# Source code for mpipartition.spherical_partition.s2_partition

import numpy as np
import numpy.typing as npt
from typing import Optional, Tuple, List
from dataclasses import dataclass

from mpi4py import MPI

def _cap_area(theta):
    """Return the area of a unit-sphere cap with opening angle ``theta``.

    The cap height is ``1 - cos(theta)`` and the area is ``2 * pi`` times
    that height. ``theta`` may be a scalar or a numpy array (the computation
    is element-wise).
    """
    return 2 * np.pi * (1 - np.cos(theta))

def _cap_angle(area):
    """Return the opening angle of a unit-sphere cap with the given ``area``.

    Inverse of the cap-area relation ``area = 2 * pi * (1 - cos(theta))``.
    ``area`` may be a scalar or a numpy array (the computation is
    element-wise).
    """
    return np.arccos(1 - area / (2 * np.pi))

def _s2_partition(n, adjust_theta=True):
    """Partition the unit sphere into two polar caps and rings of segments.

    Follows the zonal equal-area partitioning algorithm by Paul Leopardi
    ("A partition of the unit sphere into regions of equal area and small
    diameter", 2006), restricted to the 2-sphere.

    Parameters
    ----------
    n : int
        Total number of segments (two polar caps plus all ring segments).
        Must be larger than 1.
    adjust_theta : bool
        If True, the ring boundaries are shifted so that every segment has
        the same area ``4 * pi / n``. If False, the rings keep a constant
        spacing in theta.

    Returns
    -------
    theta_c : float
        Opening angle of each polar cap.
    theta_f : np.ndarray
        Theta boundaries of the rings (``n_rings + 1`` values, running from
        ``theta_c`` to ``pi - theta_c``). For ``n == 2`` it only contains
        ``pi / 2``.
    n_segments : np.ndarray
        Number of segments in each ring (dtype int32, empty for ``n == 2``).

    Raises
    ------
    ValueError
        If ``n`` is smaller than 2.
    """
    # TODO: generalize to arbitrary number of dimensions
    if n < 2:
        # explicit check instead of ``assert`` so it survives ``python -O``
        raise ValueError(f"need at least 2 segments to partition the sphere, got {n}")
    if n == 2:
        # two hemispheres: caps only, no rings
        return (
            np.pi / 2,
            np.array([np.pi / 2], dtype=np.float64),
            np.array([], dtype=np.int32),
        )

    target_area = 4 * np.pi / n

    # 1: polar caps
    theta_c = _cap_angle(target_area)

    # 2-4: rings
    delta_i = np.sqrt(target_area)  # ideal spacing
    n_rings_i = (np.pi - 2 * theta_c) / delta_i
    n_rings = max(int(n_rings_i + 0.5), 1)

    # 5-6: ring segments
    delta_f = (np.pi - 2 * theta_c) / n_rings
    theta_f = theta_c + np.arange(n_rings) * delta_f
    theta_f = np.append(theta_f, np.pi - theta_c)

    # ideal (fractional) number of segments per ring; the rounding error of
    # each ring is carried over to the next one so that the total stays n - 2
    n_segments_i = (_cap_area(theta_f[1:]) - _cap_area(theta_f[:-1])) / target_area
    n_segments = np.zeros(n_rings, dtype=np.int32)
    remainder = 0
    for i, ideal in enumerate(n_segments_i):
        ni = int(ideal + remainder + 0.5)
        remainder += ideal - ni
        n_segments[i] = ni

    # internal consistency checks of the rounding procedure
    assert abs(remainder) < 1e-6
    assert np.sum(n_segments) + 2 == n

    # 7: adjust theta_f for equal area
    if adjust_theta:
        areas = target_area * np.array(n_segments)
        # cumulative area from the pole, starting after the first cap
        cum_areas = target_area + np.cumsum(areas)
        theta_f[1:] = _cap_angle(cum_areas)

    return theta_c, theta_f, n_segments

