add-archon-model

star 5.3k

Guide for adding a new model to the Archon engine. Use when user wants to add support for a new HuggingFace model architecture in ArchonEngine.

areal-project By areal-project schedule Updated 3/24/2026

name: add-archon-model description: Guide for adding a new model to the Archon engine. Use when user wants to add support for a new HuggingFace model architecture in ArchonEngine.

Add Archon Model

Add support for a new HuggingFace model architecture in the Archon training engine.

When to Use

This skill is triggered when:

  • User asks "how do I add a model to Archon?"
  • User wants to support a new model family (e.g., Llama, Mistral, DeepSeek) in ArchonEngine
  • User mentions adding a new ModelSpec or model type for Archon

Prerequisites

Before starting, ensure:

  • The target model is available on HuggingFace (has config.json with model_type)
  • You know the HuggingFace model ID (e.g., meta-llama/Llama-3-8B)
  • The model uses a standard transformer architecture (decoder-only)

Step-by-Step Guide

Step 1: Analyze the Target Model Architecture

Read the HuggingFace model's source code to extract key architecture information.

Action: Fetch and analyze the model's HuggingFace configuration and modeling files.

  1. Read the model's config.json (via AutoConfig.from_pretrained) to identify:

    • model_type string (this is the key used for registry lookup)
    • All architecture hyperparameters (hidden_size, num_layers, etc.)
    • Any model-specific fields (e.g., qk_norm, attention_bias, MoE fields)
  2. Read the HuggingFace modeling_*.py source to identify:

    • Attention variant: Does it have Q/K norm? Attention bias? Sliding window? Multi-latent attention?
    • FFN variant: SwiGLU (gate_proj + up_proj + down_proj)? GeGLU? Standard MLP?
    • MoE support: Does it have MoE layers? What router type? Shared experts?
    • RoPE variant: Standard RoPE? YaRN? NTK-aware scaling? What is the inv_freq formula?
    • Normalization: RMSNorm or LayerNorm? Pre-norm or post-norm? Elementwise affine?
    • Weight tying: Does tie_word_embeddings appear in config?
    • State dict key names: What are the HF weight key naming conventions?
  3. Summarize findings in a checklist like:

Target model: <name>
HF model_type: "<model_type>" (and variants like "<model_type>_moe" if applicable)
Attention: [standard GQA / with QK norm / with bias / sliding window / ...]
FFN: [SwiGLU / GeGLU / standard MLP / ...]
MoE: [no / yes - num_experts, top_k, shared_experts]
RoPE: [standard / YaRN / NTK-aware / ...]
Norm: [RMSNorm / LayerNorm] with [pre-norm / post-norm]
Weight tying: [yes / no]

Step 2: Select the Reference Model

Choose the closest existing implementation as a starting point:

Target characteristics Reference Why
Dense-only, standard GQA, no QK norm qwen2 Simplest baseline, pure dense
Has QK norm, or has MoE support qwen3 Supports QK norm + MoE + shared experts

Action: Copy the reference model directory as the starting point:

areal/experimental/models/archon/<model>/
  __init__.py
  spec.py
  model/
    args.py
    model.py
    rope.py
    state_dict_adapter.py
  infra/
    parallelize.py

Step 3: Implement args.py

Adapt <Model>ModelArgs to match the target model's HuggingFace config fields.

Key changes from reference:

  1. Update the @dataclass fields to match the target model's hyperparameters:

    • Field names should use Archon conventions (dim, n_layers, n_heads, n_kv_heads, vocab_size, head_dim, hidden_dim, norm_eps, rope_theta, etc.)
    • Default values should match the smallest variant of the target model
    • Add model-specific fields (e.g., attention_bias, qk_norm, sliding_window)
  2. Update from_hf_config() to correctly map HuggingFace config attributes:

    • Use getattr(hf_config, "field_name", default) for optional fields
    • Handle variant-specific fields (e.g., MoE fields only present in MoE variants)
    • The method must return an instance of the model args class

Critical: Verify every field mapping against the HF model's config.json. Incorrect mappings here cause silent errors downstream.

Base class contract (BaseModelArgs):

