The JAX Jacobian factory pattern
Aragog's production integrator path is SUNDIALS CVODE with an analytic Jacobian built by JAX. The design choice that makes this practical is the factory pattern: the solver does not own the JAX trace itself but accepts a callable that constructs the (RHS, Jacobian) pair from the bound mesh, EOS, phase parameters, and boundary conditions at solve time.
This page explains why the factory is the unit of separation and what its consequences are.
What the factory returns
build_jax_rhs_and_jacobian (in aragog.solver.cvode_jax) takes the runtime state of an EntropySolver:
- the staggered and basic mesh arrays,
- the JAX-friendly EOS (EntropyEOS_JAX),
- the bound PhaseParams,
- the bound BoundaryParams,
- the radiogenic and tidal heating closures,
and returns a tuple (rhs, jac):
- rhs(t, y) evaluates \(dS/dt\) for a given state,
- jac(t, y) evaluates \(\partial(dS/dt)/\partial y\) via jax.jacrev.
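A minimal sketch of what such a factory can look like, assuming a toy one-dimensional, diffusion-like RHS in place of Aragog's entropy equation. The function build_toy_rhs_and_jacobian and its arguments (radii, conductivity, heating) are hypothetical stand-ins for the mesh, the EOS/phase parameters, and the heating closures. What it illustrates is the shape of the pattern: the runtime objects are captured by closure, and jax.jacrev differentiates the very same RHS that the integrator evaluates.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # stiff integrators work in float64


def build_toy_rhs_and_jacobian(radii, conductivity, heating):
    """Hypothetical stand-in for build_jax_rhs_and_jacobian.

    radii        -- 1D node positions (stands in for the mesh arrays)
    conductivity -- scalar coefficient (stands in for EOS / phase parameters)
    heating      -- JAX-traceable q(t) (stands in for the heating closures)
    """
    dr = jnp.diff(jnp.asarray(radii))  # captured by closure, constant for one solve

    def rhs_fn(t, y):
        # Nonlinear toy flux between neighbouring nodes so the Jacobian is non-trivial.
        slope = (y[1:] - y[:-1]) / dr
        flux = -conductivity * slope * (1.0 + 0.1 * jnp.tanh(y[:-1]))
        # Zero-flux boundaries: pad the interior fluxes with zeros at both ends.
        flux = jnp.concatenate([jnp.zeros(1), flux, jnp.zeros(1)])
        return -(flux[1:] - flux[:-1]) + heating(t)

    rhs = jax.jit(rhs_fn)
    jac = jax.jit(jax.jacrev(rhs_fn, argnums=1))  # d(rhs)/dy, shape (n, n)
    return rhs, jac


radii = np.linspace(0.0, 1.0, 6)
rhs, jac = build_toy_rhs_and_jacobian(radii, 2.0, lambda t: 1e-3 * jnp.exp(-t))
y0 = np.linspace(1.0, 2.0, 6)
print(rhs(0.0, y0).shape, jac(0.0, y0).shape)  # (6,) (6, 6)
```

Because the mesh and parameters are baked into the closures, a new mesh means calling the factory again, which is exactly what happens at solve time.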
Both callables are JAX pure_callback-wrapped JITted functions. CVODE calls them like any other RHS / Jacobian; it never sees the JAX layer directly.
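Because the integrator only sees plain (t, y) callables, any stiff solver that accepts a user-supplied Jacobian can consume them. The sketch below uses scipy.integrate.solve_ivp with method="BDF" as a stand-in for CVODE, since the SUNDIALS binding calls are not shown here, and a hypothetical two-component stiff system in place of \(dS/dt\). The two thin NumPy adapters are the only place where JAX arrays are converted.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import solve_ivp

jax.config.update("jax_enable_x64", True)


def rhs_fn(t, y):
    # Hypothetical stiff toy system: y0 relaxes quickly towards cos(t).
    return jnp.array([-1000.0 * (y[0] - jnp.cos(t)), y[0] - y[1] ** 3])


rhs_jax = jax.jit(rhs_fn)
jac_jax = jax.jit(jax.jacrev(rhs_fn, argnums=1))


# Thin adapters: the integrator passes NumPy arrays in and gets NumPy arrays back,
# so it never handles JAX types directly.
def rhs_np(t, y):
    return np.asarray(rhs_jax(t, y))


def jac_np(t, y):
    return np.asarray(jac_jax(t, y))


sol = solve_ivp(rhs_np, (0.0, 1.0), [0.0, 1.0], method="BDF",
                jac=jac_np, rtol=1e-8, atol=1e-10)
print(sol.success, sol.njev, sol.y[:, -1])
```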
Why a factory and not a method
The Jacobian depends on more than just (t, y). It depends on the EOS tables, the mesh geometry, the radiogenic heating closure, the boundary-condition mode, and the phase parameters. All of these are constants over a single solve() call but vary across calls (PROTEUS rebuilds the mesh on every Zalmoxis re-solve).
Three competing designs were considered:
- Make EntropySolver own the JAX trace. The solver would then need to import JAX at module load time, which would break installs that lack the [jax] extra. JAX is an optional extra precisely so that users on platforms where JAX cannot be installed (rare but real) can still run the scipy paths.
- Build the trace lazily inside solve(). The trace then has to walk into self.parameters, a Parameters dataclass holding numpy arrays and Python floats. JAX cannot trace through Parameters directly, so it would have to be unpacked on every call. The unpacking is non-trivial (e.g. mesh.basic_radii, eos._tables['temperature_solid']) and would couple the solver class to the JAX representation.
- Factory pattern (the chosen design). The solver exposes a set_jax_cvode_factory hook. PROTEUS (or any caller that has JAX installed) imports build_jax_rhs_and_jacobian and registers it once. Inside solve(), the solver calls the factory with the already-unpacked mesh, EOS, and parameters and gets back the two callables. The solver class never imports JAX, and the factory function lives in a module that is imported only by callers who explicitly opt in.
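A sketch of the hook side of the chosen design, assuming a heavily simplified solver. ToySolver, its attributes, build_callables, and numpy_factory are hypothetical; only the hook name mirrors set_jax_cvode_factory from this page. The module imports NumPy but never JAX, and the factory is invoked only once the runtime state is available as plain arrays and floats.

```python
from typing import Callable, Optional, Tuple

import numpy as np

RhsJac = Tuple[Callable, Callable]


class ToySolver:
    """Hypothetical, heavily simplified solver; this module never imports JAX."""

    def __init__(self, radii: np.ndarray, conductivity: float) -> None:
        self.radii = np.asarray(radii, dtype=float)
        self.conductivity = float(conductivity)
        self._jax_factory: Optional[Callable[..., RhsJac]] = None

    def set_jax_cvode_factory(self, factory: Callable[..., RhsJac]) -> None:
        # Only store the callable; nothing is traced or compiled at registration time.
        self._jax_factory = factory

    def build_callables(self) -> Optional[RhsJac]:
        # Called at solve time with already-unpacked arrays and floats.
        # Returning None signals "no factory registered" to the caller.
        if self._jax_factory is None:
            return None
        return self._jax_factory(radii=self.radii, conductivity=self.conductivity)


def numpy_factory(radii, conductivity):
    # A NumPy-only factory with the same calling convention, handy in unit tests.
    n = radii.size
    return (lambda t, y: -conductivity * y,
            lambda t, y: -conductivity * np.eye(n))


solver = ToySolver(np.linspace(0.0, 1.0, 4), 2.0)
print(solver.build_callables())  # None: no factory registered yet
solver.set_jax_cvode_factory(numpy_factory)
rhs, jac = solver.build_callables()
print(rhs(0.0, np.ones(4)))
```

Since the hook accepts any callable with the right return shape, the solver class stays agnostic about whether the callables were produced by JAX at all.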
The third option keeps the solver class JAX-agnostic and lets the factory be the only place that knows the trace shape. Refactors of the JAX layer (e.g. splitting compute_fluxes into smaller pieces, adding a new flux contribution) only touch cvode_jax.py and aragog.jax.phase; the solver is unaffected.
What "registered before solve()" means
EntropySolver.set_jax_cvode_factory(build_jax_rhs_and_jacobian) stores the factory on the instance. When solve() runs with solver_method = 'cvode' and use_jax_jacobian = true, it:
- Calls the factory with the bound runtime objects,
- Hands the resulting (rhs, jac) to CVODE,
- Lets CVODE drive the integration.
If the factory is not registered, the solver silently falls back to CVODE's built-in finite-difference Jacobian. This is intentional: it lets the standalone unit-test suite run without JAX installed, and it gives users on JAX-less platforms a working (slower) path.
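The dispatch-and-fallback behaviour can be sketched with scipy.integrate.solve_ivp as a stand-in integrator: passing jac=None to its BDF method makes it approximate the Jacobian by finite differences, which plays the role of CVODE's built-in FD Jacobian here. The function solve_with_optional_factory and its zero-argument factory convention are hypothetical.

```python
import logging
from typing import Callable, Optional, Tuple

import numpy as np
from scipy.integrate import solve_ivp

log = logging.getLogger("toy_solver")  # hypothetical logger name


def solve_with_optional_factory(
    rhs_numpy: Callable,
    factory: Optional[Callable[[], Tuple[Callable, Callable]]],
    t_span: Tuple[float, float],
    y0,
):
    """Hypothetical dispatch mirroring the solve() behaviour described above."""
    if factory is not None:
        # Factory registered: build (rhs, jac) from the bound runtime objects.
        rhs, jac = factory()
        path = "analytic (factory)"
    else:
        # No factory: jac=None makes BDF approximate the Jacobian by finite differences.
        rhs, jac = rhs_numpy, None
        path = "finite-difference"
    log.info("Jacobian path: %s", path)
    sol = solve_ivp(rhs, t_span, y0, method="BDF", jac=jac)
    return sol, path


def decay(t, y):
    return -2.0 * y


def decay_jac(t, y):
    return -2.0 * np.eye(len(y))


sol_fd, path_fd = solve_with_optional_factory(decay, None, (0.0, 1.0), [1.0])
sol_an, path_an = solve_with_optional_factory(
    decay, lambda: (decay, decay_jac), (0.0, 1.0), [1.0]
)
```

Both paths reach the same answer; only the cost of building the Jacobian differs, which is why the fallback is safe but slow.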
Raising an exception instead would make the JAX extra a hard dependency. The cost is that a misconfigured environment (JAX installed, factory not registered) silently runs at finite-difference-Jacobian speed instead of failing loudly. Mitigations:
- The CLI emits a one-line INFO message at solve time naming the active Jacobian path.
- aragog --versions lists JAX, so a bug report on slow integration can be diagnosed remotely.
- The PROTEUS wrapper always registers the factory; standalone users who care about speed are expected to register it themselves.
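A sketch of the kind of one-line diagnostic described above, assuming a hypothetical helper describe_jacobian_path. It uses importlib.util.find_spec to check whether JAX is importable without actually importing it, so it can name the silent-slowdown case (JAX installed, factory not registered) explicitly. The message strings are illustrative, not Aragog's actual log output.

```python
import importlib.util
import logging
from typing import Callable, Optional

log = logging.getLogger("toy_solver")  # hypothetical logger name


def describe_jacobian_path(factory: Optional[Callable], use_jax_jacobian: bool) -> str:
    """Hypothetical helper that names the active Jacobian path in one line."""
    # find_spec checks importability without actually importing JAX.
    jax_available = importlib.util.find_spec("jax") is not None
    if use_jax_jacobian and factory is not None:
        msg = "Jacobian path: JAX analytic (factory registered)"
    elif use_jax_jacobian and jax_available:
        # The silent-slowdown case: JAX is installed but no factory was registered.
        msg = "Jacobian path: finite-difference (JAX installed, factory NOT registered)"
    else:
        msg = "Jacobian path: finite-difference"
    log.info(msg)
    return msg
```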
Bit-parity with the numpy path
Every flux contribution in aragog.jax.phase is constructed to match the numpy EntropyState.update to floating-point precision. The verification suite asserts equality on a per-component basis (conduction, convection, gravitational separation, mixing) for representative (P, S, T, mesh) tuples; see first-principles verification.
Bit-parity is the contract that lets the JAX path be the production default without breaking regression tests that were tuned against numpy. If the numpy path serves as the ground truth and the JAX trace drifts away from it, the parity tests catch the drift before it lands.
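A per-component parity check can be sketched as below, assuming a toy conductive flux \(F = -k\,\partial T/\partial r\) evaluated on node differences. The functions conduction_flux_np, conduction_flux_jax, and check_parity are hypothetical and much simpler than EntropyState.update. Float64 must be enabled in JAX for the comparison to be meaningful, and the sketch uses a tight relative tolerance rather than exact equality, because XLA does not promise bitwise-identical results to NumPy for every operation.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # parity is only meaningful in float64


def conduction_flux_np(temperature, radii, conductivity):
    # Toy conductive flux between neighbouring nodes: F = -k dT/dr (NumPy reference).
    return -conductivity * np.diff(temperature) / np.diff(radii)


def conduction_flux_jax(temperature, radii, conductivity):
    # The same expression written against jax.numpy, as the traced path would be.
    return -conductivity * jnp.diff(temperature) / jnp.diff(radii)


def check_parity(temperature, radii, conductivity, rtol=1e-14):
    """Assert the two paths agree to a few ulps and return the max absolute gap."""
    expected = conduction_flux_np(temperature, radii, conductivity)
    actual = np.asarray(jax.jit(conduction_flux_jax)(temperature, radii, conductivity))
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=0.0)
    return float(np.max(np.abs(actual - expected)))


rng = np.random.default_rng(0)
temperature = 1000.0 + 500.0 * rng.random(16)
radii = np.sort(rng.random(16)) + np.arange(16)  # strictly increasing
print(check_parity(temperature, radii, conductivity=4.0))
```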
When the factory itself is the problem
A run that hits unexpected divergence under JAX but converges under FD points to a problem in the trace, not in CVODE. Diagnostic loop:
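- Set use_jax_jacobian = false and rerun. If the run converges, the bug is in the trace.
- Evaluate jac(t, y) from the factory at the failing step and compare it against a Jacobian computed by finite differences. Because jax.jacrev is exact up to rounding, the two should agree to within the finite-difference truncation error (roughly \(\sqrt{\epsilon}\) relative for one-sided differences in double precision), not to machine precision.
- If they do not agree, suspect a non-differentiable or near-discontinuous branch: a tanh switch with too narrow a width, a where whose discarded branch produces NaN (the NaN still reaches the gradient), or a JAX-unfriendly clip.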
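The second and third steps of this loop can be sketched as follows, assuming hypothetical two-component RHS functions. fd_jacobian uses one-sided differences with step \(\sqrt{\epsilon}\max(1, |y_j|)\), so agreement with jax.jacrev is expected near \(10^{-8}\) relative, not at machine precision. The second half reproduces the where-over-NaN pitfall (the pattern described in the JAX FAQ) and the usual fix of sanitising the argument before the unsafe operation.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def fd_jacobian(f, t, y):
    """One-sided finite-difference Jacobian, one column per state component."""
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(f(t, y))
    J = np.empty((f0.size, y.size))
    for j in range(y.size):
        h = np.sqrt(np.finfo(float).eps) * max(1.0, abs(y[j]))
        y_step = y.copy()
        y_step[j] += h
        J[:, j] = (np.asarray(f(t, y_step)) - f0) / h
    return J


def max_rel_mismatch(f, t, y):
    """Largest scaled difference between the jax.jacrev and FD Jacobians."""
    J_ad = np.asarray(jax.jacrev(f, argnums=1)(t, jnp.asarray(y, dtype=float)))
    J_fd = fd_jacobian(f, t, y)
    return np.max(np.abs(J_ad - J_fd) / np.maximum(np.abs(J_ad), 1.0))


def smooth_rhs(t, y):
    # Hypothetical smooth RHS: trace and FD should agree to roughly 1e-8.
    return jnp.array([-y[0] ** 2 + y[1], jnp.sin(y[0]) * y[1]])


def guarded_rhs(t, y):
    # Pitfall: log(y0) is still differentiated at y0 = 0 even though where() discards
    # it; the zero cotangent meets log's infinite derivative at 0 and produces NaN.
    return jnp.array([jnp.where(y[0] > 0.0, jnp.log(y[0]), 0.0), -y[1]])


def safe_guarded_rhs(t, y):
    # Fix: sanitise the argument first so the discarded branch stays finite.
    safe_y0 = jnp.where(y[0] > 0.0, y[0], 1.0)
    return jnp.array([jnp.where(y[0] > 0.0, jnp.log(safe_y0), 0.0), -y[1]])


print(max_rel_mismatch(smooth_rhs, 0.0, np.array([0.3, 1.2])))
y_edge = jnp.array([0.0, 1.0])
print(jax.jacrev(guarded_rhs, argnums=1)(0.0, y_edge))
print(jax.jacrev(safe_guarded_rhs, argnums=1)(0.0, y_edge))
```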
For the broader CVODE + JAX rationale and the full integrator dispatch decision tree, see CVODE and JAX. For the actual flux assembly under each path, see Heat transport.