Backends (JAX and PyTorch)
stella uses Keras 3, which can run on multiple numerical backends. You can use either the JAX or the PyTorch backend with the same stella code.
Quick facts
- Select the backend with the `KERAS_BACKEND` environment variable, set before `keras` is imported.
- The stella code is backend-agnostic: apart from selecting the backend and installing its packages, no changes are required.
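Keras fixes its backend when it is first imported, so setting `KERAS_BACKEND` later has no effect on the running process. If you would rather get a hard error than a silent fallback, a minimal guard might look like the sketch below. `select_backend` and `SUPPORTED_BACKENDS` are hypothetical names for illustration, not part of stella or Keras.

```python
import os
import sys

# Backends covered on this page (hypothetical constant for this sketch).
SUPPORTED_BACKENDS = ("jax", "torch")


def select_backend(name: str) -> str:
    """Hypothetical helper: set KERAS_BACKEND, or fail loudly if it is too late."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"unsupported backend {name!r}; expected one of {SUPPORTED_BACKENDS}")

    keras_module = sys.modules.get("keras")
    if keras_module is not None:
        # Keras already chose its backend at import time; the env var is no longer read.
        active = keras_module.config.backend()
        if active != name:
            raise RuntimeError(
                f"keras was already imported with backend {active!r}; "
                f"set KERAS_BACKEND={name} before importing keras (or restart the process)"
            )
        return name

    # Keras has not been imported yet, so this value will be picked up on import.
    os.environ["KERAS_BACKEND"] = name
    return name
```

Call it at the very top of an entry-point script, before anything imports `keras`, e.g. `select_backend("torch")` followed by `import keras`.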
Setup: JAX
```bash
pip install -U keras jax jaxlib
export KERAS_BACKEND=jax
```
Setup: PyTorch
```bash
pip install -U keras torch
export KERAS_BACKEND=torch
```
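To confirm that an environment has what the chosen backend needs before running stella, a standard-library check can look up each package without importing it. The package lists mirror the pip commands above; `missing_packages` and `REQUIRED_PACKAGES` are hypothetical names, and finding a package does not guarantee it was built with GPU or Metal support.

```python
import importlib.util

# Packages installed by the pip commands above, keyed by KERAS_BACKEND value.
REQUIRED_PACKAGES = {
    "jax": ("keras", "jax", "jaxlib"),
    "torch": ("keras", "torch"),
}


def missing_packages(backend: str) -> list[str]:
    """Return the packages required by `backend` that cannot be found."""
    # find_spec locates a top-level module without executing it; None means not installed.
    return [pkg for pkg in REQUIRED_PACKAGES[backend] if importlib.util.find_spec(pkg) is None]
```

For example, `missing_packages("torch")` returns an empty list when both `keras` and `torch` are installed.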
Usage is identical
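```python
import os

# Must run before the first `import keras`; setdefault keeps a value already exported in the shell.
os.environ.setdefault("KERAS_BACKEND", "jax")  # or "torch"

import keras

# compile=False skips compiling the model (optimizer, loss, metrics), which inference does not need.
m = keras.models.load_model("/path/to/model.keras", compile=False)
y = m.predict(x)  # x: your input batch (e.g. a NumPy array) shaped as the model expects
```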
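Because the backend is fixed per process, one way to check that the same script behaves identically under both backends is to run it in two fresh interpreters. This sketch assumes your inference code lives in a standalone script and, like the snippet above, uses `setdefault`, so the value passed in the environment wins. `run_with_backend` is a hypothetical helper.

```python
import os
import subprocess
import sys


def run_with_backend(backend: str, args: list[str]) -> str:
    """Run `python *args` in a fresh interpreter with KERAS_BACKEND set; return its stdout."""
    env = {**os.environ, "KERAS_BACKEND": backend}  # child inherits everything else unchanged
    result = subprocess.run(
        [sys.executable, *args],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the script fails under this backend
    )
    return result.stdout
```

For example, `run_with_backend("jax", ["predict.py"])` and `run_with_backend("torch", ["predict.py"])` (where `predict.py` is a hypothetical script name) give two outputs you can compare, allowing for small floating-point differences between backends.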
Troubleshooting
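- If you see a backend mismatch, make sure `KERAS_BACKEND` is set before `keras` is first imported anywhere in your Python process, including indirectly through another package. Keras reads the variable once at import time, so changing it afterwards has no effect until the process restarts. You can confirm the active backend with `keras.config.backend()`.
- Some ops or layers may perform differently across backends, so measure performance on the backend you intend to use.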
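Since performance can differ between backends, it is worth timing your own workload on each. Below is a minimal timing harness, assuming a zero-argument callable such as `lambda: m.predict(x, verbose=0)` from the usage example; `median_runtime` is a hypothetical helper. The warm-up calls matter because a backend may do one-time work on the first call (for example, JAX traces and compiles the function), which would otherwise skew the measurement.

```python
import statistics
import time
from typing import Callable


def median_runtime(fn: Callable[[], object], warmup: int = 3, repeats: int = 10) -> float:
    """Return the median wall-clock time of fn() in seconds, measured after warm-up calls."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):
        fn()  # absorb one-time costs such as tracing, compilation, and caching
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)  # median is less sensitive to outliers than the mean
```

Run it once per backend, each in its own process, and compare the results. `m.predict` returns NumPy arrays, so the timed call includes finishing the computation; if you time raw JAX operations instead, call `.block_until_ready()` on the result, because JAX dispatches work asynchronously.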
Inspect, swap, and benchmark
```python
import stella

stella.check_backend()
stella.swap_backend('torch', accelerator='mps')  # prepare the environment for PyTorch on Apple Metal (MPS)
```