@dataclass
class S2Segment:
    """A single segment (polar cap or ring segment) of a sphere partition."""

    # (lower, upper) theta boundary of the segment
    theta_range: Tuple[float, float]
    # (lower, upper) phi boundary of the segment
    phi_range: Tuple[float, float]
    # surface area of the segment
    area: float
    # length of the segment's boundary
    edge_length: float

def _build_s2_segment_list(theta_cap, ring_thetas, ring_segments):
    """Build the ordered list of all segments of a sphere partition.

    Parameters
    ----------
    theta_cap : float
        Opening angle of the two polar caps.
    ring_thetas : sequence of float
        Theta boundaries of the rings (one more entry than rings).
    ring_segments : sequence of int
        Number of segments in each ring.

    Returns
    -------
    list of S2Segment
        The cap at ``theta = 0`` first, then the segments of every ring in
        order of increasing theta (and increasing phi within a ring), and
        finally the cap at ``theta = pi``; ``sum(ring_segments) + 2``
        entries in total.
    """
    segments = []

    # cap at theta = 0
    cap_area = _cap_area(theta_cap)
    cap_edge_length = 2 * np.pi * np.sin(theta_cap)
    segments.append(
        S2Segment((0, theta_cap), (0, 2 * np.pi), cap_area, cap_edge_length)
    )

    # rings
    for i, nseg in enumerate(ring_segments):
        theta_start = ring_thetas[i]
        theta_end = ring_thetas[i + 1]
        # every segment of a ring gets an equal share of the ring area
        area = (_cap_area(theta_end) - _cap_area(theta_start)) / nseg
        phi_edges = np.linspace(0, 2 * np.pi, nseg + 1, endpoint=True)
        edge_length_top = 2 * np.pi * np.sin(theta_start) / nseg
        edge_length_bottom = 2 * np.pi * np.sin(theta_end) / nseg
        # two arcs at constant theta plus the two sides at constant phi
        edge_length = (
            edge_length_top + edge_length_bottom + 2 * (theta_end - theta_start)
        )
        for j in range(nseg):
            segments.append(
                S2Segment(
                    (theta_start, theta_end),
                    (phi_edges[j], phi_edges[j + 1]),
                    area,
                    edge_length,
                )
            )

    # cap at theta = pi (same area and edge length as the first cap)
    segments.append(
        S2Segment((np.pi - theta_cap, np.pi), (0, 2 * np.pi), cap_area, cap_edge_length)
    )

    return segments

def _print_segmentation_info(
    nranks, theta_cap, ring_thetas, ring_segments, precision=3
):
    """Print a summary of a sphere segmentation to stdout.

    Parameters
    ----------
    nranks : int
        Number of ranks the sphere is partitioned into.
    theta_cap : float
        Opening angle of the polar caps.
    ring_thetas : sequence of float
        Theta boundaries of the rings (one more entry than rings).
    ring_segments : sequence of int
        Number of segments in each ring.
    precision : int
        Number of decimals used when printing angles.
    """
    print(f"Segmentation statistics for {nranks} ranks:")
    print(f"  polar cap angle: {theta_cap:.{precision}f}")
    print(f"  number of rings: {len(ring_segments)}")
    for i, nseg in enumerate(ring_segments):
        print(
            f"    ring {i:3d}: {nseg:3d} segments between "
            f"theta=[{ring_thetas[i]:.{precision}f}, {ring_thetas[i + 1]:.{precision}f}]"
        )

# area imbalance
def _print_area_imabalance(segments: List[S2Segment], precision=3):
    """Print how unevenly the surface area is distributed over ``segments``.

    Reports the ratio of the largest segment area to the smallest and to the
    mean segment area, each formatted with ``precision`` decimals. The
    segment areas are expected to add up to the full sphere (``4 * pi``).
    """
    segment_areas = np.array([segment.area for segment in segments])
    # sanity check: the segments have to cover the whole unit sphere
    assert np.isclose(np.sum(segment_areas), 4 * np.pi)

    print("  Segment area imbalance:")
    max_over_min = np.max(segment_areas) / np.min(segment_areas)
    print(f"    max/min: {max_over_min:.{precision}f}")
    max_over_avg = np.max(segment_areas) / np.mean(segment_areas)
    print(f"    max/avg: {max_over_avg:.{precision}f}")

