Benchmark of the Laplacian

We compare four methods for computing the Laplacian of m with Neumann boundary conditions (zero flux across the faces):

  1. Concatenation: build each shifted array by concatenating an interior slice of m with the mirrored boundary slice along that axis.

  2. Rolling: shift m by one index with np.roll (which creates a new array) and overwrite the wrapped-around slice with the mirrored boundary slice.

  3. Padding: pad m with np.pad in 'reflect' mode and take shifted views of the padded array.

  4. JAX: the padding variant written with jax.numpy and compiled with @jax.jit.

A random 3D array is initialized.

[1]:
import numpy as np

m = np.random.rand(6000, 20, 20)

Assume \(\Delta x = \Delta y = \Delta z = 1\).
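
All four methods impose the zero-flux condition through mirrored ghost cells: the value one step inside a face is copied to the ghost point outside it, so the centred first derivative across the face vanishes. A minimal 1D sketch of this boundary treatment (not part of the benchmark, dx = 1):

[ ]:
# 1D illustration of the mirrored ghost cells: with u[-1] = u[1], the
# centred derivative (u[1] - u[-1]) / 2 at the left face is zero, which is
# the Neumann (zero-flux) condition.
u = np.random.rand(8)
u_padded = np.pad(u, 1, mode='reflect')          # ghost cells: u[1] on the left, u[-2] on the right
lap_1d = u_padded[2:] - 2 * u + u_padded[:-2]    # centred second difference

# at the faces the stencil reduces to 2*(u[1] - u[0]) and 2*(u[-2] - u[-1])
assert np.isclose(lap_1d[0], 2 * (u[1] - u[0]))
assert np.isclose(lap_1d[-1], 2 * (u[-2] - u[-1]))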

Method 1

[2]:
def laplacian_concat(m, dx=1.0, dy=1.0, dz=1.0):
    """Laplacian operator using numpy concatenation."""
    dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)

    # mirrored boundary slices: these become the ghost values that enforce zero flux at the faces
    m_start_x = m[1:2, :, :]
    m_end_x = m[-2:-1, :, :]

    m_start_y = m[:, 1:2, :]
    m_end_y = m[:, -2:-1, :]

    m_start_z = m[:, :, 1:2]
    m_end_z = m[:, :, -2:-1]

    m_x_plus = np.concatenate((m[1:, :, :], m_end_x), axis=0)
    m_x_minus = np.concatenate((m_start_x, m[:-1, :, :]), axis=0)
    m_y_plus = np.concatenate((m[:, 1:, :], m_end_y), axis=1)
    m_y_minus = np.concatenate((m_start_y, m[:, :-1, :]), axis=1)
    m_z_plus = np.concatenate((m[:, :, 1:], m_end_z), axis=2)
    m_z_minus = np.concatenate((m_start_z, m[:, :, :-1]), axis=2)

    return (dx2_inv * (m_x_plus + m_x_minus) +
            dy2_inv * (m_y_plus + m_y_minus) +
            dz2_inv * (m_z_plus + m_z_minus) +
            center_coeff * m)

Method 2

[3]:
def laplacian_roll(m, dx=1.0, dy=1.0, dz=1.0):
    """Laplacian operator using numpy rolling."""
    dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)

    # mirrored boundary slices, used to overwrite the wrap-around introduced by np.roll
    m_start_x = m[1, :, :]
    m_end_x = m[-2, :, :]

    m_start_y = m[:, 1, :]
    m_end_y = m[:, -2, :]

    m_start_z = m[:, :, 1]
    m_end_z = m[:, :, -2]

    m_x_plus = np.roll(m, -1, axis=0)
    m_x_plus[-1, :, :] = m_end_x
    m_x_minus = np.roll(m, 1, axis=0)
    m_x_minus[0, :, :] = m_start_x

    m_y_plus = np.roll(m, -1, axis=1)
    m_y_plus[:, -1, :] = m_end_y
    m_y_minus = np.roll(m, 1, axis=1)
    m_y_minus[:, 0, :] = m_start_y

    m_z_plus = np.roll(m, -1, axis=2)
    m_z_plus[:, :, -1] = m_end_z
    m_z_minus = np.roll(m, 1, axis=2)
    m_z_minus[:, :, 0] = m_start_z

    return (dx2_inv * (m_x_plus + m_x_minus) +
            dy2_inv * (m_y_plus + m_y_minus) +
            dz2_inv * (m_z_plus + m_z_minus) +
            center_coeff * m)

Method 3

[4]:
def laplacian_pad(m, dx=1.0, dy=1.0, dz=1.0):
    """Laplacian operator using numpy padding."""
    m_padded = np.pad(m, ((1,1), (1,1), (1,1)), mode='reflect')  # one mirrored ghost layer on every face

    dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)

    return (
        dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
        dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
        dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
        center_coeff * m
    )

We check that the three NumPy methods give identical results.

[5]:
assert np.array_equal(laplacian_concat(m), laplacian_roll(m))
assert np.array_equal(laplacian_concat(m), laplacian_pad(m))
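
As an additional, optional cross-check (this assumes SciPy is installed; it is not used anywhere else), scipy.ndimage.laplace with mode='mirror' applies the same 7-point stencil with the same mirrored boundary for \(\Delta x = \Delta y = \Delta z = 1\):

[ ]:
# Optional cross-check against SciPy (assumption: SciPy is available).
# ndimage's mode='mirror' corresponds to numpy's 'reflect' padding, i.e. the
# same mirrored ghost cells used by the three methods above.
from scipy import ndimage

assert np.allclose(laplacian_concat(m), ndimage.laplace(m, mode='mirror'))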

We compare execution times.

[6]:
%timeit L1 = laplacian_concat(m)
14 ms ± 445 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[7]:
%timeit L2 = laplacian_roll(m)
15.3 ms ± 408 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[8]:
%timeit L3 = laplacian_pad(m)
15.4 ms ± 488 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Method 1 (concatenation) is therefore slightly faster, although the three NumPy methods are within roughly 10% of each other.

Method 4: JAX with JIT compilation

Note: @jax.jit has compilation overhead that can make it slower for small problems or one-off calls (a rough illustration of this overhead follows the setup cell below). JIT compilation pays off for:

  • Large arrays (>1M elements)

  • Repeated calls on same-shaped arrays

  • GPU computation

  • Complex functions with many operations

[9]:
import jax
import jax.numpy as jnp
# Force JAX to use the CPU backend
jax.config.update("jax_platform_name", "cpu")
# Force JAX to use float64 (double precision)
jax.config.update("jax_enable_x64", True)

print("JAX devices:", jax.devices())
print("JAX default backend:", jax.default_backend())

def laplacian_jax(m, dx=1.0, dy=1.0, dz=1.0):
    """Laplacian operator using JAX with proper Neumann boundary conditions."""
    dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)

    m_jax = jnp.array(m)
    m_padded = jnp.pad(m_jax, ((1,1), (1,1), (1,1)), mode='reflect')

    return (
        dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
        dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
        dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
        center_coeff * m_jax
    )

# Pre-compiled JIT version for fair benchmarking
@jax.jit
def _laplacian_jax_compiled(m_jax, dx=1.0, dy=1.0, dz=1.0):
    """Pre-compiled JIT version to exclude compilation time from benchmarks."""
    dx2_inv, dy2_inv, dz2_inv = 1/dx**2, 1/dy**2, 1/dz**2
    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)

    m_padded = jnp.pad(m_jax, ((1,1), (1,1), (1,1)), mode='reflect')

    return (
        dx2_inv * (m_padded[2:, 1:-1, 1:-1] + m_padded[:-2, 1:-1, 1:-1]) +
        dy2_inv * (m_padded[1:-1, 2:, 1:-1] + m_padded[1:-1, :-2, 1:-1]) +
        dz2_inv * (m_padded[1:-1, 1:-1, 2:] + m_padded[1:-1, 1:-1, :-2]) +
        center_coeff * m_jax
    )

def laplacian_jax_jit(m, dx=1.0, dy=1.0, dz=1.0):
    """Wrapper for pre-compiled JIT function."""
    return _laplacian_jax_compiled(jnp.array(m), dx, dy, dz)
JAX devices: [CpuDevice(id=0)]
JAX default backend: cpu
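
As a rough illustration of the compilation overhead mentioned above (the exact numbers depend on the machine and JAX version), the jitted function can be timed on an array shape it has not seen before: the first call traces and compiles, the second reuses the cached executable.

[ ]:
# Rough illustration of JIT compilation overhead. m_small is an auxiliary
# array whose shape is not used elsewhere, so the first call triggers a
# fresh compilation; block_until_ready() waits for the asynchronous
# computation to finish before the clock is read.
import time

m_small = np.random.rand(600, 20, 20)

t0 = time.perf_counter()
laplacian_jax_jit(m_small).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
laplacian_jax_jit(m_small).block_until_ready()   # run only (cached)
t2 = time.perf_counter()

print(f"First call (incl. compilation): {t1 - t0:.4f} s")
print(f"Second call (cached):           {t2 - t1:.4f} s")
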
[10]:
# JAX comparison (might have tiny numerical differences)
jax_result = laplacian_jax(m)
jax_jit_result = laplacian_jax_jit(m)

diff_nojit = np.abs(laplacian_concat(m) - jax_result).max()
diff_jit = np.abs(laplacian_concat(m) - jax_jit_result).max()

print(f"Max difference JAX (no JIT): {diff_nojit:.2e}")
print(f"Max difference JAX (JIT):    {diff_jit:.2e}")

assert np.allclose(laplacian_concat(m), jax_result, atol=1e-15)
assert np.allclose(laplacian_concat(m), jax_jit_result, atol=1e-15)
Max difference JAX (no JIT): 0.00e+00
Max difference JAX (JIT):    8.88e-16

JAX without JIT (baseline)

[11]:
%timeit L4 = laplacian_jax(m)
10.7 ms ± 129 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

JAX with JIT

[12]:
# Pre-compile the JIT function (this triggers compilation)
print("Pre-compiling JIT function...")
_ = laplacian_jax_jit(m)  # Trigger compilation
print("Compilation done. Now benchmarking JIT performance...")

# JAX with JIT (pre-compiled, fair benchmark)
%timeit L4_jit = laplacian_jax_jit(m)
Pre-compiling JIT function...
Compilation done. Now benchmarking JIT performance...
2.72 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
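
With compilation excluded, the JIT version is therefore roughly five times faster than the NumPy methods on this problem. One caveat: JAX dispatches work asynchronously, so %timeit above may partly measure dispatch rather than computation time. A more conservative measurement blocks on the result (a sketch; the number can differ from the one above):

[ ]:
# JAX dispatches asynchronously; blocking on the result ensures the full
# computation is included in the measured time.
%timeit laplacian_jax_jit(m).block_until_ready()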