llg3d.solvers.experimental

Experimental solvers live here.

Currently contains JAX-based solver implementations.

class JaxSolver(element='Cobalt', N=5000, dt=1e-14, Jx=300, Jy=21, Jz=21, dx=1e-09, T=0.0, H_ext=0.0, init_type='0', result_file='run.npz', start_averaging=4000, n_mean=1, n_profile=0, solver='numpy', precision='double', blocking=False, seed=12345, device='auto', profiling=False, verbosity='INFO', np=1)

Bases: BaseSolver

JAX-based LLG3D solver.

Parameters:
  • element (Literal['Cobalt', 'Iron', 'Nickel'])
  • N (int)
  • dt (float)
  • Jx (int)
  • Jy (int)
  • Jz (int)
  • dx (float)
  • T (float)
  • H_ext (float)
  • init_type (Literal['0', 'dw'])
  • result_file (str)
  • start_averaging (int)
  • n_mean (int)
  • n_profile (int)
  • solver (Literal['opencl', 'mpi', 'numpy', 'jax'])
  • precision (Literal['single', 'double'])
  • blocking (bool)
  • seed (int)
  • device (Literal['cpu', 'gpu', 'auto'])
  • profiling (bool)
  • verbosity (Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
  • np (int)
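
Two of these parameters correspond to JAX settings. The following is a minimal sketch, not JaxSolver's actual code, of how `precision` and `device` could be translated; the helper names `configure_precision`, `parse_device` and `select_jax_device` are hypothetical. It relies on two documented JAX behaviours: JAX creates 32-bit floating-point arrays by default unless the `jax_enable_x64` flag is set, and `jax.devices(backend)` lists the devices of a backend. The parser also accepts indexed forms such as 'gpu:1', which the `device` attribute description lists even though the parameter annotation only names 'cpu', 'gpu' and 'auto'.

```python
from typing import Optional, Tuple

try:
    import jax  # optional: the parsing helpers below work without it
except ImportError:
    jax = None


def configure_precision(precision: str) -> str:
    """Return the dtype name for 'single' or 'double' (hypothetical helper)."""
    if precision not in ("single", "double"):
        raise ValueError(f"unsupported precision: {precision!r}")
    if precision == "double" and jax is not None:
        # JAX defaults to 32-bit floats; 64-bit mode should be enabled at
        # start-up, before any arrays are created.
        jax.config.update("jax_enable_x64", True)
    return "float64" if precision == "double" else "float32"


def parse_device(device: str) -> Tuple[Optional[str], int]:
    """Split 'cpu', 'gpu', 'gpu:1' or 'auto' into (platform, index).

    'auto' maps to (None, 0), meaning "use JAX's default backend".
    """
    if device == "auto":
        return None, 0
    platform, sep, index = device.partition(":")
    if platform not in ("cpu", "gpu"):
        raise ValueError(f"unsupported device: {device!r}")
    return platform, (int(index) if sep else 0)


def select_jax_device(device: str):
    """Return the jax.Device matching a device string (requires JAX)."""
    if jax is None:
        raise RuntimeError("JAX is not installed")
    platform, index = parse_device(device)
    devices = jax.devices(platform) if platform is not None else jax.devices()
    return devices[index]
```

With JAX installed, `select_jax_device('gpu:1')` would return the second GPU that JAX reports, and `jax.device_put(x, dev)` places an array on it.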

solver_type: ClassVar[str] = 'jax'

Solver type name
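
A class-level tag such as `solver_type` is a common way to map a name, for example one of the values accepted by the `solver` parameter, to a solver class. Whether llg3d dispatches this way is not stated here; the sketch below is a hypothetical illustration using stand-in classes (`SketchBaseSolver`, `SketchJaxSolver`) and a hypothetical `SOLVERS` registry, not the library's code.

```python
from typing import ClassVar, Dict

# Hypothetical registry mapping a solver_type tag to its class.
SOLVERS: Dict[str, type] = {}


class SketchBaseSolver:
    solver_type: ClassVar[str] = "base"

    def __init_subclass__(cls, **kwargs):
        # Runs once per subclass definition: register it under its tag.
        super().__init_subclass__(**kwargs)
        SOLVERS[cls.solver_type] = cls


class SketchJaxSolver(SketchBaseSolver):
    solver_type: ClassVar[str] = "jax"


def solver_class(name: str) -> type:
    """Look up a solver class from a name such as 'jax'."""
    try:
        return SOLVERS[name]
    except KeyError:
        raise ValueError(f"unknown solver: {name!r}") from None
```

The `ClassVar` annotation marks the tag as belonging to the class, so dataclasses and type checkers do not treat it as a per-instance field.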

_xyz_average(m1)

Compute the space average of m1 using JAX.

Parameters:

m1 (Array)

Return type:

float
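
A space average reduces an array over all of its grid points to a single number. A minimal sketch of such a reduction, assuming `m1` holds one value per grid cell (for example an array of shape `(Jx, Jy, Jz)`); the function name `xyz_average` is hypothetical, and the sketch falls back to NumPy when JAX is unavailable:

```python
try:
    import jax.numpy as xp  # prefer JAX when installed
except ImportError:
    import numpy as xp  # NumPy fallback with the same asarray/mean API


def xyz_average(m1) -> float:
    """Mean of m1 over every axis, returned as a Python float."""
    # mean() with no axis argument averages over all elements (all of x, y, z).
    # float() copies the scalar to the host; with JAX this also waits for the
    # asynchronous computation to finish.
    return float(xp.mean(xp.asarray(m1)))
```

For example, `xyz_average([[[1.0, 2.0], [3.0, 4.0]]])` gives 2.5.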

device

Device to use (‘cpu’, ‘gpu’, ‘gpu:0’, ‘gpu:1’, etc., or ‘auto’)

_simulate()

Simulates the system for N iterations using JAX.

Returns:

The time taken for the simulation

Raises:

NotImplementedError – If n_profile is not zero

Return type:

float
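
`_simulate` returns the time the simulation took. When timing JAX code one detail matters: JAX dispatches work asynchronously, so the clock should only stop once the result is ready. A minimal, hypothetical sketch of such a timing loop (the name `timed_simulation` and the `step` callable are illustrative, not part of llg3d), which also mirrors the documented `NotImplementedError` for a non-zero `n_profile`:

```python
import time
from typing import Any, Callable, Tuple


def timed_simulation(
    step: Callable[[Any], Any], state: Any, n_iter: int, n_profile: int = 0
) -> Tuple[Any, float]:
    """Apply `step` n_iter times; return (final_state, elapsed_seconds)."""
    if n_profile != 0:
        raise NotImplementedError("n_profile must be 0")
    start = time.perf_counter()
    for _ in range(n_iter):
        state = step(state)
    # JAX arrays expose block_until_ready(); plain Python/NumPy values do not.
    if hasattr(state, "block_until_ready"):
        state.block_until_ready()
    return state, time.perf_counter() - start
```

If `step` is wrapped in `jax.jit`, its first call also includes compilation time; a warm-up call before starting the clock separates the two. With JAX, the Python loop could instead be expressed with `jax.lax.fori_loop` so the whole loop is compiled as one XLA computation. Whether JaxSolver does either is not stated here.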

Modules

jax

LLG3D solver using XLA compilation.