
Installation
============

Install stella together with the dependencies for your chosen backend (JAX or PyTorch).

Via pip extras
--------------

```bash
# JAX (CPU)
pip install "stella[jax]"

# JAX on macOS with Metal (Apple Silicon)
pip install "stella[jax-mps]"

# PyTorch
pip install "stella[torch]"
```

From source (development)
-------------------------

```bash
git clone https://github.com/benjaminpope/stella
cd stella
pip install -e ".[dev]"   # quotes stop shells such as zsh from globbing the brackets
```

Selecting a backend
-------------------

Keras reads `KERAS_BACKEND` once, when it is first imported, so set the variable before importing keras:

```bash
export KERAS_BACKEND=jax   # or torch
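```

Quick sanity check
------------------

Build a tiny model and run one forward pass to confirm the backend works:

```python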
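import os
# Must run before importing keras; keeps any value already exported.
os.environ.setdefault("KERAS_BACKEND", "jax")  # or "torch"

import numpy as np
import keras

print("Backend:", keras.backend.backend())

# Tiny model: 8 inputs -> 4 hidden units -> 1 sigmoid output
m = keras.Sequential([
    keras.layers.Input((8,)),
    keras.layers.Dense(4, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

x = np.zeros((1, 8), dtype="float32")  # one sample with 8 features
print(m(x).shape)  # batch of 1, one output unit: shape (1, 1)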
```

Backend validation at runtime
-----------------------------

```python
import stella
stella.require_backend()  # raises with install hint if backend missing

```
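If you want a check that runs before Keras is even imported, a minimal sketch like the one below can report which backend packages are present. It assumes the two backends covered on this page map to the top-level packages `jax` and `torch`. `installed_backends` and `check_backend` are hypothetical helpers written for illustration only, not part of stella's API, and their default of `"jax"` merely mirrors the sanity check above (Keras itself defaults to TensorFlow when no backend is configured).

```python
import importlib.util
import os

# Hypothetical mapping (not stella's API): KERAS_BACKEND value -> package
# that must be importable for that backend to work.
BACKEND_PACKAGES = {"jax": "jax", "torch": "torch"}


def installed_backends():
    """Return backend names whose packages can be located, without importing them."""
    return [
        name
        for name, pkg in BACKEND_PACKAGES.items()
        if importlib.util.find_spec(pkg) is not None
    ]


def check_backend(name=None):
    """Raise with a pip hint if the requested backend's package is missing.

    When ``name`` is None, read KERAS_BACKEND, falling back to "jax"
    (an assumption that mirrors the sanity check, not Keras's own default).
    """
    name = name or os.environ.get("KERAS_BACKEND", "jax")
    if name not in BACKEND_PACKAGES:
        raise ValueError(
            f"unsupported backend {name!r}; expected one of {sorted(BACKEND_PACKAGES)}"
        )
    if name not in installed_backends():
        raise ImportError(
            f"backend {name!r} is not installed; try: pip install \"stella[{name}]\""
        )
    return name


if __name__ == "__main__":
    print("Installed backends:", installed_backends())
```

Because `importlib.util.find_spec` only locates a top-level package without importing it, this check is cheap and does not initialize JAX or PyTorch.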