@dataclass
class <Model>ModelArgs(BaseModelArgs):
    # ... model-specific fields ...

    @classmethod
    def from_hf_config(
        cls,
        hf_config: PretrainedConfig,
        is_critic: bool = False,
        **kwargs,
    ) -> <Model>ModelArgs:
        # Map HF config fields to Archon model args
        ...

Step 4: Implement model.py

Adapt the model architecture to match the target model.

Key components to adapt:

  1. Normalization (RMSNorm or similar):

    • Check if elementwise_affine is configurable
    • Check the epsilon default value
    • If the model uses LayerNorm, implement accordingly
  2. Attention module:

    • Q/K/V projection: Check bias presence (nn.Linear(..., bias=True/False))
    • QK norm: Add q_norm/k_norm if the model has them, remove if it doesn't
    • GQA: n_kv_heads < n_heads for grouped-query attention
    • Ulysses SP: Keep the set_cp_group / _sp_enabled pattern from the reference
    • Output projection: Check bias presence
  3. FeedForward module:

    • SwiGLU: w2(silu(w1(x)) * w3(x)) -- most common for modern LLMs
    • Check bias in linear layers
    • For MoE models: MoE module replaces FeedForward on designated layers
  4. TransformerBlock: Pre-norm (most modern LLMs) vs post-norm

    • MoE layer detection via _is_moe_layer() if applicable
  5. Top-level Model (<Model>Model(BaseArchonModel)):

    • tok_embeddings, layers (as ModuleDict), norm, output/score
    • init_weights(): Match initialization scheme from HF
    • init_buffers(): RoPE cache + MoE buffers
    • forward(): Must follow BaseArchonModel signature: (tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> Tensor

Base class contract (BaseArchonModel):

class <Model>Model(BaseArchonModel):
    def forward(self, tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> torch.Tensor: ...
    def init_weights(self) -> None: ...
    def init_buffers(self, buffer_device) -> None: ...

Step 5: Implement rope.py

Handle the rotary position embedding variant.

Options:

  1. Standard RoPE (same as qwen2/qwen3): Re-export from qwen2:

    from areal.experimental.models.archon.qwen2.model.rope import (
        apply_rotary_emb,
        precompute_rope_cache,
        repeat_kv,
        reshape_for_broadcast,
        rotate_half,
    )
    
  2. Custom RoPE (YaRN, NTK-aware, etc.): Implement custom precompute_rope_cache() and apply_rotary_emb() functions. The key difference is usually in how inv_freq is computed (scaling factors, interpolation, etc.).

Step 6: Implement state_dict_adapter.py

Map between HuggingFace and Archon weight key names.

This is the most error-prone step. The adapter must correctly handle:

  1. Key name mapping (from_hf_map dict):

    • Embedding: model.embed_tokens.weight -> tok_embeddings.weight
    • Attention: model.layers.{}.self_attn.q_proj.weight -> layers.{}.attention.wq.weight
    • FFN: model.layers.{}.mlp.gate_proj.weight -> layers.{}.feed_forward.w1.weight
    • Norms: model.layers.{}.input_layernorm.weight -> layers.{}.attention_norm.weight
    • Output: lm_head.weight -> output.weight
    • Skip keys (set to None): rotary_emb.inv_freq (computed at runtime)
    • Model-specific keys: bias terms, QK norm weights, etc.
  2. Reverse mapping (to_hf_map): Auto-generated from from_hf_map

  3. MoE expert weights (if applicable): 3D<->2D conversion for expert weights. Copy the MoE handling from qwen3 if the model has MoE.

  4. Weight tying: Skip output.weight during to_hf() if tie_word_embeddings=True

Verification approach: After implementation, the adapter should satisfy:

# Roundtrip: archon -> hf -> archon preserves all keys
hf_sd = adapter.to_hf(archon_sd)
roundtrip_sd = adapter.from_hf(hf_sd)
assert set(roundtrip_sd.keys()) == set(archon_sd.keys())

Base class contract (BaseStateDictAdapter):

class <Model>StateDictAdapter(BaseStateDictAdapter):
    def from_hf(self, hf_state_dict) -> dict[str, Any]: ...
    def to_hf(self, archon_state_dict) -> dict[str, Any]: ...
    def convert_single_to_hf(self, name, tensor) -> list[tuple[str, torch.Tensor]]: ...

Step 7: Implement parallelize.py

Define the parallelization strategy for the model.

The parallelize function applies parallelism in this order:

  1. TP (Tensor Parallelism) -- shard attention/FFN across devices
  2. EP (Expert Parallelism) -- for MoE models only
  3. CP (Context Parallelism / Ulysses SP) -- sequence parallelism
  4. AC (Activation Checkpointing) -- memory optimization
  5. torch.compile -- compilation optimization
  6. FSDP (Fully Sharded Data Parallelism) -- data parallelism

Key adaptations by model architecture:

  • Attention with QK norm: wq/wk use use_local_output=False (DTensor output for norm), add SequenceParallel(sequence_dim=2) for q_norm/k_norm
  • Attention without QK norm: wq/wk/wv all use use_local_output=True
  • Attention with bias: Bias terms follow the same parallel plan as their weights
  • MoE layers: Separate TP plan for MoE input/output, router gate, and expert weights. Copy from qwen3's apply_moe_ep_tp() and apply_non_moe_tp()
  • Dense-only models: Simpler plan without MoE handling. Copy from qwen2

Function signature (must match ParallelizeFn protocol):

def parallelize_<model>(
    model: nn.Module,
    parallel_dims: ArchonParallelDims,
    param_dtype: torch.dtype = torch.bfloat16,
    reduce_dtype: torch.dtype = torch.float32,
    loss_parallel: bool = True,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
    ac_config: ActivationCheckpointConfig | None = None,
    enable_compile: bool = True,
) -> nn.Module:

Step 8: Create spec.py and Register

Assemble the ModelSpec and register it.

from areal.experimental.models.archon.model_spec import ModelSpec, register_model_spec
from areal.experimental.models.archon.pipeline_parallel import pipeline_llm
from areal.experimental.models.archon.<model>.infra.parallelize import parallelize_<model>
from areal.experimental.models.archon.<model>.model.args import <Model>ModelArgs
from areal.experimental.models.archon.<model>.model.model import <Model>Model
from areal.experimental.models.archon.<model>.model.state_dict_adapter import (
    <Model>StateDictAdapter,
)

<MODEL>_SPEC = ModelSpec(
    name="<Model>",
    model_class=<Model>Model,
    model_args_class=<Model>ModelArgs,
    state_dict_adapter_class=<Model>StateDictAdapter,
    parallelize_fn=parallelize_<model>,
    supported_model_types=frozenset({"<model_type>"}),  # From HF config.json
    pipelining_fn=pipeline_llm,
)

# Auto-register when module is imported
register_model_spec(<MODEL>_SPEC)

__all__ = ["<MODEL>_SPEC"]

Note: supported_model_types should include all HF model_type strings that this implementation handles (e.g., {"qwen3", "qwen3_moe"} for Qwen3).

Step 9: Register in __init__.py

Add the import to areal/experimental/models/archon/__init__.py:

from areal.experimental.models.archon.<model> import spec as <model>_spec  # noqa: F401

This triggers auto-registration when the module is imported.

Step 10: Verify and Test

Verification should be done in stages, adapting based on available hardware and the test patterns in tests/experimental/archon/.

Before writing tests, examine the existing test files to understand current patterns:

tests/experimental/archon/
  conftest.py             -- Pytest configuration (version checks)
  utils.py                -- Shared utilities (model loading, comparison)
  test_qwen3_args.py      -- Args unit tests (CPU-only)
  test_state_dict_adapter.py  -- State dict roundtrip tests
  test_weight_sync.py     -- Weight completeness tests (meta device)
  test_forward.py         -- Forward precision comparison (single GPU)
  ...

Test stages (write tests appropriate for the model's complexity):

Stage 1: Args Tests (CPU-only, always write these)

Test from_hf_config() with mock HuggingFace configs:

# Pattern: Create mock PretrainedConfig, verify args mapping
from unittest.mock import MagicMock

def test_args_from_hf_config():
    hf_config = MagicMock()
    hf_config.hidden_size = 4096
    hf_config.num_hidden_layers = 32
    # ... set all required fields
    args = <Model>ModelArgs.from_hf_config(hf_config)
    assert args.dim == 4096
    assert args.n_layers == 32

Stage 2: State Dict Adapter Tests (CPU-only)

Test key mapping roundtrip:

def test_state_dict_roundtrip():
    # Create adapter with mock config
    adapter = <Model>StateDictAdapter(mock_config)
    # Create fake archon state dict with expected keys
    archon_sd = {"tok_embeddings.weight": torch.randn(vocab, dim), ...}
    # Roundtrip
    hf_sd = adapter.to_hf(archon_sd)
    roundtrip = adapter.from_hf(hf_sd)
    assert set(roundtrip.keys()) == set(archon_sd.keys())

Stage 3: Weight Completeness (meta device, CPU-only)

Verify all model parameters have HF mappings:

def test_weight_completeness():
    # Create model on meta device
    with torch.device("meta"):
        model = <Model>Model(args)
    adapter = <Model>StateDictAdapter(hf_config)
    # Check every archon param has a HF mapping
    for name, _ in model.named_parameters():
        hf_pairs = adapter.convert_single_to_hf(name, torch.empty(0))
        assert len(hf_pairs) > 0, f"No HF mapping for {name}"

Stage 4: Forward Precision (single GPU, if available)

Compare Archon model output against HuggingFace reference:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_forward_matches_hf():
    # Load both HF and Archon models
    # Run forward on same input
    # Compare logits within tolerance

Important: Do NOT hardcode the test categories. Inspect the existing test files in tests/experimental/archon/ and follow the same patterns, fixtures, and markers. Adapt test scope to the model's specific features (e.g., add MoE-specific tests only if the model has MoE).

Reference Implementations

Model Directory Features
Qwen2 areal/experimental/models/archon/qwen2/ Dense, attention bias, no QK norm
Qwen3 areal/experimental/models/archon/qwen3/ Dense + MoE, QK norm, no attention bias, shared experts

Architecture Decision Map

Feature qwen2 qwen3 What to check in target model
Attention bias Yes No attention_bias in HF config
QK norm No Yes qk_norm in HF config or QKNorm module in modeling file
MoE No Yes num_experts/num_local_experts in HF config
Shared experts No Yes num_shared_experts in HF config
Decoder sparse step No Yes decoder_sparse_step in HF config
Weight tying Both Both tie_word_embeddings in HF config
RoPE Standard Standard (re-export qwen2) Check inv_freq formula in HF modeling code

Common Mistakes

  • Not mapping all HF keys in state_dict_adapter.py (causes silent weight drops)
  • Wrong from_hf_config() field mapping (uses wrong HF config attribute name)
  • Forgetting to handle None keys in from_hf_map (keys to skip like rotary_emb.inv_freq)
  • Missing MoE expert weight 3D<->2D conversion when model has MoE
  • Wrong TP plan for attention with/without QK norm (use_local_output must match)
  • Forgetting to add import line in areal/experimental/models/archon/__init__.py
  • Not including all model_type variants in supported_model_types frozenset
  • Using print instead of areal.utils.logging.getLogger()

File Checklist

After completion, verify all files exist and are consistent:

  • areal/experimental/models/archon/<model>/__init__.py
  • areal/experimental/models/archon/<model>/spec.py -- ModelSpec + register
  • areal/experimental/models/archon/<model>/model/args.py -- ModelArgs + from_hf_config
  • areal/experimental/models/archon/<model>/model/model.py -- Model + Attention + FFN
  • areal/experimental/models/archon/<model>/model/rope.py -- RoPE (or re-export)
  • areal/experimental/models/archon/<model>/model/state_dict_adapter.py -- Key mapping
  • areal/experimental/models/archon/<model>/infra/parallelize.py -- Parallel strategy
  • areal/experimental/models/archon/__init__.py -- Import line added
  • tests/experimental/archon/test_<model>_*.py -- Tests

Install via CLI
npx skills add https://github.com/areal-project/AReaL --skill add-archon-model
Repository Details
star Stars 5,311
call_split Forks 522
navigation Branch main
article Path SKILL.md
More from Creator
areal-project
areal-project Explore all skills →