
JAX EOS subpackage

The zalmoxis.jax_eos subpackage holds line-by-line JAX ports of the performance-critical EOS kernels and the structure ODE driver. The numpy implementations in zalmoxis.eos remain the reference, and kernels here must match them within solver tolerance: rtol \(\leq 10^{-12}\) for the bilinear kernels, rtol \(\leq 10^{-6}\) for the full RHS, and rtol \(\leq 10^{-5}\) end-to-end at the default integrator tolerance.
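A kernel-level parity check of this kind can be sketched as follows. The two kernels are minimal stand-ins written for illustration: the names bilinear_np and bilinear_jax, the regular-grid layout, and the synthetic table are assumptions, not the signatures of fast_bilinear_jax or its numpy counterpart. The point is the pattern: port the numpy kernel operation by operation, run both in float64, and compare at rtol = 1e-12.

```python
import jax

# Parity at rtol 1e-12 is only meaningful in float64.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def bilinear_np(x, y, xs, ys, table):
    """Reference bilinear interpolation on a regular grid (hypothetical stand-in)."""
    i = np.clip(np.searchsorted(xs, x) - 1, 0, len(xs) - 2)
    j = np.clip(np.searchsorted(ys, y) - 1, 0, len(ys) - 2)
    tx = (x - xs[i]) / (xs[i + 1] - xs[i])
    ty = (y - ys[j]) / (ys[j + 1] - ys[j])
    return ((1 - tx) * (1 - ty) * table[i, j] + tx * (1 - ty) * table[i + 1, j]
            + (1 - tx) * ty * table[i, j + 1] + tx * ty * table[i + 1, j + 1])


@jax.jit
def bilinear_jax(x, y, xs, ys, table):
    """Line-by-line JAX port of bilinear_np (hypothetical stand-in)."""
    i = jnp.clip(jnp.searchsorted(xs, x) - 1, 0, xs.shape[0] - 2)
    j = jnp.clip(jnp.searchsorted(ys, y) - 1, 0, ys.shape[0] - 2)
    tx = (x - xs[i]) / (xs[i + 1] - xs[i])
    ty = (y - ys[j]) / (ys[j + 1] - ys[j])
    return ((1 - tx) * (1 - ty) * table[i, j] + tx * (1 - ty) * table[i + 1, j]
            + (1 - tx) * ty * table[i, j + 1] + tx * ty * table[i + 1, j + 1])


# Synthetic smooth table on an illustrative (log P, T) grid.
xs = np.linspace(5.0, 12.0, 64)      # e.g. log10(P / Pa)
ys = np.linspace(300.0, 6000.0, 48)  # e.g. T in K
table = np.sin(xs)[:, None] * np.cos(ys / 1000.0)[None, :] + 5.0

rng = np.random.default_rng(0)
x = rng.uniform(5.0, 12.0, 1000)
y = rng.uniform(300.0, 6000.0, 1000)

ref = bilinear_np(x, y, xs, ys, table)
out = np.asarray(bilinear_jax(jnp.asarray(x), jnp.asarray(y),
                              jnp.asarray(xs), jnp.asarray(ys), jnp.asarray(table)))
np.testing.assert_allclose(out, ref, rtol=1e-12)
```

Because the port keeps the same operation order, any residual difference comes only from floating-point rounding inside XLA, which is far below the 1e-12 bound.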

The path is gated by config_params['use_jax']; when it is False (the default), none of these modules are imported. Float64 is enforced at import via jax.config.update('jax_enable_x64', True), because JAX defaults to float32 and would otherwise lose ~\(10^{-7}\) relative precision versus the numpy reference. The current scope is the Stage-1b two-layer configuration (PALEOS:iron core + PALEOS-2phase:MgSiO\(_3\) mantle); other configurations transparently fall back to the numpy path at the caller.
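The gating and fallback can be pictured with a small dispatcher. Everything in it is hypothetical: the helper names, the layer_eos config key, the EOS identifier strings, and the module path zalmoxis.jax_eos.wrapper are assumptions used only to show the pattern of importing the JAX path lazily, and only when use_jax is set and the configuration is in scope.

```python
import importlib

# Hypothetical identifiers for the one configuration the JAX path covers today.
JAX_SUPPORTED_LAYERS = ("PALEOS:iron", "PALEOS-2phase:MgSiO3")


def jax_path_applies(config_params):
    """True only if use_jax is set and the layer EOS pair is in JAX scope."""
    if not config_params.get("use_jax", False):
        return False
    # Hypothetical config key holding the per-layer EOS choices.
    return tuple(config_params.get("layer_eos", ())) == JAX_SUPPORTED_LAYERS


def pick_solver(config_params, numpy_solver):
    """Return the JAX entry point when it applies, else the numpy reference."""
    if jax_path_applies(config_params):
        # Lazy import: nothing under the JAX subpackage is loaded (so x64 mode
        # is not switched on) unless this branch is actually taken.
        wrapper = importlib.import_module("zalmoxis.jax_eos.wrapper")  # assumed path
        return wrapper.solve_structure_via_jax
    return numpy_solver
```

Keeping the import inside the branch means a default run never pays the JAX import cost and never changes global JAX settings.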

| Submodule | Purpose |
| --- | --- |
| bilinear | fast_bilinear_jax, paleos_clamp_temperature_jax |
| paleos | get_paleos_unified_density_jax, mushy-zone branches |
| tdep | get_Tdep_density_jax for the PALEOS-2phase mantle |
| rhs | coupled_odes_jax (jax-traceable structure RHS) |
| solver | solve_structure_jax via diffrax.diffeqsolve(Tsit5) with event-based pressure-zero termination |
| wrapper | solve_structure_via_jax (numpy-signature entry point used by structure_model.solve_structure); accepts both temperature_function(r, P) and an explicit temperature_arrays = (r_arr, T_arr) |
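The wrapper accepts a temperature profile either as a callable temperature_function(r, P) or as sampled temperature_arrays = (r_arr, T_arr). One way to reconcile the two inside traced code is to wrap the arrays in a piecewise-linear interpolant built on jnp.interp, so the RHS only ever sees a callable. This is a sketch under that assumption; the helper make_temperature_fn is hypothetical, and the wrapper's actual handling (for example, extrapolation beyond r_arr) may differ.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def make_temperature_fn(temperature_function=None, temperature_arrays=None):
    """Return a jax-traceable T(r, P) from either accepted input form (hypothetical)."""
    if temperature_arrays is not None:
        r_arr, T_arr = (jnp.asarray(a, dtype=jnp.float64) for a in temperature_arrays)

        def T_of_r(r, P):
            # Piecewise-linear in radius; P is ignored. jnp.interp needs r_arr
            # ascending and holds the end values constant outside its range.
            return jnp.interp(r, r_arr, T_arr)

        return T_of_r
    if temperature_function is not None:
        return temperature_function
    raise ValueError("need temperature_function or temperature_arrays")


r_arr = jnp.array([0.0, 3.0e6, 6.0e6])       # m
T_arr = jnp.array([6000.0, 4000.0, 2000.0])  # K
T = jax.jit(make_temperature_fn(temperature_arrays=(r_arr, T_arr)))
print(float(T(1.5e6, 1.0e10)))  # → 5000.0
```

Converting the arrays to a closure up front keeps a single code path in the RHS, which matters because Python-level branching on the input form cannot happen inside jit-traced code.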

jax_eos

JAX ports of Zalmoxis EOS inner kernels.

This subpackage holds jax.numpy / jax.jit reimplementations of the performance-critical EOS kernels. The numpy implementations in zalmoxis.eos remain the reference; kernels here must match them within solver tolerance (rtol <= 1e-4 on all physics fields).

Module layout mirrors the numpy side:

- bilinear: fast_bilinear_jax, paleos_clamp_temperature_jax
- paleos: get_paleos_unified_density_jax
- rhs: coupled_odes_jax RHS
- solver: diffrax-based solve_structure replacement
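To make the rhs and solver entries concrete, the sketch below writes a toy hydrostatic-structure RHS in a jax-traceable way and integrates it outward until the pressure reaches zero. Assumptions: a constant density stands in for the EOS call, the state is (m, g, P) with dm/dr = 4πr²ρ, dg/dr = 4πGρ − 2g/r and dP/dr = −ρg, and a fixed-step RK4 inside jax.lax.while_loop replaces the adaptive diffrax Tsit5 solve with an event. It illustrates the pressure-zero termination idea only, not how solve_structure_jax is written.

```python
import math

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

G = 6.674e-11  # m^3 kg^-1 s^-2


def rhs(r, y, rho):
    """Toy structure RHS for state y = (m, g, P); rho stands in for the EOS."""
    m, g, P = y
    dm = 4.0 * jnp.pi * r**2 * rho
    dg = 4.0 * jnp.pi * G * rho - 2.0 * g / r
    dP = -rho * g
    return jnp.stack([dm, dg, dP])


def rk4_step(r, y, dr, rho):
    k1 = rhs(r, y, rho)
    k2 = rhs(r + dr / 2, y + dr / 2 * k1, rho)
    k3 = rhs(r + dr / 2, y + dr / 2 * k2, rho)
    k4 = rhs(r + dr, y + dr * k3, rho)
    return y + dr / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


@jax.jit
def surface_radius(P_center, rho, dr):
    """Integrate outward while P > 0, then locate P = 0 by linear interpolation."""
    r0 = 1e-3 * dr  # start just off-centre to avoid the g / r singularity
    m0 = 4.0 / 3.0 * jnp.pi * r0**3 * rho
    y0 = jnp.stack([m0, G * m0 / r0**2, P_center])

    def cond(state):
        _, _, y, _ = state
        return y[2] > 0.0  # keep stepping while the pressure is positive

    def body(state):
        _, _, y, r = state
        # Carry the previous (r, y) so the sign change can be bracketed.
        return r, y, rk4_step(r, y, dr, rho), r + dr

    r_prev, y_prev, y, r = jax.lax.while_loop(cond, body, (r0, y0, y0, r0))
    P_prev, P_last = y_prev[2], y[2]
    # P changes sign on [r_prev, r]; return the root of the linear interpolant.
    return r_prev + (r - r_prev) * P_prev / (P_prev - P_last)


rho, P_c = 5000.0, 1.4e11  # kg m^-3, Pa (illustrative values)
R = float(surface_radius(P_c, rho, 1.0e4))
# Uniform density gives P(r) = P_c - (2/3) pi G rho^2 r^2, so P = 0 at:
R_exact = math.sqrt(3.0 * P_c / (2.0 * math.pi * G * rho**2))
print(R, R_exact)
```

For uniform density the two radii should agree closely, since the only approximation near the surface is the linear root interpolation across the final step. An adaptive solver with an event replaces both the fixed step and the manual bracketing.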

The JAX path is gated behind the use_jax entry in config_params; when use_jax is False (the default) the numpy path is used unchanged.

JAX x64 mode is enabled at import. Without it, JAX defaults to float32, and all downstream density, pressure, and temperature calculations lose ~1e-7 relative precision versus the numpy reference, far above the rtol <= 1e-12 required for kernel-level parity.
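The size of the float32 loss can be checked directly. Float32 machine epsilon is about 1.19e-7, so every value stored as float32 carries a relative rounding error of up to half that, and errors accumulate across operations. The sketch below is illustrative only (it does not reproduce the module's import code): it enables x64 as the module does, then forces one value through float32 explicitly to show what the default would have stored.

```python
import jax

# Must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

value = np.float64(1.0) / 3.0  # not exactly representable in binary

x64 = jnp.asarray(value)                     # float64 under x64 mode
x32 = jnp.asarray(value, dtype=jnp.float32)  # what a float32 default would store

rel64 = abs(float(x64) - value) / value
rel32 = abs(float(x32) - value) / value
print(x64.dtype, rel64)   # → float64 0.0
print(x32.dtype)          # → float32
print(rel32 > 1e-12)      # → True: a single float32 rounding already breaks parity
print(jnp.finfo(jnp.float32).eps, jnp.finfo(jnp.float64).eps)
```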