"""
LLG3D solver using OpenCL.
Uses :doc:`../opencl_kernels` for the OpenCL kernel implementations.
"""
from __future__ import annotations
import logging
import os
import time
from pathlib import Path
from typing import Any, ClassVar
import numpy as np
import pyopencl as cl
from pyopencl import array as clarray
from pyopencl import clrandom
from pyopencl import mem_flags as mf
from ..base import BaseSolver
from ..profiling import timeit
[docs]
class TimedKernel:
    """
    Callable wrapper around an OpenCL kernel that feeds the solver's profiler.

    Args:
        kernel: The underlying OpenCL kernel
        name: Name under which the kernel's events are recorded
        solver: Solver that owns the profiling storage
    """

    def __init__(self, kernel: cl.Kernel, name: str, solver: "OpenCLSolver"):
        self.kernel, self.name, self.solver = kernel, name, solver

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """
        Launch the kernel and, when profiling is on, hand its event to the solver.

        Args:
            *args: Positional arguments forwarded to the kernel
            **kwargs: Keyword arguments forwarded to the kernel
        """
        # The kernel is always launched, whatever the profiling setting
        launched: cl.Event = self.kernel(*args, **kwargs)
        if not self.solver.profiling:
            return
        # Keep the event so the solver can read its timings later
        self.solver._add_profiling_event(self.name, launched)
[docs]
def get_context_and_device(
    device_selection: str = "auto",
) -> tuple[cl.Context, cl.Device]:
    """
    Get the OpenCL context and device.

    The selection string is validated before the OpenCL runtime is queried, so
    a malformed string always raises ``ValueError``, whatever hardware is
    present.

    Args:
        device_selection:
            - ``"auto"``: Let OpenCL choose automatically
            - ``"cpu"``: Select CPU device
            - ``"gpu"``: Select first available GPU
            - ``"gpu:N"``: Select specific GPU by index (e.g., ``"gpu:0"``, ``"gpu:1"``)

    Returns:
        - The OpenCL context
        - The OpenCL device

    Raises:
        RuntimeError: If no suitable device is found
        ValueError: If the device selection string is invalid (unknown keyword,
            non-integer or negative GPU index)
    """
    if device_selection == "auto":
        context = cl.create_some_context(interactive=False)
        return context, context.devices[0]

    # Decode the request first: which device family, and which index in it
    device_index = 0
    if device_selection == "cpu":
        family = "CPU"
    elif device_selection == "gpu":
        family = "GPU"
    elif device_selection.startswith("gpu:"):
        family = "GPU"
        index_text = device_selection.split(":", 1)[1]
        try:
            device_index = int(index_text)
        except ValueError as err:
            raise ValueError(
                f"Invalid device selection: {device_selection} "
                "(GPU index must be an integer)"
            ) from err
        # A negative index would silently wrap around to the last GPU
        if device_index < 0:
            raise ValueError(
                f"Invalid device selection: {device_selection} "
                "(GPU index must be non-negative)"
            )
    else:
        raise ValueError(f"Invalid device selection: {device_selection}")

    # Get all devices of all platforms
    all_devices = []
    for platform in cl.get_platforms():
        all_devices.extend(platform.get_devices())
    if not all_devices:
        raise RuntimeError("No OpenCL devices found")

    # Keep only the devices of the requested family
    wanted_type = getattr(cl.device_type, family)
    candidates = [d for d in all_devices if d.type & wanted_type]
    if not candidates:
        raise RuntimeError(f"No {family} devices found")
    if device_index >= len(candidates):
        # Only reachable for "gpu:N": the plain keywords use index 0
        raise RuntimeError(
            f"GPU index {device_index} not available. Found {len(candidates)} GPU(s)"
        )
    selected_device = candidates[device_index]

    # Create context with selected device
    context = cl.Context([selected_device])
    logging.info(
        "Selected OpenCL device: %s (%s)", selected_device.name, selected_device.type
    )
    return context, selected_device
