The problem that made me look for an alternative
My work involves taking models of the Universe – dark energy equations of state, modified gravity, tachyonic fields – and asking: what do the data actually say about the parameters? The tool for that question is Bayesian inference. I usually run dynesty nested sampling for anywhere from a few thousand to a few hundred thousand likelihood evaluations, depending on the complexity of the model.
For most of my PhD, I did not think much about the ODE solver inside the likelihood. solve_ivp worked and it was reliable, so I used it and moved on.
Then I started working on a tachyonic DBI dark energy model where the dark energy field is governed by a non-standard kinetic term, and the background and perturbation equations are a coupled stiff-ish system. Each likelihood call solved those ODEs, computed the comoving distance, and evaluated the distance modulus at the redshifts of 30 supernovae.
I profiled it. The ODE solve alone was taking 0.4 ms per call. In a nested sampling run with 10⁵ evaluations, that is 40 seconds just in ODE calls, before you count any bookkeeping. And for a 10-parameter model, getting a gradient via central finite differences costs 20 extra forward solves, turning those 0.4 ms into 8 ms per gradient. If each of those 10⁵ evaluations also needs a gradient, that is 800 seconds, or about 13 minutes, just for the gradients. For a single nested sampling run.
Something had to change.
Figure 1: Where time goes in a dynesty nested-sampling run on a flat ΛCDM model against 30 mock supernovae. Left: scipy pipeline — ODE solve 40 s, FD gradient 98 s, overhead 30 s. Right: diffrax pipeline — total ODE + gradient cost: 24.8 s. (Image created by author)
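The back-of-envelope costs quoted above can be reproduced with a few lines of arithmetic. This is only a sketch: the function names are mine, and the inputs are the 0.4 ms solve time and 10⁵ evaluations from the text.

```python
def forward_budget_s(t_solve_s: float, n_evals: int) -> float:
    """Total wall time spent in forward ODE solves, in seconds."""
    return t_solve_s * n_evals


def central_fd_gradient_s(t_solve_s: float, n_params: int) -> float:
    """Cost of one central-difference gradient: two solves per parameter."""
    return 2 * n_params * t_solve_s


t_solve = 0.4e-3   # seconds per ODE solve (measured value quoted in the text)
n_evals = 10**5    # likelihood evaluations in one nested-sampling run

total_forward = forward_budget_s(t_solve, n_evals)   # about 40 s
one_gradient = central_fd_gradient_s(t_solve, 10)    # about 8 ms for 10 parameters

print(f"forward solves: {total_forward:.1f} s")
print(f"one FD gradient (10 params): {one_gradient * 1e3:.1f} ms")
```

The second function is where the linear growth with parameter count comes from: every extra parameter adds two more forward solves per gradient.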
What I found: diffrax
After a day of searching, I landed on diffrax [1], a library of numerical ODE solvers written entirely in JAX. Not a neural surrogate. Not an approximation. The same embedded Runge–Kutta algorithms I already use in scipy — Tsit5 instead of RK45, but the same family of methods — just compiled, differentiable, and vectorisable.
Three properties come from the “written entirely in JAX” design:
- JIT compilation: The entire adaptive-stepping loop compiles to a single XLA kernel. No per-step Python overhead after the first call.
- Autodiff: Because every operation inside the solver is a JAX primitive, jax.grad propagates gradients through the solve. Exact gradients. One backward pass. Regardless of how many parameters.
- vmap: An entire batch of parameter vectors can be solved in parallel with jax.vmap. Critical for nested sampling.
Installing it takes 10 seconds:
pip install jax diffrax
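Before bringing in an ODE solver, the three JAX transformations named above can be tried on a toy function. The sketch below uses the dimensionless flat-ΛCDM expansion rate E(z) = H(z)/H₀; the function and variable names are my own choices for illustration and are not part of diffrax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def E(Om, z):
    """Dimensionless flat-LCDM expansion rate H(z)/H0."""
    return jnp.sqrt(Om * (1.0 + z) ** 3 + (1.0 - Om))


E_fast = jax.jit(E)                        # JIT: compiled on first call
dE_dOm = jax.grad(E, argnums=0)            # autodiff: derivative w.r.t. Om
E_batch = jax.vmap(E, in_axes=(0, None))   # vmap: batch over Om, shared z

Om_grid = jnp.linspace(0.1, 0.5, 5)

e_val = E_fast(0.3, 1.0)
g_val = dE_dOm(0.3, 1.0)       # analytic value: ((1+z)^3 - 1) / (2 E)
e_vals = E_batch(Om_grid, 1.0)

print(e_val, g_val, e_vals.shape)
```

diffrax applies exactly these transformations to a whole adaptive solve instead of a one-line formula.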
The test problem: flat ΛCDM from supernovae
To make the comparison concrete, let me show the exact problem I was working with. In a flat ΛCDM universe, the comoving distance satisfies:
$$\frac{d\chi}{dz} = \frac{c}{H(z)}, \quad H(z) = H_0\sqrt{\Omega_m(1+z)^3 + (1-\Omega_m)}, \quad \chi(0)=0$$
The distance modulus follows: μ(z) = 5 log₁₀[(1+z)χ(z) / 10 pc]. I want to infer (Ωₘ, H₀) from 30 mock SNIa distance-modulus observations.
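The factor of 1e5 that appears in the code versions of this formula is just a unit conversion: with c in km/s and H₀ in km/s/Mpc, χ comes out in Mpc, and 1 Mpc = 10⁵ × 10 pc. A minimal sketch of that conversion (the helper name is mine):

```python
import math


def distance_modulus(z: float, chi_mpc: float) -> float:
    """mu = 5 log10(d_L / 10 pc), with d_L = (1 + z) * chi and chi in Mpc."""
    d_L_mpc = (1.0 + z) * chi_mpc
    return 5.0 * math.log10(d_L_mpc * 1e5)   # 1 Mpc = 1e5 * (10 pc)


mu_at_10pc = distance_modulus(0.0, 1e-5)   # source at 10 pc: mu is 0 by definition
mu_at_1mpc = distance_modulus(0.0, 1.0)    # source at 1 Mpc: mu is 25

print(mu_at_10pc, mu_at_1mpc)
```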
The old way: SciPy
from scipy.integrate import solve_ivp
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def rhs(z, chi, Om, H0):
    # d(chi)/dz = c / H(z); chi is in Mpc when H0 is in km/s/Mpc
    return C_KMS / (H0 * np.sqrt(Om*(1+z)**3 + (1-Om)))

def forward_scipy(Om, H0, z_obs):
    # z_obs: sorted NumPy array of observed redshifts
    sol = solve_ivp(rhs, t_span=(0, z_obs[-1]),
                    y0=[0.0], t_eval=z_obs,
                    args=(Om, H0), method="RK45",
                    rtol=1e-8, atol=1e-10)
    chi = sol.y[0]  # comoving distance [Mpc] at each z_obs
    return 5 * np.log10((1 + z_obs) * chi * 1e5)  # distance modulus
The new way: Diffrax
import jax, jax.numpy as jnp
import diffrax as dfx

# Non-negotiable: enable 64-bit (more on this below)
jax.config.update("jax_enable_x64", True)

C_KMS = 299792.458  # speed of light [km/s], same constant as the scipy version

def H_jax(z, Om, H0):
    return H0 * jnp.sqrt(Om*(1+z)**3 + (1-Om))

@jax.jit  # compile once, call fast forever
def forward_diffrax(theta, z_obs):
    Om, H0 = theta[0], theta[1]
    sol = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a: C_KMS / H_jax(z, a[0], a[1])),
        dfx.Tsit5(),
        # z_obs is traced under jit, so pass z_obs[-1] directly;
        # float(z_obs[-1]) would raise a concretization error here
        t0=0.0, t1=z_obs[-1],       # initial and final value
        dt0=1e-3,                   # initial step-size
        y0=jnp.array(0.0),          # initial condition
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    )
    chi = sol.ys
    return 5 * jnp.log10((1 + z_obs) * chi * 1e5)
The physics is identical. The solver algorithm is nearly identical (Tsit5 is very similar to RK45). The only structural differences are @jax.jit and the diffrax API. Let us look at what those two changes buy.
Surprise 1: the speed
solve_ivp: 404 μs per call. diffrax post-JIT: 59 μs per call. That is about 7× faster.
I stared at this number for a few seconds the first time I saw it. Let me be honest about where the speedup actually comes from, because it is not magic.
In solve_ivp, the RK45 stepper is itself Python code built on NumPy, so the adaptive loop runs through the interpreter on every call. Memory is allocated fresh. At each step the interpreter asks: “is the local error too large? reject; else grow the step; repeat.” For a 12-step solve, that is 12 rounds of Python dispatch, 12 allocations and 12 error-estimate computations, each paying interpreter overhead, on top of several Python calls to the right-hand side per step.
In diffrax, the first call to the @jax.jit-decorated function traces the entire computation, including the adaptive while-loop (which is lowered to a lax.while_loop), and hands it to XLA to compile into a machine-code kernel. Every subsequent call executes that kernel directly: no Python inside the stepping loop, no per-step allocation and no per-step dispatch.
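To make the “adaptive loop inside lax.while_loop” idea concrete, here is a toy accept/reject stepper written directly with jax.lax.while_loop. It is a deliberately simple Euler/Heun pair with a crude step-size rule that I made up for illustration. It is not diffrax's controller or its Tsit5 implementation, only the same structural pattern.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def f(t, y):
    return -y  # toy right-hand side: dy/dt = -y


@jax.jit
def adaptive_heun(y0, t1, tol):
    def cond(carry):
        t, y, dt, n = carry
        return t < t1

    def body(carry):
        t, y, dt, n = carry
        dt = jnp.minimum(dt, t1 - t)            # never step past t1
        k1 = f(t, y)
        k2 = f(t + dt, y + dt * k1)
        y_low = y + dt * k1                     # Euler, order 1
        y_high = y + 0.5 * dt * (k1 + k2)       # Heun, order 2
        err = jnp.abs(y_high - y_low)           # local error estimate
        accept = err <= tol
        t_new = jnp.where(accept, t + dt, t)    # rejected steps leave (t, y) unchanged
        y_new = jnp.where(accept, y_high, y)
        factor = jnp.clip(0.9 * jnp.sqrt(tol / (err + 1e-300)), 0.2, 5.0)
        return t_new, y_new, dt * factor, n + 1

    init = (
        jnp.asarray(0.0, dtype=jnp.float64),
        jnp.asarray(y0, dtype=jnp.float64),
        jnp.asarray(0.1, dtype=jnp.float64),
        jnp.asarray(0, dtype=jnp.int64),
    )
    t, y, dt, n = lax.while_loop(cond, body, init)
    return y, n


y_end, n_iters = adaptive_heun(1.0, 1.0, 1e-6)
print(y_end, n_iters)  # y_end should be close to exp(-1)
```

Because the accept/reject branch is expressed with jnp.where rather than a Python if, the whole loop is traceable and compiles into one kernel.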
Figure 2: Left: single-call timing for the comoving-distance ODE at rtol = 10⁻⁸. Right: the inference problem — 30 mock supernovae, 0.1-mag noise. Both solvers produce identical μ(z) curves; only the speed differs. (Image created by author)
For 100,000 likelihood evaluations, 404 μs vs 59 μs translates to 40.4 seconds vs 5.9 seconds, and the gap widens as the model gets more complex.
Surprise 2: gradients become free
This was the part that changed not just my workflow but how I think about inference. With scipy, getting one gradient of the log-likelihood with respect to 2 parameters (Ωₘ, H₀) costs 4 forward solves (central finite differences). Once you start turning the dial up, it gets expensive fast: 10 parameters means 20 forward solves, 50 parameters means 100. The bill grows linearly with the number of parameters.
$$\frac{\partial\mathcal{F}}{\partial\Omega_m} \approx \frac{\mathcal{F}(\Omega_m+h,H_0) - \mathcal{F}(\Omega_m-h,H_0)}{2h}, \qquad \frac{\partial\mathcal{F}}{\partial H_0} \approx \frac{\mathcal{F}(\Omega_m,H_0+h) - \mathcal{F}(\Omega_m,H_0-h)}{2h}$$
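In code, the central-difference recipe above looks roughly like the sketch below. The helper and the toy loss are hypothetical stand-ins (the real loss would call the ODE-based forward model); the point is the call counter, which shows the two evaluations per parameter.

```python
import numpy as np


def central_fd_grad(f, theta, h=1e-5):
    """Central finite-difference gradient; returns (gradient, number of f calls)."""
    theta = np.asarray(theta, dtype=float)
    grad = np.zeros_like(theta)
    n_calls = 0
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = h
        grad[i] = (f(theta + step) - f(theta - step)) / (2.0 * h)
        n_calls += 2   # one evaluation on each side
    return grad, n_calls


def toy_loss(theta):
    # Hypothetical stand-in for the likelihood; gradient is (2a + 3b, 3a)
    return theta[0] ** 2 + 3.0 * theta[0] * theta[1]


g_fd, n_calls = central_fd_grad(toy_loss, [0.3, 0.7])
print(g_fd, n_calls)
```

With two parameters the counter reads 4; with fifty it would read 100, each one a full ODE solve in the real pipeline.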
With diffrax, I write:
# z_obs, mu_obs, sigma_mu: observed redshifts, distance moduli and their errors
def loss(theta):
    mu_pred = forward_diffrax(theta, z_obs)
    return 0.5 * jnp.sum(((mu_pred - mu_obs) / sigma_mu)**2)

grad_fn = jax.jit(jax.grad(loss))      # that is the entire change
g = grad_fn(jnp.array([0.3, 70.0]))    # exact gradient
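Under the hood, jax.grad runs reverse-mode autodiff through the ODE solve. By default diffrax differentiates the solver's own steps, using checkpointing to keep memory under control; the continuous adjoint equations of [2] are also available as dfx.BacksolveAdjoint. Either way, I never have to write those equations. The result is the exact gradient of the numerical solution, at a cost of a small multiple of one forward pass (about 3× here: 195 μs against 59 μs), essentially independent of the number of parameters.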
Figure 3: Left: cost of one gradient on the 2-parameter likelihood. Scipy with central finite differences costs 1.62 ms (4 ODE calls). Diffrax with autodiff costs 195 μs, an 8× saving. Right: the log-likelihood surface −logℒ(Ωₘ, H₀) with an autodiff gradient arrow pointing correctly toward lower loss. (Image created by author)
How to choose a solver
When it comes to picking a solver, you have to be a little careful. I defaulted to Tsit5 for almost everything, and it handled about 95% of my problems without complaint. If you want the whole decision process, here it is:
- Non-stiff ODE (most cosmological problems) → dfx.Tsit5() ← start here
- Very tight tolerances (< 10⁻⁹) → dfx.Dopri8()
- Stiff ODE (many steps, solver seems slow) → dfx.Kvaerno5()
- Stiff + non-stiff terms (IMEX) → dfx.KenCarp4()
- SDE → dfx.EulerHeun() or dfx.SPaRK()
A quick way to tell if your problem is stiff: print sol.stats["num_steps"]. If it is 10–100× more than you expect, the problem is probably stiff and an implicit solver is worth trying.
The payoff: cosmological inference end-to-end
Now, let me show the full inference comparison. I start both pipelines from the same bad initial guess (Ωₘ, H₀) = (0.10, 60), well away from the truth (0.30, 70), and run 350 gradient steps.
- scipy pipeline: gradient from central finite differences, simple gradient descent, fixed learning rate.
- diffrax pipeline: gradient from autodiff, Adam optimiser with a cosine-decay learning-rate schedule.
import optax  # optimisers for JAX

# Scale parameters so Adam can handle them equally
# Om ~ 0.3, h = H0/100 ~ 0.7: both O(1) now
def loss_scaled(theta_s):
    theta = jnp.array([theta_s[0], 100.0 * theta_s[1]])
    return loss(theta)

grad_scaled = jax.jit(jax.grad(loss_scaled))

schedule = optax.cosine_decay_schedule(
    init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(schedule)

theta = jnp.array([0.10, 0.60])  # start far from truth
state = opt.init(theta)

for step in range(350):
    g = grad_scaled(theta)
    updates, state = opt.update(g, state)
    theta = optax.apply_updates(theta, updates)
    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: Om={float(theta[0]):.3f} H0={100*float(theta[1]):.2f}")
Figure 4: MAP inference on flat ΛCDM from 30 mock SNIa. Diffrax (green) with Adam + autodiff: converges to Ωₘ = 0.270, H₀ = 70.94. Scipy (red) with simple gradient descent + finite differences: gets stuck at Ωₘ = 0.65, H₀ = 60 – a completely wrong region. (Image created by author)
While the diffrax pipeline recovers physically sensible parameters, the scipy pipeline cannot simultaneously move both parameters – a textbook failure of gradient descent on poorly-scaled problems. Adam handles this through its per-parameter adaptive learning rates. Strictly speaking, Adam would accept finite-difference gradients too; what autodiff adds is a cheap, exact gradient at every step, which is what makes running an adaptive optimiser for hundreds of steps practical.
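The scaling failure described above can be reproduced without any ODE at all. The sketch below runs plain gradient descent on a toy quadratic whose two directions differ in curvature by a factor of 10⁴. That ratio, the learning rate and the function name are assumptions of mine, chosen only to make the effect visible in 350 steps.

```python
def gd_on_badly_scaled_quadratic(n_steps=350, lr=1e-4):
    """Plain gradient descent on 0.5 * (1e4 * (x - 1)^2 + (y - 1)^2)."""
    x, y = 0.0, 0.0
    for _ in range(n_steps):
        gx = 1e4 * (x - 1.0)   # steep direction: limits the stable learning rate
        gy = 1.0 * (y - 1.0)   # flat direction: barely moves at that learning rate
        x -= lr * gx
        y -= lr * gy
    return x, y


x_end, y_end = gd_on_badly_scaled_quadratic()
print(f"x = {x_end:.4f} (target 1), y = {y_end:.4f} (target 1)")
```

The learning rate has to stay below 2/10⁴ for the steep direction to remain stable, and at that rate the flat direction covers only a few percent of the distance to its target. Rescaling the parameters, or a per-parameter step size as in Adam, removes the mismatch.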
Three things I got wrong (so you do not have to)
Figure 5: Left: 32-bit precision causes the same ODE to take 5.75× more steps. Centre: first JIT call pays 93 ms compilation, subsequent calls are ~1550× faster. Right: odeint reverses the argument order to f(y, t) — a completely silent error. (Image created by author)
Caveat 1: forgetting 64-bit precision. JAX defaults to 32-bit floats. If you push the tolerances (rtol < 10⁻⁷), that can lead to some very odd results: on my ODE the solver needs 69 steps in 32-bit, but only 12 in 64-bit. Tighten the tolerances further and it can fail outright. The fix is simple — enable 64-bit before you do anything else:
jax.config.update("jax_enable_x64", True)  # must be first
Caveat 2: benchmarking without warming up. The first call to any @jax.jit-decorated function includes a one-off compilation hit of about 90–100 ms. If you include that in your timings, diffrax will look slower than scipy for the wrong reason. The fix is to warm up once and throw away that first run:
_ = forward_diffrax(theta, z_obs).block_until_ready() # compile
# NOW benchmark — this is the real speed
Also: JAX dispatches asynchronously. Always call .block_until_ready() in timing loops or you measure the time to submit work, not finish it.
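Both timing rules (warm up first, then block on every call) fit into a small helper. This is a sketch with names of my own choosing; the toy function stands in for the forward model.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_repeats=100):
    """Mean wall time per call of a jitted function, excluding compilation."""
    fn(*args).block_until_ready()        # warm-up call: compiles, not timed
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn(*args).block_until_ready()    # wait for the result, not just the dispatch
    return (time.perf_counter() - start) / n_repeats


@jax.jit
def toy(x):
    return jnp.sum(jnp.sin(x) ** 2)


per_call_s = time_jitted(toy, jnp.linspace(0.0, 1.0, 1000))
print(f"{per_call_s * 1e6:.1f} us per call")
```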
Caveat 3: the argument-order trap. scipy.integrate.odeint expects f(y, t) (state first, time second). Almost everything else (solve_ivp, diffrax) expects f(t, y). If you port old odeint code to diffrax without swapping the arguments, you end up solving a different ODE and you usually won’t get an error. You’ll just get the wrong answer.
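The trap is easy to demonstrate inside SciPy itself, where odeint has a tfirst flag for exactly this reason. The toy equation dy/dt = t − y is my own choice: with y(0) = 1 its solution is y = t − 1 + 2e⁻ᵗ, whereas the swapped equation dy/dt = y − t gives y = t + 1.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp


def rhs(t, y):
    """dy/dt = t - y, written in the (t, y) order used by solve_ivp and diffrax."""
    return t - y


t_grid = np.array([0.0, 1.0])
y0 = [1.0]

# Correct: solve_ivp expects f(t, y)
y_ivp = solve_ivp(rhs, (0.0, 1.0), y0, rtol=1e-10, atol=1e-12).y[0, -1]

# Correct: odeint told that time comes first
y_odeint_ok = odeint(rhs, y0, t_grid, tfirst=True, rtol=1e-10, atol=1e-12)[-1, 0]

# Silent bug: odeint calls rhs(y, t), so it actually integrates dy/dt = y - t
y_odeint_bad = odeint(rhs, y0, t_grid, rtol=1e-10, atol=1e-12)[-1, 0]

print(y_ivp, y_odeint_ok, y_odeint_bad)
```

No exception is raised in the third call; the only symptom is a value near 2 where about 0.736 was expected.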
Should you make the switch?
The honest answer is this: if you’re solving a one-off ODE and you don’t need gradients, solve_ivp is perfectly fine — there’s no need to learn a new API. But if you’re doing inference (repeated likelihood evaluations, parameter gradients, or batched solves), the switch is worth the effort.
| Situation | solve_ivp | odeint | diffrax |
|---|---|---|---|
| One-off solve, no inference | ✓ | ✓ | fine too |
| Nested sampling / MCMC | slow | slow | YES |
| Need gradients | FD only | FD only | exact, free |
| Batch over parameter grid | for-loop | for-loop | vmap |
| Stiff system | Radau | auto (LSODA) | Kvaerno5 |
| SDE or Neural ODE | no | no | YES |
| GPU/TPU | no | no | YES |
The migration itself is small. The forward model changes by about six lines. The gradient appears by adding one more line. The rest of the inference code stays identical.
One thing worth stating clearly: diffrax is not “ML-based” in the sense of using a neural network. It is the same classical Runge–Kutta mathematics, written in JAX. The “ML acceleration” comes from JIT compilation and autodiff, both infrastructure tools from the ML world applied to a classical numerical solver. The only genuinely ML-based approach would be a neural surrogate that learns θ → μ(z) from training data, which is a separate and more advanced topic.
The complete working code
Everything above in one self-contained script (pip install jax diffrax optax):
"""
flat_lcdm_inference.py
Infer (Omega_m, H0) from 30 mock supernovae using diffrax + Adam.
pip install jax diffrax optax
"""
import jax, jax.numpy as jnp, numpy as np
import diffrax as dfx, optax
from scipy.integrate import solve_ivp  # only for generating mock data

jax.config.update("jax_enable_x64", True)

# --- Constants and data ---
C_KMS = 299792.458
z_obs = jnp.linspace(0.05, 1.5, 30)
SIGMA = 0.10

# Mock data at truth (Om=0.30, H0=70)
def chi_np(Om, H0):
    sol = solve_ivp(lambda z, y: C_KMS/(H0*np.sqrt(Om*(1+z)**3+(1-Om))),
                    (0, 1.5), [0.], t_eval=np.array(z_obs), rtol=1e-10)
    return sol.y[0]

mu_true = 5*np.log10((1+np.array(z_obs))*chi_np(0.3, 70.)*1e5)
mu_obs = jnp.array(mu_true + SIGMA*np.random.default_rng(42).standard_normal(30))

# --- diffrax forward model ---
@jax.jit
def forward(theta):
    Om, H0 = theta[0], theta[1]
    sol = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a:
                    C_KMS/(a[1]*jnp.sqrt(a[0]*(1+z)**3+(1-a[0])))),
        dfx.Tsit5(),
        t0=0., t1=1.5, dt0=1e-3, y0=jnp.array(0.),
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    ).ys
    return 5*jnp.log10((1+z_obs)*sol*1e5)

# --- Loss and gradient ---
def loss(th_s):  # optimise in scaled coords (Om, h=H0/100)
    mu = forward(jnp.array([th_s[0], 100.*th_s[1]]))
    return 0.5*jnp.sum(((mu - mu_obs)/SIGMA)**2)

grad_fn = jax.jit(jax.grad(loss))

# Warm up the JIT compiler
theta_init = jnp.array([0.10, 0.60])
_ = forward(jnp.array([0.3, 70.0])).block_until_ready()  # forward() takes unscaled (Om, H0)
_ = grad_fn(theta_init).block_until_ready()

# --- Adam optimiser with cosine LR schedule ---
sched = optax.cosine_decay_schedule(init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(sched)
theta = theta_init
state = opt.init(theta)

print(f"{'Step':>5} {'Om':>7} {'H0':>7} {'Loss':>8}")
for step in range(350):
    g = grad_fn(theta)
    upd, state = opt.update(g, state)
    theta = optax.apply_updates(theta, upd)
    if (step + 1) % 70 == 0 or step == 0:
        L = float(loss(theta))
        print(f"{step+1:5d} {float(theta[0]):7.4f} {100*float(theta[1]):7.3f} {L:8.2f}")

Om_fit, H0_fit = float(theta[0]), 100*float(theta[1])
print(f"\nFinal: Om = {Om_fit:.3f} H0 = {H0_fit:.2f}")
print(f"Truth: Om = 0.300 H0 = 70.00")
Numbers at a glance
| Measurement | scipy | diffrax | Speedup |
|---|---|---|---|
| Single forward call | 0.4 ms | 57 μs | ~7× |
| Gradient (2 params) | 1.62 ms | 195 μs | ~8× |
| 10⁵ forward calls | 40 s | 5.9 s | ~7× |
| 10⁵ gradient calls | ~98 s | ~19.6 s | ~5× |
| Final Ωₘ (350 steps) | 0.652 (wrong) | 0.270 | n/a |
| Final H₀ (350 steps) | 60.10 (stuck) | 70.94 | n/a |
The “wrong” scipy result is not a solver failure – it reflects that simple gradient descent with finite-difference gradients cannot handle the 200× scale mismatch between Ωₘ and H₀.
Final thought
Switching my forward model to diffrax did not change the physics or the inference method. It changed the practical feasibility of doing that inference at all. A nested-sampling run that was heading toward a large forward-model time budget now spends less than a minute on the forward model. The gradients that were going to cost 20 extra solves per step became essentially free.
The learning curve was about one afternoon. The debugging was mostly the 64-bit caveat and the JIT warmup confusion. The payoff has been real and immediate.
If you are a physicist using scipy for repeated likelihood evaluations and you have not looked at diffrax yet, I hope this gives you a reason to.
A note on reproducibility: the exact timings you see will differ on your machine and even between runs on the same machine. On my Mac (MacBook Air M3, base model), the diffrax forward call varied between 55 µs and 62 µs across sessions, and scipy varied between 400 µs and 407 µs. This is normal – CPU thermal state, OS scheduling, and memory cache conditions all shift the absolute numbers by 10–15%. What stays stable is the ratio: diffrax is consistently 7–8× faster than scipy on this problem. The ratio, not the absolute time, is the number to take away.
The Python code that generated every figure in this article is available at: github.com/Samit1424/ODE_solver_comparison
Note: Excluding the featured image, which was produced using an AI tool, all illustrations are the author’s original work.
References
[1] P. Kidger, On Neural Differential Equations, DPhil thesis, University of Oxford, 2021. docs.kidger.site/diffrax/
[2] R. T. Q. Chen, Y. Rubanova, J. Bettencourt, D. Duvenaud, Neural Ordinary Differential Equations, NeurIPS 2018.
