---
name: ff-new-model
description: "Complete workflow for adding a new model adapter. Covers analysis, sample dataclass, adapter implementation (4 abstract methods + per-modality encoder overrides), registry, example YAML, and verification. Trigger: 'add model', 'support new model', 'integrate model', 'new adapter'."
---
# New Model Adapter Integration

Authoritative reference: `guidance/new_model.md`. Read it first.
## Prerequisites

Before starting, make sure you understand:

- The target model's diffusers pipeline (or whether you'll need a pseudo-pipeline)
- The task type: Text-to-Image (T2I), Image-to-Image (I2I), Text-to-Video (T2V), or Image-to-Video (I2V)
- Which Sample dataclass to extend
## Phase 1: Analysis

- Identify the diffusers pipeline for the target model:
  - Check whether it exists in diffusers: `from diffusers import <Pipeline>`
  - If not, you'll need a pseudo-pipeline (see the advanced section of `guidance/new_model.md`)
- Study an existing adapter of the same task type:
  - T2I: `models/flux/flux1.py` or `models/stable_diffusion/sd3_5.py`
  - I2I: `models/flux/flux1_kontext.py` or `models/qwen_image/qwen_image_edit_plus.py`
  - T2V: `models/wan/wan2_t2v.py`
  - I2V: `models/wan/wan2_i2v.py`
- Map pipeline components to adapter responsibilities:
  - Text encoders → `encode_prompt()`, `preprocessing_modules`
  - VAE → `encode_image()` / `encode_video()` / `decode_latents()`, `preprocessing_modules`
  - Audio encoder/VAE (if any) → `encode_audio()`, `preprocessing_modules`
  - Transformer/UNet → `forward()`, `default_target_modules` (LoRA target layer names), `inference_modules`
- Also read `topics/adapter_conventions.md` for upstream alignment rules and `topics/dtype_precision.md` for precision handling in `cast_latents()`.
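To turn this component mapping into a per-pipeline checklist, a small helper can bucket component names by prefix. This is a hypothetical sketch: `plan_adapter` and `COMPONENT_HOOKS` are not part of Flow-Factory, the hook lists are taken from the mapping above and the method table in Step 3, and the prefix rule (plus the `audio_encoder`/`audio_vae` names) assumes diffusers-style component naming such as `text_encoder_2`.

```python
from typing import Dict, Iterable, List

# Hypothetical mapping from diffusers component-name prefixes to the adapter
# hooks that usually handle them. Component names are assumptions.
COMPONENT_HOOKS: Dict[str, List[str]] = {
    "text_encoder": ["encode_prompt()", "preprocessing_modules"],
    "vae": ["encode_image()", "encode_video()", "decode_latents()", "preprocessing_modules"],
    "audio_encoder": ["encode_audio()", "preprocessing_modules"],
    "audio_vae": ["encode_audio()", "preprocessing_modules"],
    "transformer": ["forward()", "default_target_modules", "inference_modules"],
    "unet": ["forward()", "default_target_modules", "inference_modules"],
}


def plan_adapter(component_names: Iterable[str]) -> Dict[str, List[str]]:
    """Map each pipeline component name to the adapter hooks it touches."""
    plan: Dict[str, List[str]] = {}
    for name in component_names:
        for prefix, hooks in COMPONENT_HOOKS.items():
            # Exact match, or a numbered variant such as "text_encoder_2".
            if name == prefix or name.startswith(prefix + "_"):
                plan[name] = hooks
                break
        else:
            # Schedulers, tokenizers, processors, etc. have no dedicated hook here.
            plan[name] = []
    return plan
```

With a loaded diffusers pipeline, `plan_adapter(pipe.components)` iterates over the component names (the keys of `DiffusionPipeline.components`); anything mapped to `[]` is left to `load_pipeline()` or ignored.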
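## Phase 2: Implementation

### Step 1: Define Sample Dataclass

```python
# src/flow_factory/models/<family>/<model>.py
from dataclasses import dataclass
from typing import ClassVar

# T2ISample (or another base Sample class) must also be imported from flow_factory.

@dataclass
class MyModelSample(T2ISample):  # or the appropriate base for your task type
    _shared_fields: ClassVar[frozenset[str]] = frozenset()
    # Add model-specific fields here if needed
```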
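If you add custom fields, the Common Pitfalls section recommends canonicalizing them to a single type in `__post_init__`. A minimal sketch of that, assuming a hypothetical `ref_latents` field; the `T2ISample` defined here is a stand-in so the snippet runs on its own and is not the real base class.

```python
from dataclasses import dataclass
from typing import Any, ClassVar


# Stand-in for flow_factory's T2ISample; the real class has its own fields.
@dataclass
class T2ISample:
    prompt: str


@dataclass
class MyModelSample(T2ISample):
    _shared_fields: ClassVar[frozenset[str]] = frozenset()
    # Hypothetical custom field that callers may pass as one item or a list.
    ref_latents: Any = None

    def __post_init__(self) -> None:
        # Run the base class's __post_init__ if it defines one.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        # Canonicalize to a list: None -> [], bare item -> [item].
        if self.ref_latents is None:
            self.ref_latents = []
        elif not isinstance(self.ref_latents, list):
            self.ref_latents = [self.ref_latents]
```

Always producing a list (never a mix of bare items and lists, never `None`) keeps the field homogeneous across samples, which is what lets gathering avoid the pickle-based fallback.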
### Step 2: Create Adapter Class

```python
from typing import List

# BaseAdapter must also be imported from flow_factory.

class MyModelAdapter(BaseAdapter):
    @property
    def preprocessing_modules(self) -> List[str]:
        return ["text_encoder", "vae"]  # Components used in Stage 1 (preprocessing)

    @property
    def inference_modules(self) -> List[str]:
        return ["vae"]  # Components needed at inference time

    @property
    def default_target_modules(self) -> List[str]:
        # LoRA target module names used when the YAML sets `target_modules: default`.
        # Override only if your transformer uses non-standard attention layer names.
        return ["to_q", "to_k", "to_v", "to_out.0"]
```
Which components are trainable is config-driven: the YAML `target_components`/`target_modules` fields are resolved by `BaseAdapter._parse_target_modules()` into `self.target_module_map` (set in `__init__`). Adapters do not override `target_module_map`.
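When `default_target_modules` returns short names like `to_q`, it is worth confirming they actually match layers in your transformer. The sketch below assumes PEFT-style matching for list-valued `target_modules` (a module is selected if its full name equals a target or ends with `.` plus the target); whether Flow-Factory passes these names to PEFT unchanged is an assumption, and `matching_modules` is a hypothetical helper.

```python
from typing import Iterable, List


def matching_modules(module_names: Iterable[str], targets: List[str]) -> List[str]:
    """Return the module names a PEFT-style suffix match would select.

    A module matches if its full dotted name equals a target or ends with
    "." + target, so "to_out.0" matches "blocks.0.attn.to_out.0" but not
    "blocks.0.attn.to_out.1".
    """
    return [
        name
        for name in module_names
        if any(name == t or name.endswith("." + t) for t in targets)
    ]
```

In practice, pass `(name for name, _ in transformer.named_modules())` as the first argument. An empty result means none of the names hit a layer, which is the "LoRA applied to wrong layers, no training effect" pitfall.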
### Step 3: Implement Required Methods

| Method | Purpose | Stage | Abstract? |
|---|---|---|---|
| `load_pipeline()` | Load diffusers pipeline | Init | Yes |
| `decode_latents()` | Latents → pixels | 3 | Yes |
| `inference()` | Full multi-step denoising | 3 | Yes |
| `forward()` | Single-step denoising loss | 6 | Yes |
| `encode_prompt()` | Text → embeddings | 1 | No (no-op default; override if your model consumes text) |
| `encode_image()` | Image → latents | 1 | No (no-op default; override if your model consumes images) |
| `encode_video()` | Video frames → latents | 1 | No (no-op default; override if your model consumes videos) |
| `encode_audio()` | Audio → embeddings/features | 1 | No (no-op default; override if your model consumes audio) |
| `preprocess_func()` | Raw inputs → cached tensors (dispatches to the 4 encoders) | 1 | No (concrete; override only for cross-modal preprocessing) |
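The split between abstract methods and optional encoder overrides in this table can be illustrated with a stand-in base class. This is not Flow-Factory's `BaseAdapter`. The signatures are simplified, the no-op encoders return an empty dict here, and `preprocess_func` only sketches the idea of dispatching to whichever encoders received inputs.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class AdapterContractSketch(ABC):
    """Stand-in illustrating the method table; not the real BaseAdapter."""

    # The four abstract methods every adapter must implement.
    @abstractmethod
    def load_pipeline(self) -> Any: ...

    @abstractmethod
    def decode_latents(self, latents: Any) -> Any: ...

    @abstractmethod
    def inference(self, *args: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    # Optional per-modality encoders: no-op defaults (empty dict in this sketch).
    def encode_prompt(self, prompts: List[str]) -> Dict[str, Any]:
        return {}

    def encode_image(self, images: List[List[Any]]) -> Dict[str, Any]:
        return {}

    def encode_video(self, videos: List[List[Any]]) -> Dict[str, Any]:
        return {}

    def encode_audio(self, audios: List[List[Any]]) -> Dict[str, Any]:
        return {}

    def preprocess_func(
        self,
        prompts: Optional[List[str]] = None,
        images: Optional[List[List[Any]]] = None,
        videos: Optional[List[List[Any]]] = None,
        audios: Optional[List[List[Any]]] = None,
    ) -> Dict[str, Any]:
        # Call the encoder for each modality that is present and merge outputs.
        out: Dict[str, Any] = {}
        if prompts is not None:
            out.update(self.encode_prompt(prompts))
        if images is not None:
            out.update(self.encode_image(images))
        if videos is not None:
            out.update(self.encode_video(videos))
        if audios is not None:
            out.update(self.encode_audio(audios))
        return out


class ToyT2IAdapter(AdapterContractSketch):
    def load_pipeline(self) -> Any:
        return None

    def decode_latents(self, latents: Any) -> Any:
        return latents

    def inference(self, *args: Any, **kwargs: Any) -> Any:
        return None

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return 0.0

    def encode_prompt(self, prompts: List[str]) -> Dict[str, Any]:
        # A text-only model overrides just the text hook (toy "embedding").
        return {"prompt_embeds": [len(p) for p in prompts]}
```

Leaving any of the four abstract methods unimplemented makes the class non-instantiable, while unused encoders can simply be left alone.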
### Step 4: Register

Add an entry to `_MODEL_ADAPTER_REGISTRY` in `src/flow_factory/models/registry.py`:

```python
'my-model': 'flow_factory.models.<family>.<model>.MyModelAdapter',
```
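Because registry values are dotted strings, the adapter class can be imported lazily when it is looked up. The following is a minimal sketch of that pattern, assuming a `module.path.ClassName` format. `resolve_adapter_class` is hypothetical (not the real `get_model_adapter_class`), and a stdlib class stands in for an adapter so the snippet runs anywhere.

```python
import importlib
from typing import Dict

# Hypothetical mini-registry in the same 'name': 'module.path.ClassName' format.
_REGISTRY: Dict[str, str] = {
    "ordered-dict": "collections.OrderedDict",  # stdlib stand-in for an adapter
}


def resolve_adapter_class(model_type: str, registry: Dict[str, str] = _REGISTRY) -> type:
    """Look up a registry entry and import its class on demand."""
    try:
        dotted = registry[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; registered: {sorted(registry)}"
        ) from None
    # Split "package.module.ClassName" into module path and class name.
    module_path, _, class_name = dotted.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

With lazy resolution, a typo in the module path or class name only surfaces at lookup time, which is why Phase 4 includes resolving the registry entry explicitly.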
## Phase 3: Configuration

Create an example YAML config at `examples/grpo/lora/<model>/default.yaml`:

```yaml
model:
  model_type: "my-model"
  model_name_or_path: "org/model-name"
  finetune_type: "lora"
  target_components: ["transformer"]
```
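Before launching a run, a quick structural check of the parsed config can catch a missing key or an unregistered `model_type`. This sketch assumes you have already loaded the file (e.g. with PyYAML's `yaml.safe_load`) and that the keys nest under `model:` as in the example. `check_model_section` and its required-key list are hypothetical, and Flow-Factory's own config loader may validate differently.

```python
from typing import Any, Dict, Iterable, List

# Keys the example config sets under `model:`; treated as required in this sketch.
REQUIRED_MODEL_KEYS = ("model_type", "model_name_or_path", "finetune_type", "target_components")


def check_model_section(cfg: Dict[str, Any], registered_types: Iterable[str]) -> List[str]:
    """Return a list of problems in cfg["model"]; an empty list means it passed."""
    model = cfg.get("model")
    if not isinstance(model, dict):
        return ["missing or non-mapping 'model' section"]
    problems = [f"missing model.{key}" for key in REQUIRED_MODEL_KEYS if key not in model]
    # The model_type must match a key in the adapter registry.
    if "model_type" in model and model["model_type"] not in set(registered_types):
        problems.append(f"model_type {model['model_type']!r} is not registered")
    return problems
```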
## Phase 4: Verification

Also read `topics/parity_testing.md` for the 4-layer verification protocol.

- `load_pipeline()` successfully loads the model
- `preprocess_func()` produces correct cached tensors
- `inference()` generates valid images/videos
- `forward()` computes loss without errors
- Training runs end-to-end with GRPO for ≥2 steps
- LoRA weights save and reload correctly
- Registry entry resolves correctly: `get_model_adapter_class('my-model')`
- Example YAML config is valid and complete
## Common Pitfalls

- Forgetting to set `preprocessing_modules`: the text encoder stays on the GPU, causing OOM during training.
- Wrong `target_components`/`target_modules` (or `default_target_modules`): LoRA is applied to the wrong components or layers, so training has no effect.
- Mismatched `_shared_fields`: data corruption during batch collation.
- Not handling `enable_preprocess=False`: encoding components are not loaded at inference time.
- Inconsistent custom field types across samples: if a custom sample field is a `Tensor` on some samples and a `List[Tensor]` on others, `gather_samples` falls back to the slow pickle-based `gather_object`. Always canonicalize to a single type in `__post_init__`; prefer `List[Tensor]` for variable-length data.
- Wrong `images`/`condition_images`/`audios` convention: `preprocess_func()`, `encode_image()`, `encode_video()`, `encode_audio()`, and `inference()` all operate at batch level. The outer list indexes samples in the batch and the inner list holds each sample's items:
  - `images` is `List[List[Image.Image]]` (`MultiImageBatch`).
  - `condition_images` is `List[List[Tensor(C,H,W)]]` (or `List[List[PIL.Image]]` for adapters that declare `python_format_columns`, e.g. Bagel).
  - `audios` is `List[List[Tensor]]` (`MultiAudioBatch`).
  - Empty samples contribute `[]` (never `None`); single-item samples contribute `[item]` (never a bare element).
  - Never pass a flat `List[Image]`/`List[Tensor]` or unwrap single-element lists. That breaks Arrow's homogeneous-column requirement and forces every downstream consumer to handle three input shapes.
  - For single-condition models, `_standardize_image_input`/`_standardize_video_input` must detect the nested format with `is_multi_image_batch`/`is_multi_video_batch`, extract the first element per sample (`[batch[0] for batch in images]`), and warn if extra conditions are discarded (e.g. `Wan2_I2V._standardize_image_input`, `LTX2_I2AV._standardize_image_input`).
  - See `topics/adapter_conventions.md`, Gotchas #5 and #6.
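For single-condition models, the per-sample extraction described in the last pitfall can be sketched as follows. The helpers are hypothetical stand-ins for `is_multi_image_batch` and `_standardize_image_input`. This version rejects anything that is not `List[List[item]]`, and, as an added assumption, it raises on samples with no condition image (the bare `batch[0]` expression would fail there anyway).

```python
import warnings
from typing import Any, List, Sequence


def _is_nested_batch(images: Sequence[Any]) -> bool:
    # Stand-in for is_multi_image_batch: outer sequence of per-sample lists.
    return isinstance(images, (list, tuple)) and all(
        isinstance(per_sample, (list, tuple)) for per_sample in images
    )


def standardize_single_condition(images: Sequence[Any]) -> List[Any]:
    """Reduce a nested per-sample batch to exactly one condition item per sample."""
    if not _is_nested_batch(images):
        raise TypeError("expected List[List[item]]: outer = samples, inner = items per sample")
    result: List[Any] = []
    for i, per_sample in enumerate(images):
        if len(per_sample) == 0:
            raise ValueError(f"sample {i} has no condition image")
        if len(per_sample) > 1:
            # Surface silently dropped conditions instead of ignoring them.
            warnings.warn(
                f"sample {i}: using the first of {len(per_sample)} condition images; "
                "the rest are discarded"
            )
        result.append(per_sample[0])
    return result
```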