sglang-diffusion-add-model

Use when adding a new diffusion model or Diffusers pipeline to SGLang.

By sgl-project, updated 6/8/2026

Add a Diffusion Model to SGLang

Use this skill when adding a new diffusion model or pipeline variant to sglang.multimodal_gen.

Two Pipeline Styles

Style A: Hybrid Monolithic Pipeline (Recommended)

This is the recommended default for most new models. It uses a three-stage structure:

BeforeDenoisingStage (model-specific)  -->  DenoisingStage (standard)  -->  DecodingStage (standard)
  • BeforeDenoisingStage: A single, model-specific stage that consolidates all pre-processing logic: input validation, text encoding, image encoding, latent preparation, timestep setup. This stage is unique per model.
  • DenoisingStage: Framework-standard stage for the denoising loop (DiT/UNet forward passes). Shared across models.
  • DecodingStage: Framework-standard stage for VAE decoding. Shared across models.

Why recommended? Modern diffusion models have highly heterogeneous pre-processing requirements (different text encoders, different latent formats, different conditioning mechanisms). The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.

Style B: Modular Composition Style

Uses the framework's fine-grained standard stages (TextEncodingStage, LatentPreparationStage, TimestepPreparationStage, etc.) to build the pipeline by composition.

This style is appropriate when:

  • The new model's pre-processing can largely reuse existing stages — e.g., a model that uses standard CLIP/T5 text encoding + standard latent preparation with minimal customization. In this case, add_standard_t2i_stages() or add_standard_ti2i_stages() may be all you need.
  • A model-specific optimization needs to be extracted as a standalone stage — e.g., a specialized encoding or conditioning step that benefits from being a separate stage for profiling, parallelism control, or reuse across multiple pipeline variants.

See existing Modular examples: QwenImagePipeline (uses add_standard_t2i_stages), FluxPipeline, WanPipeline.

How to Choose

| Situation | Recommended style |
|---|---|
| Model has unique/complex pre-processing (VLM captioning, AR token generation, custom latent packing, etc.) | Hybrid: consolidate into a BeforeDenoisingStage |
| Model fits neatly into the standard text-to-image or text+image-to-image pattern | Modular: use add_standard_t2i_stages() / add_standard_ti2i_stages() |
| Porting a Diffusers pipeline with many custom steps | Hybrid: copy the __call__ logic into a single stage |
| Adding a variant of an existing model that shares most logic | Modular: reuse existing stages, customize via PipelineConfig callbacks |
| A specific pre-processing step needs special parallelism or profiling isolation | Modular: extract that step as a dedicated stage |

Key principle (both styles): The stage(s) before DenoisingStage must produce a Req batch object with all the standard tensor fields that DenoisingStage expects (latents, timesteps, prompt_embeds, etc.). As long as this contract is met, the pipeline remains composable regardless of which style you use.
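To make that contract concrete, here is a minimal sketch of a pre-flight check you could run on the batch between your pre-processing stage(s) and DenoisingStage. `ToyReq` and `denoising_contract_errors` are hypothetical stand-ins, not SGLang APIs. The field names and the list-typing rules come from the field table and Common Pitfalls in this skill. Treating `negative_prompt_embeds` as optional, since it only matters with classifier-free guidance, is an assumption.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ToyReq:
    """Hypothetical stand-in for SGLang's Req; the real class has many more fields."""
    latents: Any = None
    timesteps: Any = None
    num_inference_steps: Optional[int] = None
    sigmas: Any = None
    prompt_embeds: Any = None
    negative_prompt_embeds: Any = None  # optional here: only needed for CFG (assumption)
    raw_latent_shape: Any = None


# Fields this sketch treats as mandatory before the denoising loop starts.
REQUIRED_FOR_DENOISING = (
    "latents",
    "timesteps",
    "num_inference_steps",
    "sigmas",
    "prompt_embeds",
    "raw_latent_shape",
)


def denoising_contract_errors(batch) -> List[str]:
    """Return a list of problems; an empty list means the contract looks satisfied."""
    errors = [
        f"missing {name}"
        for name in REQUIRED_FOR_DENOISING
        if getattr(batch, name, None) is None
    ]
    # Type rules from the pitfalls: sigmas is a plain list, embeddings are list-wrapped.
    if batch.sigmas is not None and not isinstance(batch.sigmas, list):
        errors.append("sigmas must be a Python list")
    for name in ("prompt_embeds", "negative_prompt_embeds"):
        value = getattr(batch, name, None)
        if value is not None and not isinstance(value, list):
            errors.append(f"{name} must be a list (one entry per encoder)")
    return errors


print(denoising_contract_errors(ToyReq(latents=[0.0], sigmas=(1.0, 0.5))))
```

During development, a check like this at the end of a BeforeDenoisingStage.forward() can surface a missing or mistyped field before it shows up as a noisy image.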


Key Files and Directories

| Purpose | Path |
|---|---|
| Pipeline classes | python/sglang/multimodal_gen/runtime/pipelines/ |
| Model-specific stages | python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/ |
| PipelineStage base class | python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py |
| Pipeline base class | python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py |
| Standard stages (Denoising, Decoding) | python/sglang/multimodal_gen/runtime/pipelines_core/stages/ |
| Pipeline configs | python/sglang/multimodal_gen/configs/pipeline_configs/ |
| Sampling params | python/sglang/multimodal_gen/configs/sample/ |
| DiT model implementations | python/sglang/multimodal_gen/runtime/models/dits/ |
| VAE implementations | python/sglang/multimodal_gen/runtime/models/vaes/ |
| Encoder implementations | python/sglang/multimodal_gen/runtime/models/encoders/ |
| Scheduler implementations | python/sglang/multimodal_gen/runtime/models/schedulers/ |
| Model configs (DiT, VAE, encoder) | python/sglang/multimodal_gen/configs/models/dits/, vaes/, encoders/ |
| Central registry | python/sglang/multimodal_gen/registry.py |
| Model component registry | python/sglang/multimodal_gen/runtime/models/registry.py |
| Current support list | docs/diffusion/compatibility_matrix.md |

Step-by-Step Implementation

Step 1: Obtain and Study the Reference Implementation

Before writing any code, obtain the model's reference implementation or Diffusers pipeline code. You need the actual source code to work from — do not guess or assume the model's architecture. If the user already gave a HuggingFace model ID or repo, inspect that yourself first. Ask the user only when the reference implementation is private, ambiguous, or otherwise unavailable. Typical sources are:

  • The model's Diffusers pipeline source (e.g., the pipeline_*.py file from the diffusers library or HuggingFace repo)
  • Or the model's official reference implementation (e.g., from the model author's GitHub repo)
  • Or the HuggingFace model ID so you can look up model_index.json and the associated pipeline class

Once you have the reference code, study it thoroughly:

  1. Find the model's model_index.json to identify required modules (text_encoder, vae, transformer, scheduler, etc.)
  2. Read the Diffusers pipeline's __call__ method end-to-end. Identify:
    • How text prompts are encoded
    • How latents are prepared (shape, dtype, scaling)
    • How timesteps/sigmas are computed
    • What conditioning kwargs the DiT/UNet expects
    • How the denoising loop works (classifier-free guidance, etc.)
    • How VAE decoding is done (scaling factors, tiling, etc.)
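The first item can be scripted. Below is a small sketch that reads a Diffusers `model_index.json` and returns two things: the pipeline class name, which `pipeline_name` should match, and the component names, which are a starting point for `_required_config_modules`. It relies on the usual layout of that file:

- Metadata keys start with an underscore.
- Each component maps to a `[library, class]` pair.
- `[null, null]` marks an optional component that the checkpoint does not ship.

The example JSON is made up for illustration.

```python
import json
from typing import List, Tuple

# Illustrative model_index.json content (not taken from a real checkpoint).
EXAMPLE_INDEX = """{
  "_class_name": "MyModelPipeline",
  "_diffusers_version": "0.32.0",
  "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
  "text_encoder": ["transformers", "T5EncoderModel"],
  "tokenizer": ["transformers", "T5TokenizerFast"],
  "transformer": ["diffusers", "MyModelTransformer2DModel"],
  "vae": ["diffusers", "AutoencoderKL"],
  "image_encoder": [null, null]
}"""


def summarize_model_index(text: str) -> Tuple[str, List[str]]:
    """Return (pipeline class name, sorted component names) from model_index.json."""
    index = json.loads(text)
    pipeline_name = index["_class_name"]
    modules = sorted(
        key
        for key, value in index.items()
        if not key.startswith("_")      # skip metadata such as _diffusers_version
        and isinstance(value, list)
        and value
        and value[0] is not None        # skip [null, null] placeholders
    )
    return pipeline_name, modules


name, modules = summarize_model_index(EXAMPLE_INDEX)
print(name, modules)
```

For a real checkpoint, pass in the contents of the file from the HuggingFace repo or a local snapshot.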

Step 2: Evaluate Reuse of Existing Pipelines and Stages

Before creating any new files, check whether an existing pipeline or stage can be reused or extended. Only create new pipelines/stages when the existing ones would require extensive modifications or when no similar implementation exists.

Specifically:

  1. Compare the new model's architecture against existing pipelines before creating files. Current native families include LTX-2/2.3, HunyuanVideo/FastHunyuan, Wan/FastWan/TurboWan/LingBot World, MOVA, FLUX/FLUX.2/Klein, Z-Image, Qwen-Image/edit/layered, GLM-Image, SD3, Hunyuan3D, Helios, Cosmos3, SANA, FireRed, ERNIE-Image, JoyAI, and Ideogram4. If the new model shares most of its structure with an existing one (e.g., same text encoders, similar latent format, compatible denoising loop), prefer:
    • Adding a new config variant to the existing pipeline rather than creating a new pipeline class
    • Reusing the existing BeforeDenoisingStage with minor parameter differences
    • Using add_standard_t2i_stages() / add_standard_ti2i_stages() / add_standard_ti2v_stages() if the model fits standard patterns
  2. Check existing stages in runtime/pipelines_core/stages/ and stages/model_specific_stages/. If an existing stage handles 80%+ of what the new model needs, extend it rather than duplicating it.
  3. Check existing model components — many models share VAEs (e.g., AutoencoderKL), text encoders (CLIP, T5), and schedulers. Reuse these directly instead of re-implementing.

Rule of thumb: Only create a new file when the existing implementation would need substantial structural changes to accommodate the new model, or when no architecturally similar implementation exists.

Step 3: Implement Model Components

Adapt or implement the model's core components in the appropriate directories.

DiT/Transformer (runtime/models/dits/{model_name}.py):

# python/sglang/multimodal_gen/runtime/models/dits/my_model.py

import torch
import torch.nn as nn

from sglang.multimodal_gen.runtime.layers.layernorm import (
    LayerNormScaleShift,
    RMSNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import (
    get_attn_backend,
)


class MyModelTransformer2DModel(nn.Module):
    """DiT model for MyModel.

    Adapt from the Diffusers/reference implementation. Key points:
    - Use SGLang's fused LayerNorm/RMSNorm ops (see `existing-fast-paths.md` under the benchmark/profile skill)
    - Use SGLang's attention backend selector
    - Keep the same parameter naming as Diffusers for weight loading compatibility
    """

    def __init__(self, config):
        super().__init__()
        # ... model layers ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        # ... model-specific kwargs ...
    ) -> torch.Tensor:
        # ... forward pass ...
        return output

Tensor Parallel (TP) and Sequence Parallel (SP): For multi-GPU deployment, it is recommended to add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. Reference existing implementations and adapt to your model's architecture:

  • Wan model (runtime/models/dits/wanvideo.py) — Full TP + SP reference:
    • TP: Uses ColumnParallelLinear for Q/K/V projections, RowParallelLinear for output projections, attention heads divided by tp_size
    • SP: Sequence dimension sharding via get_sp_world_size(), padding for alignment, sequence_model_parallel_all_gather for aggregation
    • Cross-attention skips SP (skip_sequence_parallel=is_cross_attention)
  • Qwen-Image model (runtime/models/dits/qwen_image.py) — SP + USPAttention reference:
    • SP: Uses USPAttention (Ulysses + Ring Attention), configured via --ulysses-degree / --ring-degree
    • TP: Uses MergedColumnParallelLinear for QKV (with Nunchaku quantization), ReplicatedLinear otherwise

Important: These are references only — each model has its own architecture and parallelism requirements. Consider:

  • How attention heads can be divided across TP ranks
  • Whether the model's sequence dimension is naturally shardable for SP
  • Which linear layers benefit from column/row parallel sharding vs. replication
  • Whether cross-attention or other special modules need SP exclusion
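Before writing any parallel layers, it can help to check the arithmetic those questions imply. The sketch below makes two assumptions:

- Attention heads are split evenly across TP ranks.
- The token sequence is padded at the end so it divides evenly across SP ranks.

The function names are hypothetical. The real implementations use helpers such as divide() and get_sp_world_size(), listed under the key imports for distributed support.

```python
from typing import List, Tuple


def tp_heads_per_rank(num_heads: int, tp_size: int) -> int:
    """Heads are divided evenly across TP ranks, so tp_size must divide num_heads."""
    if num_heads % tp_size != 0:
        raise ValueError(f"{num_heads} heads cannot be split evenly across {tp_size} TP ranks")
    return num_heads // tp_size


def sp_shard_plan(seq_len: int, sp_size: int) -> Tuple[int, List[Tuple[int, int]]]:
    """Pad seq_len to a multiple of sp_size; return (padded_len, per-rank [start, end))."""
    padded_len = -(-seq_len // sp_size) * sp_size  # integer ceil division
    chunk = padded_len // sp_size
    return padded_len, [(rank * chunk, (rank + 1) * chunk) for rank in range(sp_size)]


# A 1024x1024 image with an 8x VAE and 2x2 patches gives 64 * 64 = 4096 tokens.
print(tp_heads_per_rank(24, 4))
print(sp_shard_plan(4096, 3))
```

Cross-attention over text tokens is the case the Wan reference excludes from this sharding (skip_sequence_parallel=is_cross_attention).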

Key imports for distributed support:

from sglang.multimodal_gen.runtime.distributed import (
    divide,
    get_sp_group,
    get_sp_world_size,
    get_tp_world_size,
    sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
    ReplicatedLinear,
)

VAE (runtime/models/vaes/{model_name}.py): Implement if the model uses a non-standard VAE. Many models reuse existing VAEs.

Encoders (runtime/models/encoders/{model_name}.py): Implement if the model uses custom text/image encoders.

Schedulers (runtime/models/schedulers/{scheduler_name}.py): Implement if the model requires a custom scheduler not available in Diffusers.

Step 4: Create Model Configs

DiT Config (configs/models/dits/{model_name}.py):

# python/sglang/multimodal_gen/configs/models/dits/mymodel.py

from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.dits.base import DiTConfig


@dataclass
class MyModelDitConfig(DiTConfig):
    arch_config: dict = field(default_factory=lambda: {
        "in_channels": 16,
        "num_layers": 24,
        "patch_size": 2,
        # ... model-specific architecture params ...
    })

VAE Config (configs/models/vaes/{model_name}.py):

from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig


@dataclass
class MyModelVAEConfig(VAEConfig):
    vae_scale_factor: int = 8
    # ... VAE-specific params ...

Sampling Params (configs/sample/{model_name}.py):

from dataclasses import dataclass

from sglang.multimodal_gen.configs.sample.base import SamplingParams


@dataclass
class MyModelSamplingParams(SamplingParams):
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    height: int = 1024
    width: int = 1024
    # ... model-specific defaults ...
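These defaults interact with the VAE and DiT configs. The VAE downsamples height and width by vae_scale_factor, and the DiT then groups patch_size x patch_size latent pixels into one token. The sketch below shows that arithmetic under these assumptions:

- A 2D image DiT with square patches.
- The example values above: 16 latent channels, vae_scale_factor 8, patch_size 2.

Video models add a temporal axis and some models pack channels differently, so check your reference.

```python
from typing import Tuple


def latent_geometry(
    height: int,
    width: int,
    vae_scale_factor: int = 8,
    patch_size: int = 2,
    latent_channels: int = 16,
) -> Tuple[Tuple[int, int, int, int], int]:
    """Return (latent shape for batch size 1, number of image tokens seen by the DiT)."""
    multiple = vae_scale_factor * patch_size
    if height % multiple or width % multiple:
        raise ValueError(f"height and width must be multiples of {multiple}")
    lat_h, lat_w = height // vae_scale_factor, width // vae_scale_factor
    num_tokens = (lat_h // patch_size) * (lat_w // patch_size)
    return (1, latent_channels, lat_h, lat_w), num_tokens


print(latent_geometry(1024, 1024))
print(latent_geometry(768, 1344))
```

If a requested size fails this check, the BeforeDenoisingStage is a natural place to round it before preparing latents.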

Step 5: Create PipelineConfig

The PipelineConfig holds static model configuration and defines callback methods used by the standard DenoisingStage and DecodingStage.

# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py

from dataclasses import dataclass, field

import torch

from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
    ImagePipelineConfig,
    ModelTaskType,
    # PipelineConfig,              # common base for many video pipelines
    # SpatialImagePipelineConfig,  # alternative base for spatial image models
)
from sglang.multimodal_gen.configs.models.dits.mymodel import MyModelDitConfig
from sglang.multimodal_gen.configs.models.vaes.mymodel import MyModelVAEConfig


@dataclass
class MyModelPipelineConfig(ImagePipelineConfig):
    """Pipeline config for MyModel.

    This config provides callbacks that the standard DenoisingStage and
    DecodingStage use during execution. The BeforeDenoisingStage handles
    all model-specific pre-processing independently.
    """

    task_type: ModelTaskType = ModelTaskType.T2I
    vae_precision: str = "bf16"
    should_use_guidance: bool = True
    vae_tiling: bool = False
    enable_autocast: bool = False

    dit_config: DiTConfig = field(default_factory=MyModelDitConfig)
    vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)

    # --- Callbacks used by DenoisingStage ---

    def get_freqs_cis(self, batch, device, rotary_emb, dtype):
        """Prepare rotary position embeddings for the DiT."""
        # Model-specific RoPE computation
        ...
        return freqs_cis

    def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build positive conditioning kwargs for each denoising step."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.prompt_embeds[0],
            "timestep": t,
            # ... model-specific kwargs ...
        }

    def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build negative conditioning kwargs for CFG."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.negative_prompt_embeds[0],
            "timestep": t,
            # ... model-specific kwargs ...
        }

    # --- Callbacks used by DecodingStage ---

    def get_decode_scale_and_shift(self):
        """Return (scale, shift) for latent denormalization before VAE decode."""
        return self.vae_config.latents_std, self.vae_config.latents_mean

    def post_denoising_loop(self, latents, batch):
        """Optional post-processing after the denoising loop finishes."""
        return latents.to(torch.bfloat16)

    def post_decoding(self, frames, server_args):
        """Optional post-processing after VAE decoding."""
        return frames
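get_decode_scale_and_shift() is a common place for wrong values to turn into noisy output. Diffusers references use one of two conventions:

- SD3/FLUX-style VAEs undo normalization with `latents / scaling_factor + shift_factor`.
- Wan-style VAEs use per-channel statistics: `latents * std + mean`.

The sketch below shows both on plain nested lists, with one inner list per latent channel. This skill does not spell out how SGLang's DecodingStage consumes the returned pair, so compare against an existing config that uses the same VAE before choosing values.

```python
from typing import List, Sequence

Latents = List[List[float]]  # one inner list of values per latent channel


def denormalize_scalar(latents: Latents, scaling_factor: float, shift_factor: float = 0.0) -> Latents:
    """SD3/FLUX-style: x = z / scaling_factor + shift_factor, same for every channel."""
    return [[v / scaling_factor + shift_factor for v in channel] for channel in latents]


def denormalize_per_channel(latents: Latents, mean: Sequence[float], std: Sequence[float]) -> Latents:
    """Wan-style: x = z * std[c] + mean[c], using per-channel statistics."""
    return [
        [v * s + m for v in channel]
        for channel, m, s in zip(latents, mean, std)
    ]


z = [[1.0, -1.0], [0.0, 2.0]]
print(denormalize_per_channel(z, mean=[0.5, -0.5], std=[2.0, 1.0]))
print(denormalize_scalar(z, scaling_factor=0.5, shift_factor=0.5))
```

The scalar convention is the special case where std[c] = 1 / scaling_factor and mean[c] = shift_factor for every channel. Returning the scale inverted, or swapping scale and shift, is exactly the kind of scale/shift error Step 9 lists as a cause of noise.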

There is no separate VideoPipelineConfig base class. For video models, choose ModelTaskType.T2V, ModelTaskType.I2V, or ModelTaskType.TI2V, and follow existing video configs such as Wan, LTX, Hunyuan, Helios, or MOVA when deciding whether to subclass PipelineConfig directly or use a model-specific base.

Important: The prepare_pos_cond_kwargs / prepare_neg_cond_kwargs methods define what the DiT receives at each denoising step. These must match the DiT's forward() signature.
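A mismatch here does not always raise: a DiT whose forward() accepts **kwargs will quietly absorb a misspelled key. Below is a minimal development-time check using only the standard-library `inspect` module. It assumes the denoising loop passes these dicts as keyword arguments. `check_cond_kwargs` and `ToyDiT` are hypothetical names for illustration.

```python
import inspect
from typing import Callable, Dict, List


def check_cond_kwargs(forward: Callable, kwargs: Dict[str, object]) -> List[str]:
    """Report kwargs forward() would reject or silently absorb, plus missing required ones."""
    params = inspect.signature(forward).parameters
    var_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    accepts_any_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    problems: List[str] = []
    for key in kwargs:
        if key not in params:
            problems.append(
                f"kwarg swallowed by **kwargs: {key}" if accepts_any_kw
                else f"unexpected kwarg: {key}"
            )
    for name, p in params.items():
        if p.kind in var_kinds or name == "self":
            continue  # *args/**kwargs, and self when given an unbound function
        if p.default is inspect.Parameter.empty and name not in kwargs:
            problems.append(f"missing required kwarg: {name}")
    return problems


class ToyDiT:
    def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None):
        return hidden_states


# "timesteps" is a typo for "timestep": one unexpected and one missing kwarg.
print(check_cond_kwargs(
    ToyDiT().forward,
    {"hidden_states": 0, "encoder_hidden_states": 0, "timesteps": 1},
))
```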

Step 6: Implement the BeforeDenoisingStage (Core Step)

This is the heart of the Hybrid pattern. Create a single stage that handles ALL pre-processing.

# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py

import torch
from typing import List, Optional, Union

from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


class MyModelBeforeDenoisingStage(PipelineStage):
    """Monolithic pre-processing stage for MyModel.

    Consolidates all logic before the denoising loop:
    - Input validation
    - Text/image encoding
    - Latent preparation
    - Timestep/sigma computation

    This stage produces a Req batch with all fields required by
    the standard DenoisingStage.
    """

    def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
        super().__init__()
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.scheduler = scheduler
        # ... other initialization (image processors, scale factors, etc.) ...

    # --- Internal helper methods ---
    # Copy/adapt directly from the Diffusers reference pipeline.
    # These are private to this stage; no need to make them reusable.

    def _encode_prompt(self, prompt, device, dtype):
        """Encode text prompt into embeddings."""
        # ... model-specific text encoding logic ...
        return prompt_embeds, negative_prompt_embeds

    def _prepare_latents(self, batch_size, height, width, dtype, device, generator):
        """Create initial noisy latents."""
        # ... model-specific latent preparation ...
        return latents

    def _prepare_timesteps(self, num_inference_steps, device):
        """Compute the timestep/sigma schedule."""
        # ... model-specific timestep computation ...
        return timesteps, sigmas

    # --- Main forward method ---

    @torch.no_grad()
    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Execute all pre-processing and populate batch for DenoisingStage.

        This method mirrors the first half of a Diffusers pipeline __call__,
        up to (but not including) the denoising loop.
        """
        device = get_local_torch_device()
        dtype = torch.bfloat16
        generator = torch.Generator(device=device).manual_seed(batch.seed)

        # 1. Encode prompt
        prompt_embeds, negative_prompt_embeds = self._encode_prompt(
            batch.prompt, device, dtype
        )

        # 2. Prepare latents
        latents = self._prepare_latents(
            batch_size=1,
            height=batch.height,
            width=batch.width,
            dtype=dtype,
            device=device,
            generator=generator,
        )

        # 3. Prepare timesteps
        timesteps, sigmas = self._prepare_timesteps(
            batch.num_inference_steps, device
        )

        # 4. Populate batch with everything DenoisingStage needs
        batch.prompt_embeds = [prompt_embeds]
        batch.negative_prompt_embeds = [negative_prompt_embeds]
        batch.latents = latents
        batch.timesteps = timesteps
        batch.num_inference_steps = len(timesteps)
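        # DenoisingStage expects a plain Python list (see Common Pitfalls).
        batch.sigmas = sigmas.tolist() if hasattr(sigmas, "tolist") else list(sigmas)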
        batch.generator = generator
        batch.raw_latent_shape = latents.shape
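        # batch.height / batch.width already come from the request; reassign them
        # here only if the model adjusts the output size (e.g., rounding to a
        # multiple of vae_scale_factor * patch_size).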

        return batch

Key fields that DenoisingStage expects on the batch (set these in your forward):

| Field | Type | Description |
|---|---|---|
| batch.latents | torch.Tensor | Initial noisy latent tensor |
| batch.timesteps | torch.Tensor | Timestep schedule |
| batch.num_inference_steps | int | Number of denoising steps |
| batch.sigmas | list[float] | Sigma schedule (a Python list, not a numpy array) |
| batch.prompt_embeds | list[torch.Tensor] | Positive prompt embeddings (wrapped in a list) |
| batch.negative_prompt_embeds | list[torch.Tensor] | Negative prompt embeddings (wrapped in a list) |
| batch.generator | torch.Generator | RNG generator for reproducibility |
| batch.raw_latent_shape | tuple | Original latent shape before any packing |
| batch.height / batch.width | int | Output dimensions |
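For flow-matching models, `_prepare_timesteps` often returns values shaped like the sketch below. Two parts of it are assumptions to replace with whatever your reference pipeline actually does:

- The linear spacing from 1.0 down to 1/n mirrors what some Diffusers flow-matching pipelines (for example FLUX) pass to their scheduler.
- The shift formula is the SD3-style timestep shift.

What matters here is the output type: plain Python lists of floats, as the field table asks for. If your reference produces a numpy array or tensor, `.tolist()` converts it. Diffusers' FlowMatchEulerDiscreteScheduler also appends a terminal 0.0 to its internal sigmas, so check an existing SGLang model before deciding whether batch.sigmas should include it.

```python
from typing import List, Tuple


def flow_matching_schedule(
    num_inference_steps: int,
    shift: float = 1.0,
    num_train_timesteps: int = 1000,
) -> Tuple[List[float], List[float]]:
    """Return (timesteps, sigmas) as plain Python lists of floats."""
    n = num_inference_steps
    if n < 1:
        raise ValueError("num_inference_steps must be >= 1")
    # Linearly spaced sigmas from 1.0 down to 1/n (like np.linspace(1.0, 1/n, n)).
    step = (1.0 - 1.0 / n) / (n - 1) if n > 1 else 0.0
    sigmas = [1.0 - i * step for i in range(n)]
    # SD3-style shift: shift > 1 keeps more of the schedule at high noise levels.
    sigmas = [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]
    timesteps = [s * num_train_timesteps for s in sigmas]
    return timesteps, sigmas


timesteps, sigmas = flow_matching_schedule(4)
print(timesteps, sigmas)
```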

Step 7: Define the Pipeline Class

The pipeline class is minimal -- it just wires the stages together.

# python/sglang/multimodal_gen/runtime/pipelines/my_model.py

from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
    ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines_core.stages import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.my_model import (
    MyModelBeforeDenoisingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs


class MyModelPipeline(LoRAPipeline, ComposedPipelineBase):
    pipeline_name = "MyModelPipeline"  # Must match model_index.json _class_name

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
        # ... list all modules from model_index.json ...
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        # 1. Monolithic pre-processing (model-specific)
        self.add_stage(
            MyModelBeforeDenoisingStage(
                vae=self.get_module("vae"),
                text_encoder=self.get_module("text_encoder"),
                tokenizer=self.get_module("tokenizer"),
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )

        # 2. Standard denoising loop (framework-provided)
        self.add_stage(
            DenoisingStage(
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )

        # 3. Standard VAE decoding (framework-provided)
        self.add_standard_decoding_stage()


# REQUIRED: This is how the registry discovers the pipeline
EntryClass = [MyModelPipeline]

Step 8: Register the Model

In python/sglang/multimodal_gen/registry.py, register your configs:

register_configs(
    sampling_param_cls=MyModelSamplingParams,
    pipeline_config_cls=MyModelPipelineConfig,
    hf_model_paths=[
        "org/my-model-name",  # HuggingFace model ID(s)
    ],
    model_detectors=[
        lambda path: "my-model" in path.lower(),
    ],
)

register_configs() does not take a model_family argument. It registers the sampling and pipeline config classes, then resolves models by exact hf_model_paths or optional detector predicates. Prefer exact hf_model_paths for public checkpoints used in docs or tests; use detector predicates only for families where local mirrors, renamed repos, or generated paths are common.
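The detector in the register_configs() example is a plain substring test. That is one reason exact paths are the safer default: the detector also fires on unrelated local paths that merely contain the string. If you do need a detector, a tighter predicate that inspects only the final path component is one option. The naming pattern below (`my-model` with an optional version suffix) is hypothetical.

```python
import os
import re


def loose_detector(path: str) -> bool:
    # Same test as the lambda in the register_configs() example.
    return "my-model" in path.lower()


# Hypothetical naming scheme: "my-model", "my-model-2", "my-model-2.1", ...
_MY_MODEL_NAME = re.compile(r"my-model(-\d+(\.\d+)?)?")


def strict_detector(path: str) -> bool:
    # Look only at the last path component, e.g. "/ckpts/My-Model/" -> "my-model".
    name = os.path.basename(path.rstrip("/")).lower()
    return _MY_MODEL_NAME.fullmatch(name) is not None


for p in ["org/my-model", "/ckpts/My-Model/", "/data/not-my-model-backup"]:
    print(p, loose_detector(p), strict_detector(p))
```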

The EntryClass in your pipeline file is automatically discovered by the registry's _discover_and_register_pipelines() function -- no additional registration needed for the pipeline class itself.
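Conceptually, that discovery step only needs each pipeline module to export `EntryClass` and each class to carry a unique `pipeline_name`. The sketch below is a simplified, hypothetical version of the lookup table such a function might build. The real `_discover_and_register_pipelines()` also imports the pipeline package itself, and this skill does not describe how it treats duplicate names. Modules are faked with `types.ModuleType` so the example runs on its own.

```python
import types
from typing import Dict, Iterable


def collect_entry_classes(modules: Iterable[types.ModuleType]) -> Dict[str, type]:
    """Map pipeline_name -> class for every module that exports EntryClass."""
    registry: Dict[str, type] = {}
    for module in modules:
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue  # helper modules without EntryClass are ignored
        classes = entry if isinstance(entry, (list, tuple)) else [entry]
        for cls in classes:
            name = cls.pipeline_name  # must match model_index.json "_class_name"
            if name in registry and registry[name] is not cls:
                raise ValueError(f"duplicate pipeline_name: {name!r}")
            registry[name] = cls
    return registry


class MyModelPipeline:
    pipeline_name = "MyModelPipeline"


pipeline_module = types.ModuleType("my_model")
pipeline_module.EntryClass = [MyModelPipeline]
helper_module = types.ModuleType("my_model_utils")  # no EntryClass

print(collect_entry_classes([pipeline_module, helper_module]))
```

In this sketch, a pipeline module that forgets EntryClass, or a pipeline_name that differs from _class_name, simply leaves the checkpoint without a matching entry.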

Step 9: Verify Output Quality

After implementation, you must verify that the generated output is not noise. A noisy or garbled output image/video is the most common sign of an incorrect implementation. Common causes include:

  • Incorrect latent scale/shift factors (get_decode_scale_and_shift returning wrong values)
  • Wrong timestep/sigma schedule (order, dtype, or value range)
  • Mismatched conditioning kwargs (fields not matching the DiT's forward() signature)
  • Incorrect VAE decoder configuration (wrong vae_scale_factor, missing denormalization)
  • Rotary embedding style mismatch (is_neox_style set incorrectly)
  • Wrong prompt embedding format (missing list wrapping, wrong encoder output selection)

If the output is noise, the implementation is incorrect — do not ship it. Debug by:

  1. Comparing intermediate tensor values (latents, prompt_embeds, timesteps) against the Diffusers reference pipeline
  2. Running the Diffusers pipeline and SGLang pipeline side-by-side with the same seed
  3. Checking each stage's output shape and value range independently
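For the first two debugging steps, a small comparison helper keeps side-by-side runs readable. The sketch below works on flat lists of floats. You might obtain these from a tensor in either pipeline with `t.float().flatten().tolist()`. The tolerances are placeholder values to tune per tensor; bf16 activations will likely need looser ones. For large tensors, `torch.testing.assert_close` is the more direct tool.

```python
import math
from typing import Sequence


def compare(name: str, ours: Sequence[float], ref: Sequence[float],
            atol: float = 1e-3, rtol: float = 1e-2) -> str:
    """Summarize how closely an SGLang intermediate matches the Diffusers reference."""
    if len(ours) != len(ref):
        return f"{name}: length {len(ours)} != {len(ref)} (check shapes first)"
    pairs = list(zip(ours, ref))
    worst = max((abs(a - b) for a, b in pairs), default=0.0)
    ok = all(math.isclose(a, b, rel_tol=rtol, abs_tol=atol) for a, b in pairs)
    return f"{name}: {'OK' if ok else 'MISMATCH'} (max abs diff {worst:.3g})"


print(compare("sigmas", [1.0, 0.75, 0.5], [1.0, 0.75, 0.5]))
print(compare("latents[:3]", [0.1, -0.2, 0.3], [0.1, -0.2, 0.8]))
```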

Reference Implementations

Hybrid Style (recommended for most new models)

| Model | Pipeline | BeforeDenoisingStage | PipelineConfig |
|---|---|---|---|
| GLM-Image | runtime/pipelines/glm_image.py | stages/model_specific_stages/glm_image.py | configs/pipeline_configs/glm_image.py |
| Qwen-Image-Layered | runtime/pipelines/qwen_image.py (QwenImageLayeredPipeline) | stages/model_specific_stages/qwen_image_layered.py | configs/pipeline_configs/qwen_image.py (QwenImageLayeredPipelineConfig) |
| Cosmos3 | runtime/pipelines/cosmos3_pipeline.py | stages/model_specific_stages/cosmos3.py | configs/pipeline_configs/cosmos3.py |
| ErnieImage | runtime/pipelines/ernie_image.py | runtime/pipelines/ernie_image.py (same file as the pipeline) | configs/pipeline_configs/ernie_image.py |
| Hunyuan3D | runtime/pipelines/hunyuan3d_pipeline.py | stages/model_specific_stages/hunyuan3d.py | configs/pipeline_configs/hunyuan3d.py |
| LingBot World realtime | runtime/pipelines/lingbot_world_causal_dmd_pipeline.py | stages/model_specific_stages/lingbot_world/ | configs/pipeline_configs/lingbot_world.py |

Modular Style (when standard stages fit well)

| Model | Pipeline | Notes |
|---|---|---|
| Qwen-Image (T2I) | runtime/pipelines/qwen_image.py | Uses add_standard_t2i_stages(); standard text encoding + latent prep fits this model |
| Qwen-Image-Edit | runtime/pipelines/qwen_image.py | Uses add_standard_ti2i_stages(); standard image-to-image flow |
| Flux | runtime/pipelines/flux.py | Uses add_standard_t2i_stages() with custom prepare_mu |
| FLUX.2 / FLUX.2 Klein | runtime/pipelines/flux_2.py, flux_2_klein.py | Reuses FLUX.2 stages; Klein differences live in config and sampling params |
| Z-Image | runtime/pipelines/zimage_pipeline.py | Uses standard image pipeline stages plus Z-Image-specific config/model code |
| Ideogram4 | runtime/pipelines/ideogram.py | Uses dedicated text encoding and denoising stages while keeping standard latent prep |
| SANA | runtime/pipelines/sana.py | Spatial image pipeline; reuse the spatial image config pattern |
| Wan | runtime/pipelines/wan_pipeline.py | Uses add_standard_ti2v_stages() |

Checklist

Before submitting, verify:

Common (both styles):

  • Pipeline file exists at runtime/pipelines/{model_name}.py with EntryClass
  • PipelineConfig at configs/pipeline_configs/{model_name}.py
  • SamplingParams at configs/sample/{model_name}.py
  • DiT model at runtime/models/dits/{model_name}.py
  • DiT config at configs/models/dits/{model_name}.py
  • VAE — reuse existing (e.g., AutoencoderKL) or create new at runtime/models/vaes/
  • VAE config — reuse existing or create new at configs/models/vaes/{model_name}.py
  • Registry entry in registry.py via register_configs()
  • pipeline_name matches Diffusers model_index.json _class_name
  • _required_config_modules lists all modules from model_index.json
  • PipelineConfig callbacks (prepare_pos_cond_kwargs, get_freqs_cis, etc.) match DiT's forward() signature
  • Latent scale/shift factors are correctly configured
  • Use fused kernels where possible (see existing-fast-paths.md under the benchmark/profile skill)
  • Weight names match Diffusers for automatic loading
  • TP/SP support considered for DiT model (recommended; reference wanvideo.py for TP+SP, qwen_image.py for USPAttention)
  • Output quality verified — generated images/videos are not noise; compared against Diffusers reference output

Hybrid style only:

  • BeforeDenoisingStage at stages/model_specific_stages/{model_name}.py
  • BeforeDenoisingStage.forward() populates all fields needed by DenoisingStage

Common Pitfalls

  1. batch.sigmas must be a Python list, not a numpy array. Use .tolist() to convert.
  2. batch.prompt_embeds is a list of tensors (one per encoder), not a single tensor. Wrap with [tensor].
  3. Don't forget batch.raw_latent_shape -- DecodingStage uses it to unpack latents.
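  4. Rotary embedding style matters. With is_neox_style=True, rotation is split-half: element i is paired with element i + d/2. With is_neox_style=False, rotation is interleaved: adjacent elements 2i and 2i + 1 are paired. Check the reference model carefully.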
  5. VAE precision: Many VAEs need fp32 or bf16 for numerical stability. Set vae_precision in the PipelineConfig accordingly.
  6. Avoid forcing model-specific logic into shared stages: If your model's pre-processing doesn't naturally fit the existing standard stages, prefer the Hybrid pattern with a dedicated BeforeDenoisingStage rather than adding conditional branches to shared stages.
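The rotary-embedding pitfall is easy to miss because both styles use the same cos/sin values and differ only in which elements form a rotation pair. The sketch below rotates a 4-dimensional vector with hand-picked 90-degree angles (cos = 0, sin = 1), so the difference is visible without floating-point noise. It is plain Python for illustration, not SGLang's rotary kernel.

```python
from typing import List, Sequence


def rotate_neox(x: Sequence[float], cos: Sequence[float], sin: Sequence[float]) -> List[float]:
    """is_neox_style=True: element i is paired with element i + d/2 (split-half)."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        a, b = x[i], x[i + half]
        out[i] = a * cos[i] - b * sin[i]
        out[i + half] = b * cos[i] + a * sin[i]
    return out


def rotate_interleaved(x: Sequence[float], cos: Sequence[float], sin: Sequence[float]) -> List[float]:
    """is_neox_style=False: adjacent elements 2i and 2i + 1 form a pair."""
    out = list(x)
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * cos[i] - b * sin[i]
        out[2 * i + 1] = b * cos[i] + a * sin[i]
    return out


x = [1.0, 2.0, 3.0, 4.0]
cos, sin = [0.0, 0.0], [1.0, 1.0]  # 90-degree rotation for both frequency pairs
print(rotate_neox(x, cos, sin))         # pairs (1, 3) and (2, 4)
print(rotate_interleaved(x, cos, sin))  # pairs (1, 2) and (3, 4)
```

When reading the reference, look at how it pairs elements:

- A `rotate_half` helper that concatenates the negated second half with the first half points to split-half (neox) style.
- Reshaping the last dimension into pairs, for example before `torch.view_as_complex`, points to interleaved style.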

After Implementation: Tests and Performance Data

After the model produces non-noise output, read references/testing-and-accuracy.md before adding GPU cases, component-accuracy skips/hooks, suite entries, or benchmark claims. That reference tracks the current gpu_cases.py / testcase_configs.py / accuracy_testcase_configs.py / run_suite.py split and the component-accuracy decision rules.

Install via CLI
npx skills add https://github.com/sgl-project/sglang --skill sglang-diffusion-add-model