def _print_edge_to_area_ratio(segments: List[S2Segment], precision=3):
    """Print the summed edge length of ``segments`` divided by their summed area.

    The ratio is formatted with ``precision`` decimals.
    """
    total_area = np.sum(np.array([segment.area for segment in segments]))
    total_edge = np.sum(np.array([segment.edge_length for segment in segments]))
    print(f"  Total edge/area ratio: {total_edge / total_area:.{precision}f}")

class S2Partition:
    """An MPI decomposition of the spherical shell into equal-area segments

    Every rank of the communicator is assigned one segment (a polar cap or a
    segment of a ring); the segment index equals the rank.

    If no communicator is passed and MPI has not been initialized yet, MPI is
    initialized by the constructor and finalized again when the object is
    destroyed.

    Parameters
    ----------
    equal_area : bool
        If True, the spherical shell is divided into equal-area segments.
        If False, use equally spaced rings (in theta) instead.
    comm : MPI.Comm, optional
        The MPI communicator to use for the decomposition (default:
        COMM_WORLD)
    verbose : bool
        If True, rank 0 will print information about the segmentation.
    """

    # partition properties
    _comm: MPI.Comm
    _nranks: int
    _theta_cap: float
    _ring_thetas: npt.NDArray[np.float64]
    # NOTE(review): annotated as int64, confirm against the dtype actually
    # produced by the partitioning routine
    _ring_segments: npt.NDArray[np.int64]
    _equal_area: bool
    _ring_dtheta: Optional[float]
    _all_s2_segments: List[S2Segment]

    # rank properties
    _rank: int
    _s2_segment: S2Segment

    @property
    def comm(self):
        """MPI Communicator"""
        return self._comm

    @property
    def rank(self) -> int:
        """the MPI rank of this processor"""
        return self._rank

    @property
    def nranks(self) -> int:
        """the total number of processors"""
        return self._nranks

    @property
    def equal_area(self) -> bool:
        """whether the partition is equal-area"""
        return self._equal_area

    @property
    def theta_cap(self) -> float:
        """the polar cap angle"""
        return self._theta_cap

    @property
    def ring_thetas(self) -> npt.NDArray[np.float64]:
        """the theta boundaries of all rings"""
        return self._ring_thetas

    @property
    def ring_segments(self) -> npt.NDArray[np.int64]:
        """the number of segments in each ring"""
        return self._ring_segments

    @property
    def ring_dtheta(self) -> Optional[float]:
        """the theta spacing between rings (only for non-equal-area partitions)"""
        return self._ring_dtheta

    @property
    def phi_extent(self) -> Tuple[float, float]:
        """the phi extent of the segment assigned to this rank"""
        return self._s2_segment.phi_range

    @property
    def theta_extent(self) -> Tuple[float, float]:
        """the theta extent of the segment assigned to this rank"""
        return self._s2_segment.theta_range

    @property
    def area(self) -> float:
        """the area of the segment assigned to this rank"""
        return self._s2_segment.area

    @property
    def all_phi_extents(self) -> npt.NDArray[np.float64]:
        """the phi extent of all segments, shape (nranks, 2)"""
        return np.array([s.phi_range for s in self._all_s2_segments])

    @property
    def all_theta_extents(self) -> npt.NDArray[np.float64]:
        """the theta extent of all segments, shape (nranks, 2)"""
        return np.array([s.theta_range for s in self._all_s2_segments])

    def __init__(
        self,
        *,
        equal_area: bool = True,
        comm: Optional[MPI.Comm] = None,
        verbose: bool = False,
    ):
        # remember whether this object initialized MPI, so that only the
        # owner finalizes it again in __del__
        self._mpi_init = False
        if comm is None:
            if not MPI.Is_initialized():
                MPI.Init()
                self._mpi_init = True
            comm = MPI.COMM_WORLD

        self._comm = comm
        self._rank = self._comm.Get_rank()
        self._nranks = self._comm.Get_size()
        self._equal_area = equal_area

        self._theta_cap, self._ring_thetas, self._ring_segments = _s2_partition(
            self._nranks, equal_area
        )

        if not equal_area:
            # constant ring spacing; 0.0 if there is no ring at all
            if len(self._ring_thetas) > 1:
                self._ring_dtheta = self._ring_thetas[1] - self._ring_thetas[0]
            else:
                self._ring_dtheta = 0.0
        else:
            self._ring_dtheta = None

        self._all_s2_segments = _build_s2_segment_list(
            self._theta_cap, self._ring_thetas, self._ring_segments
        )
        assert len(self._all_s2_segments) == self._nranks
        # one segment per rank, indexed by rank
        self._s2_segment = self._all_s2_segments[self._rank]

        if verbose and self._rank == 0:
            _print_segmentation_info(
                self._nranks,
                self._theta_cap,
                self._ring_thetas,
                self._ring_segments,
            )
            _print_area_imabalance(self._all_s2_segments)
            _print_edge_to_area_ratio(self._all_s2_segments)
            print()

    def __del__(self):
        # getattr: __del__ also runs for objects whose __init__ never got to
        # set the flag; Is_finalized: never finalize MPI a second time
        if getattr(self, "_mpi_init", False) and not MPI.Is_finalized():
            MPI.Finalize()
