Source code for mpipartition.exchange

import sys
from typing import Callable, Mapping, Union

import numpy as np

from .partition import Partition

ParticleDataT = Mapping[str, np.ndarray]


def _apply_key_filter(keys, filter_key):
    """Return the entries of ``keys`` that are not excluded by ``filter_key``.

    ``filter_key`` may be ``None`` (keep everything), a callable returning a
    boolean mask (``True`` = keep) or a single value to drop.
    """
    if filter_key is None:
        return keys
    if callable(filter_key):
        return keys[filter_key(keys)]
    return keys[keys != filter_key]


def _offsets_from_counts(counts):
    """Return the exclusive prefix sum of ``counts`` (buffer displacements)."""
    return np.insert(np.cumsum(counts)[:-1], 0, 0)


def exchange(
    partition: "Partition",
    data: dict,
    key: str,
    local_keys: np.ndarray,
    *,
    verbose: Union[bool, int] = False,
    filter_key: Union[int, Callable[[np.ndarray], np.ndarray], None] = None,
    do_all2all: bool = False,
    replace_notfound_key: Union[int, None] = None,
) -> dict:
    """Distribute data among neighboring ranks and all2all by a key

    This function will assign data to the rank that owns the key. The keys that
    the local rank owns are given by ``local_keys``, which should be unique. The
    keys of the data that the local rank currently has is in ``data[key]``.
    Certain values can be ignored by setting filter_key to that value or by
    setting filter_key to a (vectorized) function that returns ``True`` for keys
    that should be redistributed and ``False`` for keys that should be ignored.

    This is a collective operation: it calls collectives and barriers on
    ``partition.comm`` and therefore has to be called by every rank.

    Parameters
    ----------
    partition
        Provides ``comm``, ``rank`` and ``nranks``; for the neighbor exchange
        also ``comm_neighbor`` and ``neighbor_ranks``.
    data
        Mapping from column name to array. Every column is indexed with the
        row indices derived from ``data[key]``, so all columns need the same
        length as ``data[key]``.
    key
        Name of the column in ``data`` that holds the keys.
    local_keys
        Keys owned by the local rank (should be unique).
    verbose
        If truthy, rank 0 prints a summary; if ``> 1`` every rank additionally
        prints debug information.
    filter_key
        Value, or vectorized predicate (``True`` = redistribute), selecting
        keys that are ignored. Ignored keys are neither requested from other
        ranks nor counted or replaced as unassigned keys.
    do_all2all
        If ``True``, exchange with all ranks instead of only the neighbors.
        A neighbor exchange that leaves keys unassigned automatically retries
        with ``do_all2all=True``.
    replace_notfound_key
        If not ``None``, keys that still have no owner after the all2all
        exchange are overwritten with this value in the returned data instead
        of aborting.

    Returns
    -------
    dict
        New dictionary with the same columns: the rows that stayed on this
        rank followed by the rows received from other ranks. If the partition
        has a single rank, ``data`` itself is returned.

    Raises
    ------
    RuntimeError
        If a neighbor exchange is requested but ``partition.comm_neighbor`` is
        ``None``.

    Notes
    -----
    Inconsistencies (a rank requesting data from itself, a changed global row
    count, or unassigned keys without ``replace_notfound_key``) terminate the
    job through ``comm.Abort()``. Negative keys are never counted as
    unassigned during verification, independent of ``filter_key``.
    """
    if not do_all2all and partition.comm_neighbor is None:
        raise RuntimeError(
            "Cannot exchange data between neighbors if no neighbor topology was "
            "created. Either create partition with ``create_neighbor_topo=True`` "
            "or set ``do_all2all=True``"
        )

    comm = partition.comm
    rank = partition.rank
    nranks = partition.nranks

    if nranks == 1:
        return data

    if do_all2all:
        # exchange particles with all ranks
        exchange_comm = comm
        exchange_nranks = nranks
        exchange_Alltoall = exchange_comm.Alltoall
        exchange_Alltoallv = exchange_comm.Alltoallv
        exchange_Allgather = exchange_comm.Allgather
        exchange_Allgatherv = exchange_comm.Allgatherv
    else:
        # exchange particles with the neighboring ranks
        exchange_comm = partition.comm_neighbor
        exchange_nranks = len(partition.neighbor_ranks)
        exchange_Alltoall = exchange_comm.Neighbor_alltoall
        exchange_Alltoallv = exchange_comm.Neighbor_alltoallv
        exchange_Allgather = exchange_comm.Neighbor_allgather
        exchange_Allgatherv = exchange_comm.Neighbor_allgatherv

    localcount = len(data[key])
    data_keys = _apply_key_filter(np.unique(data[key]), filter_key)

    # find local matches
    islocal = np.isin(data_keys, local_keys, assume_unique=True)
    nonlocal_data = data_keys[~islocal]

    # communicate nonlocal descendants: every rank announces the keys it holds
    # but does not own ("orphans")
    local_orphan_count = np.array([len(nonlocal_data)], dtype=np.int32)
    orphan_counts = np.empty(exchange_nranks, dtype=np.int32)
    exchange_Allgather(local_orphan_count, orphan_counts)
    total_orphan_count = np.sum(orphan_counts)
    orphan_offsets = _offsets_from_counts(orphan_counts)
    orphan_data = np.empty(total_orphan_count, dtype=nonlocal_data.dtype)
    # index (within the exchange communicator) of the rank announcing each orphan
    orphan_ranks = np.repeat(
        np.arange(exchange_nranks, dtype=np.int32), orphan_counts
    )
    exchange_Allgatherv(
        nonlocal_data,
        [orphan_data, (orphan_counts, orphan_offsets), nonlocal_data.dtype.char],
    )

    if verbose > 1:
        for i in range(nranks):
            if rank == i:
                print(f"Debug Desc Exchange (nonlocal), rank {i}")
                print(f" - send {local_orphan_count}")
                print(f" - recv {orphan_counts}")
                print("", flush=True)
            comm.Barrier()

    # check if we have any of them
    orphan_islocal = np.isin(orphan_data, local_keys)
    # ask: request the orphans we own from the ranks that announced them.
    # orphan_ranks is sorted, so searchsorted yields the per-rank offsets.
    orphan_requests_send = orphan_data[orphan_islocal]
    orphan_requests_send_ranks = orphan_ranks[orphan_islocal]
    orphan_requests_send_offsets = np.searchsorted(
        orphan_requests_send_ranks, np.arange(exchange_nranks)
    )
    orphan_requests_send_counts = (
        np.append(orphan_requests_send_offsets[1:], len(orphan_requests_send))
        - orphan_requests_send_offsets
    )
    orphan_requests_recv_counts = np.empty_like(orphan_requests_send_counts)
    exchange_Alltoall(orphan_requests_send_counts, orphan_requests_recv_counts)
    orphan_requests_recv_total = np.sum(orphan_requests_recv_counts)
    orphan_requests_recv_offsets = _offsets_from_counts(orphan_requests_recv_counts)
    orphan_requests_recv = np.empty(
        orphan_requests_recv_total, dtype=orphan_requests_send.dtype
    )
    exchange_Alltoallv(
        [
            orphan_requests_send,
            (orphan_requests_send_counts, orphan_requests_send_offsets),
            orphan_requests_send.dtype.char,
        ],
        [
            orphan_requests_recv,
            (orphan_requests_recv_counts, orphan_requests_recv_offsets),
            orphan_requests_recv.dtype.char,
        ],
    )

    if verbose > 1:
        for i in range(nranks):
            if rank == i:
                print(f"Debug Desc Exchange (request), rank {i}")
                print(f" - will request {np.sum(orphan_islocal)} in total")
                print(f" - send req {orphan_requests_send_counts}")
                print(f" - recv req {orphan_requests_recv_counts}")
                print("", flush=True)
            comm.Barrier()

    # verify that we don't ask ourselves for particles
    if do_all2all and (
        orphan_requests_send_counts[rank] != 0
        or orphan_requests_recv_counts[rank] != 0
    ):
        print(
            f"Error in exchange: rank {rank} is asking itself for an orphan halo: "
            f"{orphan_requests_send_counts[rank]}/{orphan_requests_recv_counts[rank]}",
            file=sys.stderr,
            flush=True,
        )
        comm.Abort()

    # prepare data to send: for every peer, the local rows whose key it requested
    orphan_requests_indices = []
    orphan_requests_mask = np.zeros(localcount, dtype=np.bool_)
    for i in range(exchange_nranks):
        req = orphan_requests_recv[
            orphan_requests_recv_offsets[i] : orphan_requests_recv_offsets[i]
            + orphan_requests_recv_counts[i]
        ]
        mask = np.isin(data[key], req)
        orphan_requests_indices.append(np.nonzero(mask)[0])
        orphan_requests_mask |= mask

    # from here on the counts refer to rows (not unique keys)
    orphan_requests_send_counts = np.array(
        [len(idx) for idx in orphan_requests_indices], dtype=np.int32
    )
    orphan_requests_recv_counts = np.empty_like(orphan_requests_send_counts)
    exchange_Alltoall(orphan_requests_send_counts, orphan_requests_recv_counts)
    orphan_requests_recv_total = np.sum(orphan_requests_recv_counts)
    orphan_requests_send_offsets = _offsets_from_counts(orphan_requests_send_counts)
    orphan_requests_recv_offsets = _offsets_from_counts(orphan_requests_recv_counts)
    orphan_requests_indices = np.concatenate(orphan_requests_indices)

    if verbose > 1:
        for i in range(nranks):
            if rank == i:
                print(f"Debug Desc Exchange (to exchange), rank {i}")
                print(f" - send {orphan_requests_send_counts}")
                print(f" - recv {orphan_requests_recv_counts}")
                print("", flush=True)
            comm.Barrier()

    # exchange every column with the same row selection and layout
    data_new = {}
    for k in data.keys():
        orphan_requests_send = data[k][orphan_requests_indices]
        orphan_requests_recv = np.empty(orphan_requests_recv_total, dtype=data[k].dtype)
        exchange_Alltoallv(
            [
                orphan_requests_send,
                (orphan_requests_send_counts, orphan_requests_send_offsets),
                orphan_requests_send.dtype.char,
            ],
            [
                orphan_requests_recv,
                (orphan_requests_recv_counts, orphan_requests_recv_offsets),
                orphan_requests_recv.dtype.char,
            ],
        )
        data_new[k] = np.concatenate(
            (data[k][~orphan_requests_mask], orphan_requests_recv)
        )

    if verbose > 1 and rank == 0:
        print("Exchange succeeded, verifying data integrity", flush=True)

    # Verification
    localcount_after = len(data_new[key])
    localcount_missmatch = local_orphan_count[0]
    # calculate new missmatch; keys excluded by filter_key were never part of
    # the exchange, so they must not be reported as unassigned either
    my_data = np.unique(data_new[key])
    my_data = my_data[my_data >= 0]
    my_data = _apply_key_filter(my_data, filter_key)
    islocal = np.isin(my_data, local_keys, assume_unique=True)
    missing_keys = my_data[~islocal]
    localcount_missmatch_after = len(missing_keys)

    localcounts = np.array(
        [
            localcount,
            localcount_after,
            localcount_missmatch,
            localcount_missmatch_after,
        ],
        dtype=np.int64,
    )
    totalcounts = np.empty_like(localcounts)
    comm.Allreduce(localcounts, totalcounts)
    (
        totalcount_before,
        totalcount_after,
        totalcount_missmatch,
        totalcount_missmatch_after,
    ) = totalcounts

    if verbose and rank == 0:
        print(f"exchange summary ({'all2all' if do_all2all else 'neighbors'}):")
        print(
            f" Ntot -> Ntot: {totalcount_before:10d} -> {totalcount_after:10d} "
            "(should remain the same)"
        )
        print(
            f" Orph -> Orph: {totalcount_missmatch:10d} -> "
            f"{totalcount_missmatch_after:10d} (should be 0 after)"
        )
        print("", flush=True)

    # did we conserve number of particles?
    if rank == 0 and totalcount_before != totalcount_after:
        print(
            "Error in exchange: Lost halos during progenitor exchange: "
            f"{totalcount_before} -> {totalcount_after}",
            file=sys.stderr,
            flush=True,
        )
        comm.Abort()

    # if we were not able to assign all orphans to the neighbors, try all2all
    if not do_all2all and totalcount_missmatch_after > 0:
        if verbose and rank == 0:
            print(
                "exchange all2all since neighbor exchange was not able to assign all: "
                f"{totalcount_missmatch} -> {totalcount_missmatch_after}",
                flush=True,
            )
        return exchange(
            partition,
            data_new,
            key,
            local_keys,
            verbose=verbose,
            filter_key=filter_key,
            do_all2all=True,
            replace_notfound_key=replace_notfound_key,
        )

    # if we are still not able to assign all orphans, replace key or abort after
    # printing some debug messages
    if replace_notfound_key is not None and localcount_missmatch_after > 0:
        # data_new[key] is a freshly concatenated array, the input is untouched
        d = data_new[key]
        d[np.isin(d, missing_keys)] = replace_notfound_key

    for i in range(nranks):
        if rank == i and localcount_missmatch_after != 0 and verbose > 1:
            print(
                f"Warning from rank {rank} in exchange: Unable to assign all "
                f"progenitors to correct ranks (failed for "
                f"{localcount_missmatch_after} out of {localcount_missmatch})"
            )
            print("Could not assign keys: ", missing_keys)
            print("", flush=True)
        comm.Barrier()

    if rank == 0 and totalcount_missmatch_after != 0:
        if replace_notfound_key is None:
            print(
                f"Error in exchange: Unable to assign all progenitors to correct ranks "
                f"(tried to reassign {totalcount_missmatch}, failed for "
                f"{totalcount_missmatch_after})",
                file=sys.stderr,
                flush=True,
            )
            comm.Abort()
        elif verbose:
            print(
                f"Warning in exchange: Unable to assign all progenitors to correct "
                f"ranks (tried to reassign {totalcount_missmatch}, failed for "
                f"{totalcount_missmatch_after}), replacing missing values with "
                f"{replace_notfound_key}",
                flush=True,
            )

    return data_new