neurojax-dev


Guidelines for developing NeuroJAX (OSL-JAX) components.

By m9h, updated 1/29/2026


NeuroJAX (OSL-JAX) Development

Philosophy

NeuroJAX follows the "Kidger Stack" philosophy, built on Patrick Kidger's JAX libraries (Equinox, Lineax, Optimistix, Diffrax, jaxtyping):

  • State: All models must subclass equinox.Module. State is explicit and immutable.
  • Solvers: Use lineax for linear solves and optimistix for non-linear optimization.
  • Differentiation: Everything must be differentiable. Avoid numpy (except for I/O); use jax.numpy.
  • Typing: Use jaxtyping to enforce shapes, e.g., Float[Array, "time sensors"].

Directory Structure

  • src/neurojax/glm.py: Mass-univariate statistics.
  • src/neurojax/inverse/: Beamformers and source reconstruction.
  • src/neurojax/models/: Biophysical and generative models (Diffrax, Equinox).
  • src/neurojax/utils/: Bridges and helpers.

Common Patterns

GLM / Linear Solvers

When solving $Ax=b$, prefer lineax over jnp.linalg.solve:

import lineax as lx

operator = lx.MatrixLinearOperator(A)
x = lx.linear_solve(operator, b, solver=lx.QR()).value  # Solution.value holds x

Random Keys

Passing an explicit key is mandatory for every stochastic operation. Split keys early:

import jax

keys = jax.random.split(key, num=100)  # one independent subkey per call
results = jax.vmap(func)(keys)         # func takes a single key

Install via CLI
npx skills add https://github.com/m9h/neurojax --skill neurojax-dev
Repository Details
  • Branch: main
  • Path: SKILL.md