Source code for haccytrees.mergertrees.assemble.catalogs2trees

from typing import Callable, Mapping, Union
import numpy as np
import pygio
import numba
import gc
import sys
from mpi4py import MPI

from ...simulations import Simulation
from mpipartition import Partition
from mpipartition import distribute, exchange
from ...utils.timer import Timer, time
from ...utils.datastores import GenericIOStore, NumpyStore
from ...utils.memory import debug_memory
from .catalogs2trees_hdf5_output import write2multiple, write2single
from .fieldsconfig import FieldsConfig


@numba.jit(nopython=True)
def _descendant_index(ids, desc_ids):
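    """Build the descendant index and the inverse progenitor lists.

    ``ids`` are the ``tree_node_index`` values of the halos at the later
    (descendant) step, ``desc_ids`` the ``desc_node_index`` values of the
    halos at the current step; negative entries mark halos without a
    descendant.

    Returns ``desc_idx`` (for each halo ``i``, the index into ``ids`` of its
    descendant, or -1) plus a CSR-style inverse mapping: the progenitors of
    descendant ``j`` are ``progenitor_array[progenitor_offsets[j]:
    progenitor_offsets[j+1]]`` (running to the end of the array for the last
    descendant), in the order they appear in ``desc_ids``.

    For example, ``ids = [10, 11]`` and ``desc_ids = [10, 10, -1]`` yield
    ``desc_idx = [0, 0, -1]``, ``progenitor_array = [0, 1]`` and
    ``progenitor_offsets = [0, 2]``.
    """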
    desc_idx = np.empty_like(desc_ids)
    ids2idx = numba.typed.Dict.empty(numba.int64, numba.int64)
    for i in range(len(ids)):
        ids2idx[ids[i]] = i

    progenitor_sizes = np.zeros_like(ids, dtype=np.uint32)

    for i in range(len(desc_ids)):
        if desc_ids[i] >= 0:
            dix = ids2idx[desc_ids[i]]
            desc_idx[i] = dix
            progenitor_sizes[dix] += 1
        else:
            desc_idx[i] = -1

    progenitor_offsets = np.empty_like(progenitor_sizes)

    progenitor_array_size = 0
    for i in range(len(ids)):
        progenitor_offsets[i] = progenitor_array_size
        progenitor_array_size += progenitor_sizes[i]

    progenitor_internal_offsets = np.zeros_like(progenitor_offsets)
    progenitor_array = np.empty(progenitor_array_size, dtype=np.uint32)

    for i in range(len(desc_ids)):
        if desc_ids[i] >= 0:
            dix = ids2idx[desc_ids[i]]
            progenitor_array[
                progenitor_offsets[dix] + progenitor_internal_offsets[dix]
            ] = i
            progenitor_internal_offsets[dix] += 1

    return desc_idx, progenitor_array, progenitor_offsets


@numba.jit(nopython=True)
def _fill_branch_pos_desc(
    branch_position,
    parent_branch_position,
    branch_size,
    progenitor_array,
    progenitor_offsets,
):
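    """Place progenitors directly behind their descendant, depth-first.

    For each halo ``j`` at the later step with a known branch position, the
    first (most massive) progenitor is placed one slot after the descendant;
    every further progenitor is shifted by the branch sizes of the
    progenitors placed before it.
    """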
    for j in range(len(parent_branch_position)):
        current_pos = parent_branch_position[j] + 1
        low = progenitor_offsets[j]
        up = (
            progenitor_offsets[j + 1]
            if j + 1 < len(progenitor_offsets)
            else len(progenitor_array)
        )
        for pj in range(low, up):
            p = progenitor_array[pj]
            branch_position[p] = current_pos
            current_pos += branch_size[p]


@numba.jit(nopython=True)
def _fill_branch_pos_new(branch_position, branch_size, current_pos, desc_index):
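    """Start a new branch for every halo without a descendant.

    Halos with ``desc_index == -1`` are appended at the end of the table,
    each advancing the insert position by its branch size; all other halos
    must already have a position from ``_fill_branch_pos_desc``. Returns the
    updated end-of-table position.
    """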
    for j in range(len(desc_index)):
        if desc_index[j] == -1:
            branch_position[j] = current_pos
            current_pos += branch_size[j]
        else:
            assert branch_position[j] != -1
    return current_pos


@numba.jit(nopython=True)
def _calculate_branch_size(branch_size, prev_branch_size, desc_index):
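    """Add the progenitor subtree sizes onto their descendants.

    Applied step by step from early to late, this makes ``branch_size[j]``
    the total number of halos in the subtree rooted at halo ``j``
    (including ``j`` itself).
    """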
    for j in range(len(desc_index)):
        if desc_index[j] != -1:
            branch_size[desc_index[j]] += prev_branch_size[j]


def _fill_branch_pos(
    branch_position,
    parent_branch_position,
    branch_size,
    desc_index,
    progenitor_array,
    progenitor_offsets,
    new_branch_insert_pos,
):
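    """Assign a forest-array position to every halo of the current step.

    Halos with a known descendant are placed right behind it (depth-first
    ordering); halos without one start new branches at the end of the table.
    Returns the updated end-of-table position.
    """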
    # loop over all progenitors
    if len(parent_branch_position) > 0:
        _fill_branch_pos_desc(
            branch_position,
            parent_branch_position,
            branch_size,
            progenitor_array,
            progenitor_offsets,
        )

    # loop over all with desc == -1:
    current_pos = _fill_branch_pos_new(
        branch_position, branch_size, new_branch_insert_pos, desc_index
    )
    return current_pos


def _read_data(
    partition: Partition,
    simulation: Simulation,
    filename: str,
    logger: Callable,
    fields_config: FieldsConfig,
    verbose: Union[bool, int],
):
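    """Read one treenodes catalog and distribute it by halo position.

    Reads the GenericIO file, drops halos with non-finite values, wraps the
    positions into ``[0, rl)``, redistributes the halos to the rank owning
    their subvolume, evaluates the derived fields, and returns only the
    fields listed in ``fields_config.keep_fields``.
    """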
    fields_xyz = [
        fields_config.node_position_x,
        fields_config.node_position_y,
        fields_config.node_position_z,
    ]

    with Timer(name="read treenodes: read genericio", logger=logger):
        data = pygio.read_genericio(
            filename,
            fields_config.read_fields,
            method=pygio.PyGenericIO.FileIO.FileIOMPI,
        )

    # mask invalid data
    with Timer(name="read treenodes: validate data", logger=logger):
        mask = np.ones_like(data[fields_xyz[0]], dtype=np.bool_)
        for k in data.keys():
            mask &= np.isfinite(data[k])
        n_invalid = len(mask) - np.sum(mask)
        if n_invalid > 0:
            if verbose > 1:
                print(
                    f"WARNING rank {partition.rank}: {n_invalid} invalid halos",
                    flush=True,
                )
            for k in data.keys():
                data[k] = data[k][mask]
        n_invalid_global = partition.comm.reduce(n_invalid, root=0, op=MPI.SUM)
        if partition.rank == 0 and n_invalid_global > 0:
            print(f"WARNING: {n_invalid_global} invalid halos (nan/inf)", flush=True)

    # normalize data
    for k in fields_xyz:
        data[k] = np.remainder(data[k], simulation.rl)

    # assign to rank by position
    with Timer(name="read treenodes: distribute", logger=logger):
        data = distribute(
            partition,
            simulation.rl,
            data,
            fields_xyz,
            verbose=verbose,
            verify_count=True,
        )

    # calculate derived fields
    with Timer(name="read treenodes: calculate derived fields", logger=logger):
        for k, f in fields_config.derived_fields.items():
            data[k] = f(data)

    # only keep necessary fields
    data_new = {}
    for k in fields_config.keep_fields:
        data_new[k] = data[k]

    return data_new


def _catalog2tree_step(
    step: int,
    data_store: Union[GenericIOStore, Mapping],
    index_store: Union[NumpyStore, Mapping],
    size_store: Mapping,
    local_desc: np.ndarray,
    partition: Partition,
    simulation: Simulation,
    treenode_base: str,
    fields_config: FieldsConfig,
    do_all2all_exchange: bool,
    fail_on_desc_not_found: bool,
    verbose: Union[bool, int],
    logger: Callable[[str], None],
):
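    """Process one snapshot: read, exchange, sort, and index progenitors.

    Reads the treenodes catalog of ``step``, sends halos whose descendant
    lives on another rank to that rank, sorts by tree-node mass (descending),
    and builds the descendant index and progenitor lists against
    ``local_desc`` (the tree-node indices of the previously processed, later
    step). Data and indices are saved to the provided stores; the tree-node
    indices of this step are returned for the next iteration.
    """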
    logger(f"\nSTEP {step:3d}\n--------\n")

    # read data and calculate derived fields (contains its own timers)
    with Timer(name="read treenodes", logger=logger):
        if "#" in treenode_base:
            treenode_filename = treenode_base.replace("#", str(step))
        else:
            treenode_filename = f"{treenode_base}{step}.treenodes"
        data = _read_data(
            partition=partition,
            simulation=simulation,
            filename=treenode_filename,
            logger=logger,
            fields_config=fields_config,
            verbose=verbose,
        )

    # exchange progenitors that belong to neighboring ranks
    with Timer(name="exchange descendants", logger=logger):
        if fail_on_desc_not_found:
            na_desc_key = None
        else:
            na_desc_key = -1
        data = exchange(
            partition,
            data,
            fields_config.desc_node_index,
            local_desc,
            filter_key=lambda idx: idx >= 0,
            verbose=verbose,
            do_all2all=do_all2all_exchange,
            replace_notfound_key=na_desc_key,
        )

    # sort arrays by tree-node mass (descending), so that each descendant's
    # progenitor list comes out ordered most massive first
    with Timer(name="sort arrays", logger=logger):
        s = np.lexsort(
            (data[fields_config.desc_node_index], data[fields_config.tree_node_mass])
        )[::-1]
        for k in data.keys():
            data[k] = data[k][s]

    # get index to descendant for each progenitor, and an inverse list of progenitors for each descendant
    with Timer(name="assign progenitors", logger=logger):
        desc_idx, progenitor_array, progenitor_offsets = _descendant_index(
            local_desc, data[fields_config.desc_node_index]
        )
        if len(progenitor_array) and progenitor_array.min() < 0:
            print(
                f"invalid progenitor array on rank {partition.rank}",
                file=sys.stderr,
                flush=True,
            )
            partition.comm.Abort()
    # prepare next step:
    local_ids = data[fields_config.tree_node_index]

    # store data
    data_store[step] = data
    index_store[step] = {
        "desc_idx": desc_idx,
        "progenitor_array": progenitor_array,
        "progenitor_offsets": progenitor_offsets,
    }
    size_store[step] = len(local_ids)

    return local_ids


def catalog2tree(
    simulation: Simulation,
    treenode_base: str,
    fields_config: FieldsConfig,
    output_file: str,
    *,  # The following arguments are keyword-only
    temporary_path: str = None,
    do_all2all_exchange: bool = False,
    split_output: bool = False,
    fail_on_desc_not_found: bool = True,
    mpi_waittime: float = 0,
    logger: Callable[[str], None] = None,
    verbose: Union[bool, int] = False,
) -> None:
    """The main routine that converts treenode catalogs to HDF5 treenode forests

    The forest assembly proceeds in three phases: (1) the treenode catalogs
    are read from the final snapshot backwards in time, halos are distributed
    by position, and progenitors are moved to the rank that owns their
    descendant; (2) the size and depth-first position of every branch within
    the local forest array are computed; (3) the reordered data is written to
    HDF5.

    Parameters
    ----------
    simulation
        a :class:`Simulation` instance containing the cosmotools steps

    treenode_base
        the base path for the treenode files.

        - if ``treenode_base`` contains ``#``, ``#`` will be replaced by the
          current step number
        - otherwise, the path will be completed by appending
          ``[step].treenodes``.

    fields_config
        a :class:`FieldsConfig` instance, containing the treenode field names

    output_file
        base path for the output file(s)

    temporary_path
        base path for temporary files. Note: folders must exist.

    do_all2all_exchange
        if ``False``, will exchange among neighboring ranks first and then
        all2all. If ``True``, will do all2all directly

    split_output
        if ``True``, forests will be stored in multiple HDF5 files (one per
        rank). If ``False``, all data will be combined in a single file
        (might not be feasible for large simulations)

    fail_on_desc_not_found
        if ``True``, will abort if a descendant halo cannot be found among
        all ranks. If ``False``, the orphaned halo will become the root of
        the subtree.

    mpi_waittime
        time in seconds for which the code will wait for MPI to be
        initialized. Can help with some MPI errors (on cooley)

    logger
        a logging function, e.g. ``print``

    verbose
        verbosity level, either 0, 1, or 2
    """
    # Set a timer for the full run
    total_timer = Timer("total time", logger=None)
    total_timer.start()

    # Initialize MPI partition
    time.sleep(mpi_waittime)
    partition = Partition(create_neighbor_topo=not do_all2all_exchange)

    # Wait a sec... (maybe solves MPI issues?)
    time.sleep(mpi_waittime)

    # Used to print status messages, only from rank 0
    def logger(x, **kwargs):
        kwargs.pop("flush", None)
        partition.comm.Barrier()
        if partition.rank == 0:
            print(x, flush=True, **kwargs)

    # Print MPI configuration
    if partition.rank == 0:
        print(f"Running catalog2tree with {partition.nranks} ranks")
        if verbose:
            print("Partition topology:")
            print("  Decomposition: ", partition.decomposition)
            print("  Ranklist: \n", partition.ranklist)
            print("", flush=True)
    if verbose > 1:
        for i in range(partition.nranks):
            if partition.rank == i:
                print(f"Rank {i}", flush=True)
                print("  Coordinates: ", partition.coordinates)
                print("  Extent:      ", partition.extent)
                print("  Origin:      ", partition.origin)
                print("  Neighbors:   ", partition.neighbors)
                if not do_all2all_exchange:
                    n26 = partition.neighbor_ranks
                    print("  26-neighbors N : ", len(n26))
                    print("  26-neighbors   : ", n26)
                print("", flush=True)
            partition.comm.Barrier()

    # cosmotools steps
    steps = simulation.cosmotools_steps

    # read final snapshot (tree roots)
    with Timer(name="read treenodes", logger=logger):
        if "#" in treenode_base:
            treenode_filename = treenode_base.replace("#", str(steps[-1]))
        else:
            treenode_filename = f"{treenode_base}{steps[-1]}.treenodes"
        data = _read_data(
            partition,
            simulation,
            treenode_filename,
            logger,
            fields_config,
            verbose=verbose,
        )

    # sort by tree_node_index
    with Timer(name="sort arrays", logger=logger):
        s = np.argsort(data[fields_config.tree_node_index])
        for k in data.keys():
            data[k] = data[k][s]

    local_ids = data[fields_config.tree_node_index]
    dtypes = {k: i.dtype for k, i in data.items()}

    # Containers to store local data from each snapshot
    data_by_step = GenericIOStore(partition, simulation.rl, temporary_path)
    data_by_step[steps[-1]] = data

    index_by_step = NumpyStore(partition, temporary_path)
    index_by_step[steps[-1]] = {
        "desc_idx": -np.ones(len(s), dtype=np.int64),
        "progenitor_array": np.empty(0, dtype=np.uint32),
        "progenitor_offsets": np.empty(0, dtype=np.uint32),
    }
    size_by_step = {steps[-1]: len(local_ids)}

    # iterate over previous snapshots: read, assign, prepare ordering
    for step in steps[-2::-1]:
        local_ids = _catalog2tree_step(
            step,
            data_by_step,
            index_by_step,
            size_by_step,
            local_ids,
            partition,
            simulation,
            treenode_base,
            fields_config,
            do_all2all_exchange,
            fail_on_desc_not_found,
            verbose,
            logger,
        )
        gc.collect()  # explicitly free memory
        if verbose:
            debug_memory(partition, "AFTER GC")

    # local size of the forest
    local_size = sum(size_by_step.values())

    logger("\nBUILDING TREE\n-------------\n")

    # size of branches, iteratively
    with Timer("calculate branch sizes", logger=logger):
        branch_sizes = {
            step: np.ones(l, dtype=np.uint64) for step, l in size_by_step.items()
        }
        desc_index = np.empty(0, dtype=np.int64)
        prev_branch_size = np.empty(0, dtype=np.uint64)
        for i, step in enumerate(steps):
            branch_size = branch_sizes[step]
            _calculate_branch_size(branch_size, prev_branch_size, desc_index)
            prev_branch_size = branch_size
            next_indices = index_by_step[step]
            desc_index = next_indices["desc_idx"]
    if verbose:
        debug_memory(partition, "AFTER BRANCH SIZE")

    # index to where in the forest each halo goes
    with Timer("calculate array positions", logger=logger):
        branch_positions = {}
        parent_branch_position = np.empty(0, dtype=np.int64)
        new_branch_insert_pos = 0
        for i, step in enumerate(steps[::-1]):
            branch_size = branch_sizes[step]
            next_indices = index_by_step[step]
            desc_index = next_indices["desc_idx"]
            # convert array and offsets to list of lists
            # progenitor_idx = np.array_split(next_indices['progenitor_array'],
            #                                 next_indices['progenitor_offsets'][1:])
            branch_position = -np.ones(len(branch_size), dtype=np.int64)
            current_pos = _fill_branch_pos(
                branch_position,
                parent_branch_position,
                branch_size,
                desc_index,
                next_indices["progenitor_array"],
                next_indices["progenitor_offsets"],
                new_branch_insert_pos,
            )
            # mark end-of-table
            new_branch_insert_pos = current_pos
            # update for next iteration
            parent_branch_position = branch_position
            branch_positions[step] = branch_position
    if verbose:
        debug_memory(partition, "AFTER ARRAY POS")

    # Assert that we got all halos
    if new_branch_insert_pos != local_size:
        print(
            f"Error (rank {partition.rank}): forest size is {new_branch_insert_pos} "
            f"instead of {local_size}",
            flush=True,
        )
        partition.comm.Abort()

    # write data
    with Timer("output HDF5 forest", logger=logger):
        if split_output:
            write2multiple(
                partition,
                steps,
                local_size,
                branch_sizes,
                branch_positions,
                data_by_step,
                fields_config,
                dtypes,
                output_file,
                logger,
            )
        else:
            write2single(
                partition,
                steps,
                local_size,
                branch_sizes,
                branch_positions,
                data_by_step,
                fields_config,
                dtypes,
                output_file,
                logger,
            )

    logger("cleanup...")
    for step in simulation.cosmotools_steps:
        data_by_step.remove(step)
        index_by_step.remove(step)

    total_timer.stop()
    if partition.rank == 0:
        print("\n----------------------------------------\n")
        print("Forest creation finished!\n")
        print("Final timers:")
        timer_names = list(Timer.timers.keys())
        timer_names.sort()
        maxtimerlen = max(len(n) for n in timer_names) + 1
        for n, t in Timer.timers.items():
            print(f"{n:{maxtimerlen}s}: {t:4e}s")
    partition.comm.Barrier()
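

# Example invocation (illustrative sketch: ``my_simulation``, ``my_fields``,
# and all paths below are placeholders; construct them for your own run):
#
#     catalog2tree(
#         my_simulation,                       # a Simulation instance
#         "/data/treenodes/m000-#.treenodes",  # '#' is replaced by the step
#         my_fields,                           # a FieldsConfig instance
#         "/data/output/m000-forest",
#         temporary_path="/scratch/tmp",       # folders must already exist
#         split_output=True,
#         verbose=1,
#     )
#
# The driver has to be launched under MPI, e.g. ``mpirun -n 128 python run.py``.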