name: add-archon-model
description: Guide for adding a new model to the Archon engine. Use when user wants to add support for a new HuggingFace model architecture in ArchonEngine.
Add Archon Model
Add support for a new HuggingFace model architecture in the Archon training engine.
When to Use
This skill is triggered when:
- User asks "how do I add a model to Archon?"
- User wants to support a new model family (e.g., Llama, Mistral, DeepSeek) in ArchonEngine
- User mentions adding a new ModelSpec or model type for Archon
Prerequisites
Before starting, ensure:
- The target model is available on HuggingFace (has config.json with model_type)
- You know the HuggingFace model ID (e.g., meta-llama/Llama-3-8B)
- The model uses a standard transformer architecture (decoder-only)
Step-by-Step Guide
Step 1: Analyze the Target Model Architecture
Read the HuggingFace model's source code to extract key architecture information.
Action: Fetch and analyze the model's HuggingFace configuration and modeling files.
Read the model's config.json (via AutoConfig.from_pretrained) to identify:
- model_type string (this is the key used for registry lookup)
- All architecture hyperparameters (hidden_size, num_hidden_layers, etc.)
- Any model-specific fields (e.g., qk_norm, attention_bias, MoE fields)
Read the HuggingFace modeling_*.py source to identify:
- Attention variant: Does it have Q/K norm? Attention bias? Sliding window? Multi-latent attention?
- FFN variant: SwiGLU (gate_proj + up_proj + down_proj)? GeGLU? Standard MLP?
- MoE support: Does it have MoE layers? What router type? Shared experts?
- RoPE variant: Standard RoPE? YaRN? NTK-aware scaling? What is the inv_freq formula?
- Normalization: RMSNorm or LayerNorm? Pre-norm or post-norm? Elementwise affine?
- Weight tying: Does tie_word_embeddings appear in config?
- State dict key names: What are the HF weight key naming conventions?
Summarize findings in a checklist like:
Target model: <name>
HF model_type: "<model_type>" (and variants like "<model_type>_moe" if applicable)
Attention: [standard GQA / with QK norm / with bias / sliding window / ...]
FFN: [SwiGLU / GeGLU / standard MLP / ...]
MoE: [no / yes - num_experts, top_k, shared_experts]
RoPE: [standard / YaRN / NTK-aware / ...]
Norm: [RMSNorm / LayerNorm] with [pre-norm / post-norm]
Weight tying: [yes / no]
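Part of this checklist can be pre-filled from the config itself. The sketch below works on a plain dict, such as the parsed contents of config.json. The helper name summarize_hf_config and the set of probed keys are illustrative assumptions, not part of Archon, and features such as QK norm usually still have to be confirmed in the modeling file.

```python
from typing import Any


def summarize_hf_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Collect checklist-relevant fields from a parsed config.json (hypothetical helper)."""
    n_heads = cfg.get("num_attention_heads")
    # Configs that omit num_key_value_heads are treated here as plain multi-head attention.
    n_kv_heads = cfg.get("num_key_value_heads", n_heads)
    # MoE configs commonly expose one of these keys; if neither is present,
    # this sketch assumes a dense model.
    num_experts = cfg.get("num_experts", cfg.get("num_local_experts"))
    return {
        "model_type": cfg.get("model_type"),
        "gqa": n_heads is not None and n_kv_heads != n_heads,
        "attention_bias": bool(cfg.get("attention_bias", False)),
        "sliding_window": cfg.get("sliding_window"),
        "moe": num_experts is not None,
        "rope_scaling": cfg.get("rope_scaling"),
        "tie_word_embeddings": bool(cfg.get("tie_word_embeddings", False)),
    }


# Made-up values for illustration only, not taken from a real checkpoint.
example = {
    "model_type": "my_model",
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "tie_word_embeddings": True,
}
summary = summarize_hf_config(example)
```

A key that is missing from the config only means the feature is not configurable there; the modeling file may still hard-code it.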
Step 2: Select the Reference Model
Choose the closest existing implementation as a starting point:
| Target characteristics | Reference | Why |
|---|---|---|
| Dense-only, standard GQA, no QK norm | qwen2 | Simplest baseline, pure dense |
| Has QK norm, or has MoE support | qwen3 | Supports QK norm + MoE + shared experts |
Action: Copy the reference model directory as the starting point:
areal/experimental/models/archon/<model>/
  __init__.py
  spec.py
  model/
    args.py
    model.py
    rope.py
    state_dict_adapter.py
  infra/
    parallelize.py
Step 3: Implement args.py
Adapt <Model>ModelArgs to match the target model's HuggingFace config fields.
Key changes from reference:
Update the @dataclass fields to match the target model's hyperparameters:
- Field names should use Archon conventions (dim, n_layers, n_heads, n_kv_heads, vocab_size, head_dim, hidden_dim, norm_eps, rope_theta, etc.)
- Default values should match the smallest variant of the target model
- Add model-specific fields (e.g., attention_bias, qk_norm, sliding_window)
Update from_hf_config() to correctly map HuggingFace config attributes:
- Use getattr(hf_config, "field_name", default) for optional fields
- Handle variant-specific fields (e.g., MoE fields only present in MoE variants)
- The method must return an instance of the model args class
Critical: Verify every field mapping against the HF model's config.json. Incorrect
mappings here cause silent errors downstream.
Base class contract (BaseModelArgs):
@dataclass
class <Model>ModelArgs(BaseModelArgs):
    # ... model-specific fields ...

    @classmethod
    def from_hf_config(
        cls,
        hf_config: PretrainedConfig,
        is_critic: bool = False,
        **kwargs,
    ) -> <Model>ModelArgs:
        # Map HF config fields to Archon model args
        ...
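To make the mapping concrete, here is a minimal sketch of the getattr pattern. ToyModelArgs is a hypothetical stand-in that does not subclass BaseModelArgs, its default values are arbitrary, and the HF attribute names follow Llama-style configs, so each one should be checked against the target model's config.json.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class ToyModelArgs:
    """Stand-in for <Model>ModelArgs; the real class would subclass BaseModelArgs."""

    dim: int = 1024
    n_layers: int = 24
    n_heads: int = 16
    n_kv_heads: int = 16
    vocab_size: int = 32000
    norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_bias: bool = False

    @classmethod
    def from_hf_config(cls, hf_config, is_critic: bool = False, **kwargs) -> "ToyModelArgs":
        # is_critic and **kwargs mirror the contract; this sketch ignores them.
        n_heads = hf_config.num_attention_heads
        return cls(
            dim=hf_config.hidden_size,
            n_layers=hf_config.num_hidden_layers,
            n_heads=n_heads,
            # Optional in some configs: fall back to multi-head attention when absent.
            n_kv_heads=getattr(hf_config, "num_key_value_heads", n_heads),
            vocab_size=hf_config.vocab_size,
            norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
            rope_theta=getattr(hf_config, "rope_theta", 10000.0),
            attention_bias=getattr(hf_config, "attention_bias", False),
        )


# SimpleNamespace stands in for a PretrainedConfig; values are illustrative.
hf_config = SimpleNamespace(
    hidden_size=2048,
    num_hidden_layers=16,
    num_attention_heads=16,
    vocab_size=50000,
    rms_norm_eps=1e-5,
)
args = ToyModelArgs.from_hf_config(hf_config)
```

Required fields are read with plain attribute access so that a missing attribute fails loudly, while only optional fields get a getattr default.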
Step 4: Implement model.py
Adapt the model architecture to match the target model.
Key components to adapt:
Normalization (RMSNorm or similar):
- Check if elementwise_affine is configurable
- Check the epsilon default value
- If the model uses LayerNorm, implement accordingly
Attention module:
- Q/K/V projection: Check bias presence (nn.Linear(..., bias=True/False))
- QK norm: Add q_norm/k_norm if the model has them, remove if it doesn't
- GQA: n_kv_heads < n_heads for grouped-query attention
- Ulysses SP: Keep the set_cp_group / _sp_enabled pattern from the reference
- Output projection: Check bias presence
FeedForward module:
- SwiGLU: w2(silu(w1(x)) * w3(x)) -- most common for modern LLMs
- Check bias in linear layers
- For MoE models: MoE module replaces FeedForward on designated layers
TransformerBlock: Pre-norm (most modern LLMs) vs post-norm
- MoE layer detection via _is_moe_layer() if applicable
Top-level Model (<Model>Model(BaseArchonModel)):
- tok_embeddings, layers (as ModuleDict), norm, output/score
- init_weights(): Match initialization scheme from HF
- init_buffers(): RoPE cache + MoE buffers
- forward(): Must follow BaseArchonModel signature: (tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> Tensor
Base class contract (BaseArchonModel):
class <Model>Model(BaseArchonModel):
    def forward(self, tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> torch.Tensor: ...
    def init_weights(self) -> None: ...
    def init_buffers(self, buffer_device) -> None: ...
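As a numeric reference for the SwiGLU expression w2(silu(w1(x)) * w3(x)), the following framework-free sketch evaluates it for a single token vector. It uses plain Python lists instead of nn.Linear and omits bias; treating w3 as up_proj and w2 as down_proj is an assumption inferred from the formula, since only the gate_proj to w1 mapping is stated explicitly.

```python
import math


def silu(z: float) -> float:
    """SiLU (swish) activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def matvec(weight: list[list[float]], x: list[float]) -> list[float]:
    """Bias-free linear layer; weight has shape [out_features][in_features]."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in weight]


def swiglu(x, w1, w2, w3):
    """Compute w2(silu(w1(x)) * w3(x)) for a single token vector x."""
    gate = [silu(g) for g in matvec(w1, x)]  # gate_proj followed by SiLU
    up = matvec(w3, x)  # up_proj (assumed to correspond to w3)
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return matvec(w2, hidden)  # down_proj (assumed to correspond to w2)


# Tiny example with dim=2 and hidden_dim=3; weights are arbitrary.
x = [1.0, -2.0]
w1 = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]
w3 = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
w2 = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
y = swiglu(x, w1, w2, w3)
```

A hand-checkable case like this can help when comparing a ported FeedForward against the HF module, because swapping w1 and w3 changes the result.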
Step 5: Implement rope.py
Handle the rotary position embedding variant.
Options:
Standard RoPE (same as qwen2/qwen3): Re-export from qwen2:
from areal.experimental.models.archon.qwen2.model.rope import (
    apply_rotary_emb,
    precompute_rope_cache,
    repeat_kv,
    reshape_for_broadcast,
    rotate_half,
)
Custom RoPE (YaRN, NTK-aware, etc.): Implement custom precompute_rope_cache() and apply_rotary_emb() functions. The key difference is usually in how inv_freq is computed (scaling factors, interpolation, etc.).
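To show where a custom variant typically diverges, the sketch below computes the standard inv_freq and a linearly scaled version (position interpolation). The function names are hypothetical and are not the Archon precompute_rope_cache() API. YaRN and NTK-aware schemes use different per-frequency formulas that are not reproduced here.

```python
def rope_inv_freq(
    head_dim: int, rope_theta: float = 10000.0, scaling_factor: float = 1.0
) -> list[float]:
    """inv_freq[i] = 1 / rope_theta ** (2i / head_dim), divided by a linear factor."""
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for rotary embeddings")
    base = [1.0 / (rope_theta ** (2 * i / head_dim)) for i in range(head_dim // 2)]
    # Linear position interpolation: dividing every frequency by the factor
    # is equivalent to dividing every position by it.
    return [f / scaling_factor for f in base]


def rope_angles(positions: list[int], inv_freq: list[float]) -> list[list[float]]:
    """Outer product of positions and inv_freq; cos/sin of these fill a RoPE cache."""
    return [[p * f for f in inv_freq] for p in positions]


standard = rope_inv_freq(head_dim=8)
scaled = rope_inv_freq(head_dim=8, scaling_factor=4.0)
angles = rope_angles([0, 1, 2], standard)
```

Comparing the computed inv_freq values against the ones produced by the HF rotary embedding module is a cheap check before running a full forward comparison.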
Step 6: Implement state_dict_adapter.py
Map between HuggingFace and Archon weight key names.
This is the most error-prone step. The adapter must correctly handle:
Key name mapping (from_hf_map dict):
- Embedding: model.embed_tokens.weight -> tok_embeddings.weight
- Attention: model.layers.{}.self_attn.q_proj.weight -> layers.{}.attention.wq.weight
- FFN: model.layers.{}.mlp.gate_proj.weight -> layers.{}.feed_forward.w1.weight
- Norms: model.layers.{}.input_layernorm.weight -> layers.{}.attention_norm.weight
- Output: lm_head.weight -> output.weight
- Skip keys (set to None): rotary_emb.inv_freq (computed at runtime)
- Model-specific keys: bias terms, QK norm weights, etc.
Reverse mapping (to_hf_map): Auto-generated from from_hf_map
MoE expert weights (if applicable): 3D<->2D conversion for expert weights. Copy the MoE handling from qwen3 if the model has MoE.
Weight tying: Skip output.weight during to_hf() if tie_word_embeddings=True
Verification approach: After implementation, the adapter should satisfy:
# Roundtrip: archon -> hf -> archon preserves all keys
hf_sd = adapter.to_hf(archon_sd)
roundtrip_sd = adapter.from_hf(hf_sd)
assert set(roundtrip_sd.keys()) == set(archon_sd.keys())
Base class contract (BaseStateDictAdapter):
class <Model>StateDictAdapter(BaseStateDictAdapter):
    def from_hf(self, hf_state_dict) -> dict[str, Any]: ...
    def to_hf(self, archon_state_dict) -> dict[str, Any]: ...
    def convert_single_to_hf(self, name, tensor) -> list[tuple[str, torch.Tensor]]: ...
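The mapping rules for this step (templated keys, None for skipped keys, an auto-generated reverse map, and the weight-tying skip) can be sketched without any framework. Everything below is a simplified illustration, not the BaseStateDictAdapter implementation: translate_key is a hypothetical helper, and the full path used for the rotary_emb.inv_freq key is an assumed example.

```python
import re

# Subset of a from_hf_map; "{}" marks the layer index, None marks a key to skip.
FROM_HF_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    # Assumed full key path for the skipped rotary buffer.
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "lm_head.weight": "output.weight",
}
# Reverse map generated from the forward map, dropping skipped keys.
TO_HF_MAP = {v: k for k, v in FROM_HF_MAP.items() if v is not None}

_LAYER_INDEX = re.compile(r"(?<=\.)\d+(?=\.)")


def translate_key(key: str, mapping: dict):
    """Return the mapped key, None for a skipped key, or raise if unmapped."""
    match = _LAYER_INDEX.search(key)
    # Abstract the first numeric path component (the layer index) into "{}".
    template = _LAYER_INDEX.sub("{}", key, count=1)
    if template not in mapping:
        raise KeyError(f"No mapping for {key}")
    target = mapping[template]
    if target is None or match is None:
        return target
    return target.format(match.group())


def from_hf(hf_sd: dict) -> dict:
    out = {}
    for key, value in hf_sd.items():
        new_key = translate_key(key, FROM_HF_MAP)
        if new_key is not None:  # drop keys that are recomputed at runtime
            out[new_key] = value
    return out


def to_hf(archon_sd: dict, tie_word_embeddings: bool = False) -> dict:
    out = {}
    for key, value in archon_sd.items():
        if tie_word_embeddings and key == "output.weight":
            continue  # tied to the embedding, so not exported separately
        out[translate_key(key, TO_HF_MAP)] = value
    return out
```

Raising on an unmapped key, rather than skipping it, is a deliberate choice in this sketch: it surfaces a missing mapping immediately instead of letting a weight drop silently.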
Step 7: Implement parallelize.py
Define the parallelization strategy for the model.
The parallelize function applies parallelism in this order:
1. TP (Tensor Parallelism) -- shard attention/FFN across devices
2. EP (Expert Parallelism) -- for MoE models only
3. CP (Context Parallelism / Ulysses SP) -- sequence parallelism
4. AC (Activation Checkpointing) -- memory optimization
5. torch.compile -- compilation optimization
6. FSDP (Fully Sharded Data Parallelism) -- data parallelism
Key adaptations by model architecture:
- Attention with QK norm: wq/wk use use_local_output=False (DTensor output for norm), add SequenceParallel(sequence_dim=2) for q_norm/k_norm
- Attention without QK norm: wq/wk/wv all use use_local_output=True
- Attention with bias: Bias terms follow the same parallel plan as their weights
- MoE layers: Separate TP plan for MoE input/output, router gate, and expert weights. Copy from qwen3's apply_moe_ep_tp() and apply_non_moe_tp()
- Dense-only models: Simpler plan without MoE handling. Copy from qwen2
Function signature (must match ParallelizeFn protocol):
def parallelize_<model>(
    model: nn.Module,
    parallel_dims: ArchonParallelDims,
    param_dtype: torch.dtype = torch.bfloat16,
    reduce_dtype: torch.dtype = torch.float32,
    loss_parallel: bool = True,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
    ac_config: ActivationCheckpointConfig | None = None,
    enable_compile: bool = True,
) -> nn.Module:
    ...
Step 8: Create spec.py and Register
Assemble the ModelSpec and register it.
from areal.experimental.models.archon.model_spec import ModelSpec, register_model_spec
from areal.experimental.models.archon.pipeline_parallel import pipeline_llm
from areal.experimental.models.archon.<model>.infra.parallelize import parallelize_<model>
from areal.experimental.models.archon.<model>.model.args import <Model>ModelArgs
from areal.experimental.models.archon.<model>.model.model import <Model>Model
from areal.experimental.models.archon.<model>.model.state_dict_adapter import (
    <Model>StateDictAdapter,
)

<MODEL>_SPEC = ModelSpec(
    name="<Model>",
    model_class=<Model>Model,
    model_args_class=<Model>ModelArgs,
    state_dict_adapter_class=<Model>StateDictAdapter,
    parallelize_fn=parallelize_<model>,
    supported_model_types=frozenset({"<model_type>"}),  # From HF config.json
    pipelining_fn=pipeline_llm,
)
# Auto-register when module is imported
register_model_spec(<MODEL>_SPEC)
__all__ = ["<MODEL>_SPEC"]
Note: supported_model_types should include all HF model_type strings that this
implementation handles (e.g., {"qwen3", "qwen3_moe"} for Qwen3).
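The reason every model_type variant has to be listed becomes clearer with a toy registry keyed by model_type. This is only an illustration of the lookup idea: ToySpec, register_toy_spec, and get_toy_spec are hypothetical names, and the real register_model_spec may behave differently, for example on duplicate registrations.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    """Minimal stand-in for ModelSpec; only the fields needed for lookup."""

    name: str
    supported_model_types: frozenset


_REGISTRY: dict[str, ToySpec] = {}


def register_toy_spec(spec: ToySpec) -> None:
    """Index the spec under every model_type it declares."""
    for model_type in spec.supported_model_types:
        # Rejecting duplicates is an assumption made for this sketch.
        if model_type in _REGISTRY:
            raise ValueError(f"model_type {model_type!r} already registered")
        _REGISTRY[model_type] = spec


def get_toy_spec(model_type: str) -> ToySpec:
    """Look up a spec by the model_type string found in an HF config.json."""
    try:
        return _REGISTRY[model_type]
    except KeyError:
        raise KeyError(f"No spec registered for model_type {model_type!r}") from None


register_toy_spec(
    ToySpec(name="Qwen3", supported_model_types=frozenset({"qwen3", "qwen3_moe"}))
)
```

With a lookup like this, a checkpoint whose config.json reports a variant string that is not in supported_model_types would simply not be found, even though the implementation could handle it.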
Step 9: Register in __init__.py
Add the import to areal/experimental/models/archon/__init__.py:
from areal.experimental.models.archon.<model> import spec as <model>_spec # noqa: F401
This triggers auto-registration when the module is imported.
Step 10: Verify and Test
Verification should be done in stages, adapting based on available hardware and the test
patterns in tests/experimental/archon/.
Before writing tests, examine the existing test files to understand current patterns:
tests/experimental/archon/
  conftest.py                 -- Pytest configuration (version checks)
  utils.py                    -- Shared utilities (model loading, comparison)
  test_qwen3_args.py          -- Args unit tests (CPU-only)
  test_state_dict_adapter.py  -- State dict roundtrip tests
  test_weight_sync.py         -- Weight completeness tests (meta device)
  test_forward.py             -- Forward precision comparison (single GPU)
  ...
Test stages (write tests appropriate for the model's complexity):
Stage 1: Args Tests (CPU-only, always write these)
Test from_hf_config() with mock HuggingFace configs:
# Pattern: Create mock PretrainedConfig, verify args mapping
from unittest.mock import MagicMock

def test_args_from_hf_config():
    hf_config = MagicMock()
    hf_config.hidden_size = 4096
    hf_config.num_hidden_layers = 32
    # ... set all required fields
    # MagicMock auto-creates any attribute, so getattr(hf_config, name, default)
    # never falls back to the default: set optional fields explicitly too.
    args = <Model>ModelArgs.from_hf_config(hf_config)
    assert args.dim == 4096
    assert args.n_layers == 32
Stage 2: State Dict Adapter Tests (CPU-only)
Test key mapping roundtrip:
import torch

def test_state_dict_roundtrip():
    # Create adapter with mock config
    adapter = <Model>StateDictAdapter(mock_config)
    # Create fake archon state dict with expected keys
    archon_sd = {"tok_embeddings.weight": torch.randn(vocab, dim), ...}
    # Roundtrip
    hf_sd = adapter.to_hf(archon_sd)
    roundtrip = adapter.from_hf(hf_sd)
    assert set(roundtrip.keys()) == set(archon_sd.keys())
Stage 3: Weight Completeness (meta device, CPU-only)
Verify all model parameters have HF mappings:
import torch

def test_weight_completeness():
    # Create model on meta device
    with torch.device("meta"):
        model = <Model>Model(args)
    adapter = <Model>StateDictAdapter(hf_config)
    # Check every archon param has a HF mapping
    for name, _ in model.named_parameters():
        hf_pairs = adapter.convert_single_to_hf(name, torch.empty(0))
        assert len(hf_pairs) > 0, f"No HF mapping for {name}"
Stage 4: Forward Precision (single GPU, if available)
Compare Archon model output against HuggingFace reference:
import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_forward_matches_hf():
    # Load both HF and Archon models
    # Run forward on same input
    # Compare logits within tolerance
    ...
Important: Do NOT hardcode the test categories. Inspect the existing test files in
tests/experimental/archon/ and follow the same patterns, fixtures, and markers. Adapt
test scope to the model's specific features (e.g., add MoE-specific tests only if the
model has MoE).
Reference Implementations
| Model | Directory | Features |
|---|---|---|
| Qwen2 | areal/experimental/models/archon/qwen2/ | Dense, attention bias, no QK norm |
| Qwen3 | areal/experimental/models/archon/qwen3/ | Dense + MoE, QK norm, no attention bias, shared experts |
Architecture Decision Map
| Feature | qwen2 | qwen3 | What to check in target model |
|---|---|---|---|
| Attention bias | Yes | No | attention_bias in HF config |
| QK norm | No | Yes | qk_norm in HF config or QKNorm module in modeling file |
| MoE | No | Yes | num_experts/num_local_experts in HF config |
| Shared experts | No | Yes | num_shared_experts in HF config |
| Decoder sparse step | No | Yes | decoder_sparse_step in HF config |
| Weight tying | Both | Both | tie_word_embeddings in HF config |
| RoPE | Standard | Standard (re-export qwen2) | Check inv_freq formula in HF modeling code |
Common Mistakes
- Not mapping all HF keys in state_dict_adapter.py (causes silent weight drops)
- Wrong from_hf_config() field mapping (uses wrong HF config attribute name)
- Forgetting to handle None keys in from_hf_map (keys to skip like rotary_emb.inv_freq)
- Missing MoE expert weight 3D<->2D conversion when model has MoE
- Wrong TP plan for attention with/without QK norm (use_local_output must match)
- Forgetting to add import line in areal/experimental/models/archon/__init__.py
- Not including all model_type variants in supported_model_types frozenset
- Using print instead of areal.utils.logging.getLogger()
File Checklist
After completion, verify all files exist and are consistent:
- areal/experimental/models/archon/<model>/__init__.py
- areal/experimental/models/archon/<model>/spec.py -- ModelSpec + register
- areal/experimental/models/archon/<model>/model/args.py -- ModelArgs + from_hf_config
- areal/experimental/models/archon/<model>/model/model.py -- Model + Attention + FFN
- areal/experimental/models/archon/<model>/model/rope.py -- RoPE (or re-export)
- areal/experimental/models/archon/<model>/model/state_dict_adapter.py -- Key mapping
- areal/experimental/models/archon/<model>/infra/parallelize.py -- Parallel strategy
- areal/experimental/models/archon/__init__.py -- Import line added
- tests/experimental/archon/test_<model>_*.py -- Tests