
aragog.jax

The aragog.jax package contains JAX-traceable replicas of the EOS, phase evaluator, and dSdt right-hand side. They are used to build an analytic Jacobian via jax.jacrev and feed it to SUNDIALS CVODE through the registered factory in aragog.solver.cvode_jax.

This module is loaded only when solver.use_jax_jacobian = true (the production default). The numpy path in aragog.solver.entropy_state remains the reference implementation and is exercised by the standalone tests; both must agree to numerical precision (see first-principles verification).

For when and why to opt into the JAX path, see CVODE and JAX-derived Jacobians.

Submodules:

  • aragog.jax.eos: EntropyEOS_JAX, the JAX-traceable P-S table loader, and PhaseState, the per-cell phase cache. Mirrors the public surface of aragog.eos.EntropyEOS.
  • aragog.jax.phase: PhaseParams, MeshArrays, PhaseProperties, FluxOutput, compute_fluxes, compute_mlt, evaluate_phase. SPIDER-parity two-stage blend, mixing-length convection, and pure-functional flux assembly.
  • aragog.jax.solver: BoundaryParams, SolveResult. The standalone JAX solve path used by the verification suite; in production the JAX path supplies only the Jacobian and CVODE drives the integration.
  • aragog.jax.nondim: NonDimScales. Reference scales used to non-dimensionalise state and time before passing them to the integrator.

aragog.jax.eos

JAX-traceable P-S equation of state. Loads SPIDER-format two-phase tables and provides temperature, density, melt_fraction, and the latent-heat / phase-boundary derivatives via jax.jit-compatible bilinear interpolation.

eos

JAX-based entropy EOS layer for PALEOS P-S tables.

Drop-in replacement for aragog.eos.entropy.EntropyEOS using JAX arrays and jax.scipy.interpolate.RegularGridInterpolator. All methods are JIT-compilable and differentiable via jax.grad.

Table loading uses the existing numpy loader (disk I/O is not JIT-compiled). The loaded grids are converted to JAX arrays and stored as equinox Module fields so the entire EOS object is a valid JAX pytree.

Dependencies: jax, equinox (both already in PROTEUS ecosystem via atmodeller).

EntropyEOS_JAX(eos_dir)

Bases: Module

JAX-based entropy EOS from PALEOS P-S tables.

Drop-in replacement for aragog.eos.entropy.EntropyEOS with all lookups JIT-compilable and differentiable. Constructed from the same SPIDER-format table files.

Parameters:

  • eos_dir (Path or str, required): Directory containing the SPIDER-format P-S table files.
Source code in src/aragog/jax/eos.py
def __init__(self, eos_dir: Path | str):
    eos_dir = Path(eos_dir)
    if not eos_dir.is_dir():
        raise FileNotFoundError(f'EOS directory not found: {eos_dir}')

    logger.info('Loading JAX entropy EOS from %s', eos_dir)

    # Load tables from disk (numpy) and convert to JAX interpolators
    def _make_table(name: str, phase: str) -> _Table2D:
        if name == 'dTdPs':
            fname = f'adiabat_temp_grad_{phase}.dat'
        else:
            fname = f'{name}_{phase}.dat'
        t = _load_spider_ps_table(eos_dir / fname)
        return _Table2D(t['P'], t['S'], t['values'])

    self._temperature_solid = _make_table('temperature', 'solid')
    self._temperature_melt = _make_table('temperature', 'melt')
    self._density_solid = _make_table('density', 'solid')
    self._density_melt = _make_table('density', 'melt')
    self._heat_capacity_solid = _make_table('heat_capacity', 'solid')
    self._heat_capacity_melt = _make_table('heat_capacity', 'melt')
    self._dTdPs_solid = _make_table('dTdPs', 'solid')
    self._dTdPs_melt = _make_table('dTdPs', 'melt')

    # Phase boundaries
    sol = _load_spider_phase_boundary(eos_dir / 'solidus_P-S.dat')
    liq = _load_spider_phase_boundary(eos_dir / 'liquidus_P-S.dat')
    self._solidus = _PhaseBoundary1D(sol['P'], sol['S'])
    self._liquidus = _PhaseBoundary1D(liq['P'], liq['S'])

    # Domain bounds
    self.P_min = min(self._temperature_solid.P_min, self._temperature_melt.P_min)
    self.P_max = max(self._temperature_solid.P_max, self._temperature_melt.P_max)
    self.S_min = min(self._temperature_solid.S_min, self._temperature_melt.S_min)
    self.S_max = max(self._temperature_solid.S_max, self._temperature_melt.S_max)

    logger.info(
        'JAX EOS loaded: P=[%.2e, %.2e] Pa, S=[%.0f, %.0f] J/kg/K',
        self.P_min,
        self.P_max,
        self.S_min,
        self.S_max,
    )

compute_phase_state(P, S, k_solid, k_liquid, matprop_smooth_width=0.0)

Single-pass SPIDER-parity phase evaluation (cp_blend='latent').

Bit-for-bit mirror of numpy EntropyPhaseEvaluator._update_eos with cp_blend='latent'. All properties share the same intermediates, and each is the smth-blend smth * mixed + (1 - smth) * single matching SPIDER combine_matprop (eos_composite.c:278-285).
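With matprop_smooth_width = 0, smth is just an indicator of 0 <= gphi <= 1, so the blend returns the two-phase ('mixed') value inside the mushy band and the single-phase table value outside it. A scalar stdlib sketch of that sharp case follows; the tanh-smoothed branch (_tanh_weight_jax) is deliberately not reproduced because its form is not shown in this section, and the function names are illustrative.

```python
def gphi_of(S: float, S_sol: float, S_liq: float) -> float:
    """Unclipped generalised melt fraction, with the same 1e-10 floor."""
    return (S - S_sol) / max(S_liq - S_sol, 1e-10)


def sharp_smth(gphi: float) -> float:
    """smth for matprop_smooth_width == 0: 1 inside [0, 1], 0 outside."""
    return 1.0 if 0.0 <= gphi <= 1.0 else 0.0


def combine_matprop(smth: float, mixed: float, single: float) -> float:
    """SPIDER-style blend: smth * mixed + (1 - smth) * single."""
    return smth * mixed + (1.0 - smth) * single
```

For example, a cell halfway between solidus and liquidus (gphi = 0.5) gets smth = 1 and therefore takes the mixed value unchanged.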

Parameters:

  • P (Array, required): Pressure [Pa], same shape as S.
  • S (Array, required): Entropy [J/kg/K], same shape as P.
  • k_solid (float, required): Single-phase solid thermal conductivity [W/m/K].
  • k_liquid (float, required): Single-phase liquid thermal conductivity [W/m/K].
  • matprop_smooth_width (float, default 0.0): SPIDER's -matprop_smooth_width. 0.0 reproduces the sharp convention (smth = 1 for gphi inside [0, 1], 0 outside); 0.01 is the production setting.

Returns:

  • PhaseState: All blended properties and shared intermediates.
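Step 3 of compute_phase_state builds the two-phase ('mixed') properties from phase-boundary values alone. The scalar transcription below keeps the same floors (1e-10 K on dT_phase, 1 kg/m^3 on densities, 100 J/kg/K on Cp_mixed) so the formulas can be checked by hand; the function name and the numbers in the test are illustrative, not taken from a real EOS.

```python
from typing import NamedTuple


class MixedProps(NamedTuple):
    T: float
    rho: float
    Cp: float
    alpha: float
    dTdPs: float


def mixed_properties(phi, S_sol, S_liq, T_sol, T_liq, rho_sol, rho_liq):
    """Scalar transcription of Step 3 of compute_phase_state."""
    dT_phase = max(T_liq - T_sol, 1e-10)
    T_avg = T_sol + 0.5 * dT_phase
    # Temperature: linear blend between the boundary temperatures.
    T_mixed = phi * T_liq + (1.0 - phi) * T_sol
    # Density: harmonic mean, i.e. specific volumes mix linearly.
    inv_rho = phi / max(rho_liq, 1.0) + (1.0 - phi) / max(rho_sol, 1.0)
    rho_mixed = 1.0 / max(inv_rho, 1e-30)
    # Latent-heat-augmented expansivity and heat capacity.
    alpha_mixed = (rho_sol - rho_liq) / dT_phase / max(rho_mixed, 1.0)
    Cp_mixed = max((S_liq - S_sol) / dT_phase * T_avg, 100.0)
    # Adiabatic gradient from the identity dT/dP|_S = alpha T / (rho Cp).
    dTdPs_mixed = alpha_mixed * T_mixed / (max(rho_mixed, 1.0) * max(Cp_mixed, 100.0))
    return MixedProps(T_mixed, rho_mixed, Cp_mixed, alpha_mixed, dTdPs_mixed)
```

With a 400 K melting interval and a 200 J/kg/K entropy jump, Cp_mixed comes out at 1100 J/kg/K, roughly an order of magnitude above a typical single-phase Cp; that is the latent heat being spread across the mushy band.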

Source code in src/aragog/jax/eos.py
def compute_phase_state(
    self,
    P: jax.Array,
    S: jax.Array,
    k_solid: float,
    k_liquid: float,
    matprop_smooth_width: float = 0.0,
) -> PhaseState:
    """Single-pass SPIDER-parity phase evaluation (cp_blend='latent').

    Bit-for-bit mirror of numpy ``EntropyPhaseEvaluator._update_eos``
    with ``cp_blend='latent'``. All properties share the same
    intermediates, and each is the smth-blend
    ``smth * mixed + (1 - smth) * single`` matching SPIDER
    ``combine_matprop`` (eos_composite.c:278-285).

    Parameters
    ----------
    P, S : jax.Array
        Pressure [Pa] and entropy [J/kg/K], same shape.
    k_solid, k_liquid : float
        Single-phase thermal conductivities [W/m/K].
    matprop_smooth_width : float, default 0.0
        SPIDER's ``-matprop_smooth_width``. ``0.0`` reproduces the
        sharp ``smth=1`` inside [0,1] convention; ``0.01`` is the
        production setting.

    Returns
    -------
    PhaseState
        All blended properties and shared intermediates.
    """
    # ── Step 1: phase boundaries (computed ONCE) ────────────────
    S_sol = self.solidus_entropy(P)
    S_liq = self.liquidus_entropy(P)
    dS_phase = jnp.maximum(S_liq - S_sol, 1e-10)
    gphi = (S - S_sol) / dS_phase
    phi_arr = jnp.clip(gphi, 0.0, 1.0)

    # smth: matprop_smooth_width blend factor
    # (SPIDER util.c:get_smoothing). matprop_smooth_width is a static
    # Python float here, so the if-branch is resolved at trace time.
    if matprop_smooth_width > 0:
        smth = jnp.where(
            gphi > 0.5,
            1.0 - _tanh_weight_jax(gphi, 1.0, matprop_smooth_width),
            _tanh_weight_jax(gphi, 0.0, matprop_smooth_width),
        )
    else:
        smth = jnp.where((gphi >= 0.0) & (gphi <= 1.0), 1.0, 0.0)

    # ── Step 2: phase-boundary table evaluations (ONCE each) ────
    T_sol = self._lookup_at_phase_boundary('temperature', P, 'solid')
    T_liq = self._lookup_at_phase_boundary('temperature', P, 'melt')
    rho_sol = self._lookup_at_phase_boundary('density', P, 'solid')
    rho_liq = self._lookup_at_phase_boundary('density', P, 'melt')

    # ── Step 3: intermediate two-phase ('mixed') properties ─────
    dT_phase = jnp.maximum(T_liq - T_sol, 1e-10)
    T_avg = T_sol + 0.5 * dT_phase

    # T: linear blend
    T_mixed = phi_arr * T_liq + (1.0 - phi_arr) * T_sol

    # rho: harmonic mean
    inv_rho_mixed = phi_arr / jnp.maximum(rho_liq, 1.0) + (1.0 - phi_arr) / jnp.maximum(
        rho_sol, 1.0
    )
    rho_mixed = 1.0 / jnp.maximum(inv_rho_mixed, 1e-30)

    # alpha and Cp: latent-heat-augmented (SPIDER eos_composite.c:227-246).
    # The 100 J/kg/K floor on Cp_mixed is a defensive guard against
    # division-by-near-zero when the latent budget collapses
    # (dT_phase very large, or S_liq -> S_sol, or T_avg small near
    # the eutectic). MgSiO3 production runs are always well above
    # this floor; a triggering EOS is the signal that the upstream
    # property tables need clipping, not Aragog's runtime.
    alpha_mixed = (rho_sol - rho_liq) / dT_phase / jnp.maximum(rho_mixed, 1.0)
    Cp_mixed = jnp.maximum((S_liq - S_sol) / dT_phase * T_avg, 100.0)

    # dTdPs: analytical from intermediates
    dTdPs_mixed = (
        alpha_mixed * T_mixed / (jnp.maximum(rho_mixed, 1.0) * jnp.maximum(Cp_mixed, 100.0))
    )

    # cond: linear blend
    cond_mixed = phi_arr * k_liquid + (1.0 - phi_arr) * k_solid

    # ── Step 4: single-phase table evaluations (SPIDER 269-276) ──
    # Evaluate at S_sol/S_liq when mushy, at actual S otherwise.
    mushy = (phi_arr > 0) & (phi_arr < 1)
    S_for_solid = jnp.where(mushy, S_sol, S)
    S_for_melt = jnp.where(mushy, S_liq, S)

    # The ``gphi > 0.5`` switches below pick which pure-phase table to
    # evaluate when the cell is *outside* the mushy mask. This 0.5 is
    # a binary discriminator (gphi outside [0,1] is by definition
    # super-liquidus or sub-solidus), not the rheological critical
    # melt fraction; see ``PhaseParams.phi_rheo`` for the RCMF.
    def _table_lookup_blend(prop_name: str) -> jax.Array:
        solid_tbl, melt_tbl = self._get_tables(prop_name)
        v_sol = solid_tbl(P, S_for_solid)
        v_mel = melt_tbl(P, S_for_melt)
        return jnp.where(gphi > 0.5, v_mel, v_sol)

    T_single = _table_lookup_blend('temperature')
    rho_single = _table_lookup_blend('density')
    Cp_single = _table_lookup_blend('heat_capacity')
    dTdPs_single = _table_lookup_blend('dTdPs')
    # alpha derived from thermodynamic identity (no thermal_exp tables yet)
    alpha_single = dTdPs_single * rho_single * Cp_single / jnp.maximum(T_single, 1.0)
    cond_single = jnp.where(gphi > 0.5, k_liquid, k_solid)

    # ── Step 5: combine_matprop blend (SPIDER 278-285) ──────────
    def _blend(mixed, single):
        return smth * mixed + (1.0 - smth) * single

    temperature = _blend(T_mixed, T_single)
    density = _blend(rho_mixed, rho_single)
    heat_capacity = _blend(Cp_mixed, Cp_single)
    alpha_raw = _blend(alpha_mixed, alpha_single)
    dTdPs_val = _blend(dTdPs_mixed, dTdPs_single)
    thermal_conductivity = _blend(cond_mixed, cond_single)

    # Guard: clamp negative alpha (matches numpy line 309)
    eps_a = 1.0e-8
    thermal_expansivity = 0.5 * (
        alpha_raw + jnp.sqrt(alpha_raw * alpha_raw + eps_a * eps_a)
    )

    latent_heat = self.latent_heat(P)

    return PhaseState(
        temperature=temperature,
        density=density,
        heat_capacity=heat_capacity,
        thermal_expansivity=thermal_expansivity,
        dTdPs=dTdPs_val,
        thermal_conductivity=thermal_conductivity,
        melt_fraction=phi_arr,
        gphi=gphi,
        smth=smth,
        latent_heat=latent_heat,
    )
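The final guard on alpha is a smooth maximum rather than a hard max(alpha, 0): 0.5 * (a + sqrt(a^2 + eps^2)) tends to a for a >> eps, to about eps^2 / (4|a|) for a << -eps, and equals eps/2 at a = 0, so its derivative has no kink for jax.jacrev to trip over. A stdlib illustration with the same eps_a = 1e-8:

```python
import math

EPS_A = 1.0e-8  # same eps_a as compute_phase_state


def smooth_clamp_positive(alpha_raw: float, eps: float = EPS_A) -> float:
    """Differentiable stand-in for max(alpha_raw, 0)."""
    return 0.5 * (alpha_raw + math.sqrt(alpha_raw * alpha_raw + eps * eps))


def d_smooth_clamp(alpha_raw: float, eps: float = EPS_A) -> float:
    """Analytic derivative 0.5 * (1 + a / sqrt(a^2 + eps^2)), continuous in a."""
    return 0.5 * (1.0 + alpha_raw / math.sqrt(alpha_raw * alpha_raw + eps * eps))
```

For physical expansivities (order 1e-5 /K) the guard changes alpha only at the 1e-8 relative level, while negative raw values are pushed to a tiny positive number.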

dTdPs(P, S)

Adiabatic temperature gradient dT/dP|_S (P, S) [K/Pa].

Source code in src/aragog/jax/eos.py
def dTdPs(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Adiabatic temperature gradient dT/dP|_S (P, S) [K/Pa]."""
    return self._lookup_phase_weighted('dTdPs', P, S)

density(P, S)

Density rho(P, S) [kg/m^3], matching numpy EntropyEOS.density.

  • Mushy zone (0 < phi < 1): harmonic mean of end-member densities evaluated at phase-boundary entropies (lever rule, SPIDER eos_composite.c:236-237).
  • Pure phase (phi = 0 or phi = 1): evaluate the active single-phase table at the actual S (clamped by the table itself). SPIDER combine_matprop(smth=0, mixed, single) selects the single-phase branch in this regime. Using the harmonic-mean form unconditionally biases the fully molten density relative to the numpy EntropyEOS.
Source code in src/aragog/jax/eos.py
def density(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Density rho(P, S) [kg/m^3], matching numpy EntropyEOS.density.

    - Mushy zone (0 < phi < 1): harmonic mean of end-member
      densities evaluated at phase-boundary entropies (Lever Rule,
      SPIDER ``eos_composite.c:236-237``).
    - Pure phase (phi = 0 or phi = 1): evaluate the active single-
      phase table at the actual S (clamped by the table itself).
      SPIDER ``combine_matprop(smth=0, mixed, single)`` selects the
      single-phase branch in this regime. Using the harmonic-mean
      form unconditionally biases the fully-molten density
      relative to the numpy EntropyEOS.
    """
    phi = self.melt_fraction(P, S)
    mushy = (phi > 0) & (phi < 1)

    solid_table, melt_table = self._get_tables('density')

    # Mushy zone: harmonic mean at phase boundaries
    rho_sol_boundary = self._lookup_at_phase_boundary('density', P, 'solid')
    rho_liq_boundary = self._lookup_at_phase_boundary('density', P, 'melt')
    inv_rho_mushy = phi / jnp.maximum(rho_liq_boundary, 1.0) + (1.0 - phi) / jnp.maximum(
        rho_sol_boundary, 1.0
    )
    rho_mushy = 1.0 / jnp.maximum(inv_rho_mushy, 1e-30)

    # Single-phase: evaluate at actual S (clamped by the table). Pick
    # the melt table for phi >= 0.5, solid otherwise (matches numpy).
    # NOTE: this 0.5 is a binary table-selector for the non-mushy
    # fallback only — outside the mushy band phi is essentially 0 or
    # 1 by construction. It is NOT the rheological critical melt
    # fraction (RCMF). The RCMF lives in
    # ``EntropyPhaseEvaluator._phi_rheo`` / ``PhaseParams.phi_rheo``
    # and drives the viscosity tanh blend separately.
    rho_solid_single = solid_table(P, S)
    rho_melt_single = melt_table(P, S)
    rho_single = jnp.where(phi >= 0.5, rho_melt_single, rho_solid_single)

    return jnp.where(mushy, rho_mushy, rho_single)

heat_capacity(P, S)

Specific heat capacity Cp(P, S) [J/kg/K].

Source code in src/aragog/jax/eos.py
def heat_capacity(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Specific heat capacity Cp(P, S) [J/kg/K]."""
    return self._lookup_phase_weighted('heat_capacity', P, S)

latent_heat(P)

Latent heat L(P) = T_fus * (S_liq - S_sol) [J/kg], where T_fus = 0.5 * (T_sol + T_liq) is the midpoint of the solidus and liquidus temperatures.

Source code in src/aragog/jax/eos.py
def latent_heat(self, P: jax.Array) -> jax.Array:
    """Latent heat L(P) = T_fus x (S_liq - S_sol) [J/kg]."""
    S_sol = self.solidus_entropy(P)
    S_liq = self.liquidus_entropy(P)
    T_sol = self._lookup_at_phase_boundary('temperature', P, 'solid')
    T_liq = self._lookup_at_phase_boundary('temperature', P, 'melt')
    T_fus = 0.5 * (T_sol + T_liq)
    return T_fus * jnp.maximum(S_liq - S_sol, 1.0)
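Because T_fus equals the T_avg used in compute_phase_state, the latent heat returned here is exactly Cp_mixed * dT_phase there (whenever neither floor is active): integrating the augmented heat capacity across the melting interval recovers L. A scalar check with illustrative values:

```python
def latent_heat_sketch(S_sol, S_liq, T_sol, T_liq):
    """L = T_fus * (S_liq - S_sol), T_fus the solidus/liquidus midpoint."""
    T_fus = 0.5 * (T_sol + T_liq)
    # Same 1 J/kg/K floor on the entropy jump as latent_heat().
    return T_fus * max(S_liq - S_sol, 1.0)


def cp_mixed_times_interval(S_sol, S_liq, T_sol, T_liq):
    """Cp_mixed * dT_phase from compute_phase_state, floors omitted."""
    dT_phase = T_liq - T_sol
    T_avg = T_sol + 0.5 * dT_phase
    return (S_liq - S_sol) / dT_phase * T_avg * dT_phase
```

For S_liq - S_sol = 300 J/kg/K and a 2000-2400 K interval both give 6.6e5 J/kg.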

liquidus_entropy(P)

Liquidus entropy S_liq(P) [J/kg/K].

Source code in src/aragog/jax/eos.py
def liquidus_entropy(self, P: jax.Array) -> jax.Array:
    """Liquidus entropy S_liq(P) [J/kg/K]."""
    return self._liquidus(P)

liquidus_entropy_dP(P)

dS_liq/dP at the given pressure(s), in J/(kg·K·Pa).

Source code in src/aragog/jax/eos.py
def liquidus_entropy_dP(self, P: jax.Array) -> jax.Array:
    """dS_liq/dP at the given pressure(s), in J/(kg·K·Pa)."""
    return self._liquidus.dSdP(P)

melt_fraction(P, S)

Melt fraction phi from position between solidus and liquidus.

phi = 0 for S <= S_sol, phi = 1 for S >= S_liq, linear between.

Source code in src/aragog/jax/eos.py
def melt_fraction(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Melt fraction phi from position between solidus and liquidus.

    phi = 0 for S <= S_sol, phi = 1 for S >= S_liq, linear between.
    """
    S_sol = self.solidus_entropy(P)
    S_liq = self.liquidus_entropy(P)
    dS = jnp.maximum(S_liq - S_sol, 1e-10)
    return jnp.clip((S - S_sol) / dS, 0.0, 1.0)
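Since phi is a clipped linear ramp in S, its derivative with respect to S (what the analytic Jacobian propagates through phase-dependent terms) is piecewise constant: 1 / (S_liq - S_sol) strictly inside the band and zero outside. The stdlib sketch below reproduces both away from the two kinks; what an autodiff library returns exactly at S = S_sol or S = S_liq is a convention of that library and is not modelled here.

```python
def melt_fraction_sketch(S, S_sol, S_liq):
    """phi = clip((S - S_sol) / dS, 0, 1) with the same 1e-10 floor on dS."""
    dS = max(S_liq - S_sol, 1e-10)
    return min(max((S - S_sol) / dS, 0.0), 1.0)


def dphi_dS_sketch(S, S_sol, S_liq):
    """Derivative of the clipped ramp, valid away from the kinks."""
    dS = max(S_liq - S_sol, 1e-10)
    return 1.0 / dS if S_sol < S < S_liq else 0.0
```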

solidus_entropy(P)

Solidus entropy S_sol(P) [J/kg/K].

Source code in src/aragog/jax/eos.py
def solidus_entropy(self, P: jax.Array) -> jax.Array:
    """Solidus entropy S_sol(P) [J/kg/K]."""
    return self._solidus(P)

solidus_entropy_dP(P)

dS_sol/dP at the given pressure(s), in J/(kg·K·Pa).

Needed by the SPIDER-parity bracket Jmix in aragog.jax.phase.compute_fluxes. Mirrors numpy EntropyEOS.solidus_entropy_dP (entropy.py:377-388).

Source code in src/aragog/jax/eos.py
def solidus_entropy_dP(self, P: jax.Array) -> jax.Array:
    """dS_sol/dP at the given pressure(s), in J/(kg·K·Pa).

    Needed by the SPIDER-parity bracket Jmix in
    ``aragog.jax.phase.compute_fluxes``. Mirrors numpy
    ``EntropyEOS.solidus_entropy_dP`` (entropy.py:377-388).
    """
    return self._solidus.dSdP(P)

temperature(P, S)

Temperature T(P, S) [K].

Source code in src/aragog/jax/eos.py
def temperature(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Temperature T(P, S) [K]."""
    return self._lookup_phase_weighted('temperature', P, S)

thermal_expansivity(P, S)

Thermal expansivity alpha(P, S) [1/K].

Derived: alpha = rho * Cp * |dTdPs| / T.

Source code in src/aragog/jax/eos.py
def thermal_expansivity(self, P: jax.Array, S: jax.Array) -> jax.Array:
    """Thermal expansivity alpha(P, S) [1/K].

    Derived: alpha = rho * Cp * |dTdPs| / T.
    """
    T = self.temperature(P, S)
    rho = self.density(P, S)
    Cp = self.heat_capacity(P, S)
    dTdPs_val = self.dTdPs(P, S)
    return rho * Cp * jnp.abs(dTdPs_val) / jnp.maximum(T, 1.0)

PhaseState

Bases: NamedTuple

Material properties at (P, S) following SPIDER eos_composite.c convention.

All scalar properties (T, rho, Cp, alpha, dTdPs, k) result from a smth-blend between two-phase mixed values (analytical formulas at the phase boundaries) and single table values (looked up at the actual P, and at S_sol or S_liq when mushy or at the actual S otherwise). This matches numpy EntropyPhaseEvaluator._update_eos step-for-step.

aragog.jax.phase

Phase-aware property evaluation and flux assembly. evaluate_phase performs the SPIDER two-stage tanh blend elementwise over the nodes it is given; compute_fluxes is the pure-functional physics kernel of the ODE right-hand side, returning heat flux at basic nodes, internal heating at staggered nodes, eddy diffusivity, and capacitance.

phase

JAX phase evaluator and flux computation for the entropy solver.

Pure-functional replacements for entropy_phase.py (phase properties) and entropy_state.py (MLT convection, heat/mass fluxes). All functions are JIT-compilable and differentiable.

The numpy versions mutate state arrays in-place. The JAX versions take arrays in and return NamedTuples out, with no side effects.

Dependencies: jax, equinox (already in PROTEUS ecosystem via atmodeller).

PhaseParams(phi_rheo=0.4, phi_width=0.15, viscosity_solid=1e+21, viscosity_liquid=0.1, grain_size=0.001, k_solid=4.0, k_liquid=2.0, matprop_smooth_width=0.0, conduction=True, convection=True, grav_sep=False, mixing=False, eddy_diff_thermal=1.0, eddy_diff_chemical=1.0, kappah_floor=0.0, bottom_up_grav_sep=True, phase_smoothing='tanh', phase_smoothing_width=0.01)

Bases: Module

Static parameters for phase evaluation and flux computation.

Constructed once from config, passed as args to JIT-compiled functions. All fields are scalars or 1D JAX arrays.

Source code in src/aragog/jax/phase.py
def __init__(
    self,
    phi_rheo: float = 0.4,
    phi_width: float = 0.15,
    viscosity_solid: float = 1e21,
    viscosity_liquid: float = 1e-1,
    grain_size: float = 1e-3,
    k_solid: float = 4.0,
    k_liquid: float = 2.0,
    matprop_smooth_width: float = 0.0,
    conduction: bool = True,
    convection: bool = True,
    grav_sep: bool = False,
    mixing: bool = False,
    eddy_diff_thermal: float = 1.0,
    eddy_diff_chemical: float = 1.0,
    kappah_floor: float = 0.0,
    bottom_up_grav_sep: bool = True,
    phase_smoothing: str = 'tanh',
    phase_smoothing_width: float = 0.01,
):
    self.phi_rheo = phi_rheo
    self.phi_width = phi_width
    self.log10_visc_solid = jnp.log10(viscosity_solid)
    self.log10_visc_liquid = jnp.log10(viscosity_liquid)
    self.visc_liquid = viscosity_liquid
    self.grain_size = grain_size
    self.k_solid = k_solid
    self.k_liquid = k_liquid
    self.matprop_smooth_width = matprop_smooth_width
    self.conduction = float(conduction)
    self.convection = float(convection)
    self.grav_sep = float(grav_sep)
    self.mixing = float(mixing)
    self.eddy_diff_thermal = eddy_diff_thermal
    self.eddy_diff_chemical = eddy_diff_chemical
    self.kappah_floor = kappah_floor
    self.bottom_up_grav_sep = float(bottom_up_grav_sep)
    if phase_smoothing not in ('cubic_hermite', 'tanh'):
        raise ValueError(
            f"phase_smoothing must be 'cubic_hermite' or 'tanh', got {phase_smoothing!r}"
        )
    self.phase_smoothing_tanh = 1.0 if phase_smoothing == 'tanh' else 0.0
    self.phase_smoothing_width = float(phase_smoothing_width)
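The constructor stores the boolean transport switches as 0.0/1.0 floats (and phase_smoothing as the float phase_smoothing_tanh), presumably so that compute_fluxes can gate terms by multiplication, as its "multiply by flag to enable/disable" comment says, instead of branching in Python. A minimal sketch of that pattern; the Flags class and the flux terms are hypothetical stand-ins, not Aragog code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Flags:
    """Hypothetical stand-in for the float-valued switches on PhaseParams."""

    conduction: float = 1.0
    convection: float = 1.0
    grav_sep: float = 0.0

    @classmethod
    def from_bools(cls, conduction=True, convection=True, grav_sep=False):
        # Same conversion PhaseParams.__init__ applies: float(bool).
        return cls(float(conduction), float(convection), float(grav_sep))


def total_flux(flags: Flags, F_cond: float, F_conv: float, F_sep: float) -> float:
    """Sum hypothetical flux terms, each gated by its 0.0/1.0 flag."""
    return flags.conduction * F_cond + flags.convection * F_conv + flags.grav_sep * F_sep
```

A disabled term is still evaluated and then multiplied by zero, which keeps one traced graph for every flag combination at the cost of computing unused terms.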

MeshArrays(d_dr_matrix, quantity_matrix, area, volume, radii_basic, mixing_length, mixing_length_sq, mixing_length_cu, radii_stag, P_stag, P_basic, gravity, dP_dr_basic=None, gravity_stag=None)

Bases: Module

Static mesh geometry arrays, converted from numpy Mesh once.

All fields are 1D JAX arrays except the transform matrices (d_dr_matrix, quantity_matrix), which are 2D.

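The custom __init__ keeps callers that predate the dP_dr_basic field working. When dP_dr_basic is not supplied, it is derived from P_basic and radii_basic via numpy gradient (matches entropy_state._dP_dr_basic = np.gradient(P_basic, r_basic)). When gravity_stag is not supplied, it defaults to the midpoint average 0.5 * (gravity[:-1] + gravity[1:]).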

Source code in src/aragog/jax/phase.py
def __init__(
    self,
    d_dr_matrix,
    quantity_matrix,
    area,
    volume,
    radii_basic,
    mixing_length,
    mixing_length_sq,
    mixing_length_cu,
    radii_stag,
    P_stag,
    P_basic,
    gravity,
    dP_dr_basic=None,
    gravity_stag=None,
):
    """Custom init so callers that predate the dP_dr_basic field
    keep working. When ``dP_dr_basic`` is not supplied we derive it
    from ``P_basic`` and ``radii_basic`` via numpy gradient (matches
    ``entropy_state._dP_dr_basic = np.gradient(P_basic, r_basic)``).
    """
    self.d_dr_matrix = d_dr_matrix
    self.quantity_matrix = quantity_matrix
    self.area = area
    self.volume = volume
    self.radii_basic = radii_basic
    self.mixing_length = mixing_length
    self.mixing_length_sq = mixing_length_sq
    self.mixing_length_cu = mixing_length_cu
    self.radii_stag = radii_stag
    self.P_stag = P_stag
    self.P_basic = P_basic
    self.gravity = gravity
    if dP_dr_basic is None:
        import numpy as _np

        dP_dr_basic = jnp.asarray(
            _np.gradient(_np.asarray(P_basic), _np.asarray(radii_basic))
        )
    self.dP_dr_basic = dP_dr_basic
    # Default ``gravity_stag`` to a midpoint average of ``gravity``
    # so callers that don't supply a per-staggered profile still
    # see a reasonable approximation. ``from_numpy_mesh`` overrides
    # this with the EOS-interpolated value at staggered radii
    # whenever the column is available.
    if gravity_stag is None:
        g_arr = jnp.asarray(gravity)
        gravity_stag = 0.5 * (g_arr[:-1] + g_arr[1:])
    self.gravity_stag = jnp.asarray(gravity_stag)

from_numpy_mesh(mesh) staticmethod

Build from a numpy Aragog Mesh object.

Source code in src/aragog/jax/phase.py
@staticmethod
def from_numpy_mesh(mesh) -> 'MeshArrays':
    """Build from a numpy Aragog Mesh object."""
    P_basic_arr = np.asarray(mesh.basic_pressure).ravel()
    r_basic_arr = np.asarray(mesh.basic.radii).ravel()
    r_stag_arr = np.asarray(mesh.staggered.radii).ravel()
    return MeshArrays(
        d_dr_matrix=jnp.asarray(mesh._d_dr_transform),
        quantity_matrix=jnp.asarray(mesh._quantity_transform),
        area=jnp.asarray(np.asarray(mesh.basic.area).ravel()),
        volume=jnp.asarray(np.asarray(mesh.basic.volume).ravel()),
        radii_basic=jnp.asarray(r_basic_arr),
        radii_stag=jnp.asarray(r_stag_arr),
        mixing_length=jnp.asarray(np.asarray(mesh.basic.mixing_length).ravel()),
        mixing_length_sq=jnp.asarray(np.asarray(mesh.basic.mixing_length_squared).ravel()),
        mixing_length_cu=jnp.asarray(np.asarray(mesh.basic.mixing_length_cubed).ravel()),
        P_stag=jnp.asarray(np.asarray(mesh.staggered_pressure).ravel()),
        P_basic=jnp.asarray(P_basic_arr),
        # SPIDER-parity dP/dr at basic nodes via numpy gradient (matches
        # entropy_state._dP_dr_basic = np.gradient(P_basic, r_basic)).
        dP_dr_basic=jnp.asarray(np.gradient(P_basic_arr, r_basic_arr)),
        # Per-node gravity profile when the external mesh file
        # carries eos_gravity (UserDefinedEOS / Zalmoxis path),
        # else scalar broadcast. The per-node array is aligned to
        # the same basic-node grid as area / volume / mixing_length,
        # so the MLT buoyancy cascade in compute_mlt picks up the
        # radial dependence without any downstream broadcasting
        # change. Scalar fallback chain:
        # mesh.eos._gravitational_acceleration (AdamsWilliamsonEOS),
        # mesh.settings.gravitational_acceleration (external EOS
        # without the private attribute), then 9.81 m/s^2.
        gravity=_build_gravity_array(mesh, r_stag=False),
        # Same construction at the staggered radii. Mirrors numpy's
        # entropy_solver.py ``g_stag = np.interp(r_stag, eos_radius,
        # eos_gravity)`` (with scalar fallback). Aligned to the
        # staggered grid.
        gravity_stag=_build_gravity_array(mesh, r_stag=True),
    )

PhaseProperties

Bases: NamedTuple

Material properties at a set of mesh nodes.

FluxOutput

Bases: NamedTuple

Output from the flux computation at basic nodes.

compute_fluxes(S_stag, time, eos, params, mesh, heating_rate, S_basic_cmb_override=None, dSdr_cmb_override=None)

Compute all heat and mass fluxes from the entropy profile.

This is the physics kernel called by the ODE RHS. It is a pure function: no mutation, no side effects, fully JIT-compilable.
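After the matrix product, compute_fluxes overwrites both end values of dS/dr: the surface value always copies its inner neighbour (SPIDER's boundary-copy convention), and the CMB value either takes dSdr_cmb_override or, when no override is given, also copies its inner neighbour. A list-based sketch of that ordering follows; the JAX code expresses the same updates functionally with dSdr.at[i].set(...).

```python
def apply_boundary_copies(dSdr, dSdr_cmb_override=None):
    """Return a copy of dSdr with compute_fluxes' end-value conventions applied."""
    out = list(dSdr)  # never mutate the input, like the .at[].set updates
    out[-1] = out[-2]  # surface: copy the adjacent interior value
    if dSdr_cmb_override is not None:
        out[0] = dSdr_cmb_override  # energy_balance: state-tracked gradient
    else:
        out[0] = out[1]  # other modes: copy the adjacent interior value
    return out
```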

Parameters:

  • S_stag (Array, required): Entropy at staggered nodes [J/kg/K].
  • time (float, required): Current time [yr] (used for radionuclide heating).
  • eos (EntropyEOS_JAX, required): JAX EOS tables.
  • params (PhaseParams, required): Static material parameters and transport flags.
  • mesh (MeshArrays, required): Mesh geometry and transform matrices.
  • heating_rate (Array, required): Internal heating rate at staggered nodes [W/kg] (radionuclide + tidal, pre-computed by the caller).
  • S_basic_cmb_override (float or None, default None): Optional override for the entropy at the CMB basic node (basic-node index 0). Used by the energy_balance core BC, which reconstructs S_basic[0] from the state-tracked dSdr_cmb via S_stag[0] + dSdr_cmb * (r_basic[0] - r_stag[0]). When None, the standard quantity_matrix mapping is used.
  • dSdr_cmb_override (float or None, default None): Optional override for the entropy gradient at the CMB basic node (dSdr[0]). Used by the energy_balance core BC, where dSdr_cmb is a state-tracked variable. When None, the standard d_dr_matrix mapping is used.

Returns:

  • FluxOutput: Heat flux, mass flux, eddy diffusivity, heating.

Source code in src/aragog/jax/phase.py
def compute_fluxes(
    S_stag: jax.Array,
    time: float,
    eos: EntropyEOS_JAX,
    params: PhaseParams,
    mesh: MeshArrays,
    heating_rate: jax.Array,
    S_basic_cmb_override=None,
    dSdr_cmb_override=None,
) -> FluxOutput:
    """Compute all heat and mass fluxes from the entropy profile.

    This is the physics kernel called by the ODE RHS. It is a pure
    function: no mutation, no side effects, fully JIT-compilable.

    Parameters
    ----------
    S_stag : jax.Array
        Entropy at staggered nodes [J/kg/K].
    time : float
        Current time [yr] (used for radionuclide heating).
    eos : EntropyEOS_JAX
        JAX EOS tables.
    params : PhaseParams
        Static material parameters and transport flags.
    mesh : MeshArrays
        Mesh geometry and transform matrices.
    heating_rate : jax.Array
        Internal heating rate at staggered nodes [W/kg]
        (radionuclide + tidal, pre-computed by caller).
    S_basic_cmb_override : float or None
        Optional override for the entropy at the CMB basic node
        (basic-node index 0). Used by the energy_balance core BC
        which reconstructs S_basic[0] from the state-tracked
        dSdr_cmb via S[0] + dSdr_cmb * (r_basic[0] - r_stag[0]).
        When None, the standard quantity_matrix mapping is used.
    dSdr_cmb_override : float or None
        Optional override for the entropy gradient at the CMB
        basic node (dSdr[0]). Used by the energy_balance core BC
        where dSdr_cmb is a state-tracked variable. When None,
        the standard d_dr_matrix mapping is used.

    Returns
    -------
    FluxOutput
        Heat flux, mass flux, eddy diffusivity, heating.
    """
    # Interpolate entropy to basic nodes and compute gradient
    S_basic = mesh.quantity_matrix @ S_stag
    dSdr = mesh.d_dr_matrix @ S_stag

    # SPIDER ic.c:450 boundary-copy convention for dSdr at the surface
    # (mirrors numpy entropy_state.py:390 ``dSdxi[-1] = dSdxi[-2]``).
    # The d_dr_matrix gives a linear-extrapolation gradient at the top
    # basic node, but the SPIDER convention is to copy the adjacent
    # interior value. Without this, the surface dSdr can flip sign
    # relative to numpy, blowing up kappa_h and dS/dt at the second-
    # to-last staggered cell. The CMB side has its own override path
    # via ``dSdr_cmb_override``, so the surface copy here is the only
    # one needed at this layer.
    dSdr = dSdr.at[-1].set(dSdr[-2])

    # Apply energy_balance overrides at the CMB basic node.
    # The ``is not None`` checks below are ordinary Python conditionals,
    # resolved at trace time: supplying or omitting an override selects a
    # different traced graph (and a retrace under jax.jit), while the
    # override values themselves may be traced JAX scalars.
    if S_basic_cmb_override is not None:
        S_basic = S_basic.at[0].set(S_basic_cmb_override)
    if dSdr_cmb_override is not None:
        dSdr = dSdr.at[0].set(dSdr_cmb_override)
    else:
        # Mirror numpy entropy_state.py:389 ``dSdxi[0] = dSdxi[1]`` for
        # the non-energy_balance modes (quasi_steady etc.) where there
        # is no boundary-state override. In energy_balance mode the
        # explicit override above wins and this branch is skipped.
        dSdr = dSdr.at[0].set(dSdr[1])

    # Phase properties at basic nodes only. The SPIDER-bracket Jmix
    # is built entirely from basic-node quantities, so staggered-node
    # phase properties are not materialised here.
    phase_basic = evaluate_phase(eos, params, mesh.P_basic, S_basic)

    # MLT eddy diffusivity
    kappa_h, kappa_c = compute_mlt(dSdr, phase_basic, mesh, params)

    # Basic node properties
    rho = phase_basic.density
    T = phase_basic.temperature
    Cp = phase_basic.heat_capacity
    k = phase_basic.thermal_conductivity
    dTdPs_basic = phase_basic.dTdPs

    # Heat flux components (multiply by flag to enable/disable)
    heat_flux = jnp.zeros_like(S_basic)

    # Conduction (SPIDER decomposition matching numpy entropy_state):
    #   F_cond = -k * [(T/Cp) * dS/dr + dT/dr|_adiabat]
    # where the adiabatic gradient is dT/dr|_ad = dTdPs * dPdr_basic
    # (EOS-table dT/dP|_S times the structural Adams-Williamson dP/dr).
    # This avoids the noise from finite-differencing T_stag and matches
    # SPIDER's eos-consistent conduction at phase boundaries.
    # Cp floor (100 J/kg/K) parity with numpy entropy_state.update. A
    # logger.warning inside this jit-compiled RHS would only fire
    # during tracing, not on actual data; the equivalent diagnostic
    # is emitted at load time by EntropySolver._check_eos_floors,
    # plus a once-per-instance runtime warning on the numpy path.
    Cp_safe = jnp.maximum(Cp, 100.0)
    superadiabatic = (T / Cp_safe) * dSdr
    dT_dr_adiabat = dTdPs_basic * mesh.dP_dr_basic
    heat_flux = heat_flux + params.conduction * (-k * (superadiabatic + dT_dr_adiabat))

    # Convection: F_conv = rho * T * kappa_h * (-dS/dr)
    heat_flux = heat_flux + params.convection * (rho * T * kappa_h * (-dSdr))

    # Mass flux for gravitational separation and mixing
    mass_flux = jnp.zeros_like(S_basic)

    # Gravitational separation.
    #
    # Raw Stokes/permeability-driven mass flux:
    #     jgrav_raw = rho * phi * (1 - phi) * v_rel
    # SPIDER analogue smoothing (SPIDER/energy.c:523-533,
    # JGRAV_BOTTOM_UP + get_smoothing): multiply jgrav_raw by a bounded
    # polynomial of an UN-truncated two-phase fraction
    #     gphi = (S - S_sol(P)) / (S_liq(P) - S_sol(P))
    # evaluated at the staggered cell immediately BELOW the interface.
    # The polynomial `16 * gphi^2 * (1 - gphi)^2` (clipped to [0, 1])
    # vanishes cleanly at both pure phases and has bounded derivatives
    # everywhere, unlike SPIDER's tanh smoothing. This is the scipy-
    # path fix from entropy_state.py mirrored here so the JAX backend
    # doesn't reproduce the pre-fix CMB drain at first crystallisation.
    phi_b = phase_basic.melt_fraction
    v_rel = relative_velocity(
        eos,
        params,
        mesh.P_basic,
        rho,
        phi_b,
        mesh.gravity,
    )
    jgrav_raw = rho * phi_b * (1.0 - phi_b) * v_rel

    # gphi at STAGGERED nodes (cell below each basic interface)
    S_sol_stag = eos.solidus_entropy(mesh.P_stag)
    S_liq_stag = eos.liquidus_entropy(mesh.P_stag)
    dS_stag = jnp.maximum(S_liq_stag - S_sol_stag, 1.0)
    gphi_stag = (S_stag - S_sol_stag) / dS_stag

    smth_stag = phase_boundary_smoothing(gphi_stag, params)

    # Map staggered smoothing to basic-node interfaces: interior basic
    # node i (1..N-1) sees the smoothing of staggered node i-1 (the
    # cell BELOW). Boundary basic nodes (0 and -1) use smth = 1 as a
    # placeholder because the mass flux at those indices is zeroed a
    # few lines below anyway. Lengths: staggered has N entries, basic
    # has N+1; smth_stag[:-1] supplies the N-1 interior interfaces and
    # the two boundary placeholders bring the total to N+1.
    smth_basic = jnp.concatenate(
        [
            jnp.array([1.0]),
            smth_stag[:-1],
            jnp.array([1.0]),
        ]
    )

    # `bottom_up_grav_sep = 1.0` selects the smoothed flux, 0.0 selects
    # the raw flux (for reproducing the pre-fix drain in regression
    # tests).
    jgrav_smoothed = jgrav_raw * smth_basic
    jgrav = (
        params.bottom_up_grav_sep * jgrav_smoothed
        + (1.0 - params.bottom_up_grav_sep) * jgrav_raw
    )
    mass_flux = mass_flux + params.grav_sep * jgrav

    # Zero mass fluxes at boundaries (SPIDER convention)
    mass_flux = mass_flux.at[0].set(0.0)
    mass_flux = mass_flux.at[-1].set(0.0)

    # Add latent heat transport from mass flux
    heat_flux = heat_flux + mass_flux * phase_basic.latent_heat

    # Mixing flux (SPIDER-parity bracket form, heat flux only).
    #
    # Mirrors numpy ``entropy_state.update`` (entropy_state.py:656-692):
    #     Jmix_heat = -kappa_c * rho * T_fus * bracket * smth_basic_mix
    #     bracket = dS/dr - [phi·dS_liq/dP + (1-phi)·dS_sol/dP] · dP/dr
    # evaluated at basic nodes. The mass-flux form
    # ``mass_flux += mixing * rho * kappa_c * (-dphi/dr)`` (delivered
    # via the latent-heat term) is NOT equivalent and diverges from
    # the numpy production path.
    #
    # The smth polynomial ``16·gphi²·(1-gphi)²`` at basic nodes zeroes
    # the flux outside the mushy band; the bracket itself is finite in
    # pure phases because dS_sol/dP and dS_liq/dP are bounded.
    S_sol_basic = eos.solidus_entropy(mesh.P_basic)
    S_liq_basic = eos.liquidus_entropy(mesh.P_basic)
    dS_phase_basic = jnp.maximum(S_liq_basic - S_sol_basic, 1.0)
    dS_sol_dP_basic = eos.solidus_entropy_dP(mesh.P_basic)
    dS_liq_dP_basic = eos.liquidus_entropy_dP(mesh.P_basic)
    phi_clipped = phase_basic.melt_fraction
    bracket = (
        dSdr
        - (phi_clipped * dS_liq_dP_basic + (1.0 - phi_clipped) * dS_sol_dP_basic)
        * mesh.dP_dr_basic
    )

    # T_fus = latent_heat / dS_phase (mirrors numpy
    # _ensure_basic_phase_boundary_cache: T_fus_basic = L_basic / dS_phase).
    T_fus_basic = phase_basic.latent_heat / dS_phase_basic

    # Smoothing: gphi at basic nodes, same smoothing family as Jgrav
    # (cubic Hermite or SPIDER tanh, selected by params.phase_smoothing_tanh).
    gphi_basic = (S_basic - S_sol_basic) / dS_phase_basic
    smth_basic_mix = phase_boundary_smoothing(gphi_basic, params)

    jmix_heat = -kappa_c * rho * T_fus_basic * bracket * smth_basic_mix
    # Zero at CMB and surface (no mass/heat transfer across those boundaries)
    jmix_heat = jmix_heat.at[0].set(0.0)
    jmix_heat = jmix_heat.at[-1].set(0.0)
    heat_flux = heat_flux + params.mixing * jmix_heat

    return FluxOutput(
        heat_flux=heat_flux,
        mass_flux=mass_flux,
        eddy_diffusivity=kappa_h,
        heating=heating_rate,
        jmix_heat=jmix_heat,
    )
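The staggered-to-basic smoothing map and the bump polynomial described in the comments above can be exercised on their own. This is a minimal sketch: `bump_smoothing` and `stag_to_basic_smoothing` are hypothetical names, and clipping gphi to [0, 1] *before* evaluating `16 * gphi^2 * (1 - gphi)^2` is an assumption about what "clipped to [0, 1]" means. The real `phase_boundary_smoothing` is not shown here and can also select a tanh form via `params.phase_smoothing_tanh`.

```python
import jax.numpy as jnp


def bump_smoothing(gphi):
    """Hypothetical stand-in for phase_boundary_smoothing (polynomial branch).

    Assumes gphi is clipped to [0, 1] first, so the bump is exactly 0 in
    both pure phases (gphi <= 0 or gphi >= 1) and peaks at 1 for gphi = 0.5.
    """
    g = jnp.clip(gphi, 0.0, 1.0)
    return 16.0 * g**2 * (1.0 - g) ** 2


def stag_to_basic_smoothing(smth_stag):
    """Map N staggered-cell values onto the N+1 basic-node interfaces.

    Interior basic node i (1..N-1) takes staggered cell i-1, the cell just
    below it; both boundary nodes get a 1.0 placeholder because their mass
    flux is zeroed afterwards.
    """
    one = jnp.ones((1,), dtype=smth_stag.dtype)
    return jnp.concatenate([one, smth_stag[:-1], one])
```

With three staggered cells, `stag_to_basic_smoothing(jnp.array([0.1, 0.2, 0.3]))` returns four values, `[1.0, 0.1, 0.2, 1.0]`; the top cell's value is dropped because no interior interface lies above it.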

compute_mlt(dSdr, phase_basic, mesh, params)

Compute MLT eddy diffusivity from the entropy gradient.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dSdr | Array | Entropy gradient at basic nodes [J/kg/K/m]. | required |
| phase_basic | PhaseProperties | Material properties at basic nodes. | required |
| mesh | MeshArrays | Mesh geometry. | required |
| params | PhaseParams | Static parameters. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| kappa_h | Array | Thermal eddy diffusivity at basic nodes [m^2/s]. |
| kappa_c | Array | Chemical eddy diffusivity at basic nodes [m^2/s]. |
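Two conventions shape what compute_mlt returns: a narrow tanh blend between the viscous and inviscid MLT velocities around the critical Reynolds number, and the SPIDER sign convention for the eddy-diffusivity parameters (positive means "scale the MLT value", non-positive means "use the constant -value"). A minimal sketch of both follows. The helper names are hypothetical, and since the value of `RE_CRIT` is not given on this page it is passed in as `re_crit`.

```python
import jax.numpy as jnp


def regime_blend(viscous_velocity, inviscid_velocity, reynolds, re_crit):
    """Blend viscous and inviscid velocities with a narrow tanh step.

    The inviscid weight is 0.5 at Re = re_crit and saturates within a few
    multiples of blend_width = 0.01 * re_crit on either side.
    """
    blend_width = 0.01 * re_crit
    w_inv = 0.5 * (1.0 + jnp.tanh((reynolds - re_crit) / blend_width))
    return (1.0 - w_inv) * viscous_velocity + w_inv * inviscid_velocity


def apply_eddy_convention(kh_raw, eddy_diff):
    """SPIDER convention: eddy_diff > 0 scales kh_raw; otherwise use -eddy_diff."""
    return jnp.where(eddy_diff > 0, eddy_diff * kh_raw, jnp.full_like(kh_raw, -eddy_diff))
```

For example, `apply_eddy_convention(kh, -5.0)` yields a uniform 5 m^2/s regardless of the MLT value, which is how a constant eddy diffusivity is requested.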

Source code in src/aragog/jax/phase.py
def compute_mlt(
    dSdr: jax.Array,
    phase_basic: PhaseProperties,
    mesh: MeshArrays,
    params: PhaseParams,
) -> tuple[jax.Array, jax.Array]:
    """Compute MLT eddy diffusivity from the entropy gradient.

    Parameters
    ----------
    dSdr : jax.Array
        Entropy gradient at basic nodes [J/kg/K/m].
    phase_basic : PhaseProperties
        Material properties at basic nodes.
    mesh : MeshArrays
        Mesh geometry.
    params : PhaseParams
        Static parameters.

    Returns
    -------
    kappa_h : jax.Array
        Thermal eddy diffusivity at basic nodes [m^2/s].
    kappa_c : jax.Array
        Chemical eddy diffusivity at basic nodes [m^2/s].
    """
    alpha = phase_basic.thermal_expansivity
    T = phase_basic.temperature
    Cp = phase_basic.heat_capacity
    nu = phase_basic.kinematic_viscosity

    # Buoyancy from entropy gradient. ``jnp.abs`` has a non-differentiable
    # kink at zero (any subgradient in [-1, +1] is mathematically valid,
    # and JAX simply picks one), and the backward pass through the abs
    # and the downstream multiplications can produce NaN at exact-zero
    # entropy gradients. Numpy uses ``_smooth_abs_neg(dSdr, eps=1e-30)``,
    # a smoothed ``max(-x, 0)``; here we use the regularised
    # ``0.5*(|x| + sqrt(x^2 + eps^2)) ~ |x|`` instead, which keeps the
    # forward equal to ``abs(x)`` to ULP precision while delivering a
    # finite analytic gradient everywhere.
    eps_abs = 1.0e-30
    abs_dSdr_safe = 0.5 * (jnp.abs(dSdr) + jnp.sqrt(dSdr * dSdr + eps_abs * eps_abs))
    effective_superadiabatic = alpha * T * abs_dSdr_safe / jnp.maximum(Cp, 1.0)
    velocity_prefactor = mesh.gravity * effective_superadiabatic

    # Convective mask: unstable when dS/dr < 0.
    # Hard mask matching the numpy reference. Using jnp.where (not boolean
    # indexing) for JAX traceability. Produces exactly 0 at stable nodes,
    # avoiding spurious convection from a soft sigmoid.
    conv_mask = jnp.where(dSdr < 0.0, 1.0, 0.0)

    # Viscous velocity (Re <= Re_crit)
    viscous_velocity = (velocity_prefactor * mesh.mixing_length_cu / (18.0 * nu)) * conv_mask

    # Inviscid velocity (Re > Re_crit). ``jnp.sqrt(jnp.maximum(x, 0))`` is
    # the textbook NaN-safe forward, but its backward pass at x = 0
    # evaluates the sqrt derivative 0.5 / sqrt(0) = inf, and any zero
    # factor it meets on the way back (the maximum's zero subgradient for
    # x < 0, or conv_mask at stable nodes) turns that into 0 * inf = NaN.
    # Numpy avoids this with ``np.sqrt(x + 1e-20)`` (always-positive
    # argument). Mirror it here with the same constant: eps_sqrt = 1e-20
    # is far below any physical inviscid_velocity_sq, so the forward
    # matches numpy and differs negligibly from sqrt(x) for non-trivial x.
    eps_sqrt = 1.0e-20
    inviscid_velocity_sq = (velocity_prefactor * mesh.mixing_length_sq / 16.0) * conv_mask
    inviscid_velocity = jnp.sqrt(inviscid_velocity_sq + eps_sqrt)

    # Reynolds number
    reynolds = viscous_velocity * mesh.mixing_length / nu

    # Smooth blend between viscous and inviscid regimes. The narrow
    # blend_width (0.01 * RE_CRIT) keeps inviscid k_h confined to the
    # convecting regime; widening the blend leaks inviscid mixing
    # into the solid regime and induces T_core bistability.
    blend_width = 0.01 * RE_CRIT
    inviscid_weight = 0.5 * (
        1.0 + jnp.tanh((reynolds - RE_CRIT) / jnp.maximum(blend_width, 1e-30))
    )

    # Raw eddy diffusivity
    kh_raw = (
        (1.0 - inviscid_weight) * viscous_velocity + inviscid_weight * inviscid_velocity
    ) * mesh.mixing_length

    # Thermal eddy diffusivity (SPIDER convention: positive=scale, negative=constant)
    kappa_h = jnp.where(
        params.eddy_diff_thermal > 0,
        params.eddy_diff_thermal * kh_raw,
        jnp.full_like(kh_raw, -params.eddy_diff_thermal),
    )

    # Chemical eddy diffusivity (from raw kh, before floor)
    kappa_c = jnp.where(
        params.eddy_diff_chemical > 0,
        params.eddy_diff_chemical * kh_raw,
        jnp.full_like(kh_raw, -params.eddy_diff_chemical),
    )

    # kappa_h floor (phase-dependent, modulated by melt fraction).
    # Production PROTEUS runs use kappah_floor = 10 m^2/s; the
    # phi-modulated f_floor ramps from 0 in solid layers (no spurious
    # convective flux) to ~1 in mushy/liquid layers, where physical
    # convection is expected and MLT can otherwise numerically freeze.
    # See solver/entropy_state.py for the same comment block on the
    # numpy path. The transition is anchored on the rheological critical
    # melt fraction (``params.phi_rheo``, default 0.4) with width
    # ``params.phi_width`` (default 0.15) so the floor turns on exactly
    # where Costa-blended viscosity drops, consistent across config knobs.
    phi_basic = phase_basic.melt_fraction
    f_floor = tanh_weight(phi_basic, params.phi_rheo, params.phi_width)
    kh_floor = params.kappah_floor * f_floor
    kappa_h = jnp.maximum(kappa_h, kh_floor)

    # SPIDER energy.c:220-223 CMB fix: use kappa_h from the first interior
    # node at the CMB basic node, since kappa_h is a nonlinear function of
    # dSdr and the boundary extrapolation can over- or under-estimate it
    # relative to the interior value. Mirrors numpy entropy_state.py:533.
    kappa_h = kappa_h.at[0].set(kappa_h[1])

    return kappa_h, kappa_c
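The two gradient regularisations in compute_mlt can be checked in isolation with `jax.grad`. In the minimal sketch below, the constants match the listing above and the function names are hypothetical. It enables double precision. That is an assumption about the production setup, suggested by the float64 casts in BoundaryParams. It matters: `eps_abs**2 = 1e-60` underflows to zero in float32, which would reintroduce the NaN at dSdr = 0.

```python
import jax
import jax.numpy as jnp

# Assumed: x64 mode. In float32, 1e-30**2 underflows to 0 and the abs fix breaks.
jax.config.update("jax_enable_x64", True)

EPS_ABS = 1.0e-30
EPS_SQRT = 1.0e-20


def safe_abs(x):
    # Same expression as compute_mlt: ~|x| away from 0, finite gradient at 0.
    return 0.5 * (jnp.abs(x) + jnp.sqrt(x * x + EPS_ABS * EPS_ABS))


def naive_masked_sqrt(v):
    # sqrt(max(x, 0)) where x has been zeroed by a mask, as at a stable node.
    mask = 0.0
    return jnp.sqrt(jnp.maximum(v * mask, 0.0))


def safe_masked_sqrt(v):
    # The eps-shifted form used in compute_mlt.
    mask = 0.0
    return jnp.sqrt(v * mask + EPS_SQRT)


g_abs = jax.grad(safe_abs)(0.0)            # finite
g_naive = jax.grad(naive_masked_sqrt)(1.0)  # non-finite: 0 * inf in the backward pass
g_safe = jax.grad(safe_masked_sqrt)(1.0)    # exactly 0.0
```

The naive form is fine in the forward pass. Only the gradient is affected, which is why the problem surfaces in the jax.jacrev Jacobian rather than in the RHS itself.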

evaluate_phase(eos, params, P, S)

Compute all material properties at (P, S) nodes.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| eos | EntropyEOS_JAX | JAX EOS tables. | required |
| params | PhaseParams | Static material parameters. | required |
| P | Array | Pressure [Pa], 1D. | required |
| S | Array | Entropy [J/kg/K], 1D. | required |

Returns:

| Type | Description |
|---|---|
| PhaseProperties | All material properties at the given nodes. |

Source code in src/aragog/jax/phase.py
def evaluate_phase(
    eos: EntropyEOS_JAX,
    params: PhaseParams,
    P: jax.Array,
    S: jax.Array,
) -> PhaseProperties:
    """Compute all material properties at (P, S) nodes.

    Parameters
    ----------
    eos : EntropyEOS_JAX
        JAX EOS tables.
    params : PhaseParams
        Static material parameters.
    P : jax.Array
        Pressure [Pa], 1D.
    S : jax.Array
        Entropy [J/kg/K], 1D.

    Returns
    -------
    PhaseProperties
        All material properties at the given nodes.
    """
    # SPIDER-parity single-pass evaluation: T, rho, Cp, alpha, dTdPs, k
    # all derived from one shared (S_sol, S_liq, gphi, smth, phase-boundary)
    # cache, with the smth-blend mixed<->single matching numpy
    # EntropyPhaseEvaluator._update_eos.
    state = eos.compute_phase_state(
        P,
        S,
        k_solid=params.k_solid,
        k_liquid=params.k_liquid,
        matprop_smooth_width=params.matprop_smooth_width,
    )
    phi = state.melt_fraction

    # Viscosity: two-stage blend mirroring numpy entropy_phase.py:311-327
    # (used downstream by MLT -> kappa_c -> Jmix, which is why both stages
    # matter). Stage 1: tanh blend at phi_rheo (SPIDER util.c:255-259).
    # Stage 2: combine_matprop with the cached matprop_smooth_width smth
    # that compute_phase_state also uses for T/rho/Cp/alpha/k.
    w = tanh_weight(phi, params.phi_rheo, params.phi_width)
    log_visc_mixed = (1.0 - w) * params.log10_visc_solid + w * params.log10_visc_liquid
    log_visc_single = jnp.where(
        phi > 0.5,
        params.log10_visc_liquid,
        params.log10_visc_solid,
    )
    log_visc = state.smth * log_visc_mixed + (1.0 - state.smth) * log_visc_single
    viscosity = 10.0**log_visc
    kinematic_viscosity = viscosity / state.density

    # Capacitance for entropy equation
    capacitance = state.density * state.temperature

    return PhaseProperties(
        temperature=state.temperature,
        density=state.density,
        heat_capacity=state.heat_capacity,
        thermal_expansivity=state.thermal_expansivity,
        dTdPs=state.dTdPs,
        melt_fraction=phi,
        viscosity=viscosity,
        kinematic_viscosity=kinematic_viscosity,
        thermal_conductivity=state.thermal_conductivity,
        latent_heat=state.latent_heat,
        capacitance=capacitance,
    )
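The two-stage viscosity blend in evaluate_phase can be reproduced on scalars. This is a minimal sketch. The form of `tanh_weight` is an assumption (a 0-to-1 tanh step centred on `phi_c`), because the helper itself is not listed on this page. The defaults `phi_rheo = 0.4` and `phi_width = 0.15` come from the comments in compute_mlt. Here `smth` is passed in directly, whereas in aragog it comes from `compute_phase_state`.

```python
import jax.numpy as jnp


def tanh_weight(phi, phi_c, width):
    # Hypothetical form of aragog's tanh_weight: smooth 0 -> 1 step at phi_c.
    return 0.5 * (1.0 + jnp.tanh((phi - phi_c) / width))


def blended_log_viscosity(phi, smth, log10_visc_solid, log10_visc_liquid,
                          phi_rheo=0.4, phi_width=0.15):
    # Stage 1: rheological tanh blend across the critical melt fraction.
    w = tanh_weight(phi, phi_rheo, phi_width)
    log_mixed = (1.0 - w) * log10_visc_solid + w * log10_visc_liquid
    # Single-phase value: liquid above phi = 0.5, solid below.
    log_single = jnp.where(phi > 0.5, log10_visc_liquid, log10_visc_solid)
    # Stage 2: smth picks the mixed value inside the mushy band and the
    # single-phase value outside it.
    return smth * log_mixed + (1.0 - smth) * log_single
```

At `phi = phi_rheo` with `smth = 1`, the result is the midpoint of the two log-viscosities. As in evaluate_phase, the dynamic viscosity is then `10.0**log_visc` and the kinematic viscosity is that divided by density.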

aragog.jax.solver

Standalone JAX solve path used by the verification suite. The production path uses CVODE driven by aragog.solver.entropy_solver and only borrows the Jacobian from jax.jacrev; this module's solve_entropy is kept for parity tests and gradient-based sensitivity studies.

solver

JAX ODE solver for the entropy equation (research-only).

Direct-JAX integration of the entropy equation via diffrax Kvaerno5 (5th-order ESDIRK, A-L stable). The RHS applies boundary conditions, computes flux divergence, and adds internal heating, all in pure JAX.

Not production-ready: the diffrax ESDIRK solvers (kvaerno3, kvaerno5) stall at the first crystallisation step on the cubic-Hermite J_grav smoothing in coupled Earth-mantle runs. Kept for autodiff development. The production JAX integration path is the CVODE Option-Z path in aragog/solver/cvode_jax.py, selected by setting EnergyParams.use_jax_jacobian = True (PROTEUS-side backend="jax"), which uses CVODE for time stepping and JAX only for the analytic Jacobian.

Dependencies: jax, equinox, diffrax, lineax (transitive via diffrax).

BoundaryParams(*, outer_bc_type, outer_bc_value, emissivity, T_eq, inner_bc_type, inner_bc_value, core_density, core_heat_capacity, tfac_core_avg, cmb_area=0.0, core_M=0.0, cmb_dr_cmb=0.0, param_utbl=False, param_utbl_const=0.0)

Bases: Module

Boundary condition configuration as a JAX pytree.

Surface BC types:
  1 = grey-body (F = emissivity * sigma * (T^4 - T_eq^4))
  4 = prescribed flux (from atmosphere module)

CMB BC types:
  0 = insulating (F = 0)
  1 = core cooling (Bower+2018 Eq. 37)
  2 = prescribed flux
  3 = prescribed temperature (preserve conduction-derived flux)
  5 = energy_balance (SPIDER bit-parity): F_cmb derived from the boundary entropy gradient (state-tracked dSdr_cmb). Used by dSdt_energy_balance.

All float fields are stored as JAX arrays (not Python floats) to avoid JIT recompilation when values change between coupling steps.

Energy-balance constants are optional (default 0); only used when inner_bc_type == 5 via dSdt_energy_balance.
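The grey-body surface condition (outer_bc_type 1) is a one-line formula. The sketch below evaluates it with the CODATA 2018 Stefan-Boltzmann constant; `grey_body_flux` is a hypothetical name. It uses only arithmetic operators, so it accepts Python floats or JAX arrays alike.

```python
SIGMA_SB = 5.670374419e-8  # Stefan-Boltzmann constant [W m^-2 K^-4], CODATA 2018


def grey_body_flux(T_surface, emissivity, T_eq):
    """Surface BC type 1: F = emissivity * sigma * (T^4 - T_eq^4) [W/m^2].

    Positive F means net outward (cooling) flux; F = 0 when T_surface == T_eq.
    """
    return emissivity * SIGMA_SB * (T_surface**4 - T_eq**4)
```

A 1000 K black-body surface with T_eq = 0 radiates about 5.67e4 W/m^2. The flux changes sign if the surface is colder than T_eq.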

Source code in src/aragog/jax/solver.py
def __init__(
    self,
    *,
    outer_bc_type,
    outer_bc_value,
    emissivity,
    T_eq,
    inner_bc_type,
    inner_bc_value,
    core_density,
    core_heat_capacity,
    tfac_core_avg,
    cmb_area=0.0,
    core_M=0.0,
    cmb_dr_cmb=0.0,
    param_utbl=False,
    param_utbl_const=0.0,
):
    self.outer_bc_type = outer_bc_type
    self.outer_bc_value = jnp.asarray(outer_bc_value, dtype=jnp.float64)
    self.emissivity = jnp.asarray(emissivity, dtype=jnp.float64)
    self.T_eq = jnp.asarray(T_eq, dtype=jnp.float64)
    self.inner_bc_type = inner_bc_type
    self.inner_bc_value = jnp.asarray(inner_bc_value, dtype=jnp.float64)
    self.core_density = jnp.asarray(core_density, dtype=jnp.float64)
    self.core_heat_capacity = jnp.asarray(core_heat_capacity, dtype=jnp.float64)
    self.tfac_core_avg = jnp.asarray(tfac_core_avg, dtype=jnp.float64)
    self.cmb_area = jnp.asarray(cmb_area, dtype=jnp.float64)
    self.core_M = jnp.asarray(core_M, dtype=jnp.float64)
    self.cmb_dr_cmb = jnp.asarray(cmb_dr_cmb, dtype=jnp.float64)
    self.param_utbl = bool(param_utbl)
    self.param_utbl_const = jnp.asarray(param_utbl_const, dtype=jnp.float64)
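In energy_balance mode (inner_bc_type 5), the CMB basic-node entropy is rebuilt from the state-tracked dSdr_cmb, as stated in the compute_fluxes docstring: S_basic[0] = S[0] + dSdr_cmb * (r_basic[0] - r_stag[0]). The minimal sketch below combines that reconstruction with the `is None` branching used in compute_fluxes. `cmb_basic_overrides` is a hypothetical helper. In aragog the caller computes S_basic_cmb_override itself and passes both overrides in.

```python
import jax.numpy as jnp


def cmb_basic_overrides(S_stag, S_basic, dSdr, r_basic, r_stag, dSdr_cmb=None):
    """Apply energy_balance-style CMB overrides to basic-node S and dS/dr.

    The `is None` test is a Python-level branch resolved at trace time, so
    the function stays jit-compatible in both modes.
    """
    if dSdr_cmb is None:
        # Non-energy_balance modes: SPIDER copy of the first interior gradient.
        return S_basic, dSdr.at[0].set(dSdr[1])
    # Linear extrapolation from the first staggered node down to the CMB.
    S_cmb = S_stag[0] + dSdr_cmb * (r_basic[0] - r_stag[0])
    return S_basic.at[0].set(S_cmb), dSdr.at[0].set(dSdr_cmb)
```

Under `jax.jit`, calling it with and without `dSdr_cmb` produces two separate traces, one per mode.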

SolveResult

Bases: NamedTuple

Output from solve_entropy.

aragog.jax.nondim

Reference scales used internally to non-dimensionalise the state vector and time before they enter the integrator. The triplet (state_scale, rhs_scale, t_ref) is built from EntropySolver._S_ref, _dSdr_ref, and _t_ref_yr; it is not user-configurable.

nondim

Nondimensional scaling spec shared by the numpy and JAX RHS paths.

Single source of truth for the (state_scale, rhs_scale, t_ref) triplet that scales physical-units state into the BDF integrator's O(1) work space. Built once by the EntropySolver, consumed by both the scipy/CVODE wrapper (entropy_solver.py) and the JAX CVODE factory (cvode_jax.py).
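The scaling can be written as a wrapper around a physical-units RHS. Differentiating the wrapped RHS with jax.jacrev then gives the nondimensional Jacobian directly, equal to diag(rhs_scale) @ J_phys @ diag(state_scale). This is a minimal sketch: `make_nondim_rhs` is hypothetical, the linear toy RHS stands in for dSdt, and the real cvode_jax factory may organise this differently.

```python
import jax
import jax.numpy as jnp


def make_nondim_rhs(rhs_phys, state_scale, t_ref):
    """Wrap a physical-units RHS f(t, y) into the integrator's O(1) space."""
    rhs_scale = t_ref / state_scale  # the NonDimScales contract

    def rhs_nd(t_nd, y_nd):
        # Undo the scaling, evaluate in physical units, rescale the result.
        return rhs_phys(t_nd * t_ref, y_nd * state_scale) * rhs_scale

    return rhs_nd, rhs_scale


# Toy linear RHS dy/dt = A @ y standing in for the entropy equation.
A = jnp.array([[-1.0, 0.5], [0.2, -3.0]])
state_scale = jnp.array([3000.0, 0.01])
t_ref = 1.0e3

rhs_nd, rhs_scale = make_nondim_rhs(lambda t, y: A @ y, state_scale, t_ref)
J_nd = jax.jacrev(rhs_nd, argnums=1)(0.0, jnp.ones(2))
# Chain rule: J_nd = diag(rhs_scale) @ J_phys @ diag(state_scale).
J_expected = rhs_scale[:, None] * A * state_scale[None, :]
```

The diagonal of J_nd is t_ref * diag(J_phys), independent of state_scale. Only the off-diagonal couplings between components with different units are rescaled.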

Internal contract enforced by __post_init__:

rhs_scale = t_ref / state_scale          (per component)
state_scale > 0, rhs_scale > 0, t_ref > 0
state_scale.shape == rhs_scale.shape

Because __post_init__ enforces this contract before any caller can use an instance, the two consumers (the scipy/CVODE wrapper in entropy_solver.py and the JAX CVODE factory in cvode_jax.py) do NOT need to re-check these invariants on entry.
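The validation pattern can be illustrated with a small frozen dataclass. `NonDimScalesSketch` below is a hypothetical re-creation written for illustration, not the real class. Raising `ValueError` on a violated contract is an assumption, because the documentation only says that __post_init__ raises.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class NonDimScalesSketch:
    """Hypothetical re-creation of the NonDimScales contract checks."""

    state_scale: np.ndarray
    t_ref: float
    rhs_scale: Optional[np.ndarray] = None

    def __post_init__(self):
        state_scale = np.asarray(self.state_scale, dtype=float)
        if state_scale.ndim != 1:
            raise ValueError("state_scale must be 1-D")
        if self.t_ref <= 0:
            raise ValueError("t_ref must be > 0")
        if np.any(state_scale <= 0):
            raise ValueError("state_scale must be > 0")
        expected = self.t_ref / state_scale
        if self.rhs_scale is None:
            rhs_scale = expected
        else:
            rhs_scale = np.asarray(self.rhs_scale, dtype=float)
        if rhs_scale.shape != state_scale.shape:
            raise ValueError("rhs_scale and state_scale shapes differ")
        if not np.allclose(rhs_scale, expected, rtol=1e-12, atol=0.0):
            raise ValueError("rhs_scale must equal t_ref / state_scale")
        # Frozen dataclass: bypass __setattr__ to store the normalised arrays.
        object.__setattr__(self, "state_scale", state_scale)
        object.__setattr__(self, "rhs_scale", rhs_scale)

    @property
    def n(self) -> int:
        """Number of state-vector components."""
        return self.state_scale.shape[0]
```

Positivity of rhs_scale follows from the other checks: t_ref and state_scale are positive, and rhs_scale must match their ratio.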

NonDimScales(state_scale, t_ref, rhs_scale=None) dataclass

Per-component nondim scales for the entropy solver state vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state_scale | ndarray, shape (n,) | Physical-units scale per state component: y_phys = y_nd * state_scale. Different components can have different scales (entropy in J/kg/K, dSdr in J/kg/K/m, T_core in K), so the scale is per-element. | required |
| t_ref | float | Time scale: t_phys = t_nd * t_ref. Same value for every state component. | required |
| rhs_scale | ndarray, shape (n,), or None | Optional precomputed RHS scale: dydt_nd = dydt_phys * rhs_scale. If None, it is derived automatically as t_ref / state_scale. When provided, it must satisfy the contract or __post_init__ raises. | None |

n property

Number of state-vector components.