Source code for llg3d.grid

"""Define the computational grid for the simulation."""

from typing import ClassVar
from dataclasses import dataclass, field, asdict
import numpy as np

from . import solvers


@dataclass
class Grid:
    """
    Store the geometry of a uniform Cartesian grid.

    Only the point counts ``Jx``, ``Jy``, ``Jz`` and the spacing ``dx`` are
    constructor arguments. Every other field is declared with
    ``field(init=False)`` and is filled in by :meth:`__post_init__`.
    """

    # Parameter values correspond to the global grid
    Jx: int  #: number of points in x direction
    Jy: int  #: number of points in y direction
    Jz: int  #: number of points in z direction
    dx: float  #: grid spacing in x direction
    dy: float = field(init=False)  #: grid spacing in y direction
    dz: float = field(init=False)  #: grid spacing in z direction
    dV: float = field(init=False)  #: elemental volume
    Lx: float = field(init=False)  #: physical length in x direction
    Ly: float = field(init=False)  #: physical length in y direction
    Lz: float = field(init=False)  #: physical length in z direction
    dims: tuple[int, int, int] = field(init=False)  #: local grid dimensions
    V: float = field(init=False)  #: total volume
    ntot: int = field(init=False)  #: total number of points
    ncell: int = field(init=False)  #: total number of cells
    inv_dx2: float = field(init=False)  #: :math:`1/dx^2`
    inv_dy2: float = field(init=False)  #: :math:`1/dy^2`
    inv_dz2: float = field(init=False)  #: :math:`1/dz^2`
    center_coeff: float = field(init=False)  #: center coefficient for Laplacian
    uniform: ClassVar[bool] = True  #: whether the grid is uniform

    def __post_init__(self) -> None:
        """Compute every derived grid quantity from ``Jx``, ``Jy``, ``Jz``, ``dx``."""
        self.dy = self.dz = self.dx  # Enforce dx = dy = dz
        # J points along an axis span J - 1 intervals of width d
        self.Lx = (self.Jx - 1) * self.dx
        self.Ly = (self.Jy - 1) * self.dy
        self.Lz = (self.Jz - 1) * self.dz
        # Only the x axis is divided by solvers.size; the floor division
        # silently drops any remainder of Jx / solvers.size.
        self.dims = self.Jx // solvers.size, self.Jy, self.Jz
        self.dV = self.dx * self.dy * self.dz
        self.V = self.Lx * self.Ly * self.Lz
        self.ntot = self.Jx * self.Jy * self.Jz
        self.ncell = (self.Jx - 1) * (self.Jy - 1) * (self.Jz - 1)
        # precompute the Laplacian coefficients (uniform grid spacing)
        self.inv_dx2 = self.inv_dy2 = self.inv_dz2 = 1 / self.dx**2
        self.center_coeff = -6.0 * self.inv_dx2

    def __str__(self) -> str:
        """Return grid information."""
        header = "\t\t".join(("x", "y", "z"))
        # NOTE(review): the line breaks and spacing inside this literal were
        # reconstructed from a whitespace-collapsed listing - confirm them
        # against the original file before relying on the exact layout.
        s = f"""\
---
\t{header}
J\t{self.Jx}\t\t{self.Jy}\t\t{self.Jz}
L\t{self.Lx:.08e}\t{self.Ly:.08e}\t{self.Lz:.08e}
d\t{self.dx:.08e}\t{self.dy:.08e}\t{self.dz:.08e}
---
dV = {self.dV:.08e}
V = {self.V:.08e}
ntot = {self.ntot:d}
ncell = {self.ncell:d}
---"""
        return s

    def get_x_coords(
        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
    ) -> np.ndarray:
        """
        Return the x coordinates.

        Args:
            local: if True, returns the local coordinates,
                otherwise the global coordinates
            dtype: data type of the coordinates

        Returns:
            1D array with the x coordinates

        Raises:
            ValueError: raised by ``np.split`` when ``local`` is True and
                ``Jx`` cannot be divided into ``solvers.size`` equal parts
        """
        x_global = np.linspace(0, self.Lx, self.Jx, dtype=dtype)  # global coordinates
        if not local:
            return x_global
        # Split x into solvers.size equal parts and keep the one at index solvers.rank
        return np.split(x_global, solvers.size)[solvers.rank]

    def get_mesh(
        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Return a meshgrid of the coordinates.

        Use ij indexing. Only the x axis depends on ``local``; the y and z
        axes always cover the whole grid.

        Args:
            local: if True, returns the local coordinates,
                otherwise the global coordinates
            dtype: data type of the coordinates

        Returns:
            Tuple of 3D arrays with the coordinates
        """
        x = self.get_x_coords(local=local, dtype=dtype)
        y = np.linspace(0, self.Ly, self.Jy, dtype=dtype)
        z = np.linspace(0, self.Lz, self.Jz, dtype=dtype)
        # np.meshgrid returns a list on NumPy < 2.0 and a tuple from 2.0 on;
        # unpack so the annotated tuple is returned on every version.
        mesh_x, mesh_y, mesh_z = np.meshgrid(x, y, z, indexing="ij")
        return mesh_x, mesh_y, mesh_z

    def as_dict(self) -> dict:
        """
        Export grid parameters to a dictionary for JAX JIT compatibility.

        The dictionary holds every dataclass field. ``uniform`` is a
        ``ClassVar`` and therefore not a field, so it is not included.

        Returns:
            Dictionary containing grid parameters needed for computations
        """
        return asdict(self)