llg3d.solvers.experimental.jax
LLG3D solver using XLA compilation.
Warning
This module is experimental and not maintained.
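Functions
| compute_H_anisotropy(m, coeff_2, anisotropy) | Compute anisotropy field (JIT compiled). |
| compute_laplacian(m, dx, dy, dz) | Compute 3D Laplacian with Neumann boundary conditions (JIT compiled). |
| compute_slope(g_params, e_params, m, R_random) | JIT-compiled version of compute_slope_jax using modular sub-functions. |
| compute_space_average_jax(m1) | Compute space average using midpoint method on GPU (JIT compiled). |
| cross_product(a, b) | Compute cross product \(a \times b\) (JIT compiled). |
| laplacian3D(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff) | Compute Laplacian for a single component with Neumann boundary conditions. |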
Classes
| JaxSolver | JAX-based LLG3D solver. |
- compute_H_anisotropy(m, coeff_2, anisotropy)
Compute anisotropy field (JIT compiled).
- Parameters:
m (Array) – Magnetization array (shape (3, nx, ny, nz))
coeff_2 (float) – Coefficient for anisotropy
anisotropy (int) – Anisotropy type (0: uniaxial, 1: cubic)
- Returns:
Anisotropy field array (shape (3, nx, ny, nz))
- Return type:
Array
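A minimal sketch of how an anisotropy field with this signature can be computed under JIT. The formulas are assumptions, not taken from the solver: the uniaxial easy axis is taken along x, giving H = coeff_2 · m_x · e_x, and the cubic field is written component-wise as coeff_2 · m_i³, with any sign convention folded into coeff_2. The name anisotropy_field_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def anisotropy_field_sketch(m, coeff_2, anisotropy):
    """Hypothetical anisotropy field; m has shape (3, nx, ny, nz)."""
    # Assumed uniaxial case: easy axis along x, field only in the x component.
    uniaxial = jnp.zeros_like(m).at[0].set(coeff_2 * m[0])
    # Assumed cubic case: component-wise cubic term, sign folded into coeff_2.
    cubic = coeff_2 * m**3
    # `anisotropy` is a traced value under jit, so a Python `if` cannot be
    # used; jnp.where computes both candidates and keeps one.
    return jnp.where(anisotropy == 0, uniaxial, cubic)
```

Marking anisotropy as static, for example with jax.jit(..., static_argnames="anisotropy"), would allow a plain Python if and compile one specialized version per anisotropy type.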
- laplacian3D(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff)
Compute the Laplacian of a single magnetization component with Neumann boundary conditions (JIT compiled).
- Parameters:
m_i (Array) – Single component of magnetization (shape (nx, ny, nz))
dx2_inv (float) – Inverse of squared grid spacing in x direction
dy2_inv (float) – Inverse of squared grid spacing in y direction
dz2_inv (float) – Inverse of squared grid spacing in z direction
center_coeff (float) – Coefficient for the center point
- Returns:
Laplacian of m_i (shape (nx, ny, nz))
- Return type:
Array
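A minimal sketch of a 7-point finite-difference Laplacian with Neumann boundaries. It assumes center_coeff = -2 (dx2_inv + dy2_inv + dz2_inv) and that the zero-gradient condition is imposed with ghost cells copying the boundary value; neither choice is confirmed by this page, and laplacian3d_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def laplacian3d_sketch(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff):
    """7-point Laplacian of one (nx, ny, nz) component with Neumann BCs."""
    # Assumption: ghost cells copy the boundary value (mode="edge"), which
    # makes the one-sided difference across each boundary face zero.
    p = jnp.pad(m_i, 1, mode="edge")
    center = p[1:-1, 1:-1, 1:-1]
    return (
        dx2_inv * (p[2:, 1:-1, 1:-1] + p[:-2, 1:-1, 1:-1])
        + dy2_inv * (p[1:-1, 2:, 1:-1] + p[1:-1, :-2, 1:-1])
        + dz2_inv * (p[1:-1, 1:-1, 2:] + p[1:-1, 1:-1, :-2])
        + center_coeff * center
    )
```

Using mode="reflect" instead would mirror the first interior point into the ghost cell, the usual second-order variant of the same boundary condition.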
- compute_laplacian(m, dx, dy, dz)
Compute 3D Laplacian with Neumann boundary conditions (JIT compiled).
- Parameters:
m (Array) – Magnetization array (shape (3, nx, ny, nz))
dx (float) – Grid spacing in x direction
dy (float) – Grid spacing in y direction
dz (float) – Grid spacing in z direction
- Returns:
Laplacian of m (shape (3, nx, ny, nz))
- Return type:
Array
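A sketch of how the three-component Laplacian can be assembled from a single-component stencil: derive the inverse squared spacings from dx, dy, dz and map the stencil over the leading axis with jax.vmap. The centre weight and the edge-padded Neumann treatment are assumptions, and the names vector_laplacian_sketch and _lap_component are hypothetical.

```python
import jax
import jax.numpy as jnp


def _lap_component(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff):
    # 7-point stencil; Neumann BCs assumed via edge-padded ghost cells.
    p = jnp.pad(m_i, 1, mode="edge")
    return (
        dx2_inv * (p[2:, 1:-1, 1:-1] + p[:-2, 1:-1, 1:-1])
        + dy2_inv * (p[1:-1, 2:, 1:-1] + p[1:-1, :-2, 1:-1])
        + dz2_inv * (p[1:-1, 1:-1, 2:] + p[1:-1, 1:-1, :-2])
        + center_coeff * p[1:-1, 1:-1, 1:-1]
    )


@jax.jit
def vector_laplacian_sketch(m, dx, dy, dz):
    """Laplacian of each component of m, shape (3, nx, ny, nz)."""
    dx2_inv, dy2_inv, dz2_inv = 1.0 / dx**2, 1.0 / dy**2, 1.0 / dz**2
    # Assumed centre weight of the 7-point stencil.
    center_coeff = -2.0 * (dx2_inv + dy2_inv + dz2_inv)
    # Map over axis 0 of m; the scalar coefficients are shared (None).
    lap = jax.vmap(_lap_component, in_axes=(0, None, None, None, None))
    return lap(m, dx2_inv, dy2_inv, dz2_inv, center_coeff)
```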
- compute_space_average_jax(m1)
Compute space average using midpoint method on GPU (JIT compiled).
- Parameters:
m1 (Array) – First component of magnetization (shape (nx, ny, nz))
- Returns:
Space average of m1
- Return type:
float
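A sketch of a midpoint-rule space average on a uniform nodal grid. It assumes "midpoint method" means estimating each cell's midpoint value as the average of its 8 corner nodes and then averaging over cells; the actual weighting used by compute_space_average_jax is not shown on this page, and space_average_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def space_average_sketch(m1):
    """Average of a nodal field m1 (shape (nx, ny, nz)) via cell midpoints."""
    # Average neighbouring nodes one axis at a time; after three passes each
    # entry is the mean of a cell's 8 corners: shape (nx-1, ny-1, nz-1).
    c = 0.5 * (m1[1:] + m1[:-1])
    c = 0.5 * (c[:, 1:] + c[:, :-1])
    c = 0.5 * (c[:, :, 1:] + c[:, :, :-1])
    # On a uniform grid every cell has the same volume, so the average is
    # the plain mean over cells.
    return jnp.mean(c)
```

On a uniform grid this is equivalent to trapezoidal node weights (weight 1/2 per axis on boundary nodes). The result is a 0-d JAX array; float(...) converts it to a Python float.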
- cross_product(a, b)
Compute cross product \(a \times b\) (JIT compiled).
- Parameters:
a (Array) – First vector (shape (3, nx, ny, nz))
b (Array) – Second vector (shape (3, nx, ny, nz))
- Returns:
Cross product \(a \times b\) (shape (3, nx, ny, nz))
- Return type:
Array
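Because the vector components sit on axis 0 of each field, the cross product is the usual component formula applied at every grid point. A minimal sketch; cross_product_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def cross_product_sketch(a, b):
    """Pointwise a x b for fields of shape (3, nx, ny, nz)."""
    # Each a[k], b[k] is an (nx, ny, nz) array; arithmetic is element-wise.
    return jnp.stack([
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ])
```

jnp.cross(a, b, axis=0) should give the same result; the explicit form makes the component layout visible.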
- compute_slope(g_params, e_params, m, R_random)
JIT-compiled version of compute_slope_jax using modular sub-functions.
- Parameters:
g_params (dict) – Grid parameters dict (dx, dy, dz)
e_params (dict) – Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G, anisotropy)
m (Array) – Magnetization array (shape (3, nx, ny, nz))
R_random (Array) – Random field array (shape (3, nx, ny, nz))
- Returns:
Slope array (shape (3, nx, ny, nz))
- Return type:
Array
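Assuming the slope is the time derivative dm/dt passed to the time integrator, a generic Landau-Lifshitz form of the LLG right-hand side is dm/dt = -s [m × H_eff + λ_G m × (m × H_eff)]. How compute_slope builds H_eff from coeff_1, coeff_2, coeff_3, the anisotropy term and R_random, and which prefactor s it uses, is not documented here, so H_eff and scale are hypothetical inputs in this sketch. The local cross helper duplicates the component formula so the block runs on its own.

```python
import jax
import jax.numpy as jnp


def cross(a, b):
    # Pointwise cross product for fields with components on axis 0.
    return jnp.stack([
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ])


@jax.jit
def ll_slope_sketch(m, H_eff, lambda_G, scale):
    """Hypothetical Landau-Lifshitz slope for fields of shape (3, nx, ny, nz)."""
    m_x_H = cross(m, H_eff)
    # Precession term m x H plus damping term lambda_G * m x (m x H).
    return -scale * (m_x_H + lambda_G * cross(m, m_x_H))
```

The damping term rotates m toward H_eff, and the slope vanishes when m is parallel to H_eff.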
- class JaxSolver(element='Cobalt', N=5000, dt=1e-14, Jx=300, Jy=21, Jz=21, dx=1e-09, T=0.0, H_ext=0.0, init_type='0', result_file='run.npz', start_averaging=4000, n_mean=1, n_profile=0, solver='numpy', precision='double', blocking=False, seed=12345, device='auto', profiling=False, verbosity='INFO', np=1)
Bases: BaseSolver
JAX-based LLG3D solver.
- Parameters:
element (Literal['Cobalt', 'Iron', 'Nickel'])
N (int)
dt (float)
Jx (int)
Jy (int)
Jz (int)
dx (float)
T (float)
H_ext (float)
init_type (Literal['0', 'dw'])
result_file (str)
start_averaging (int)
n_mean (int)
n_profile (int)
solver (Literal['opencl', 'mpi', 'numpy', 'jax'])
precision (Literal['single', 'double'])
blocking (bool)
seed (int)
device (Literal['cpu', 'gpu', 'auto'])
profiling (bool)
verbosity (Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
np (int)
- solver_type: ClassVar[str] = 'jax'
Solver type name
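The precision and device options have to be mapped onto JAX's own settings. JAX computes in 32-bit floats unless 64-bit mode is enabled, so precision='double' only takes effect with jax_enable_x64 switched on. Whether JaxSolver does this itself is not shown on this page; configure_jax_sketch below is a hypothetical helper illustrating one way to do it.

```python
import jax
import jax.numpy as jnp


def configure_jax_sketch(precision="double", device="auto"):
    """Hypothetical mapping of JaxSolver-style options onto JAX settings."""
    if precision == "double":
        # Without this flag, float64 requests are truncated to float32
        # (JAX emits a warning when that happens).
        jax.config.update("jax_enable_x64", True)
    dtype = jnp.float64 if precision == "double" else jnp.float32
    if device == "auto":
        # Default backend: an accelerator if one is available, else the CPU.
        dev = jax.devices()[0]
    else:
        # "cpu" or "gpu"; raises an error if that backend is not available.
        dev = jax.devices(device)[0]
    return dtype, dev
```

Arrays can then be created in the chosen precision and placed on the chosen device, for example jax.device_put(jnp.zeros((3, Jx, Jy, Jz), dtype=dtype), dev).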