---
name: fla-kda
description: >
  FLA KDA kernel workflow and public technical notes. Use when modifying or
  reviewing fla/ops/kda/**, KDA gate modes, chunk intra/inter kernels,
  safe_gate behavior, KDA backends, or KDA-specific tests and benchmarks.
---

FLA KDA Skill
Use this skill for KDA-specific work under fla/ops/kda/** and tests that
exercise KDA behavior.
Public code map
- Public API: fla.ops.kda.chunk_kda, fla.ops.kda.fused_recurrent_kda.
- Gate helpers: naive_kda_gate, naive_kda_lowerbound_gate, kda_gate_fwd, kda_gate_bwd, fused_kda_gate, kda_gate_chunk_cumsum in fla/ops/kda/gate.py.
- Chunk forward: chunk_kda_fwd in chunk_fwd.py.
- Intra/inter forward: chunk_kda_fwd_intra, chunk_kda_fwd_kernel_intra_sub_chunk, chunk_kda_fwd_kernel_inter_solve_fused in chunk_intra.py.
- Token-parallel non-safe path: chunk_kda_fwd_intra_token_parallel in chunk_intra_token_parallel.py.
- WY recompute: recompute_w_u_fwd and recompute_w_u_fwd_kda_kernel in wy_fast.py.
- Backward: chunk_kda_bwd, chunk_kda_bwd_intra, chunk_kda_bwd_wy_dqkg_fused.
- Backends: FlashKDABackend, KDATileLangBackend.
Gate modes
chunk_kda has two gate input contracts:
- Pre-gated mode: use_gate_in_kernel=False. g is already the log-space decay tensor. A_log, dt_bias, and lower_bound are not part of the gate activation.
- In-kernel mode: use_gate_in_kernel=True. g is raw gate input. A_log is required and dt_bias is optional.
  - Without safe_gate, activation is -exp(A_log) * softplus(g + dt_bias).
  - With safe_gate, activation is lower_bound * sigmoid(exp(A_log) * (g + dt_bias)).
safe_gate=True requires use_gate_in_kernel=True, lower_bound is not None,
and -5 <= lower_bound < 0.
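
The gate contracts above can be sketched as a scalar reference in plain Python. This is only an illustration, not the kernel implementation: reference_gate is a hypothetical helper name, it works on single floats rather than tensors, and it assumes that an omitted dt_bias behaves like a zero bias.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable softplus: log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    # Stable logistic function that avoids overflow for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def reference_gate(g, A_log=None, dt_bias=None, lower_bound=None,
                   use_gate_in_kernel=False, safe_gate=False):
    """Hypothetical scalar reference for the gate modes described above."""
    if not use_gate_in_kernel:
        if safe_gate:
            raise ValueError("safe_gate=True requires use_gate_in_kernel=True")
        return g  # pre-gated: g is already the log-space decay
    if A_log is None:
        raise ValueError("A_log is required when use_gate_in_kernel=True")
    # Assumption: a missing dt_bias is treated as zero.
    x = g if dt_bias is None else g + dt_bias
    if not safe_gate:
        return -math.exp(A_log) * softplus(x)
    if lower_bound is None or not (-5 <= lower_bound < 0):
        raise ValueError("safe_gate=True requires -5 <= lower_bound < 0")
    return lower_bound * sigmoid(math.exp(A_log) * x)
```

Both in-kernel activations return non-positive values, which is what lets the later sections treat the accumulated gate as a monotonic decay.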
Safe gate numerical note
With lower_bound=-5, every per-token gate value is in [-5, 0) before the
RCP_LN2 conversion used by chunk_kda_fwd. A 16-token sub-chunk can therefore
accumulate as much as -80 in natural-log units. In base-2 units the magnitude
is larger still, about 80 / ln(2) ≈ 115.4, so the safe intra path offsets the
exponent instead of feeding the full span to exp2 directly.
chunk_kda_fwd_kernel_intra_sub_chunk uses a midpoint offset before
exponentiation:
- b_gm = b_g - b_gn;
- exp2(b_gm) and exp2(-b_gm).
With the midpoint offset, each exponent operand covers at most about half of the
16-token sub-chunk. Under lower_bound=-5, that is about 40 / ln(2) ≈ 57.7 in
base-2 units, which is below the kernel's exp2 safety comment threshold. The
important invariant is not the raw cumulative value alone; it is that each
exponentiation uses a local offset rather than the full chunk cumsum directly.
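
The worst-case arithmetic behind this note can be checked with a few lines of plain Python. This is a sketch under stated assumptions: the gate is pinned at lower_bound=-5 for all 16 tokens, the cumulative sum is inclusive, and the midpoint reference is taken at index 8 of the sub-chunk. The exact index the kernel loads may differ.

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)  # converts natural-log exponents to base-2
BC = 16                        # sub-chunk length from the note above
LOWER_BOUND = -5.0

# Worst case: every token sits at the lower bound.
gates = [LOWER_BOUND] * BC

# Inclusive cumulative log-decay, already converted to base-2 units.
cumsum = []
total = 0.0
for v in gates:
    total += v
    cumsum.append(total * RCP_LN2)

# Exponent magnitude if the full span were fed to exp2 without an offset.
full_span = abs(cumsum[-1])

# Midpoint offset (assumed index BC // 2): b_gm = b_g - b_gn.
b_gn = cumsum[BC // 2]
b_gm = [c - b_gn for c in cumsum]

# Largest magnitude seen by either exp2(b_gm) or exp2(-b_gm).
worst_offset = max(abs(x) for x in b_gm)

print(f"full span: {full_span:.1f}, with midpoint offset: {worst_offset:.1f}")
```

Under these assumptions the full span is 80 / ln(2) and the offset form stays at 40 / ln(2), matching the figures quoted above.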
For inter-subchunk work, chunk_kda_fwd_kernel_inter_solve_fused computes decay
ratios with paired offsets such as:
- exp2(b_g1 - b_gn1) and exp2(b_gn1 - b_g0);
- exp2(b_g2 - b_gn2) and exp2(b_gn2 - b_g1).
Both exponents in each pair are non-positive under monotonic accumulated decay, so the off-diagonal inter path avoids positive exponent growth. The triangular solve operates on masked lower-triangular blocks, so it does not introduce an unbounded exponent path.
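
The sign argument for the paired offsets can be illustrated in plain Python. This is a sketch, not kernel code: it uses a sub-chunk length of 4 for brevity, random non-positive per-token gates, and assumes b_gn1 is the cumulative gate at the first token of sub-chunk 1.

```python
import itertools
import math
import random

BC = 4  # small sub-chunk length for illustration (the note above uses 16)
random.seed(0)

# Per-token log-decay is non-positive, so its running sum is non-increasing.
gates = [-random.uniform(0.0, 5.0) for _ in range(3 * BC)]
g = list(itertools.accumulate(gates))

b_g0, b_g1 = g[0:BC], g[BC:2 * BC]
b_gn1 = g[BC]  # assumed reference: first token of sub-chunk 1

q_side = [x - b_gn1 for x in b_g1]  # exponent in exp2(b_g1 - b_gn1)
k_side = [b_gn1 - x for x in b_g0]  # exponent in exp2(b_gn1 - b_g0)

# Neither factor can grow: every exponent is <= 0.
assert all(e <= 0.0 for e in q_side + k_side)

# The two factors recombine into the direct ratio exp2(b_g1[i] - b_g0[j]).
i, j = 2, 1
direct = 2.0 ** (b_g1[i] - b_g0[j])
factored = (2.0 ** q_side[i]) * (2.0 ** k_side[j])
assert math.isclose(direct, factored, rel_tol=1e-9)
```

Splitting the ratio around a shared reference keeps each factor in (0, 1] while their product still equals the decay between the two tokens.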
Safe vs non-safe intra path
- Safe path: chunk_kda_fwd_intra(..., safe_gate=True) calls chunk_kda_fwd_kernel_intra_sub_chunk for 16-token diagonal blocks, then calls chunk_kda_fwd_kernel_inter_solve_fused with USE_SAFE_GATE=True.
- Non-safe path: safe_gate=False calls chunk_kda_fwd_intra_token_parallel for diagonal blocks, then calls the same inter/solve kernel with USE_SAFE_GATE=False.
- Do not change one path without checking the other path unless the contract is explicitly safe-only or non-safe-only.
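
As a reading aid, the dispatch described above can be written down as a small lookup. intra_kernel_plan is a hypothetical helper for illustration only; it returns the kernel names from this skill as strings and does not call into fla.

```python
def intra_kernel_plan(safe_gate: bool) -> dict:
    """Hypothetical summary of which kernels each intra path uses."""
    diagonal = (
        "chunk_kda_fwd_kernel_intra_sub_chunk"
        if safe_gate
        else "chunk_kda_fwd_intra_token_parallel"
    )
    return {
        "diagonal": diagonal,
        "inter_solve": "chunk_kda_fwd_kernel_inter_solve_fused",
        "USE_SAFE_GATE": safe_gate,
    }


safe, non_safe = intra_kernel_plan(True), intra_kernel_plan(False)

# Stages that are identical on both paths: a change there affects both.
shared = [stage for stage in safe if safe[stage] == non_safe[stage]]
```

Here shared contains only the inter/solve stage, which is why an edit to that kernel needs checking under both USE_SAFE_GATE settings.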
Correctness checklist
Before finishing a KDA behavior change, use fla-correctness-coverage and cover
only axes affected by the change:
- dense and varlen sequence layout;
- forward and backward if training path is touched;
- pre-gated, non-safe in-kernel, and safe in-kernel gate modes where supported;
- raw beta logits and post-sigmoid beta where supported;
- use_qk_l2norm_in_kernel=True/False where relevant;
- MHA and GVA (HV > H);
- D != Dv when value dimension is involved;
- initial/final state, return_intermediate_states, and CP paths when touched;
- backend verifier behavior for FlashKDA / TileLang changes;
- gate numerical extremes when gate math or intra/inter decay is touched: lower_bound=-5, a lower bound close to 0, large positive and negative g + dt_bias, extreme A_log, long-sequence cumulative decay, chunk boundaries, and ragged varlen boundaries.
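
One way to enumerate a few of these axes is a small case generator. This is a sketch, not the repository's test harness: GATE_MODES and build_cases are hypothetical names, only three of the axes are shown, and the safe mode is pinned at lower_bound=-5 as one of the extremes listed above.

```python
import itertools

# The three gate contracts, expressed as keyword settings.
GATE_MODES = {
    "pre_gated": dict(use_gate_in_kernel=False, safe_gate=False, lower_bound=None),
    "in_kernel": dict(use_gate_in_kernel=True, safe_gate=False, lower_bound=None),
    "safe_in_kernel": dict(use_gate_in_kernel=True, safe_gate=True, lower_bound=-5.0),
}


def build_cases(layouts=("dense", "varlen"), l2norm=(True, False),
                gate_modes=tuple(GATE_MODES)):
    """Cross the selected axes; pass narrower tuples to cover only touched axes."""
    cases = []
    for layout, norm, mode in itertools.product(layouts, l2norm, gate_modes):
        cfg = dict(layout=layout, use_qk_l2norm_in_kernel=norm, gate_mode=mode)
        cfg.update(GATE_MODES[mode])
        cases.append(cfg)
    return cases


cases = build_cases()
```

Because safe_gate only appears inside a named mode, invalid combinations such as safe_gate=True with use_gate_in_kernel=False never enter the matrix. A list like this could be handed to a parametrized test.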
Style constraints
- Use platform helpers from fla.utils (device, device_platform, IS_NVIDIA, IS_NVIDIA_HOPPER, IS_NVIDIA_BLACKWELL, IS_AMD, IS_INTEL) instead of adding new direct torch.cuda platform checks in tests or public code. If no helper covers the condition, add one in fla.utils first.
- Keep math derivations in operator docs or PR text; in Triton kernels, prefer compact shape comments and one-line rationale comments.
- Do not include internal-only paths, private model names, local machine paths, or private workload identifiers in public tests or skills.