def visualize_s2_partition(
    nranks: int, equal_area: bool = True, use_mollweide: bool = True, fig=None
):
    """Visualize the S2 partitioning of the sphere.

    Parameters
    ----------
    nranks : int
        Number of ranks to partition the sphere into.
    equal_area : bool, optional
        If True, partition the sphere into equal area regions by adjusting
        theta. Otherwise, keep delta_theta of rings constant.
    use_mollweide : bool, optional
        If True, use the Mollweide projection. Otherwise, use a regular plot.
    fig : matplotlib.figure.Figure, optional
        Figure to plot on. If None, create a new figure.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the plot.
    matplotlib.axes.Axes
        Axes containing the plot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError("matplotlib is required for visualization") from err

    if fig is None:
        fig = plt.figure()
    if use_mollweide:
        ax = fig.add_subplot(1, 1, 1, projection="mollweide")
    else:
        ax = fig.add_subplot(1, 1, 1)

    theta_c, theta_f, n_segments = _s2_partition(nranks, equal_area)

    # the Mollweide axes span [-pi/2, pi/2] vertically, theta spans [0, pi]
    theta_offset = np.pi / 2 if use_mollweide else 0.0

    for i, nsegments in enumerate(n_segments):
        theta_lo = theta_f[i] - theta_offset
        theta_hi = theta_f[i + 1] - theta_offset
        # draw vertical bars
        for j in range(nsegments):
            ax.plot(
                (2 * np.pi / nsegments * j - np.pi) * np.ones(2),
                [theta_lo, theta_hi],
                color="black",
                linewidth=0.5,
            )
        # draw upper ring
        ax.axhline(theta_lo, color="black", linewidth=0.5)

    # draw final ring (only if we have rings!)
    if n_segments.size:
        ax.axhline(theta_hi, color="black", linewidth=0.5)
    else:
        # no rings: the only boundary is the one between the two caps; it
        # needs the same vertical shift as the ring boundaries above
        ax.axhline(theta_c - theta_offset, color="black", linewidth=0.5)

    if not use_mollweide:
        ax.set(
            xlabel=r"$\phi$",
            ylabel=r"$\theta$",
            xlim=(-np.pi, np.pi),
            ylim=(0, np.pi),
        )
    return fig, ax