# Source code for haccytrees.mergertrees.forest_matrix

import numpy as np
import numba
from typing import Mapping, Union

from ..simulations import Simulation

def _count_rows(snapnum, desc_index, mass, row_idx, mass_threshold=0):
    nhalos = len(snapnum)
    lastsnap = snapnum[0]
    current_row = 0
    row_above_threshold = False
    for i in range(nhalos):
        if desc_index[i] == -1 or snapnum[i] >= lastsnap:
            row_above_threshold = mass[i] > mass_threshold
            if row_above_threshold:
                current_row += 1
        # elif snapnum[i] >= lastsnap:
        #     current_row += 1
        lastsnap = snapnum[i]
        row_idx[i] = current_row - 1 if row_above_threshold else -1
    return current_row

@numba.jit(nopython=True, parallel=True)
def _fill_matrix(mat, data, row_idx, col_idx):
    """Scatter a flat per-halo array into a 2D matrix, in place.

    Entry ``k`` of ``data`` is written to ``mat[row_idx[k], col_idx[k]]``.
    Entries whose ``row_idx`` is negative (``-1`` marks a discarded branch)
    are skipped. The loop is parallelised with ``numba.prange``.
    NOTE(review): if two entries ever map to the same cell, the surviving
    value depends on thread scheduling.
    """
    n_entries = len(data)
    for k in numba.prange(n_entries):
        target_row = row_idx[k]
        if target_row >= 0:
            mat[target_row, col_idx[k]] = data[k]

def _fill_hostidx(
    tree_node_index_mat, fof_halo_tag_mat, top_host_row_mat, direct_host_row_mat, parent_tni_mat, parent_fof_mat,
    nrows = tree_node_index_mat.shape[0]
    ncols = tree_node_index_mat.shape[1]
    host_rows_per_snap = np.empty(ncols, dtype=np.int64)
    host_rows_per_snap[:] = -1
    for i in range(nrows):
        infall_col = 0
        for j in range(ncols):
            if tree_node_index_mat[i, j] == 0:
                if infall_col == 0:
                    # halo doesn't exist yet at this snapshot
                    # this is a subhalo
                    host_row = host_rows_per_snap[j]
                    top_host_row = top_host_row_mat[host_row, j]
                    direct_host_row_mat[i, j] = host_row
                    top_host_row_mat[i, j] = (
                        top_host_row if top_host_row != -1 else host_row
                    parent_tni_mat[i, j] = tree_node_index_mat[top_host_row, j]
                    parent_fof_mat[i, j] = fof_halo_tag_mat[top_host_row, j]
                # a top halo
                infall_col = j
                host_rows_per_snap[j] = i
                parent_tni_mat[i, j] = tree_node_index_mat[i, j]
                parent_fof_mat[i, j] = fof_halo_tag_mat[i, j]

def forest2matrix(
    forest: Mapping[str, np.ndarray],
    simulation: Union[str, Simulation],
    target_index: Union[int, None] = None,
    *,
    subhalo_data: Mapping[str, np.ndarray] = None,
    branchmass_threshold: float = None,
) -> Mapping[str, np.ndarray]:
    """Convert a haccytree forest to a matrix, where each row is a branch

    Parameters
    ----------
    forest:
        the full treenode forest returned by :func:`read_forest`. The mapping
        passed in is not modified.

    simulation:
        a :class:`Simulation` instance, or the name of one (looked up in
        ``Simulation.simulations``). Its ``cosmotools_steps`` determine the
        number of matrix columns.

    target_index:
        (optional) if set, a matrix of only the tree starting at
        `target_index` will be calculated

    subhalo_data:
        currently does nothing

    branchmass_threshold:
        if set, removes all branches whose mass at infall (the first halo of
        the branch in forest order) is not strictly above the threshold. If
        not set, a threshold of 0 is used, i.e. zero-mass branches are
        dropped.

    Returns
    -------
    matrices: Mapping[str, np.ndarray]
        arrays of shape `(nbranches, nsnapshots)` for each of the forest
        properties. Additionally:

        - `matrices["top_host_row"]` contains the row number of the main host
        - `matrices["direct_host_row"]` contains the row of the direct host in
          the hierarchy
        - `matrices["parent_tni"]` contains the `tree_node_index` of the main
          (top) host halo
        - `matrices["parent_fof"]` contains the `fof_halo_tag` of the main
          (top) host halo

        `top_host_row` and `direct_host_row` are `-1` for host halos.
        `parent_tni` and `parent_fof` of a host halo are its own
        `tree_node_index` and `fof_halo_tag`.

    Notes
    -----
    All the properties (except `top_host_row`, `direct_host_row`,
    `parent_tni`, `parent_fof`) are 0 for entries where the halo does not
    exist or when the halo is a subhalo. For `top_host_row`,
    `direct_host_row`, `parent_tni`, and `parent_fof`, the values are `-1`
    prior to the nucleation of the branch.
    """
    if isinstance(simulation, str):
        simulation = Simulation.simulations[simulation]

    if target_index is not None:
        start = target_index
        end = target_index + forest["branch_size"][target_index]
        forest = {k: v[start:end] for k, v in forest.items()}
    else:
        # shallow copy: the pops below must not mutate the caller's forest
        forest = dict(forest)

    # make sure we don't use invalid indices...
    forest.pop("halo_index", None)
    forest.pop("descendant_idx", None)

    ncols = len(simulation.cosmotools_steps)
    nhalos = len(forest["snapnum"])
    col_idx = forest["snapnum"]
    row_idx = np.empty(nhalos, dtype=np.int64)
    mass_threshold = 0 if branchmass_threshold is None else branchmass_threshold
    nrows = _count_rows(
        forest["snapnum"],
        forest["desc_node_index"],
        forest["tree_node_mass"],
        row_idx,
        mass_threshold=mass_threshold,
    )

    # remove forest specific keys (ordered list keeps the output key order
    # deterministic, unlike iterating a set of strings)
    discard_keys = {
        "branch_size",
        "descendant_idx",
        "progenitor_count",
        "progenitor_offset",
        "halo_index",
        "scale_factor",
        "snapnum",
    }
    data_keys = [k for k in forest if k not in discard_keys]

    matrices = {}
    for k in data_keys:
        data = forest[k]
        matrices[k] = np.zeros((nrows, ncols), dtype=data.dtype)
        _fill_matrix(matrices[k], data, row_idx, col_idx)

    # NOTE(review): `subhalo_data` (per the original comment: 'mass',
    # 'hostidx', 'direct_hostidx', 'infallidx', 'snapnum') is accepted for
    # API compatibility but not used yet.

    # Additional indices; all start at -1 ("not yet nucleated" / "is a host").
    # parent_fof was previously left uninitialized (np.empty), i.e. garbage
    # before nucleation.
    for key in ("top_host_row", "direct_host_row", "parent_tni", "parent_fof"):
        matrices[key] = np.full((nrows, ncols), -1, dtype=np.int64)
    _fill_hostidx(
        matrices["tree_node_index"],
        matrices["fof_halo_tag"],
        matrices["top_host_row"],
        matrices["direct_host_row"],
        matrices["parent_tni"],
        matrices["parent_fof"],
    )
    return matrices