Source code for mpipartition.spherical_partition.s2_distribute

from typing import Union
import numpy as np

from .s2_partition import S2Partition
from .._send_home import distribute_dataset_by_home

ParticleDataT = dict[str, np.ndarray]
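# For illustration only (not part of the original module): a ParticleDataT is a
# plain dict mapping column names to equal-length 1-D numpy arrays, e.g.
#   {"theta": np.array([0.1, 1.7]), "phi": np.array([0.5, 3.2]), "mass": np.ones(2)}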


def s2_distribute(
    partition: S2Partition,
    data: ParticleDataT,
    *,
    theta_key: str = "theta",
    phi_key: str = "phi",
    verbose: Union[bool, int] = False,
    verify_count: bool = True,
    validate_home: bool = False,
    all2all_iterations: int = 1,
) -> ParticleDataT:
    """Distribute particles among MPI ranks according to the S2 partition.

    Parameters
    ----------
    partition:
        The S2 partition to use for the distribution.

    data:
        The particle data to distribute, as a collection of 1-dimensional arrays.
        Each array must have the same length (number of particles), and the map
        needs to contain at least the keys `theta_key` and `phi_key`.

    theta_key:
        The key in `data` that contains the particle theta coordinates
        (colatitude), in the range [0, pi].

    phi_key:
        The key in `data` that contains the particle phi coordinates
        (longitude), in the range [0, 2*pi).

    verbose:
        If True, print summary statistics of the distribution. If > 1, print
        statistics of each rank (i.e. how much data each rank sends to every
        other rank).

    verify_count:
        If True, make sure that the total number of objects is conserved.

    validate_home:
        If True, validate that each rank indeed owns the particles that it was
        sent.

    all2all_iterations:
        The number of iterations to use for the all-to-all communication. This
        is useful for large datasets, where MPI_Alltoallv may fail.

    Returns
    -------
    data: ParticleDataT
        The distributed particle data (i.e. the data that this rank owns)

    """
    # verify data is normalized
    assert np.all(data[theta_key] >= 0)
    assert np.all(data[theta_key] <= np.pi)
    assert np.all(data[phi_key] >= 0)
    assert np.all(data[phi_key] < 2 * np.pi)

    # ring idx: 0=cap, 1=first ring, 2=second ring, etc.
    if partition.equal_area:
        ring_idx = np.digitize(data[theta_key], partition.ring_thetas)
    else:  # equal theta
        if partition.nranks == 2:
            ring_idx = (data[theta_key] > partition.theta_cap).astype(np.int32)
        else:
            assert partition.ring_dtheta is not None
            ring_idx = (
                (data[theta_key] - partition.theta_cap) // partition.ring_dtheta
            ).astype(np.int32) + 1
            ring_idx = np.clip(ring_idx, 0, len(partition.ring_segments) + 1)
    ring_idx[data[theta_key] == np.pi] -= 1  # handle cases where theta == pi

    phi_idx = np.zeros_like(ring_idx, dtype=np.int32)
    mask_is_on_ring = (ring_idx > 0) & (ring_idx <= len(partition.ring_segments))
    phi_idx[mask_is_on_ring] = (
        data[phi_key][mask_is_on_ring]
        / (2 * np.pi)
        * partition.ring_segments[ring_idx[mask_is_on_ring] - 1]
    ).astype(np.int32)

    # rank index where each ring starts
    ring_start_idx = np.zeros(len(partition.ring_segments) + 2, dtype=np.int32)
    ring_start_idx[1] = 1
    ring_start_idx[2:] = np.cumsum(partition.ring_segments) + 1

    # rank index of each particle
    home_idx = ring_start_idx[ring_idx] + phi_idx
    assert np.all(home_idx >= 0)
    assert np.all(home_idx < partition.nranks)

    data_new = distribute_dataset_by_home(
        partition,
        data,
        home_idx=home_idx,
        verbose=verbose,
        verify_count=verify_count,
        all2all_iterations=all2all_iterations,
    )

    if validate_home:
        assert np.all(data_new[theta_key] >= partition.theta_extent[0])
        if partition.theta_extent[1] < np.pi:
            assert np.all(data_new[theta_key] < partition.theta_extent[1])
        else:
            # bottom cap, we allow theta == pi
            assert np.all(data_new[theta_key] <= partition.theta_extent[1])
        assert np.all(data_new[phi_key] >= partition.phi_extent[0])
        assert np.all(data_new[phi_key] < partition.phi_extent[1])

    return data_new
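
A minimal usage sketch (not part of the module source above): it assumes `S2Partition` and `s2_distribute` are importable from the top-level `mpipartition` package, that `S2Partition()` can be constructed with its defaults on the world communicator, and that the script is launched under MPI (e.g. `mpirun -n 8 python example.py`). Each rank draws random points on the sphere and hands them to `s2_distribute`; afterwards every particle a rank holds lies inside its `theta_extent`/`phi_extent`.

import numpy as np
from mpipartition import S2Partition, s2_distribute  # assumed top-level exports

# assumed default: equal-area segmentation of the full sphere over all ranks
partition = S2Partition()
rng = np.random.default_rng(42)

# sample points uniformly on the unit sphere: theta in [0, pi], phi in [0, 2*pi)
n_local = 10_000
data = {
    "theta": np.arccos(rng.uniform(-1.0, 1.0, n_local)),
    "phi": rng.uniform(0.0, 2 * np.pi, n_local),
    "id": rng.integers(0, 2**62, n_local),
}

# exchange particles so that each rank owns the points inside its segment
data = s2_distribute(partition, data, verbose=True, validate_home=True)

# per-rank summary: local particle count and the extent of this rank's segment
print(len(data["theta"]), partition.theta_extent, partition.phi_extent)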