"""Define the computational grid for the simulation."""fromtypingimportClassVarfromdataclassesimportdataclass,field,asdictimportnumpyasnpfrom.importsolvers
@dataclass
class Grid:
    """Stores grid data.

    Parameter values describe the global grid. Only ``dx`` is supplied as a
    spacing: ``dy`` and ``dz`` are forced equal to it, so the grid is always
    uniform. Derived quantities (lengths, volumes, point/cell counts and
    Laplacian coefficients) are filled in by ``__post_init__``.

    The x direction is divided into ``solvers.size`` equal contiguous chunks
    (selected by ``solvers.rank``); ``dims`` is the shape of one such chunk.
    """

    # Parameter values correspond to the global grid
    Jx: int  #: number of points in x direction
    Jy: int  #: number of points in y direction
    Jz: int  #: number of points in z direction
    dx: float  #: grid spacing in x direction
    dy: float = field(init=False)  #: grid spacing in y direction
    dz: float = field(init=False)  #: grid spacing in z direction
    dV: float = field(init=False)  #: elemental volume
    Lx: float = field(init=False)  #: physical length in x direction
    Ly: float = field(init=False)  #: physical length in y direction
    Lz: float = field(init=False)  #: physical length in z direction
    dims: tuple[int, int, int] = field(init=False)  #: local grid dimensions
    V: float = field(init=False)  #: total volume
    ntot: int = field(init=False)  #: total number of points
    ncell: int = field(init=False)  #: total number of cells
    inv_dx2: float = field(init=False)  #: :math:`1/dx^2`
    inv_dy2: float = field(init=False)  #: :math:`1/dy^2`
    inv_dz2: float = field(init=False)  #: :math:`1/dz^2`
    center_coeff: float = field(init=False)  #: center coefficient for Laplacian

    uniform: ClassVar[bool] = True  #: whether the grid is uniform

    def __post_init__(self) -> None:
        """Compute grid characteristics.

        Raises:
            ValueError: if ``Jx`` is not divisible by ``solvers.size``, i.e. the
                x direction cannot be split evenly across ranks.
        """
        self.dy = self.dz = self.dx  # Enforce dx = dy = dz
        self.Lx = (self.Jx - 1) * self.dx
        self.Ly = (self.Jy - 1) * self.dy
        self.Lz = (self.Jz - 1) * self.dz
        # Only x is decomposed across ranks; y and z stay whole on every rank.
        self.dims = self._local_points(self.Jx, solvers.size), self.Jy, self.Jz
        self.dV = self.dx * self.dy * self.dz
        self.V = self.Lx * self.Ly * self.Lz
        self.ntot = self.Jx * self.Jy * self.Jz
        self.ncell = (self.Jx - 1) * (self.Jy - 1) * (self.Jz - 1)
        # precompute the Laplacian coefficients (uniform grid spacing)
        self.inv_dx2 = self.inv_dy2 = self.inv_dz2 = 1 / self.dx**2
        # Centre weight -2*(1/dx^2 + 1/dy^2 + 1/dz^2) reduces to -6/dx^2
        # because dx = dy = dz.
        self.center_coeff = -6.0 * self.inv_dx2

    @staticmethod
    def _local_points(n_global: int, n_ranks: int) -> int:
        """Return how many points each rank owns along the decomposed axis.

        ``get_x_coords`` splits the x coordinates with ``np.split``, which
        only accepts equal divisions. Rejecting an uneven split here keeps
        ``dims`` consistent with that split instead of silently dropping the
        remainder through floor division.

        Args:
            n_global: number of global points along the axis
            n_ranks: number of ranks sharing the axis

        Returns:
            ``n_global // n_ranks``

        Raises:
            ValueError: if ``n_global`` is not a multiple of ``n_ranks``.
        """
        if n_global % n_ranks != 0:
            raise ValueError(
                f"cannot split {n_global} grid points evenly across "
                f"{n_ranks} ranks"
            )
        return n_global // n_ranks

    def __str__(self) -> str:
        """Return grid information."""
        header = "\t\t".join(("x", "y", "z"))
        s = f"""\
---\t{header}
J\t{self.Jx}\t\t{self.Jy}\t\t{self.Jz}
L\t{self.Lx:.08e}\t{self.Ly:.08e}\t{self.Lz:.08e}
d\t{self.dx:.08e}\t{self.dy:.08e}\t{self.dz:.08e}
---
dV = {self.dV:.08e}
V = {self.V:.08e}
ntot = {self.ntot:d}
ncell = {self.ncell:d}
---
"""
        return s

    def get_x_coords(
        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
    ) -> np.ndarray:
        """
        Returns the x coordinates.

        Args:
            local: if True, returns the local coordinates (this rank's
                contiguous chunk), otherwise the global coordinates
            dtype: data type of the coordinates

        Returns:
            1D array with the x coordinates, evenly spaced from 0 to ``Lx``
        """
        x_global = np.linspace(0, self.Lx, self.Jx, dtype=dtype)  # global coordinates
        if not local:
            return x_global
        # Split x into solvers.size equal chunks and keep this rank's chunk.
        return np.split(x_global, solvers.size)[solvers.rank]

    def get_mesh(
        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Returns a meshgrid of the coordinates. Uses ij indexing.

        Args:
            local: if True, only the x coordinates are restricted to this
                rank's chunk (y and z are never split); otherwise the global
                coordinates are used
            dtype: data type of the coordinates

        Returns:
            Tuple of 3D arrays with the coordinates
        """
        x = self.get_x_coords(local=local, dtype=dtype)
        y = np.linspace(0, self.Ly, self.Jy, dtype=dtype)
        z = np.linspace(0, self.Lz, self.Jz, dtype=dtype)
        return np.meshgrid(x, y, z, indexing="ij")

    def as_dict(self) -> dict:
        """
        Export grid parameters to a dictionary for JAX JIT compatibility.

        Only dataclass instance fields are exported (via
        ``dataclasses.asdict``); the ``uniform`` class variable is not
        included.

        Returns:
            Dictionary containing grid parameters needed for computations
        """
        return asdict(self)