[docs]
def get_precision(device: cl.Device, precision: str) -> np.dtype:
    """
    Map the requested precision onto a numpy floating-point dtype.

    Args:
        device: OpenCL device
        precision: Precision of the simulation; ``"double"`` selects float64,
            any other value selects float32

    Returns:
        The numpy float type (float32 or float64)

    Raises:
        RuntimeError: If double precision is asked while the device does not support it
    """
    wants_double = precision == "double"
    # The device capability is only inspected when double precision is requested
    if wants_double and not device.double_fp_config:
        raise RuntimeError("The selected device does not support double precision.")
    if wants_double:
        return np.dtype(np.float64)
    return np.dtype(np.float32)
[docs]
class Program:
    """
    Class to manage the OpenCL kernels for the LLG3D simulation.

    The program is built once, at construction time, from the solver's settings.

    Args:
        solver: The solver whose context, grid, precision, geometry and material
            settings drive the build
    """

    def __init__(self, solver: OpenCLSolver):
        self.solver = solver  #: The OpenCLSolver instance
        self.cl_program: cl.Program = self._get_built_program()  #: The OpenCL program

    def _get_built_program(self) -> cl.Program:
        """
        Return the OpenCL program built from the source code.

        The source is read from ``kernels.cl`` next to this module. The solver
        settings are passed to the compiler as ``-D`` definitions: precision,
        grid dimensions, optional error checking, anisotropy type and the
        periodic axes implied by the geometry.

        Returns:
            The OpenCL program object
        """
        solver = self.solver
        grid = solver.grid
        opencl_code: str = (Path(__file__).parent / "kernels.cl").read_text()

        build_options = (
            "-D USE_DOUBLE_PRECISION" if solver.np_float == np.float64 else ""
        )
        build_options += f" -D NX={grid.Jx} -D NY={grid.Jy} -D NZ={grid.Jz}"
        # Standard OpenCL option: correctly rounded single-precision divide/sqrt
        build_options += " -cl-fp32-correctly-rounded-divide-sqrt"

        # Optional error detection in kernels, opt-in via env var
        if solver.error_check_enabled:
            build_options += " -D ENABLE_ERROR_CHECK"

        # Add anisotropy type directive (cubic is the default, no flag needed)
        if solver.elem.anisotropy == "uniaxial":
            build_options += " -D ANISOTROPY_UNIAXIAL"

        # Periodic boundary conditions: each geometry adds one more periodic axis
        geometry = solver.geometry
        periodic_axes = (
            ("PBC_X", ("periodic_wire", "periodic_layer", "periodic_box")),
            ("PBC_Y", ("periodic_layer", "periodic_box")),
            ("PBC_Z", ("periodic_box",)),
        )
        for macro, geometries in periodic_axes:
            if geometry in geometries:
                build_options += f" -D {macro}"

        return cl.Program(solver.context, opencl_code).build(options=build_options)

    def get_kernel(
        self, kernel_name: str, arg_types: list | None = None
    ) -> TimedKernel:
        """
        Returns the specified kernel by name, wrapped with timing.

        Args:
            kernel_name: Name of the kernel to retrieve
            arg_types: List of argument types for the kernel. Defaults to
                ``[None]`` (a fresh list on every call, never a shared one).

        Returns:
            The TimedKernel wrapper around the OpenCL kernel
        """
        if arg_types is None:
            arg_types = [None]
        kernel: cl.Kernel = getattr(self.cl_program, kernel_name)
        kernel.set_arg_types(arg_types)
        return TimedKernel(kernel, kernel_name, self.solver)
