Benchmark of the Laplacian¶
We compare several methods for computing the Laplacian of m with Neumann boundary conditions (zero flux on the edges). The first two build the shifted-neighbour arrays directly:
1. Slices of m are extracted and concatenated with sliced m.
2. Slices are extracted from m, the elements of m are shifted by one index with np.roll (creating a new array), and the extracted slices are assigned to the boundary planes.
A third method builds the neighbours from a padded copy of m, and a fourth repeats the padded version in JAX, without and with JIT compilation. A minimal 1D sketch of the first two constructions follows.
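The sketch below uses a small hypothetical 1D array a (not part of the benchmark) and shows that both constructions produce the same "right neighbour" array with a reflected, zero-flux boundary value:

import numpy as np

a = np.array([0., 1., 2., 3., 4.])

# Method 1 style: concatenate the shifted interior with the reflected edge value.
a_plus_concat = np.concatenate((a[1:], a[-2:-1]))

# Method 2 style: roll, then overwrite the wrapped-around element.
a_plus_roll = np.roll(a, -1)
a_plus_roll[-1] = a[-2]

print(a_plus_concat)  # [1. 2. 3. 4. 3.]
print(a_plus_roll)    # [1. 2. 3. 4. 3.]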
A random 3D array is initialized.
[1]:
import numpy as np
m = np.random.rand(6000, 20, 20)
Assume \(\Delta x = \Delta y = \Delta z = 1\).
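All of the implementations below evaluate the standard second-order, 7-point stencil

\[
(\nabla^2 m)_{i,j,k} \approx \frac{m_{i+1,j,k} - 2m_{i,j,k} + m_{i-1,j,k}}{\Delta x^2}
+ \frac{m_{i,j+1,k} - 2m_{i,j,k} + m_{i,j-1,k}}{\Delta y^2}
+ \frac{m_{i,j,k+1} - 2m_{i,j,k} + m_{i,j,k-1}}{\Delta z^2},
\]

where out-of-range neighbours are replaced by the mirrored interior values (for example \(m_{-1,j,k} = m_{1,j,k}\)), which enforces the zero-flux condition.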
Method 1¶
[2]:
def laplacian_concat(m, dx=1.0, dy=1.0, dz=1.0):
"""Laplacian operator using numpy concatenation."""
dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
m_start_x = m[1:2, :, :]
m_end_x = m[-2:-1, :, :]
m_start_y = m[:, 1:2, :]
m_end_y = m[:, -2:-1, :]
m_start_z = m[:, :, 1:2]
m_end_z = m[:, :, -2:-1]
m_x_plus = np.concatenate((m[1:, :, :], m_end_x), axis=0)
m_x_minus = np.concatenate((m_start_x, m[:-1, :, :]), axis=0)
m_y_plus = np.concatenate((m[:, 1:, :], m_end_y), axis=1)
m_y_minus = np.concatenate((m_start_y, m[:, :-1, :]), axis=1)
m_z_plus = np.concatenate((m[:, :, 1:], m_end_z), axis=2)
m_z_minus = np.concatenate((m_start_z, m[:, :, :-1]), axis=2)
return (dx2_inv * (m_x_plus + m_x_minus) +
dy2_inv * (m_y_plus + m_y_minus) +
dz2_inv * (m_z_plus + m_z_minus) +
center_coeff * m)
Method 2¶
[3]:
def laplacian_roll(m, dx=1.0, dy=1.0, dz=1.0):
"""Laplacian operator using numpy rolling."""
dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
m_start_x = m[1, :, :]
m_end_x = m[-2, :, :]
m_start_y = m[:, 1, :]
m_end_y = m[:, -2, :]
m_start_z = m[:, :, 1]
m_end_z = m[:, :, -2]
m_x_plus = np.roll(m, -1, axis=0)
m_x_plus[-1, :, :] = m_end_x
m_x_minus = np.roll(m, 1, axis=0)
m_x_minus[0, :, :] = m_start_x
m_y_plus = np.roll(m, -1, axis=1)
m_y_plus[:, -1, :] = m_end_y
m_y_minus = np.roll(m, 1, axis=1)
m_y_minus[:, 0, :] = m_start_y
m_z_plus = np.roll(m, -1, axis=2)
m_z_plus[:, :, -1] = m_end_z
m_z_minus = np.roll(m, 1, axis=2)
m_z_minus[:, :, 0] = m_start_z
return (dx2_inv * (m_x_plus + m_x_minus) +
dy2_inv * (m_y_plus + m_y_minus) +
dz2_inv * (m_z_plus + m_z_minus) +
center_coeff * m)
Method 3¶
[4]:
def laplacian_pad(m, dx=1.0, dy=1.0, dz=1.0):
"""Laplacian operator using numpy padding."""
m_padded = np.pad(m, ((1,1), (1,1), (1,1)), mode='reflect')
dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
return (
dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
center_coeff * m
)
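For reference, np.pad with mode='reflect' mirrors the values adjacent to the edge without repeating the edge value itself, so the ghost cells equal the first interior neighbours, i.e. the same m[1] / m[-2] values used in Methods 1 and 2. A quick check on a small hypothetical 1D array:

import numpy as np

print(np.pad(np.array([0., 1., 2., 3.]), 1, mode='reflect'))
# [1. 0. 1. 2. 3. 2.]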
We check that the three NumPy methods give identical results.
[5]:
assert np.array_equal(laplacian_concat(m), laplacian_roll(m))
assert np.array_equal(laplacian_concat(m), laplacian_pad(m))
We compare execution times.
[6]:
%timeit L1 = laplacian_concat(m)
14 ms ± 445 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[7]:
%timeit L2 = laplacian_roll(m)
15.3 ms ± 408 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[8]:
%timeit L3 = laplacian_pad(m)
15.4 ms ± 488 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Method 1 (concatenation) is therefore slightly faster, although all three NumPy variants are within about 10% of one another.
Method 4: JAX with JIT compilation¶
Note: @jax.jit has compilation overhead that can make it slower for small problems or single calls (a short timing sketch follows the list below). It’s more efficient for:
Large arrays (>1M elements)
Repeated calls on same-shaped arrays
GPU computation
Complex functions with many operations
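As a minimal sketch of this trade-off (the function f and array x below are hypothetical toy examples, not part of the benchmark), the first call to a jitted function includes tracing and XLA compilation, while subsequent calls with the same shapes and dtypes reuse the cached executable:

import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # hypothetical toy computation, not the Laplacian
    return (x * x).sum()

x = jnp.ones((6000, 20, 20))

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: tracing + XLA compilation
t1 = time.perf_counter()
f(x).block_until_ready()   # second call: reuses the cached executable
t2 = time.perf_counter()

print(f"first call (with compilation): {t1 - t0:.4f} s")
print(f"second call (cached):          {t2 - t1:.4f} s")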
[9]:
import jax
import jax.numpy as jnp
# Force JAX to use the CPU backend
jax.config.update("jax_platform_name", "cpu")
# Force JAX to use float64 (double precision)
jax.config.update("jax_enable_x64", True)
print("JAX devices:", jax.devices())
print("JAX default backend:", jax.default_backend())
def laplacian_jax(m, dx=1.0, dy=1.0, dz=1.0):
"""Laplacian operator using JAX with proper Neumann boundary conditions."""
dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
m_jax = jnp.array(m)
m_padded = jnp.pad(m_jax, ((1,1), (1,1), (1,1)), mode='reflect')
return (
dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
center_coeff * m_jax
)
# Pre-compiled JIT version for fair benchmarking
@jax.jit
def _laplacian_jax_compiled(m_jax, dx=1.0, dy=1.0, dz=1.0):
"""Pre-compiled JIT version to exclude compilation time from benchmarks."""
dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
m_padded = jnp.pad(m_jax, ((1,1), (1,1), (1,1)), mode='reflect')
return (
dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
center_coeff * m_jax
)
def laplacian_jax_jit(m, dx=1.0, dy=1.0, dz=1.0):
"""Wrapper for pre-compiled JIT function."""
return _laplacian_jax_compiled(jnp.array(m), dx, dy, dz)
JAX devices: [CpuDevice(id=0)]
JAX default backend: cpu
[10]:
# JAX comparison (might have tiny numerical differences)
jax_result = laplacian_jax(m)
jax_jit_result = laplacian_jax_jit(m)
diff_fast = np.abs(laplacian_concat(m) - jax_result).max()
diff_jit = np.abs(laplacian_concat(m) - jax_jit_result).max()
print(f"Max difference JAX (no JIT): {diff_fast:.2e}")
print(f"Max difference JAX (JIT): {diff_jit:.2e}")
assert np.allclose(laplacian_concat(m), jax_result, atol=1e-15)
assert np.allclose(laplacian_concat(m), jax_jit_result, atol=1e-15)
Max difference JAX (no JIT): 0.00e+00
Max difference JAX (JIT): 8.88e-16
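The ~1e-16 discrepancy for the JIT version comes from XLA fusing and reordering the floating-point operations, which changes rounding at the level of machine epsilon; it is harmless here.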
JAX without JIT (baseline)¶
[11]:
%timeit L4 = laplacian_jax(m)
10.7 ms ± 129 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
JAX with JIT¶
[12]:
# Pre-compile the JIT function (this triggers compilation)
print("Pre-compiling JIT function...")
_ = laplacian_jax_jit(m) # Trigger compilation
print("Compilation done. Now benchmarking JIT performance...")
# JAX with JIT (pre-compiled, fair benchmark)
%timeit L4_jit = laplacian_jax_jit(m)
Pre-compiling JIT function...
Compilation done. Now benchmarking JIT performance...
2.72 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
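With JIT, the padded-stencil Laplacian runs in about 2.7 ms, roughly five times faster than the fastest pure-NumPy version (about 14 ms). One caveat worth keeping in mind: JAX dispatches work asynchronously, so timing a call that merely returns a jax.Array can in principle underestimate the true runtime. A more conservative benchmark blocks on the result explicitly, for example:

%timeit laplacian_jax_jit(m).block_until_ready()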