llg3d.solvers.experimental.jax

LLG3D solver using XLA compilation.

Warning

This solver is experimental and not maintained.

Functions

compute_H_anisotropy(m, coeff_2, anisotropy)

Compute anisotropy field (JIT compiled).

compute_laplacian(m, dx, dy, dz)

Compute 3D Laplacian with Neumann boundary conditions (JIT compiled).

compute_slope(g_params, e_params, m, R_random)

JIT-compiled version of compute_slope_jax using modular sub-functions.

compute_space_average_jax(m1)

Compute space average using midpoint method on GPU (JIT compiled).

cross_product(a, b)

Compute cross product \(a \times b\) (JIT compiled).

laplacian3D(m_i, dx2_inv, dy2_inv, dz2_inv, ...)

Compute Laplacian for a single component with Neumann boundary conditions.

Classes

JaxSolver([element, N, dt, Jx, Jy, Jz, dx, ...])

JAX-based LLG3D solver.

compute_H_anisotropy(m, coeff_2, anisotropy)

Compute anisotropy field (JIT compiled).

Parameters:
  • m (Array) – Magnetization array (shape (3, nx, ny, nz))

  • coeff_2 (float) – Coefficient for anisotropy

  • anisotropy (int) – Anisotropy type (0: uniaxial, 1: cubic)

Returns:

Anisotropy field array (shape (3, nx, ny, nz))

Return type:

Array
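Because `anisotropy` selects a code path, it cannot simply be a traced value under `jax.jit`: a Python `if` on a tracer fails. One way to handle this, sketched below, is to mark the argument static with `static_argnames` so that each anisotropy type compiles its own kernel. The name `anisotropy_field` is hypothetical, and the field expressions (easy axis along x for the uniaxial case, the usual cubic form otherwise) are illustrative textbook forms, not necessarily the expressions or sign conventions this solver uses.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("anisotropy",))
def anisotropy_field(m, coeff_2, anisotropy):
    """Hypothetical anisotropy field for m of shape (3, nx, ny, nz)."""
    mx, my, mz = m[0], m[1], m[2]
    # `anisotropy` is static, so this Python branch is resolved at trace time
    # and each anisotropy type gets its own compiled kernel.
    if anisotropy == 0:
        # Illustrative uniaxial form with the easy axis along x (assumption).
        return jnp.stack([coeff_2 * mx, jnp.zeros_like(my), jnp.zeros_like(mz)])
    if anisotropy == 1:
        # Illustrative cubic form (assumption):
        # H_i = -coeff_2 * m_i * (sum of the squares of the other two components)
        return -coeff_2 * jnp.stack(
            [mx * (my**2 + mz**2), my * (mx**2 + mz**2), mz * (mx**2 + my**2)]
        )
    raise ValueError(f"unknown anisotropy type: {anisotropy}")
```

Calling with a new `anisotropy` value triggers a recompilation; `coeff_2` stays a traced argument, so changing it does not.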

laplacian3D(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff)

Compute Laplacian for a single component with Neumann boundary conditions (JIT compiled).

Parameters:
  • m_i (Array) – Single component of magnetization (shape (nx, ny, nz))

  • dx2_inv (float) – Inverse of squared grid spacing in x direction

  • dy2_inv (float) – Inverse of squared grid spacing in y direction

  • dz2_inv (float) – Inverse of squared grid spacing in z direction

  • center_coeff (float) – Coefficient for the center point

Returns:

Laplacian of m_i (shape (nx, ny, nz))

Return type:

Array
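The parameter list suggests a standard 7-point stencil in which `center_coeff` is the weight of the center point. The sketch below assumes `center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)` and imposes the Neumann condition by copying each boundary value into a ghost cell (`mode="edge"` padding), which makes the difference across every boundary face zero. The helper name and this ghost-cell choice are assumptions; the solver may, for example, mirror the first interior point instead.

```python
import jax
import jax.numpy as jnp


@jax.jit
def laplacian3d_neumann(m_i, dx2_inv, dy2_inv, dz2_inv, center_coeff):
    """Hypothetical 7-point Laplacian of one component, shape (nx, ny, nz)."""
    # Ghost cells repeat the boundary value: zero normal difference (Neumann).
    p = jnp.pad(m_i, 1, mode="edge")
    c = p[1:-1, 1:-1, 1:-1]  # original grid values
    return (
        dx2_inv * (p[2:, 1:-1, 1:-1] + p[:-2, 1:-1, 1:-1])
        + dy2_inv * (p[1:-1, 2:, 1:-1] + p[1:-1, :-2, 1:-1])
        + dz2_inv * (p[1:-1, 1:-1, 2:] + p[1:-1, 1:-1, :-2])
        + center_coeff * c  # assumed: -2 * (dx2_inv + dy2_inv + dz2_inv)
    )
```

Passing precomputed inverse squares keeps divisions out of the stencil, and they can be computed once per run from the grid spacing.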

compute_laplacian(m, dx, dy, dz)

Compute 3D Laplacian with Neumann boundary conditions (JIT compiled).

Parameters:
  • m (Array) – Magnetization array (shape (3, nx, ny, nz))

  • dx (float) – Grid spacing in x direction

  • dy (float) – Grid spacing in y direction

  • dz (float) – Grid spacing in z direction

Returns:

Laplacian of m (shape (3, nx, ny, nz))

Return type:

Array
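The full-vector version applies the same stencil to each of the three components. A compact way to write it, shown here as a sketch under the same edge-padding assumption, pads one spatial axis at a time and takes a second difference with `jnp.diff(..., n=2)`. The component axis 0 is never padded, so all three components are handled in one pass. The name `laplacian_neumann` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def laplacian_neumann(m, dx, dy, dz):
    """Hypothetical Laplacian of m with shape (3, nx, ny, nz)."""
    lap = jnp.zeros_like(m)
    # Spatial axes are 1, 2, 3; axis 0 holds the vector components.
    for axis, h in zip((1, 2, 3), (dx, dy, dz)):
        pad = [(0, 0)] * m.ndim
        pad[axis] = (1, 1)
        # Edge padding repeats the boundary value in a ghost cell (assumed Neumann).
        mp = jnp.pad(m, pad, mode="edge")
        # u[i+1] - 2 u[i] + u[i-1] along this axis, same length as the original.
        lap = lap + jnp.diff(mp, n=2, axis=axis) / h**2
    return lap
```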

compute_space_average_jax(m1)

Compute space average using midpoint method on GPU (JIT compiled).

Parameters:

m1 (Array) – First component of magnetization (shape (nx, ny, nz))

Returns:

Space average of m1

Return type:

float
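The docstring does not spell out the midpoint scheme. One plausible reading, assumed in this sketch, is to average neighboring nodes along each axis to obtain values at the cell midpoints and then take the plain mean of those midpoint values; this gives boundary nodes less weight than interior ones. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def space_average_midpoint(m1):
    """Hypothetical midpoint-rule average of m1 with shape (nx, ny, nz)."""
    # Average adjacent nodes along x, then y, then z: each value of c is the
    # mean of the 8 nodes surrounding one cell midpoint.
    c = 0.5 * (m1[1:, :, :] + m1[:-1, :, :])
    c = 0.5 * (c[:, 1:, :] + c[:, :-1, :])
    c = 0.5 * (c[:, :, 1:] + c[:, :, :-1])
    return jnp.mean(c)
```

`float(space_average_midpoint(m1))` turns the 0-d JAX array into a Python `float`, matching the documented return type; the conversion waits for the device computation and copies the value to the host.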

cross_product(a, b)

Compute cross product \(a \times b\) (JIT compiled).

Parameters:
  • a (Array) – First vector (shape (3, nx, ny, nz))

  • b (Array) – Second vector (shape (3, nx, ny, nz))

Returns:

Cross product \(a \times b\) (shape (3, nx, ny, nz))

Return type:

Array
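With the vector components stored along axis 0, the cross product reduces to three element-wise expressions over whole grids. A minimal sketch, following the documented shapes (the implementation itself is assumed):

```python
import jax
import jax.numpy as jnp


@jax.jit
def cross_product(a, b):
    """Sketch of a x b for fields of shape (3, nx, ny, nz)."""
    # Each line is one component of the cross product, evaluated on the whole grid.
    return jnp.stack(
        [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]
    )
```

`jnp.cross(a, b, axis=0)` computes the same result; writing the components out explicitly makes the memory layout obvious.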

compute_slope(g_params, e_params, m, R_random)

JIT-compiled version of compute_slope_jax using modular sub-functions.

Parameters:
  • g_params (dict) – Grid parameters dict (dx, dy, dz)

  • e_params (dict) – Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G, anisotropy)

  • m (Array) – Magnetization array (shape (3, nx, ny, nz))

  • R_random (Array) – Random field array (shape (3, nx, ny, nz))

Returns:

Slope array (shape (3, nx, ny, nz))

Return type:

Array
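The slope is presumably the right-hand side dm/dt of the LLG equation evaluated on the grid. The docstring does not say how `coeff_1`, `coeff_3`, `lambda_G` and `R_random` enter, so the sketch below only shows the common Landau-Lifshitz structure, a precession term plus a damping term, for an effective field `H_eff` assumed to be already assembled (for instance from the Laplacian, anisotropy and random fields). The name `ll_slope` and the normalization are assumptions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ll_slope(m, H_eff, lambda_G):
    """Hypothetical Landau-Lifshitz slope for fields of shape (3, nx, ny, nz)."""
    m_x_H = jnp.cross(m, H_eff, axis=0)
    # dm/dt = -m x H_eff - lambda_G * m x (m x H_eff)  (assumed normalization)
    return -m_x_H - lambda_G * jnp.cross(m, m_x_H, axis=0)
```

Both terms are perpendicular to `m`, so the slope does not change |m| to first order in the time step.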

class JaxSolver(element='Cobalt', N=5000, dt=1e-14, Jx=300, Jy=21, Jz=21, dx=1e-09, T=0.0, H_ext=0.0, init_type='0', result_file='run.npz', start_averaging=4000, n_mean=1, n_profile=0, solver='numpy', precision='double', blocking=False, seed=12345, device='auto', profiling=False, verbosity='INFO', np=1)

Bases: BaseSolver

JAX-based LLG3D solver.

Parameters:
  • element (Literal['Cobalt', 'Iron', 'Nickel'])

  • N (int)

  • dt (float)

  • Jx (int)

  • Jy (int)

  • Jz (int)

  • dx (float)

  • T (float)

  • H_ext (float)

  • init_type (Literal['0', 'dw'])

  • result_file (str)

  • start_averaging (int)

  • n_mean (int)

  • n_profile (int)

  • solver (Literal['opencl', 'mpi', 'numpy', 'jax'])

  • precision (Literal['single', 'double'])

  • blocking (bool)

  • seed (int)

  • device (Literal['cpu', 'gpu', 'auto'])

  • profiling (bool)

  • verbosity (Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])

  • np (int)

solver_type: ClassVar[str] = 'jax'

Solver type name

_xyz_average(m1)

Compute the space average of m1 using JAX.

Parameters:

m1 (Array)

Return type:

float

_simulate()

Simulate the system for N iterations using JAX.

The computation runs on the device selected by the device setting (‘cpu’, ‘gpu’, ‘gpu:0’, ‘gpu:1’, etc., or ‘auto’).

Returns:

The time taken for the simulation

Raises:

NotImplementedError – If n_profile is not zero

Return type:

float
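Two practical points matter when timing a JAX loop: compilation should be kept out of the measurement, and because JAX dispatches work asynchronously, the timer must wait for the result with `block_until_ready()`. The sketch below combines both with a lookup for the documented device strings. The helpers `resolve_device` and `simulate`, the explicit Euler step with renormalization, and excluding compile time from the reported duration are all assumptions, not the solver's confirmed behavior.

```python
import time

import jax
import jax.numpy as jnp


def resolve_device(spec="auto"):
    """Hypothetical helper: map 'cpu', 'gpu', 'gpu:<k>' or 'auto' to a jax.Device."""
    if spec == "auto":
        try:
            return jax.devices("gpu")[0]
        except RuntimeError:  # raised when no GPU backend is available
            return jax.devices("cpu")[0]
    platform, _, index = spec.partition(":")
    return jax.devices(platform)[int(index) if index else 0]


def simulate(m0, slope_fn, dt, N, device="auto"):
    """Hypothetical driver: N explicit-Euler steps, returns (m, elapsed seconds)."""
    m0 = jax.device_put(m0, resolve_device(device))

    def run(m):
        def step(_, m):
            m = m + dt * slope_fn(m)
            # Renormalize so that |m| = 1 at every grid point.
            return m / jnp.linalg.norm(m, axis=0, keepdims=True)

        # fori_loop keeps the whole time loop inside one compiled program.
        return jax.lax.fori_loop(0, N, step, m)

    # Ahead-of-time compilation, so compile time is not included in the timing.
    compiled = jax.jit(run).lower(m0).compile()
    t0 = time.perf_counter()
    m = compiled(m0).block_until_ready()  # wait for the asynchronous dispatch
    return m, time.perf_counter() - t0
```

For example, `simulate(m0, lambda m: -jnp.cross(m, H, axis=0), 0.01, 50, device="cpu")` precesses `m0` about a fixed field `H` on the CPU and returns the final state with the elapsed wall-clock time in seconds.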