---
name: sglang-diffusion-add-model
description: Use when adding a new diffusion model or Diffusers pipeline to SGLang.
---
Add a Diffusion Model to SGLang
Use this skill when adding a new diffusion model or pipeline variant to sglang.multimodal_gen.
Two Pipeline Styles
Style A: Hybrid Monolithic Pipeline (Recommended)
The recommended default for most new models. Uses a three-stage structure:
BeforeDenoisingStage (model-specific) --> DenoisingStage (standard) --> DecodingStage (standard)
- BeforeDenoisingStage: A single, model-specific stage that consolidates all pre-processing logic: input validation, text encoding, image encoding, latent preparation, timestep setup. This stage is unique per model.
- DenoisingStage: Framework-standard stage for the denoising loop (DiT/UNet forward passes). Shared across models.
- DecodingStage: Framework-standard stage for VAE decoding. Shared across models.
Why recommended? Modern diffusion models have highly heterogeneous pre-processing requirements (different text encoders, different latent formats, different conditioning mechanisms). The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.
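The division of labor can be sketched without any SGLang imports. In this toy version each stage is a plain function that takes a batch and returns it. `Batch`, `run_pipeline` and the three stage functions are hypothetical stand-ins: the real stages subclass `PipelineStage` and also receive `server_args`. The string "tensors" only trace which stage touched what.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Batch:
    """Hypothetical stand-in for SGLang's Req batch object."""
    prompt: str
    fields: dict = field(default_factory=dict)


Stage = Callable[[Batch], Batch]


def my_model_before_denoising(batch: Batch) -> Batch:
    # Model-specific: every pre-processing step lives in this one stage.
    batch.fields["prompt_embeds"] = [f"embeds({batch.prompt})"]
    batch.fields["latents"] = "initial-noise"
    batch.fields["timesteps"] = [999, 500, 1]
    return batch


def denoising(batch: Batch) -> Batch:
    # Shared: reads only the standard fields the previous stage produced.
    for t in batch.fields["timesteps"]:
        batch.fields["latents"] = f"step({batch.fields['latents']}, t={t})"
    return batch


def decoding(batch: Batch) -> Batch:
    # Shared: turns the final latents into the output.
    batch.fields["output"] = f"vae_decode({batch.fields['latents']})"
    return batch


def run_pipeline(stages: List[Stage], batch: Batch) -> Batch:
    for stage in stages:
        batch = stage(batch)
    return batch


result = run_pipeline([my_model_before_denoising, denoising, decoding], Batch("a cat"))
print(result.fields["output"])
```

To port a different model, you would replace only `my_model_before_denoising`. The two shared stages keep working as long as the same fields are filled in.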
Style B: Modular Composition Style
Uses the framework's fine-grained standard stages (TextEncodingStage, LatentPreparationStage, TimestepPreparationStage, etc.) to build the pipeline by composition.
This style is appropriate when:
- The new model's pre-processing can largely reuse existing stages, e.g., a model that uses standard CLIP/T5 text encoding plus standard latent preparation with minimal customization. In this case, `add_standard_t2i_stages()` or `add_standard_ti2i_stages()` may be all you need.
- A model-specific optimization needs to be extracted as a standalone stage, e.g., a specialized encoding or conditioning step that benefits from being a separate stage for profiling, parallelism control, or reuse across multiple pipeline variants.
See existing Modular examples: QwenImagePipeline (uses add_standard_t2i_stages), FluxPipeline, WanPipeline.
How to Choose
| Situation | Recommended Style |
|---|---|
| Model has unique/complex pre-processing (VLM captioning, AR token generation, custom latent packing, etc.) | Hybrid — consolidate into a BeforeDenoisingStage |
| Model fits neatly into standard text-to-image or text+image-to-image pattern | Modular — use add_standard_t2i_stages() / add_standard_ti2i_stages() |
| Porting a Diffusers pipeline with many custom steps | Hybrid — copy the __call__ logic into a single stage |
| Adding a variant of an existing model that shares most logic | Modular — reuse existing stages, customize via PipelineConfig callbacks |
| A specific pre-processing step needs special parallelism or profiling isolation | Modular — extract that step as a dedicated stage |
Key principle (both styles): The stage(s) before DenoisingStage must produce a Req batch object with all the standard tensor fields that DenoisingStage expects (latents, timesteps, prompt_embeds, etc.). As long as this contract is met, the pipeline remains composable regardless of which style you use.
Key Files and Directories
| Purpose | Path |
|---|---|
| Pipeline classes | python/sglang/multimodal_gen/runtime/pipelines/ |
| Model-specific stages | python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/ |
| PipelineStage base class | python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py |
| Pipeline base class | python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py |
| Standard stages (Denoising, Decoding) | python/sglang/multimodal_gen/runtime/pipelines_core/stages/ |
| Pipeline configs | python/sglang/multimodal_gen/configs/pipeline_configs/ |
| Sampling params | python/sglang/multimodal_gen/configs/sample/ |
| DiT model implementations | python/sglang/multimodal_gen/runtime/models/dits/ |
| VAE implementations | python/sglang/multimodal_gen/runtime/models/vaes/ |
| Encoder implementations | python/sglang/multimodal_gen/runtime/models/encoders/ |
| Scheduler implementations | python/sglang/multimodal_gen/runtime/models/schedulers/ |
| Model/VAE/DiT configs | python/sglang/multimodal_gen/configs/models/dits/, vaes/, encoders/ |
| Central registry | python/sglang/multimodal_gen/registry.py |
| Model component registry | python/sglang/multimodal_gen/runtime/models/registry.py |
| Current support list | docs/diffusion/compatibility_matrix.md |
Step-by-Step Implementation
Step 1: Obtain and Study the Reference Implementation
Before writing any code, obtain the model's reference implementation or Diffusers pipeline code. You need the actual source code to work from — do not guess or assume the model's architecture. If the user already gave a HuggingFace model ID or repo, inspect that yourself first. Ask the user only when the reference implementation is private, ambiguous, or otherwise unavailable. Typical sources are:
- The model's Diffusers pipeline source (e.g., the `pipeline_*.py` file from the `diffusers` library or HuggingFace repo)
- Or the model's official reference implementation (e.g., from the model author's GitHub repo)
- Or the HuggingFace model ID, so you can look up `model_index.json` and the associated pipeline class
Once you have the reference code, study it thoroughly:
- Find the model's `model_index.json` to identify required modules (text_encoder, vae, transformer, scheduler, etc.)
- Read the Diffusers pipeline's `__call__` method end-to-end. Identify:
  - How text prompts are encoded
  - How latents are prepared (shape, dtype, scaling)
  - How timesteps/sigmas are computed
  - What conditioning kwargs the DiT/UNet expects
  - How the denoising loop works (classifier-free guidance, etc.)
  - How VAE decoding is done (scaling factors, tiling, etc.)
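When the reference is a Diffusers checkpoint, you can read the module list straight from `model_index.json`. The helper below is a hypothetical sketch that uses only the standard library, run here on an illustrative index. It assumes the usual Diffusers layout:

- Keys that start with `_` are metadata.
- Component keys map to a `[library, class]` pair.
- Optional components may be stored as `[null, null]`.
- A few keys hold plain flags.

The results correspond to `pipeline_name` and `_required_config_modules` in the pipeline class you write later.

```python
import json
from typing import Dict, List, Tuple


def summarize_model_index(text: str) -> Tuple[str, List[str], Dict[str, Tuple[str, str]]]:
    """Return (pipeline class name, sorted component names, component -> (library, class))."""
    index = json.loads(text)
    components: Dict[str, Tuple[str, str]] = {}
    for key, value in index.items():
        if key.startswith("_"):
            continue  # metadata such as _class_name or _diffusers_version
        # Components look like ["diffusers", "AutoencoderKL"]; skip flags and [null, null].
        if isinstance(value, list) and len(value) == 2 and value[0] is not None:
            components[key] = (value[0], value[1])
    return index["_class_name"], sorted(components), components


EXAMPLE = """{
  "_class_name": "MyModelPipeline",
  "_diffusers_version": "0.30.0",
  "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
  "text_encoder": ["transformers", "T5EncoderModel"],
  "tokenizer": ["transformers", "T5TokenizerFast"],
  "transformer": ["diffusers", "MyModelTransformer2DModel"],
  "vae": ["diffusers", "AutoencoderKL"],
  "image_encoder": [null, null],
  "requires_safety_checker": false
}"""

class_name, modules, details = summarize_model_index(EXAMPLE)
print(class_name)  # MyModelPipeline
print(modules)     # ['scheduler', 'text_encoder', 'tokenizer', 'transformer', 'vae']
```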
Step 2: Evaluate Reuse of Existing Pipelines and Stages
Before creating any new files, check whether an existing pipeline or stage can be reused or extended. Only create new pipelines/stages when the existing ones would require extensive modifications or when no similar implementation exists.
Specifically:
- Compare the new model's architecture against existing pipelines before creating files. Current native families include LTX-2/2.3, HunyuanVideo/FastHunyuan, Wan/FastWan/TurboWan/LingBot World, MOVA, FLUX/FLUX.2/Klein, Z-Image, Qwen-Image/edit/layered, GLM-Image, SD3, Hunyuan3D, Helios, Cosmos3, SANA, FireRed, ERNIE-Image, JoyAI, and Ideogram4. If the new model shares most of its structure with an existing one (e.g., same text encoders, similar latent format, compatible denoising loop), prefer:
  - Adding a new config variant to the existing pipeline rather than creating a new pipeline class
  - Reusing the existing `BeforeDenoisingStage` with minor parameter differences
  - Using `add_standard_t2i_stages()` / `add_standard_ti2i_stages()` / `add_standard_ti2v_stages()` if the model fits standard patterns
- Check existing stages in `runtime/pipelines_core/stages/` and `stages/model_specific_stages/`. If an existing stage handles 80%+ of what the new model needs, extend it rather than duplicating it.
- Check existing model components: many models share VAEs (e.g., `AutoencoderKL`), text encoders (CLIP, T5), and schedulers. Reuse these directly instead of re-implementing them.
Rule of thumb: Only create a new file when the existing implementation would need substantial structural changes to accommodate the new model, or when no architecturally similar implementation exists.
Step 3: Implement Model Components
Adapt or implement the model's core components in the appropriate directories.
DiT/Transformer (runtime/models/dits/{model_name}.py):
```python
# python/sglang/multimodal_gen/runtime/models/dits/my_model.py
import torch
import torch.nn as nn

from sglang.multimodal_gen.runtime.layers.layernorm import (
    LayerNormScaleShift,
    RMSNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import (
    get_attn_backend,
)


class MyModelTransformer2DModel(nn.Module):
    """DiT model for MyModel.

    Adapt from the Diffusers/reference implementation. Key points:
    - Use SGLang's fused LayerNorm/RMSNorm ops (see `existing-fast-paths.md` under the benchmark/profile skill)
    - Use SGLang's attention backend selector
    - Keep the same parameter naming as Diffusers for weight loading compatibility
    """

    def __init__(self, config):
        super().__init__()
        # ... model layers ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        # ... model-specific kwargs ...
    ) -> torch.Tensor:
        # ... forward pass ...
        return output
```
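Weight loading depends on matching Diffusers parameter names, so it can help to diff the two name sets before the first load. This is a hypothetical standard-library sketch. In practice, the checkpoint names might come from `safetensors.safe_open(path, framework="pt").keys()` and the port's names from `model.state_dict().keys()`. Any intentional renames done by a custom weight loader (for example, fused QKV projections) will also appear in the diff, and you will need to filter them out by hand.

```python
from typing import Iterable, Set, Tuple


def diff_param_names(checkpoint: Iterable[str], ported: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected), using PyTorch's load_state_dict vocabulary.

    missing:    parameters the ported module defines but the checkpoint lacks
                (they would keep their random initialization).
    unexpected: checkpoint entries the ported module has no parameter for
                (those weights would be dropped).
    """
    ckpt, ours = set(checkpoint), set(ported)
    return ours - ckpt, ckpt - ours


checkpoint_names = [
    "transformer_blocks.0.attn.to_q.weight",
    "transformer_blocks.0.attn.to_k.weight",
    "proj_out.weight",
]
ported_names = [
    "transformer_blocks.0.attn.q_proj.weight",  # renamed by mistake
    "transformer_blocks.0.attn.to_k.weight",
    "proj_out.weight",
]
missing, unexpected = diff_param_names(checkpoint_names, ported_names)
print(sorted(missing))     # ['transformer_blocks.0.attn.q_proj.weight']
print(sorted(unexpected))  # ['transformer_blocks.0.attn.to_q.weight']
```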
Tensor Parallel (TP) and Sequence Parallel (SP): For multi-GPU deployment, it is recommended to add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. Reference existing implementations and adapt to your model's architecture:
- Wan model (`runtime/models/dits/wanvideo.py`), full TP + SP reference:
  - TP: Uses `ColumnParallelLinear` for Q/K/V projections, `RowParallelLinear` for output projections, attention heads divided by `tp_size`
  - SP: Sequence dimension sharding via `get_sp_world_size()`, padding for alignment, `sequence_model_parallel_all_gather` for aggregation
  - Cross-attention skips SP (`skip_sequence_parallel=is_cross_attention`)
- Qwen-Image model (`runtime/models/dits/qwen_image.py`), SP + USPAttention reference:
  - SP: Uses `USPAttention` (Ulysses + Ring Attention), configured via `--ulysses-degree` / `--ring-degree`
  - TP: Uses `MergedColumnParallelLinear` for QKV (with Nunchaku quantization), `ReplicatedLinear` otherwise
Important: These are references only — each model has its own architecture and parallelism requirements. Consider:
- How attention heads can be divided across TP ranks
- Whether the model's sequence dimension is naturally shardable for SP
- Which linear layers benefit from column/row parallel sharding vs. replication
- Whether cross-attention or other special modules need SP exclusion
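For the first point, the arithmetic behind a column-parallel Q/K/V split is simple enough to sketch in plain Python. `exact_divide`, `local_head_range` and `local_qkv_columns` are hypothetical helpers. Real code would use the `divide` and `get_tp_world_size()` imports listed in this step, but this sketch does not assume their exact behavior. It also assumes a contiguous split in which each rank owns a block of whole heads.

```python
from typing import Tuple


def exact_divide(numerator: int, denominator: int) -> int:
    # Fail loudly instead of silently dropping heads when the split is uneven.
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    return numerator // denominator


def local_head_range(num_heads: int, tp_size: int, tp_rank: int) -> Tuple[int, int]:
    """Half-open [start, stop) range of attention heads owned by tp_rank."""
    heads_per_rank = exact_divide(num_heads, tp_size)
    start = tp_rank * heads_per_rank
    return start, start + heads_per_rank


def local_qkv_columns(num_heads: int, head_dim: int, tp_size: int, tp_rank: int) -> Tuple[int, int]:
    """Output-feature slice of one Q (or K, or V) projection held by tp_rank."""
    start, stop = local_head_range(num_heads, tp_size, tp_rank)
    return start * head_dim, stop * head_dim


print(local_head_range(24, 4, 1))        # (6, 12)
print(local_qkv_columns(24, 128, 4, 1))  # (768, 1536)
```

Because each rank holds whole heads, attention itself runs locally. The row-parallel output projection then combines the partial results across ranks. A head count that does not divide evenly by the TP size (24 heads on 5 ranks, say) raises an error here, and it is exactly the case the first consideration asks you to check.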
Key imports for distributed support:
```python
from sglang.multimodal_gen.runtime.distributed import (
    divide,
    get_sp_group,
    get_sp_world_size,
    get_tp_world_size,
    sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
    ReplicatedLinear,
)
```
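The Wan SP reference pads the sequence for alignment, shards it, and uses an all-gather for aggregation. Below is a list-based sketch of that round trip, with hypothetical helper names and the all-gather simulated by concatenation. A real implementation shards a tensor along its sequence dimension. It must also keep the padded positions out of attention (for example, via masking), which this sketch does not model.

```python
from typing import List, Sequence


def padded_length(seq_len: int, sp_world_size: int) -> int:
    # Round up so every rank gets an equally sized chunk.
    return (seq_len + sp_world_size - 1) // sp_world_size * sp_world_size


def shard_sequence(tokens: Sequence[int], sp_world_size: int, sp_rank: int, pad_value: int = 0) -> List[int]:
    total = padded_length(len(tokens), sp_world_size)
    padded = list(tokens) + [pad_value] * (total - len(tokens))
    chunk = total // sp_world_size
    return padded[sp_rank * chunk:(sp_rank + 1) * chunk]


def gather_and_unpad(shards: Sequence[Sequence[int]], original_len: int) -> List[int]:
    # Stand-in for an all-gather along the sequence dimension, then padding removal.
    return [tok for shard in shards for tok in shard][:original_len]


tokens = list(range(10))
shards = [shard_sequence(tokens, 4, rank) for rank in range(4)]
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 0]]
print(gather_and_unpad(shards, len(tokens)) == tokens)  # True
```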
VAE (runtime/models/vaes/{model_name}.py): Implement if the model uses a non-standard VAE. Many models reuse existing VAEs.
Encoders (runtime/models/encoders/{model_name}.py): Implement if the model uses custom text/image encoders.
Schedulers (runtime/models/schedulers/{scheduler_name}.py): Implement if the model requires a custom scheduler not available in Diffusers.
Step 4: Create Model Configs
DiT Config (configs/models/dits/{model_name}.py):
```python
# python/sglang/multimodal_gen/configs/models/dits/mymodel.py
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.dits.base import DiTConfig


@dataclass
class MyModelDitConfig(DiTConfig):
    arch_config: dict = field(default_factory=lambda: {
        "in_channels": 16,
        "num_layers": 24,
        "patch_size": 2,
        # ... model-specific architecture params ...
    })
```
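The `field(default_factory=...)` wrapper is required, not decorative. Python dataclasses reject a bare `dict` default, and the factory gives every config instance its own copy. The standalone check below uses `BaseDitConfig` and `SketchDitConfig` as hypothetical stand-ins for `DiTConfig` and `MyModelDitConfig`.

```python
from dataclasses import dataclass, field


@dataclass
class BaseDitConfig:  # hypothetical stand-in for DiTConfig
    arch_config: dict = field(default_factory=dict)


@dataclass
class SketchDitConfig(BaseDitConfig):
    arch_config: dict = field(default_factory=lambda: {
        "in_channels": 16,
        "num_layers": 24,
        "patch_size": 2,
    })


a, b = SketchDitConfig(), SketchDitConfig()
a.arch_config["num_layers"] = 12  # mutate one instance only
print(b.arch_config["num_layers"])  # 24: each instance got a fresh dict

try:
    @dataclass
    class BrokenConfig:
        arch_config: dict = {"in_channels": 16}  # mutable default
except ValueError as exc:
    print(f"rejected: {exc}")
```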
VAE Config (configs/models/vaes/{model_name}.py):
```python
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig


@dataclass
class MyModelVAEConfig(VAEConfig):
    vae_scale_factor: int = 8
    # ... VAE-specific params ...
```
Sampling Params (configs/sample/{model_name}.py):
```python
from dataclasses import dataclass

from sglang.multimodal_gen.configs.sample.base import SamplingParams


@dataclass
class MyModelSamplingParams(SamplingParams):
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    height: int = 1024
    width: int = 1024
    # ... model-specific defaults ...
```
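The height/width defaults interact with the VAE and DiT configs. The latent grid is the image size divided by `vae_scale_factor`, and a patchified DiT sees one token per `patch_size` x `patch_size` latent patch. The sketch below uses the example values from this step: 1024x1024, a scale factor of 8, 16 channels and a patch size of 2. It assumes that `in_channels` equals the VAE's latent channel count and that the model patchifies latents this way, so check both against your reference. The unpacked shape is the kind of value that `batch.raw_latent_shape` records. Both helpers are hypothetical.

```python
from typing import Tuple


def latent_shape(batch_size: int, height: int, width: int,
                 latent_channels: int, vae_scale_factor: int) -> Tuple[int, int, int, int]:
    """(B, C, H/f, W/f) latent shape for an image VAE with downscale factor f."""
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of vae_scale_factor")
    return (batch_size, latent_channels, height // vae_scale_factor, width // vae_scale_factor)


def num_image_tokens(latent_h: int, latent_w: int, patch_size: int) -> int:
    """Sequence length after patchifying the latent grid into patch_size x patch_size patches."""
    if latent_h % patch_size or latent_w % patch_size:
        raise ValueError("latent grid must be a multiple of patch_size")
    return (latent_h // patch_size) * (latent_w // patch_size)


shape = latent_shape(1, 1024, 1024, latent_channels=16, vae_scale_factor=8)
print(shape)                                     # (1, 16, 128, 128)
print(num_image_tokens(shape[2], shape[3], 2))  # 4096
```

Under these assumptions, requested sizes should be multiples of `vae_scale_factor * patch_size` (16 here). Video VAEs typically add a temporal compression factor on top.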
Step 5: Create PipelineConfig
The PipelineConfig holds static model configuration and defines callback methods used by the standard DenoisingStage and DecodingStage.
```python
# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py
from dataclasses import dataclass, field

import torch

from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
    ImagePipelineConfig,
    ModelTaskType,
    # PipelineConfig,  # common base for many video pipelines
    # SpatialImagePipelineConfig,  # alternative base for spatial image models
)
from sglang.multimodal_gen.configs.models.dits.mymodel import MyModelDitConfig
from sglang.multimodal_gen.configs.models.vaes.mymodel import MyModelVAEConfig


@dataclass
class MyModelPipelineConfig(ImagePipelineConfig):
    """Pipeline config for MyModel.

    This config provides callbacks that the standard DenoisingStage and
    DecodingStage use during execution. The BeforeDenoisingStage handles
    all model-specific pre-processing independently.
    """

    task_type: ModelTaskType = ModelTaskType.T2I
    vae_precision: str = "bf16"
    should_use_guidance: bool = True
    vae_tiling: bool = False
    enable_autocast: bool = False

    dit_config: DiTConfig = field(default_factory=MyModelDitConfig)
    vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)

    # --- Callbacks used by DenoisingStage ---

    def get_freqs_cis(self, batch, device, rotary_emb, dtype):
        """Prepare rotary position embeddings for the DiT."""
        freqs_cis = ...  # Model-specific RoPE computation goes here
        return freqs_cis

    def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build positive conditioning kwargs for each denoising step."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.prompt_embeds[0],
            "timestep": t,
            # ... model-specific kwargs ...
        }

    def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build negative conditioning kwargs for CFG."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.negative_prompt_embeds[0],
            "timestep": t,
            # ... model-specific kwargs ...
        }

    # --- Callbacks used by DecodingStage ---

    def get_decode_scale_and_shift(self):
        """Return (scale, shift) for latent denormalization before VAE decode."""
        return self.vae_config.latents_std, self.vae_config.latents_mean

    def post_denoising_loop(self, latents, batch):
        """Optional post-processing after the denoising loop finishes."""
        return latents.to(torch.bfloat16)

    def post_decoding(self, frames, server_args):
        """Optional post-processing after VAE decoding."""
        return frames
```
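Many noisy-output bugs start in the `get_decode_scale_and_shift` callback. As a hedged illustration, assume Wan-style per-channel normalization: latents are stored as `(x - mean) / std` and must be mapped back with `z * std + mean` before the VAE decoder runs. The helpers below are hypothetical and use one scalar per channel. Other VAEs use a single `scaling_factor` (and sometimes a `shift_factor`) instead. Before you decide what to return, confirm in the `DecodingStage` source how it applies the returned pair.

```python
from typing import List, Sequence


def normalize(x: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Encode side: per-channel (x - mean) / std."""
    return [(v - m) / s for v, m, s in zip(x, mean, std)]


def denormalize(z: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Decode-side inverse: per-channel z * std + mean."""
    return [v * s + m for v, m, s in zip(z, mean, std)]


mean, std = [0.5, -1.0], [2.0, 0.5]
x = [1.0, 2.0]
z = normalize(x, mean, std)
print(z)                          # [0.25, 6.0]
print(denormalize(z, mean, std))  # [1.0, 2.0]
# Swapping mean and std is a silent bug: the result is finite but wrong.
print(denormalize(z, std, mean))  # [2.125, -5.5]
```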
There is no separate VideoPipelineConfig base class. For video models, choose
ModelTaskType.T2V, ModelTaskType.I2V, or ModelTaskType.TI2V, and follow
existing video configs such as Wan, LTX, Hunyuan, Helios, or MOVA when deciding
whether to subclass PipelineConfig directly or use a model-specific base.
Important: The prepare_pos_cond_kwargs / prepare_neg_cond_kwargs methods define what the DiT receives at each denoising step. These must match the DiT's forward() signature.
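One way to check that match before running a full generation is to compare the kwargs dict against the DiT's signature with Python's `inspect` module. `check_cond_kwargs` and `FakeDiT` are hypothetical. With a real model, you would pass the bound `module.forward` and the dict returned by `prepare_pos_cond_kwargs`. The check catches naming mismatches only, not wrong tensor shapes.

```python
import inspect
from typing import Any, Callable, Dict, List


def check_cond_kwargs(forward: Callable, kwargs: Dict[str, Any]) -> List[str]:
    """List mismatches between a kwargs dict and forward()'s signature."""
    params = inspect.signature(forward).parameters
    problems: List[str] = []
    accepts_any = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if not accepts_any:
        problems += [f"unexpected kwarg: {name}" for name in kwargs if name not in params]
    for name, p in params.items():
        required = p.default is inspect.Parameter.empty and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        if required and name not in kwargs:
            problems.append(f"missing required arg: {name}")
    return problems


class FakeDiT:  # hypothetical stand-in for the real nn.Module
    def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None):
        return hidden_states


kwargs = {"hidden_states": 0, "encoder_hidden_states": 0, "timestep": 0, "img_ids": 0}
print(check_cond_kwargs(FakeDiT().forward, kwargs))  # ['unexpected kwarg: img_ids']
```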
Step 6: Implement the BeforeDenoisingStage (Core Step)
This is the heart of the Hybrid pattern. Create a single stage that handles ALL pre-processing.
```python
# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py
import torch

from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


class MyModelBeforeDenoisingStage(PipelineStage):
    """Monolithic pre-processing stage for MyModel.

    Consolidates all logic before the denoising loop:
    - Input validation
    - Text/image encoding
    - Latent preparation
    - Timestep/sigma computation

    This stage produces a Req batch with all fields required by
    the standard DenoisingStage.
    """

    def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
        super().__init__()
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.scheduler = scheduler
        # ... other initialization (image processors, scale factors, etc.) ...

    # --- Internal helper methods ---
    # Copy/adapt directly from the Diffusers reference pipeline.
    # These are private to this stage; no need to make them reusable.

    def _encode_prompt(self, prompt, device, dtype):
        """Encode text prompt into embeddings."""
        # ... model-specific text encoding logic ...
        return prompt_embeds, negative_prompt_embeds

    def _prepare_latents(self, batch_size, height, width, dtype, device, generator):
        """Create initial noisy latents."""
        # ... model-specific latent preparation ...
        return latents

    def _prepare_timesteps(self, num_inference_steps, device):
        """Compute the timestep/sigma schedule."""
        # ... model-specific timestep computation ...
        return timesteps, sigmas

    # --- Main forward method ---

    @torch.no_grad()
    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Execute all pre-processing and populate batch for DenoisingStage.

        This method mirrors the first half of a Diffusers pipeline __call__,
        up to (but not including) the denoising loop.
        """
        device = get_local_torch_device()
        dtype = torch.bfloat16
        generator = torch.Generator(device=device).manual_seed(batch.seed)

        # 1. Encode prompt
        prompt_embeds, negative_prompt_embeds = self._encode_prompt(
            batch.prompt, device, dtype
        )

        # 2. Prepare latents
        latents = self._prepare_latents(
            batch_size=1,
            height=batch.height,
            width=batch.width,
            dtype=dtype,
            device=device,
            generator=generator,
        )

        # 3. Prepare timesteps
        timesteps, sigmas = self._prepare_timesteps(
            batch.num_inference_steps, device
        )

        # 4. Populate batch with everything DenoisingStage needs
        batch.prompt_embeds = [prompt_embeds]
        batch.negative_prompt_embeds = [negative_prompt_embeds]
        batch.latents = latents
        batch.timesteps = timesteps
        batch.num_inference_steps = len(timesteps)
        # sigmas must be a Python list, not a numpy array or tensor
        batch.sigmas = sigmas.tolist() if hasattr(sigmas, "tolist") else list(sigmas)
        batch.generator = generator
        batch.raw_latent_shape = latents.shape
        # batch.height / batch.width already hold the requested size; reassign
        # them here only if the model adjusts the output resolution.
        return batch
```
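What `_prepare_timesteps` returns is model-specific. One hedged example is a flow-matching model with a static shift, in the style of Diffusers' `FlowMatchEulerDiscreteScheduler`:

- Sigmas start as an evenly spaced ramp from 1.0 down to 1/N.
- They are warped by `shift * s / (1 + (shift - 1) * s)`.
- Timesteps are `sigma * num_train_timesteps`.

`flow_match_schedule` is a hypothetical pure-Python helper, and the reference pipeline remains the authority. Models such as Flux compute a resolution-dependent shift instead, and Diffusers' scheduler also appends a terminal sigma of 0.0. Returning plain lists sidesteps the numpy pitfall for `batch.sigmas`.

```python
from typing import List, Tuple


def flow_match_schedule(num_inference_steps: int, shift: float = 1.0,
                        num_train_timesteps: int = 1000) -> Tuple[List[float], List[float]]:
    """Return (timesteps, sigmas) as Python lists, highest noise first."""
    n = num_inference_steps
    if n < 1:
        raise ValueError("num_inference_steps must be >= 1")
    # Evenly spaced ramp from 1.0 down to 1/n (equivalent to linspace(1.0, 1/n, n)).
    base = [1.0] if n == 1 else [1.0 + (1.0 / n - 1.0) * i / (n - 1) for i in range(n)]
    # Static time shift: shift > 1 keeps more of the schedule at high noise levels.
    sigmas = [shift * s / (1.0 + (shift - 1.0) * s) for s in base]
    timesteps = [s * num_train_timesteps for s in sigmas]
    return timesteps, sigmas


timesteps, sigmas = flow_match_schedule(4)
print(sigmas)     # [1.0, 0.75, 0.5, 0.25]
print(timesteps)  # [1000.0, 750.0, 500.0, 250.0]
```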
Key fields that DenoisingStage expects on the batch (set these in your forward):
| Field | Type | Description |
|---|---|---|
| batch.latents | torch.Tensor | Initial noisy latent tensor |
| batch.timesteps | torch.Tensor | Timestep schedule |
| batch.num_inference_steps | int | Number of denoising steps |
| batch.sigmas | list[float] | Sigma schedule (as a list, not numpy) |
| batch.prompt_embeds | list[torch.Tensor] | Positive prompt embeddings (wrapped in list) |
| batch.negative_prompt_embeds | list[torch.Tensor] | Negative prompt embeddings (wrapped in list) |
| batch.generator | torch.Generator | RNG generator for reproducibility |
| batch.raw_latent_shape | tuple | Original latent shape before any packing |
| batch.height / batch.width | int | Output dimensions |
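During bring-up, it can save time to check this contract right after the BeforeDenoisingStage runs. The helper below is a hypothetical debugging aid, not part of SGLang. It checks that the fields in the table are set, that the list-typed fields really are lists, and that the step count agrees with the schedule. `SimpleNamespace` stands in for `Req`. This sketch does not establish whether every field is required in every mode (for example, without CFG).

```python
from types import SimpleNamespace
from typing import List

REQUIRED_FIELDS = (
    "latents", "timesteps", "num_inference_steps", "sigmas", "prompt_embeds",
    "negative_prompt_embeds", "generator", "raw_latent_shape", "height", "width",
)


def check_denoising_inputs(batch) -> List[str]:
    """Return human-readable problems with the fields DenoisingStage reads."""
    problems = [f"missing: {name}" for name in REQUIRED_FIELDS
                if getattr(batch, name, None) is None]
    sigmas = getattr(batch, "sigmas", None)
    if sigmas is not None and not isinstance(sigmas, list):
        problems.append("sigmas must be a Python list; convert with .tolist() or list()")
    for name in ("prompt_embeds", "negative_prompt_embeds"):
        value = getattr(batch, name, None)
        if value is not None and not isinstance(value, list):
            problems.append(f"{name} must be a list of tensors; wrap with [tensor]")
    timesteps = getattr(batch, "timesteps", None)
    steps = getattr(batch, "num_inference_steps", None)
    if timesteps is not None and steps is not None and len(timesteps) != steps:
        problems.append("num_inference_steps does not match len(timesteps)")
    return problems


batch = SimpleNamespace(
    latents=object(), timesteps=[1000.0, 500.0], num_inference_steps=2,
    sigmas=(1.0, 0.5),        # tuple instead of list
    prompt_embeds="embeds",   # not wrapped in a list
    negative_prompt_embeds=["neg_embeds"], generator=object(),
    raw_latent_shape=(1, 16, 128, 128), height=1024, width=1024,
)
for problem in check_denoising_inputs(batch):
    print(problem)
```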
Step 7: Define the Pipeline Class
The pipeline class is minimal -- it just wires the stages together.
```python
# python/sglang/multimodal_gen/runtime/pipelines/my_model.py
from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
    ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines_core.stages import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.my_model import (
    MyModelBeforeDenoisingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs


class MyModelPipeline(LoRAPipeline, ComposedPipelineBase):
    pipeline_name = "MyModelPipeline"  # Must match model_index.json _class_name

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
        # ... list all modules from model_index.json ...
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        # 1. Monolithic pre-processing (model-specific)
        self.add_stage(
            MyModelBeforeDenoisingStage(
                vae=self.get_module("vae"),
                text_encoder=self.get_module("text_encoder"),
                tokenizer=self.get_module("tokenizer"),
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )

        # 2. Standard denoising loop (framework-provided)
        self.add_stage(
            DenoisingStage(
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )

        # 3. Standard VAE decoding (framework-provided)
        self.add_standard_decoding_stage()


# REQUIRED: This is how the registry discovers the pipeline
EntryClass = [MyModelPipeline]
```
Step 8: Register the Model
In python/sglang/multimodal_gen/registry.py, register your configs:
```python
register_configs(
    sampling_param_cls=MyModelSamplingParams,
    pipeline_config_cls=MyModelPipelineConfig,
    hf_model_paths=[
        "org/my-model-name",  # HuggingFace model ID(s)
    ],
    model_detectors=[
        lambda path: "my-model" in path.lower(),
    ],
)
```
register_configs() does not take a model_family argument. It registers the
sampling and pipeline config classes, then resolves models by exact
hf_model_paths or optional detector predicates. Prefer exact hf_model_paths
for public checkpoints used in docs or tests; use detector predicates only for
families where local mirrors, renamed repos, or generated paths are common.
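The reason to prefer exact paths is easiest to see with an overly broad detector. The sketch below models lookup as two passes: exact `hf_model_paths` first, then detector predicates. That precedence, `ConfigEntry` and `resolve_config` are all assumptions for illustration, so check `registry.py` for the real lookup order.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class ConfigEntry:  # hypothetical; not SGLang's internal registry record
    name: str
    hf_model_paths: List[str] = field(default_factory=list)
    model_detectors: List[Callable[[str], bool]] = field(default_factory=list)


def resolve_config(path: str, entries: List[ConfigEntry]) -> Optional[str]:
    # Pass 1 (assumed): an exact HuggingFace ID always wins.
    for entry in entries:
        if path in entry.hf_model_paths:
            return entry.name
    # Pass 2 (assumed): fall back to detector predicates; first match wins.
    for entry in entries:
        if any(detect(path) for detect in entry.model_detectors):
            return entry.name
    return None


entries = [
    ConfigEntry("OtherModel", model_detectors=[lambda p: "model" in p.lower()]),  # too broad
    ConfigEntry(
        "MyModel",
        hf_model_paths=["org/my-model-name"],
        model_detectors=[lambda p: "my-model" in p.lower()],
    ),
]
print(resolve_config("org/my-model-name", entries))       # MyModel
print(resolve_config("/mnt/ckpts/My-Model-v2", entries))  # OtherModel
print(resolve_config("foo/bar", entries))                 # None
```

Here the local mirror resolves to the wrong family because a broad predicate matched first, while the exact HuggingFace ID is unaffected.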
The EntryClass in your pipeline file is automatically discovered by the registry's _discover_and_register_pipelines() function -- no additional registration needed for the pipeline class itself.
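Here is a rough model of what that discovery means in practice. Modules that define `EntryClass` contribute their pipelines, and modules without it are silently ignored. The function below is a hypothetical sketch, not the actual `_discover_and_register_pipelines()`. It builds modules by hand with `types.ModuleType` so it runs standalone, whereas a real scan would import each module in the package (for example, with `pkgutil.iter_modules` and `importlib.import_module`). Keying the registry by `pipeline_name` is also an assumption.

```python
import types
from typing import Dict, Iterable


def discover_pipelines(modules: Iterable[types.ModuleType]) -> Dict[str, type]:
    registry: Dict[str, type] = {}
    for module in modules:
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue  # no EntryClass: nothing from this module is registered
        for cls in entry if isinstance(entry, (list, tuple)) else [entry]:
            registry[cls.pipeline_name] = cls
    return registry


class MyModelPipeline:
    pipeline_name = "MyModelPipeline"


with_entry = types.ModuleType("my_model")
with_entry.EntryClass = [MyModelPipeline]
forgot_entry = types.ModuleType("other_model")
forgot_entry.OtherPipeline = type("OtherPipeline", (), {"pipeline_name": "OtherPipeline"})

print(sorted(discover_pipelines([with_entry, forgot_entry])))  # ['MyModelPipeline']
```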
Step 9: Verify Output Quality
After implementation, you must verify that the generated output is not noise. A noisy or garbled output image/video is the most common sign of an incorrect implementation. Common causes include:
- Incorrect latent scale/shift factors (`get_decode_scale_and_shift` returning wrong values)
- Wrong timestep/sigma schedule (order, dtype, or value range)
- Mismatched conditioning kwargs (fields not matching the DiT's `forward()` signature)
- Incorrect VAE decoder configuration (wrong `vae_scale_factor`, missing denormalization)
- Rotary embedding style mismatch (`is_neox_style` set incorrectly)
- Wrong prompt embedding format (missing list wrapping, wrong encoder output selection)
If the output is noise, the implementation is incorrect — do not ship it. Debug by:
- Comparing intermediate tensor values (latents, prompt_embeds, timesteps) against the Diffusers reference pipeline
- Running the Diffusers pipeline and SGLang pipeline side-by-side with the same seed
- Checking each stage's output shape and value range independently
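For the side-by-side comparisons, two summary numbers per tensor go a long way: the maximum absolute difference and the cosine similarity of the flattened values. The helper below is a hypothetical pure-Python sketch. With torch tensors, you would flatten first, e.g. `t.flatten().float().tolist()`. As a rough heuristic:

- A cosine near 1.0 with a large difference points toward a scale factor.
- A low cosine points toward a structural mismatch, such as a wrong layout, schedule or RoPE style.

```python
import math
from typing import Dict, Sequence


def compare_values(reference: Sequence[float], ours: Sequence[float]) -> Dict[str, float]:
    if len(reference) != len(ours):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(ours)}")
    max_abs_diff = max((abs(a - b) for a, b in zip(reference, ours)), default=0.0)
    dot = sum(a * b for a, b in zip(reference, ours))
    norm_ref = math.sqrt(sum(a * a for a in reference))
    norm_ours = math.sqrt(sum(b * b for b in ours))
    cosine = dot / (norm_ref * norm_ours) if norm_ref and norm_ours else float("nan")
    return {"max_abs_diff": max_abs_diff, "cosine": cosine}


reference = [1.0, -2.0, 3.0]
print(compare_values(reference, [1.0, -2.0, 3.0]))  # identical
print(compare_values(reference, [2.0, -4.0, 6.0]))  # same direction, 2x scale
print(compare_values(reference, [3.0, 1.0, -2.0]))  # values shuffled
```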
Reference Implementations
Hybrid Style (recommended for most new models)
| Model | Pipeline | BeforeDenoisingStage | PipelineConfig |
|---|---|---|---|
| GLM-Image | runtime/pipelines/glm_image.py | stages/model_specific_stages/glm_image.py | configs/pipeline_configs/glm_image.py |
| Qwen-Image-Layered | runtime/pipelines/qwen_image.py (QwenImageLayeredPipeline) | stages/model_specific_stages/qwen_image_layered.py | configs/pipeline_configs/qwen_image.py (QwenImageLayeredPipelineConfig) |
| Cosmos3 | runtime/pipelines/cosmos3_pipeline.py | stages/model_specific_stages/cosmos3.py | configs/pipeline_configs/cosmos3.py |
| ErnieImage | runtime/pipelines/ernie_image.py | runtime/pipelines/ernie_image.py | configs/pipeline_configs/ernie_image.py |
| Hunyuan3D | runtime/pipelines/hunyuan3d_pipeline.py | stages/model_specific_stages/hunyuan3d.py | configs/pipeline_configs/hunyuan3d.py |
| LingBot World realtime | runtime/pipelines/lingbot_world_causal_dmd_pipeline.py | stages/model_specific_stages/lingbot_world/ | configs/pipeline_configs/lingbot_world.py |
Modular Style (when standard stages fit well)
| Model | Pipeline | Notes |
|---|---|---|
| Qwen-Image (T2I) | runtime/pipelines/qwen_image.py | Uses add_standard_t2i_stages(); standard text encoding + latent prep fits this model |
| Qwen-Image-Edit | runtime/pipelines/qwen_image.py | Uses add_standard_ti2i_stages(); standard image-to-image flow |
| Flux | runtime/pipelines/flux.py | Uses add_standard_t2i_stages() with custom prepare_mu |
| FLUX.2 / FLUX.2 Klein | runtime/pipelines/flux_2.py, flux_2_klein.py | Reuses FLUX.2 stages; Klein differences live in config and sampling params |
| Z-Image | runtime/pipelines/zimage_pipeline.py | Uses standard image pipeline stages plus Z-Image-specific config/model code |
| Ideogram4 | runtime/pipelines/ideogram.py | Uses dedicated text encoding and denoising stages while keeping standard latent prep |
| SANA | runtime/pipelines/sana.py | Spatial image pipeline; reuse the spatial image config pattern |
| Wan | runtime/pipelines/wan_pipeline.py | Uses add_standard_ti2v_stages() |
Checklist
Before submitting, verify:
Common (both styles):
- Pipeline file exists at `runtime/pipelines/{model_name}.py` with `EntryClass`
- PipelineConfig at `configs/pipeline_configs/{model_name}.py`
- SamplingParams at `configs/sample/{model_name}.py`
- DiT model at `runtime/models/dits/{model_name}.py`
- DiT config at `configs/models/dits/{model_name}.py`
- VAE: reuse an existing one (e.g., `AutoencoderKL`) or create a new one at `runtime/models/vaes/`
- VAE config: reuse an existing one or create a new one at `configs/models/vaes/{model_name}.py`
- Registry entry in `registry.py` via `register_configs()`
- `pipeline_name` matches the Diffusers `model_index.json` `_class_name`
- `_required_config_modules` lists all modules from `model_index.json`
- `PipelineConfig` callbacks (`prepare_pos_cond_kwargs`, `get_freqs_cis`, etc.) match the DiT's `forward()` signature
- Latent scale/shift factors are correctly configured
- Use fused kernels where possible (see `existing-fast-paths.md` under the benchmark/profile skill)
- Weight names match Diffusers for automatic loading
- TP/SP support considered for the DiT model (recommended; reference `wanvideo.py` for TP+SP, `qwen_image.py` for USPAttention)
- Output quality verified: generated images/videos are not noise and have been compared against the Diffusers reference output
Hybrid style only:
- BeforeDenoisingStage at `stages/model_specific_stages/{model_name}.py`
- `BeforeDenoisingStage.forward()` populates all fields needed by `DenoisingStage`
Common Pitfalls
- `batch.sigmas` must be a Python list, not a numpy array. Use `.tolist()` to convert.
- `batch.prompt_embeds` is a list of tensors (one per encoder), not a single tensor. Wrap with `[tensor]`.
- Don't forget `batch.raw_latent_shape`: `DecodingStage` uses it to unpack latents.
- Rotary embedding style matters: `is_neox_style=True` means split-half rotation, `is_neox_style=False` means interleaved. Check the reference model carefully.
- VAE precision: many VAEs need fp32 or bf16 for numerical stability. Set `vae_precision` in the PipelineConfig accordingly.
- Avoid forcing model-specific logic into shared stages: if your model's pre-processing doesn't naturally fit the existing standard stages, prefer the Hybrid pattern with a dedicated BeforeDenoisingStage rather than adding conditional branches to shared stages.
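The two rotary styles differ only in which coordinates get paired before rotation. The hedged pure-Python sketch below works on a single vector and takes one angle per pair directly; real RoPE derives each angle from the position and a frequency base. `rotate_neox` and `rotate_interleaved` are hypothetical names, not SGLang functions.

```python
import math
from typing import List, Sequence, Tuple


def _rotate_pair(a: float, b: float, angle: float) -> Tuple[float, float]:
    c, s = math.cos(angle), math.sin(angle)
    return a * c - b * s, a * s + b * c


def rotate_neox(x: Sequence[float], angles: Sequence[float]) -> List[float]:
    """is_neox_style=True: pair element i with element i + d/2 (split-half)."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        out[i], out[i + half] = _rotate_pair(x[i], x[i + half], angles[i])
    return out


def rotate_interleaved(x: Sequence[float], angles: Sequence[float]) -> List[float]:
    """is_neox_style=False: pair element 2i with element 2i + 1 (interleaved)."""
    out = list(x)
    for i in range(len(x) // 2):
        out[2 * i], out[2 * i + 1] = _rotate_pair(x[2 * i], x[2 * i + 1], angles[i])
    return out


x = [1.0, 2.0, 3.0, 4.0]
angles = [math.pi / 2, 0.0]  # rotate the first pair by 90 degrees, leave the second alone
print([round(v, 6) for v in rotate_neox(x, angles)])         # [-3.0, 2.0, 1.0, 4.0]
print([round(v, 6) for v in rotate_interleaved(x, angles)])  # [-2.0, 1.0, 3.0, 4.0]
```

Both functions rotate the same vector by the same angles, yet produce different outputs. This is why a wrong `is_neox_style` yields plausible-looking but incorrect attention rather than an error.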
After Implementation: Tests and Performance Data
After the model produces non-noise output, read
references/testing-and-accuracy.md before
adding GPU cases, component-accuracy skips/hooks, suite entries, or benchmark
claims. That reference tracks the current gpu_cases.py / testcase_configs.py
/ accuracy_testcase_configs.py / run_suite.py split and the component-accuracy
decision rules.