[docs]
class OpenCLSolver(BaseSolver):
    """OpenCL-based LLG3D solver."""

    solver_type: ClassVar[str] = "opencl"  #: Solver type name

    def __post_init__(self) -> None:
        """
        Initialize OpenCL solver and check parameters.

        Creates the context and command queue, builds the kernels and allocates
        the device/host buffers used for error codes and for the two-level
        reduction.

        Raises:
            ValueError: If the grid is not uniform
        """
        logging.info("Initializing OpenCL solver...")
        logging.info("Initializing context...")
        self.context, opencl_device = get_context_and_device(self.device)
        # Enable profiling on the command queue
        properties = cl.command_queue_properties.PROFILING_ENABLE
        self.queue = cl.CommandQueue(self.context, properties=properties)
        self.profiling_events: list[tuple[str, cl.Event]] = []
        #: Numpy float type (np.dtype(np.float32) or np.dtype(np.float64))
        self.np_float: np.dtype = get_precision(opencl_device, self.precision)
        # Optional error detection in kernels, opt-in via env var
        error_check_env = os.getenv("LLG3D_ENABLE_ERROR_CHECK", "0").lower()
        #: Whether error checking is enabled in kernels
        self.error_check_enabled: bool = error_check_env in {"1", "true", "on", "yes"}

        super().__post_init__()

        # Check that the grid is uniform
        if not self.grid.uniform:
            raise ValueError("OpenCLSolver only supports uniform grids.")

        # Create OpenCL kernels
        program = Program(self)
        kernel_arg_types = [None] * 4 + [self.np_float] * 7
        # Fused kernels combining slope computation and updates
        self.predict_kernel = program.get_kernel(
            "predict", kernel_arg_types
        )  #: Prediction kernel with slope storage
        # Correction kernel with error codes
        correct_arg_types = kernel_arg_types + [None]
        self.correct_and_normalize_kernel = program.get_kernel(
            "correct_and_normalize", correct_arg_types
        )  #: Correction kernel (slope + update_2 + normalize) with error codes
        #: Weighted reduction kernel for m1 averaging
        self.sum_m1_weighted_kernel = program.get_kernel(
            "sum_m1_weighted", [None, None, None]
        )
        #: Second-level reduction to finalize sum on GPU
        self.reduce_partial_sums_kernel = program.get_kernel(
            "reduce_partial_sums", [None, None, None, None]
        )

        # Buffer for error codes (one int32 per work-item when checking is
        # enabled, a single placeholder int32 otherwise). Always present to
        # simplify the kernel call.
        host_err_shape = self.grid.ntot if self.error_check_enabled else 1
        self.h_error_codes = np.zeros(host_err_shape, dtype=np.int32)
        # The device buffer is created from the zeroed host array so that its
        # initial content is defined. NOTE(review): kernels.cl is not visible
        # here; if the kernel only writes codes on failure, an uninitialized
        # buffer would report spurious errors in _check_errors.
        self.d_error_codes = cl.Buffer(
            self.context,
            mf.READ_WRITE | mf.COPY_HOST_PTR,
            hostbuf=self.h_error_codes,
        )

        # Pre-allocate buffers for two-level reduction (used in _xyz_average)
        self.wgroup_size = 256
        self.num_groups = (self.grid.ntot + self.wgroup_size - 1) // self.wgroup_size
        self.d_partial_sums = cl.Buffer(
            self.context,
            cl.mem_flags.READ_WRITE,
            size=self.num_groups * self.np_float.itemsize,
        )
        # Buffer for final sum (single value)
        self.d_final_sum = cl.Buffer(
            self.context,
            cl.mem_flags.READ_WRITE,
            size=self.np_float.itemsize,
        )
        self.h_final_sum = np.empty(1, dtype=self.np_float)

        # Buffer to accumulate all xyz_average results during simulation
        # Will be sized dynamically based on N and n_mean
        self.d_all_averages: cl.Buffer | None = None
        self.h_all_averages: np.ndarray | None = None
        self.average_counter = 0
        # (t, n) of every deferred average, in the order they were enqueued
        self._xyz_average_meta: list[tuple[float, int]] = []

        # Compute work-group size for second reduction: next power of 2 that
        # covers num_groups, capped at 256
        self.reduce_wgroup_size = 1
        while self.reduce_wgroup_size < self.num_groups:
            self.reduce_wgroup_size *= 2
        self.reduce_wgroup_size = min(self.reduce_wgroup_size, 256)

    def _init_rng(self) -> Any:
        """
        Initialize a random number generator for temperature fluctuations.

        Returns:
            An OpenCL random number generator
        """
        return clrandom.PhiloxGenerator(self.context, seed=self.seed)

    def _compute_R_random(self, d_R_random: clarray.Array):
        """
        Compute random number array for thermal fluctuations.

        Args:
            d_R_random: Device array to fill with random numbers
        """
        event = self.rng.fill_normal(d_R_random)  # type: ignore[attr-defined]
        # Record profiling event if enabled
        if self.profiling:
            self._add_profiling_event("_compute_R_random", event)

    @timeit
    def _xyz_average(self, m: Any, immediate: bool = True) -> float | None:
        """
        Compute the space average of m_n using a two-level GPU reduction.

        Args:
            m: Current magnetization device array
            immediate: If True (default), read result immediately and return it.
                If False, accumulate on GPU without synchronization.

        Returns:
            The average value if immediate=True, None otherwise.
        """
        itemsize = self.np_float.itemsize

        # First reduction: work-groups produce partial sums
        self.sum_m1_weighted_kernel(
            self.queue,
            (self.num_groups * self.wgroup_size,),  # Global work size
            (self.wgroup_size,),  # Local work size
            m.data,
            self.d_partial_sums,
            cl.LocalMemory(self.wgroup_size * itemsize),
        )
        # Second reduction: reduce partial sums to single value on GPU
        self.reduce_partial_sums_kernel(
            self.queue,
            (self.reduce_wgroup_size,),  # Global work size
            (self.reduce_wgroup_size,),  # Local work size (single work-group)
            self.d_partial_sums,
            self.d_final_sum,
            cl.LocalMemory(self.reduce_wgroup_size * itemsize),
            np.int32(self.num_groups),
        )

        if immediate:
            # Read result immediately for unit tests or debugging
            cl.enqueue_copy(
                self.queue,
                self.h_final_sum,
                self.d_final_sum,
            ).wait()
            return float(self.h_final_sum[0]) / self.grid.ncell

        # Deferred mode: copy the sum into its slot of the accumulation buffer
        # (device-to-device, no synchronization); the division by ncell is done
        # later in _process_deferred_xyz_averages
        assert self.d_all_averages is not None
        cl.enqueue_copy(
            self.queue,
            self.d_all_averages,
            self.d_final_sum,
            dst_offset=self.average_counter * itemsize,
            byte_count=itemsize,
        )
        self.average_counter += 1
        return None

    def _record_xyz_average(self, m_n: Any, t: float, n: int) -> None:
        """
        Override to defer reading results until simulation end.

        Args:
            m_n: Magnetization device array; only its first component is averaged
            t: Time associated with the sample
            n: Iteration number associated with the sample
        """
        # If buffer is not allocated (e.g. in tests calling this directly),
        # fall back to synchronous execution
        if self.d_all_averages is None:
            # Compute immediately
            val = self._xyz_average(m_n[0], immediate=True)
            # Record in records
            if "xyz_average" not in self.records:
                self.records["xyz_average"] = []
            self.records["xyz_average"].append((t, val))
            # Accumulate for m1_mean
            if self.start_averaging is None or n >= self.start_averaging:
                if "m1_mean" not in self.observables:
                    self.observables["m1_mean"] = 0.0
                self.observables["m1_mean"] += val
            return

        # Deferred mode (during simulation)
        # Just compute and store on GPU, don't read back yet
        self._xyz_average(m_n[0], immediate=False)
        # Store metadata for later association with results; the attribute may
        # be missing on instances created without __post_init__
        if not hasattr(self, "_xyz_average_meta"):
            self._xyz_average_meta = []
        self._xyz_average_meta.append((t, n))

    def _check_errors(self, iteration: int) -> None:
        """
        Check for non-finite values detected by the kernel.

        Does nothing when error checking is disabled. Otherwise the error codes
        are read back with a blocking copy and the first non-zero entry is
        reported.

        Args:
            iteration: Current iteration number (for error message)

        Raises:
            RuntimeError: If non-finite values are detected
        """
        if not self.error_check_enabled:
            return
        cl.enqueue_copy(self.queue, self.h_error_codes, self.d_error_codes).wait()
        if not np.any(self.h_error_codes):
            return
        first_gid = int(np.flatnonzero(self.h_error_codes)[0])
        # Convert gid to (i, j, k) coordinates (k is the fastest-varying index)
        rest, k = divmod(first_gid, self.grid.Jz)
        i, j = divmod(rest, self.grid.Jy)
        raise RuntimeError(
            f"Non-finite value detected at iteration n={iteration}\n"
            f"Location: gid={first_gid}, i={i}, j={j}, k={k}"
        )

    @timeit
    def _update_x_profiles(self, m_n: Any, t: float):
        """
        Update the x profiles of m_n at time t.

        Makes the device array available on the host and computes the y-z
        averaged profile of each magnetization component, then appends them
        to ``self.records["x_profiles"]``.

        Args:
            m_n: Magnetization device array (three components)
            t: Time associated with the profiles
        """
        # Initialize x_profiles on first use
        if "x_profiles" not in self.records:
            self.records["x_profiles"] = {
                "t": [],
                "m1": [],
                "m2": [],
                "m3": [],
            }
        profiles = self.records["x_profiles"]
        # Blocking map of the device array to a host-side array
        h_m_n = m_n.map_to_host(is_blocking=True)
        # save profiles
        profiles["t"].append(t)
        profiles["m1"].append(self._yz_average(h_m_n[0]))
        profiles["m2"].append(self._yz_average(h_m_n[1]))
        profiles["m3"].append(self._yz_average(h_m_n[2]))

    def _simulate(self) -> float:
        """
        Simulates the system over N iterations.

        Each iteration draws the random field, runs the prediction and the
        correction kernels, optionally checks the error codes, swaps the two
        magnetization buffers and records the new state.

        Returns:
            The time taken for the simulation
        """
        elem = self.elem
        grid = self.grid
        queue = self.queue
        h_m_n = self._init_m_n()  # Create initial magnetization array on host

        # Create device buffers
        d_m_n = clarray.to_device(queue, h_m_n)
        d_R_random = clarray.Array(queue, (3, *grid.dims), self.np_float)
        d_m_np1 = clarray.empty_like(d_m_n)
        d_s_pre = cl.Buffer(self.context, mf.READ_WRITE, size=h_m_n.nbytes)

        # Always start from a clean deferred-average state, so that results of
        # a previous run can never be replayed by _process_deferred_xyz_averages
        self.average_counter = 0
        self._xyz_average_meta = []
        # Pre-allocate buffer for all xyz_average results
        max_averages = (self.N // self.n_mean) + 2 if self.n_mean > 0 else 0
        if max_averages > 0:
            self.d_all_averages = cl.Buffer(
                self.context,
                cl.mem_flags.READ_WRITE,
                size=max_averages * self.np_float.itemsize,
            )
            self.h_all_averages = np.empty(max_averages, dtype=self.np_float)
        else:
            # No deferred storage: _record_xyz_average falls back to
            # synchronous reads
            self.d_all_averages = None
            self.h_all_averages = None

        t = 0.0
        self._record(d_m_n, t, 0)  # Record the initial solution

        start_time = time.perf_counter()
        for n in self._progress_bar():
            t += self.dt
            self._compute_R_random(d_R_random)

            # Arguments shared by both stages; rebuilt every iteration because
            # the two magnetization buffers are swapped at the end of the step
            stage_args = (
                d_m_n.data,
                d_m_np1.data,
                d_R_random.data,
                d_s_pre,
                grid.inv_dx2,
                elem.coeff_1,
                elem.coeff_2,
                elem.coeff_3,
                elem.coeff_4,
                elem.lambda_G,
                self.dt,
            )
            # Prediction phase: compute slope, store it, and apply first Euler update
            self.predict_kernel(queue, (grid.ntot,), None, *stage_args)
            # Correction phase: compute slope, apply midpoint update and normalize
            self.correct_and_normalize_kernel(
                queue, (grid.ntot,), None, *stage_args, self.d_error_codes
            )
            # Check for errors if enabled
            self._check_errors(n)
            # Swap the buffers for the next iteration
            d_m_n, d_m_np1 = d_m_np1, d_m_n
            self._record(d_m_n, t, n)

        queue.finish()  # Ensure all operations are complete before timing
        total_time = time.perf_counter() - start_time

        # Now read back all xyz_average results in one go
        self._process_deferred_xyz_averages()
        if self.profiling:
            self._process_profiling_events()
        self._finalize(d_m_n, t)
        return total_time

    def _add_profiling_event(self, name: str, event: cl.Event) -> None:
        """
        Add a profiling event to the list.

        Args:
            name: Name of the kernel or function that produced the event
            event: The OpenCL event to time later
        """
        self.profiling_events.append((name, event))

    def _process_profiling_events(self) -> None:
        """
        Process pending profiling events and update stats.

        Fill the :attr:`~llg3d.solvers.base.BaseSolver.profiling_stats` dict with
        cumulative time and call counts for each kernel and function. The list
        of pending events is emptied afterwards.
        """
        if not self.profiling_events:
            return
        # Wait for all events to complete
        cl.wait_for_events([e for _, e in self.profiling_events])
        for name, event in self.profiling_events:
            # Event timestamps are in nanoseconds
            elapsed_sec = (event.profile.end - event.profile.start) * 1e-9
            # Use name as-is if it starts with underscore (function), else prefix
            # with kernel_
            key = name if name.startswith("_") else f"kernel_{name}"
            stats = self.profiling_stats[key]
            stats["time"] += elapsed_sec
            stats["calls"] += 1
        self.profiling_events.clear()

    def _process_deferred_xyz_averages(self) -> None:
        """
        Read back and process deferred xyz_average results.

        - Fill the ``xyz_average`` entry of the
          :attr:`~llg3d.solvers.base.BaseSolver.records` dict with (time, value) pairs
        - Add a ``m1_mean`` entry to the
          :attr:`~llg3d.solvers.base.BaseSolver.observables` dict, holding the
          sum of the samples in the averaging window (only if there is one).
        """
        if self.average_counter <= 0:
            return
        assert self.d_all_averages is not None
        assert self.h_all_averages is not None

        # Read back all results in one go (blocking)
        h_results = self.h_all_averages[: self.average_counter]
        cl.enqueue_copy(
            self.queue,
            h_results,
            self.d_all_averages,
        ).wait()
        # Convert to physical values
        averaged_values = h_results / self.grid.ncell

        # Store individual (time, value) pairs in records
        if "xyz_average" not in self.records:
            self.records["xyz_average"] = []
        # Retrieve metadata (defaults to empty list if not set)
        meta = getattr(self, "_xyz_average_meta", [])

        m1_sum = 0.0
        # If no samples fall in the averaging window, avoid division by zero later
        found_sample_in_window = False
        for (t, n), val in zip(meta, averaged_values):
            val_float = float(val)
            self.records["xyz_average"].append((t, val_float))
            if self.start_averaging is None or n >= self.start_averaging:
                m1_sum += val_float
                found_sample_in_window = True

        # Store sum in observables for averaging in _finalize
        # Only set if we actually accumulated something
        if found_sample_in_window:
            self.observables["m1_mean"] = m1_sum