name: flydsl-kernel-authoring description: > Comprehensive reference for authoring FlyDSL GPU kernels on AMD GPUs. Covers the layout algebra, tiled copy/MMA, buffer ops, loop-carried range loops, SmemAllocator, autotuning, and common patterns. Use when writing, reviewing, or understanding FlyDSL kernel code. allowed-tools: Read Edit Bash Grep Glob Agent
FlyDSL Kernel Authoring Skill
Overview
FlyDSL is a Python DSL and MLIR-based compiler for writing high-performance GPU kernels on AMD GPUs (MI300X/MI350). It provides explicit layout algebra for controlling data movement, tiling, and memory access patterns. The layout system is the core abstraction that distinguishes FlyDSL from Triton/Gluon.
Repository: /FlyDSL/ (installed in editable mode)
Target GPU: gfx942 (MI300X, CDNA3), gfx950 (MI350, CDNA4)
Python: 3.12, ROCm 7.2
Scope (read this first): This skill is the reference — the full layout-algebra API surface, per-op tables, MFMA/copy-atom catalogs, environment variables, and an exhaustive troubleshooting list. Reach for it to look something up while writing or reviewing kernel code. If instead you want a guided, step-by-step procedure that turns a kernel requirement into a finished, tested kernel (classify -> skeleton -> compute -> control flow -> test), use the flydsl-tile-programming skill, which is the wizard companion to this reference. For diagnosing a kernel that already compiles but produces NaN/inf/wrong results, use the debug-flydsl-kernel skill.
1. Architecture and Compilation
Pipeline
Python (@flyc.kernel/@flyc.jit)
-> AST Rewriting (for/if -> scf.for/scf.if)
-> MLIR Tracing (generates Fly dialect + gpu/arith/scf/memref/vector ops)
-> MlirCompiler.compile() (Fly -> ROCDL -> LLVM -> HSACO binary)
-> JITCFunction (ExecutionEngine wrapper)
Key Passes
Pipeline is built by RocmBackend._pipeline_parts() and split into three stages — see docs/architecture_guide.md §3 for the per-pass table. Highlights:
fly-rewrite-func-signature- Rewrite DSL types at function / SCF boundaries to packed LLVM structsfly-layout-lowering- Lower layout algebra (fly.crd2idx, partitions, divides) to arithmeticfly-convert-atom-call-to-ssa-form+fly-promote-regmem-to-vectorssa- Lift copy/MMA atom calls and register memory to vector SSAconvert-fly-to-rocdl- Fly ops -> ROCDL intrinsicsgpu-module-to-binary{format=fatbin}- Emit HSACO binary via LLVM AMDGPU backend
Key Source Paths
python/flydsl/compiler/- JIT compilation (jit_function.py, kernel_function.py)python/flydsl/expr/- DSL expression API (primitive.py, derived.py, typing.py)python/flydsl/expr/primitive.py- All layout algebra functionspython/flydsl/expr/derived.py- CopyAtom, MmaAtom, TiledCopy, TiledMma wrapperspython/flydsl/expr/gpu.py- GPU operations (thread_idx, block_idx, barrier)python/flydsl/expr/buffer_ops.py- AMD buffer load/store intrinsicspython/flydsl/expr/rocdl/- MFMA/WMMA and other ROCm intrinsics (package: cdna4, cluster, inline_asm, tdm_ops, universal)python/flydsl/utils/smem_allocator.py- LDS (shared memory) managementkernels/- Pre-built kernels (preshuffle_gemm.py, layernorm, softmax, rmsnorm)
2. Layout System (Core Abstraction)
Core Types
| Type | Description | Example |
|---|---|---|
!fly.int_tuple |
Integer tuple (can be nested) | (8, 16), (8, (4, 2)) |
!fly.layout |
(Shape, Stride) pair | (8, 16):(1, 8) (col-major) |
!fly.memref |
Memory reference with layout | Typed pointer + layout info |
Construction
import flydsl.expr as fx
shape = fx.make_shape(8, 16) # IntTuple (8, 16)
stride = fx.make_stride(1, 8) # IntTuple (1, 8)
layout = fx.make_layout(shape, stride) # Layout (8,16):(1,8)
# Shorthand with Python tuples
layout = fx.make_layout((8, 16), (1, 8))
# Coordinates
coord = fx.make_coord(i, j)
# Nested shapes for hierarchical tiling
shape_nested = fx.make_shape(9, (4, 8)) # (9, (4, 8))
# Identity layout
identity = fx.make_identity_layout((M, N))
Coordinate Mapping
The fundamental operation maps logical coordinates to physical memory indices.
Formula: Index = sum(coord_i * stride_i)
idx = fx.crd2idx(coord, layout) # Coordinate -> linear index
coord = fx.idx2crd(idx, layout) # Linear index -> coordinate
s = fx.size(layout) # Total element count (product of shape)
Example: For layout (8, 16):(1, 8) (8x16, column-major):
crd2idx((3, 5), layout)=3*1 + 5*8= 43idx2crd(43, layout)=(43 % 8, 43 / 8)=(3, 5)
Query Operations
fx.size(layout) # Total element count
fx.get_shape(layout) # Extract shape IntTuple
fx.get_stride(layout) # Extract stride IntTuple
fx.get(int_tuple, i) # Get i-th element
fx.rank(int_tuple) # Number of top-level modes
Layout Algebra Operations
Composition: fx.composition(A, B)
Compose two layouts: result(x) = A(B(x)). Used to apply permutations or tile coordinate mappings.
Complement: fx.complement(tiler, target_size)
Compute remaining modes not covered by tiler, up to target_size. Internal building block for divides.
Coalesce: fx.coalesce(layout)
Simplify layout by merging adjacent modes. Preserves mapping but flattens structure.
Right Inverse: fx.right_inverse(layout)
Compute right inverse of layout mapping.
Recast: fx.recast_layout(layout, old_bits, new_bits)
Adjust layout for type width change (e.g., FP16->FP8).
Product Operations (Combine Layouts)
Products combine two layouts to create a larger layout:
fx.logical_product(layout, tiler) # Basic mode-wise concatenation
fx.raked_product(thr, val) # Interleaved access pattern (common for TiledCopy)
fx.blocked_product(layout, tiler) # Blocked access pattern
fx.zipped_product(layout, tiler) # Zipped modes
fx.tiled_product(layout, tiler) # Hierarchical tiled structure
fx.flat_product(layout, tiler) # Flattened result
Divide Operations (Partition Layouts)
Divides split a layout by a divisor, creating tile + rest dimensions:
fx.logical_divide(layout, divisor) # Basic partitioning (uses complement internally)
fx.zipped_divide(layout, divisor) # Zipped division
fx.tiled_divide(layout, divisor) # Hierarchical tiled division
fx.flat_divide(layout, divisor) # Flattened division
Structural Operations
fx.select(int_tuple, indices=[0, 2]) # Pick specific modes
fx.group(int_tuple, begin=1, end=3) # Group modes into nested tuple
fx.append(base, elem) # Append mode
fx.prepend(base, elem) # Prepend mode
fx.zip(lhs, rhs) # Zip two IntTuples
fx.slice(src, coord) # Slice at coordinate (None = keep mode)
3. Writing Kernels
Basic Pattern
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import buffer_ops, const_expr, gpu, range_constexpr, rocdl
@flyc.kernel
def my_kernel(
A: fx.Tensor, # GPU tensor (memref via DLPack)
B: fx.Tensor,
N: fx.Constexpr[int], # Compile-time constant
):
tid = gpu.thread_id("x")
bid = gpu.block_id("x")
# ... kernel body ...
@flyc.jit
def launch(
A: fx.Tensor,
B: fx.Tensor,
N: fx.Constexpr[int],
stream: fx.Stream = fx.Stream(None),
):
my_kernel(A, B, N).launch(
grid=(N // 256,), block=(256,), stream=stream
)
# Usage:
import torch
A = torch.randn(1024, device="cuda", dtype=torch.float32)
B = torch.empty(1024, device="cuda", dtype=torch.float32)
launch(A, B, 1024)
Current Syntax Quick Reference
Use the current public FlyDSL surface from kernels/preshuffle_gemm.py when writing new kernels:
Vec = fx.Vector
tx = gpu.thread_id("x")
bx = gpu.block_id("x")
by = gpu.block_id("y")
i32_m: fx.Int32
c_m = fx.Index(i32_m)
c4 = fx.Index(4)
zero_f = fx.Float32(0.0)
layout = fx.make_layout((4, 64), (64, 1))
coord = fx.idx2crd(tx, layout)
wave_id = fx.get(coord, 0)
lane_id = fx.get(coord, 1)
acc = Vec.filled(4, 0.0, fx.Float32)
v_i64 = Vec(raw_vec).bitcast(fx.Int64)
elem0 = v_i64[0]
rsrc = buffer_ops.create_buffer_resource(tensor, max_size=True)
word = buffer_ops.buffer_load(rsrc, fx.Int32(offset), vec_width=4, dtype=fx.Int32)
Older code may use gpu.thread_idx.x, gpu.block_idx.x, arith.constant(...), T.i32, and raw vector.* helpers. Keep those when editing existing code that already uses them heavily, but prefer gpu.thread_id/block_id, fx.Index/fx.Int32/fx.Float32, and fx.Vector for new code.
Parameter Types
| Type | Description | At host boundary |
|---|---|---|
fx.Tensor |
GPU tensor (memref) | Auto-converted from torch.Tensor via DLPack |
fx.Constexpr[int] |
Compile-time constant | Different values -> different compiled kernels |
fx.Int32 |
Runtime i32 | Auto-converted from Python int |
fx.Stream |
CUDA/HIP stream | fx.Stream(None) for default stream |
Thread/Block Hierarchy
from flydsl.expr import gpu
tid_x = gpu.thread_id("x") # Preferred current spelling
bid_x = gpu.block_id("x")
bid_y = gpu.block_id("y")
# Legacy spelling still appears in older kernels:
tid_x = gpu.thread_idx.x
bid_x = gpu.block_idx.x
gpu.barrier() # Workgroup synchronization
Control Flow
from flydsl.expr import range_constexpr
# Compile-time unrolled loop (emitted inline in IR)
for i in range_constexpr(N):
...
# Runtime loop (lowered by AST rewriting)
for i in range(runtime_value):
...
Runtime vs Compile-Time Conditions (Current Style)
Use Python/DSL operators for runtime SSA comparisons. The AST rewriter lowers dynamic if conditions to scf.IfOp, and comparison operators like ==, <, >= generate the needed MLIR predicates.
tid = gpu.thread_id("x")
lane = tid % fx.Index(64)
c_zero = fx.Index(0)
c_limit = fx.Index(8)
# Preferred: readable DSL comparisons
if lane == c_zero:
...
in_range = lane < c_limit
val = fx.arith.select(in_range, good_val, zero_val)
# Avoid for simple integer comparisons
in_range = arith.cmpi(arith.CmpIPredicate.slt, lane, c_limit)
Use const_expr(...) only for values known at trace/compile time, such as Python booleans, constexpr arguments, loop-unroll choices, or type/layout branches:
if const_expr(trans_v):
...
if const_expr(max_context_partition_num <= WARP_SIZE):
...
Do not wrap GPU runtime values in const_expr. Even with @flyc.kernel(known_block_size=(256, 1, 1)), gpu.thread_id("x"), lane, and warp_id are runtime SSA values; the compiler knows their range, not the current lane instance.
# Wrong: lane depends on gpu.thread_id("x")
if const_expr(lane == c_zero):
...
# Correct
if lane == c_zero:
...
Keep explicit arith.cmpi(...) / arith.unwrap(...) for low-level manual MLIR construction, such as passing a raw condition to scf.IfOp directly:
cond = arith.unwrap(partition_idx >= visible_tile_count)
if_op = scf.IfOp(cond, has_else=False)
Frontend Semantic Restrictions
When writing or reviewing @flyc.kernel / @flyc.jit code, proactively avoid these patterns because they can conflict with MLIR construction even if they look valid in plain Python.
Do not define values inside
if/elseand use them later outside the branch. Keep a single explicit definition path.if cond: dst = a else: dst = b use(dst) # avoid this patternDo not mutate captured outer variables inside nested helper functions. Read-only closure capture is acceptable, but writes should go through explicit parameters and return values.
def kernel(): acc = fx.Float32(0.0) def helper(acc): acc = acc + fx.Float32(1.0) return acc acc = helper(acc)Avoid early
return, and do not placereturn/yieldinsideif/elsebranches. Prefer a single explicit exit so the frontend can determine result types.if cond: out = v0 else: out = v1 return outCompile-time conditions must use
const_expr(...). Useif const_expr(flag): ...for constexpr flags and other static decisions. A plain Pythonifis only safe when the condition is already a Pythonbool.Runtime branches inside helper functions should be dispatched via local
@flyc.jit. When a branch body has side effects, loop-carried values, or branch-local definitions, split the branch bodies into local helpers and wrap theifin a local JIT helper:def then_path(): ... def else_path(): ... @flyc.jit def dispatch(): if runtime_cond: then_path() else: else_path() dispatch()
Runtime Loops with Loop-Carried Values (Software Pipelining)
Use init= on range() to create a runtime loop with explicit SSA phi nodes for loop-carried state. This is required for software pipelining (prefetch patterns) where data must flow across iterations.
Pattern (from preshuffle_gemm.py):
# Prologue: load first tile
tile_0 = prefetch(0)
init_state = [acc_init, tile_0_flat_val1, tile_0_flat_val2, ...]
# Runtime loop with loop-carried state
# Use fx.Index(...) bounds so the AST rewriter does not treat this as a Python unrolled range.
_start = fx.Index(0)
_stop = fx.Index(N - 1)
_step = fx.Index(1)
for iv, state in range(_start, _stop, _step, init=init_state):
acc_in = state[0]
tile_in = state[1:]
next_tile = prefetch(iv + 1) # load NEXT data
acc_in = compute(acc_in, tile_in) # compute CURRENT
results = yield [acc_in] + next_tile # carry to next iter
# Epilogue: process last tile from results
acc_final = results[0]
tile_final = results[1:]
compute(acc_final, tile_final)
How it works in MLIR:
| Element | Meaning |
|---|---|
init=init_state |
List of SSA values that seed the runtime loop block arguments for iteration 0 |
state |
The loop-carried block arguments (phi nodes) for THIS iteration |
yield [...] |
Feeds values back as next iteration's state |
results |
After loop exits, holds the last yielded values |
Three critical pitfalls (all verified by debugging):
Loop bounds must be DSL index values, NOT Python ints. If you write
range(0, 15, 1, init=...), the AST rewriter treats constant bounds as a Pythonrangeand unrolls the loop — silently ignoringinit=. Usefx.Index(0),fx.Index(15),fx.Index(1)instead.Prefer internal types, but unwrap at hard boundaries. Most
range(..., init=...)uses accept DSL numeric/vector values. If a lower-level helper explicitly expects rawir.Value, unwrap withv.ir_value()/_raw(v)at that boundary only.Clear
SmemPtr._view_cachebefore epilogue.SmemPtr.get()caches the view it creates. If called inside the runtime loop body, the cached view is defined in the loop scope. Using it in the epilogue (outside the loop) causes an SSA dominance error. Fix:# After the runtime loop, before epilogue compute: my_smem_ptr._view_cache = None
Arithmetic Operations
c42 = fx.Index(42) # index type constant (preferred)
c3_14 = fx.Float32(3.14) # f32 constant (preferred)
mask = fx.Int32(0xFF) # i32 constant (preferred)
# Prefer operators / Numeric methods
result = a + b
result = a * scale
result = cond.select(true_val, false_val)
# Keep direct arith.*FOp only when explicit fastmath flags are required.
Internal Types: Vector and Numeric (PREFERRED)
Use FlyDSL's internal typed system instead of raw MLIR ops. The Vector class wraps vector<NxTy> with operator overloading and type-safe methods.
Vec = fx.Vector
# Wrap raw vector values
acc = Vec(frag_C.load()) # vector<Nxf32> → Vector with * / + operators
# Indexing (replaces vector.extract)
val = acc[idx] # returns Float32 scalar
# Bitcast (replaces vector.bitcast)
v_f32 = Vec(raw_vec).bitcast(fx.Float32) # vector<Nxi32> → vector<Nxf32>
# Type conversion (replaces arith.trunc_f / arith.ext_f)
bf16_val = f32_val.to(fx.BFloat16) # f32 → bf16
# Arithmetic — use Python operators, not arith.mulf/addf
result = (val * scale_a) * scale_b
# Splat constant vector
zeros = Vec.filled(N, 0.0, fx.Float32)
# Index cast — use fx.Int32 instead of arith.index_cast
idx = fx.Int32(gpu.block_id("x") * tile_m)
Prefer internal types over raw ops:
| Raw MLIR op | Internal type equivalent |
|---|---|
vector.extract(v, static_position=[i], ...) |
Vec(v)[i] |
vector.bitcast(target_ty, v) |
Vec(v).bitcast(Float32) |
arith.trunc_f(ty, v) |
v.to(BFloat16) |
arith.mulf(a, b) |
a * b |
arith.addf(a, b) |
a + b |
arith.index_cast(T.i32, v) |
fx.Int32(v) |
Use Vec.filled(...) for splats and Vec.from_elements(...) for vectors from scalars.
Arith Ops Availability Table
| Operation | Function | Works on Vectors | Notes |
|---|---|---|---|
| Add | a + b |
Yes | Use direct FOp only for explicit fastmath |
| Multiply | a * b |
Yes | Use direct FOp only for explicit fastmath |
| Negate | -a |
Yes | |
| Max | a.maximumf(b) |
Yes | Good for ReLU |
| Compare | arith.cmpf(a, b, pred) |
Yes | Returns i1/vec |
| Select | cond.select(t, f) |
Yes | |
| Abs | no direct helper | Use -v, comparison, and cond.select(...) |
|
| FMA | a * b + c |
Yes | Use direct FOp only when explicit fastmath is needed |
| Splat const | Vec.filled(width, val, dtype) |
Creates vector | For scalar broadcast |
Printf Debugging
fx.printf("tid={} bid={} val={}", tid, bid, value)
4. Data Movement Patterns
Layout-Based Copy (Preferred for Element-wise Kernels)
The standard pattern: divide tensor by tile size, slice by block/thread, copy via atoms.
@flyc.kernel
def my_kernel(A: fx.Tensor, B: fx.Tensor, BLOCK_DIM: fx.Constexpr[int]):
bid = fx.block_idx.x
tid = fx.thread_idx.x
# 1. Divide tensor into blocks
tA = fx.logical_divide(A, fx.make_layout(BLOCK_DIM, 1))
tB = fx.logical_divide(B, fx.make_layout(BLOCK_DIM, 1))
# 2. Select this block's tile
tA = fx.slice(tA, (None, bid))
tB = fx.slice(tB, (None, bid))
# 3. Further divide for per-thread access
tA = fx.logical_divide(tA, fx.make_layout(1, 1)) # 1 element per thread
tB = fx.logical_divide(tB, fx.make_layout(1, 1))
# 4. Allocate registers
copyAtom = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32)
rA = fx.make_rmem_tensor(1, fx.Float32)
# 5. Copy: global -> register -> compute -> global
fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA)
# ... compute on register values ...
fx.copy_atom_call(copyAtom, rA, fx.slice(tB, (None, tid)))
Vectorized Loads (Wide Copies)
VEC_WIDTH = 4
copy_bits = VEC_WIDTH * 32 # 128 bits
copyAtom = fx.make_copy_atom(fx.UniversalCopy(copy_bits), fx.Float32)
rA = fx.make_rmem_tensor(VEC_WIDTH, fx.Float32)
# Divide for VEC_WIDTH elements per thread
tA = fx.logical_divide(tA, fx.make_layout(VEC_WIDTH, 1))
fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA)
# Load/store as vectors
vec = fx.memref_load_vec(rA) # Load vector from register memref
fx.memref_store_vec(vec, rA) # Store vector to register memref
TiledCopy Abstraction (for 2D Copies)
# Define thread and value layouts
thr_layout = fx.make_layout((4, 1), (1, 1)) # 4 threads
val_layout = fx.make_layout((1, 8), (1, 1)) # 8 values per thread
# Create copy atom
copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32)
# Build tiled copy with raked product layout
layout_thr_val = fx.raked_product(thr_layout, val_layout)
tile_mn = fx.make_tile(4, 8)
tiled_copy = fx.make_tiled_copy(copy_atom, layout_thr_val, tile_mn)
# Get this thread's slice and partition
thr_copy = tiled_copy.get_slice(tid)
partition_src = thr_copy.partition_S(src_tensor)
partition_dst = thr_copy.partition_D(dst_tensor)
frag = fx.make_fragment_like(partition_src)
# Execute copy: src -> fragment -> dst
fx.copy(copy_atom, partition_src, frag)
fx.copy(copy_atom, frag, partition_dst)
Buffer Load/Store (AMD Intrinsics)
from flydsl.expr import buffer_ops
rsrc = buffer_ops.create_buffer_resource(tensor)
# offset is in ELEMENTS (not bytes)
data = buffer_ops.buffer_load(rsrc, offset, vec_width=4)
buffer_ops.buffer_store(data, rsrc, offset)
Copy Atom Types
| Type | Bits | Usage |
|---|---|---|
fx.UniversalCopy32b() |
32 | 1x f32 element copy |
fx.UniversalCopy(64) |
64 | 2x f32 elements |
fx.UniversalCopy(128) |
128 | 4x f32 elements |
fx.rocdl.BufferCopy128b() |
128 | AMD buffer load 4xf32 |
5. Shared Memory (LDS)
SmemAllocator Pattern
from flydsl.utils.smem_allocator import SmemAllocator
from flydsl.expr.typing import T
from flydsl.compiler.kernel_function import CompilationContext
allocator = SmemAllocator(None, arch="gfx942", global_sym_name="smem0")
lds_a = allocator.allocate_array(T.f16, 8192) # Allocate typed arrays
lds_b = allocator.allocate_array(T.f16, 8192)
@flyc.kernel
def my_kernel(A: fx.Tensor, ...):
lds_base = allocator.get_base() # Get base ptr inside kernel
lds_a_ptr = lds_a(lds_base) # SmemPtr for typed access
val = lds_a_ptr.load([idx])
lds_a_ptr.store(val, [idx])
# Finalize in GPU module body (before launch)
comp_ctx = CompilationContext.get_current()
with ir.InsertionPoint(comp_ctx.gpu_module_body):
allocator.finalize()
LDS Capacity
| Architecture | GPU | LDS per CU |
|---|---|---|
| gfx942 | MI300X | 64 KB |
| gfx950 | MI350 | 160 KB |
6. MFMA Integration (Matrix Math)
Available MFMA Instructions
from flydsl.expr import rocdl
# FP16/BF16 MFMA
result = rocdl.mfma_f32_16x16x16_f16(a, b, acc)
# FP8 MFMA
result = rocdl.mfma_f32_16x16x32_fp8(a, b, acc)
# INT8 MFMA
result = rocdl.mfma_i32_16x16x32i8(a, b, acc)
GEMM Pattern (Preshuffle)
The preshuffle GEMM pattern in kernels/preshuffle_gemm.py:
- B matrix is pre-shuffled to layout: (N/16, K/64, 4, 16, kpack_bytes)
- A tiles loaded from global to LDS with XOR16 swizzle for bank-conflict avoidance
- K64-byte micro-steps: each step issues 2x K32 MFMA operations
- Ping-pong LDS (lds_stage=2) for overlapping loads with compute
- Epilogue: either direct row-major store or CShuffle via LDS for packing
7. Reduction Patterns
Warp Reduction (AMD wave64)
XOR-shuffle-based intra-wave reduction:
width_i32 = fx.Int32(64)
for sh in [32, 16, 8, 4, 2, 1]:
off = fx.Int32(sh)
peer = gpu.ShuffleOp(val, off, width_i32, mode="xor").shuffleResult
val = ArithValue(val) + peer # use explicit FOp only if fastmath flags are needed
Block Reduction
- Intra-wave XOR shuffle (shifts: 32, 16, 8, 4, 2, 1)
- Lane 0 writes per-wave partial to LDS
gpu.barrier()- Wave 0 reads and reduces NUM_WAVES partials from LDS
See kernels/reduce.py for reusable implementations.
8. Common Patterns and Recipes
Element-wise Kernel Template
@flyc.kernel
def elementwise_kernel(In: fx.Tensor, Out: fx.Tensor, BLOCK: fx.Constexpr[int], VEC: fx.Constexpr[int]):
bid, tid = fx.block_idx.x, fx.thread_idx.x
tile = BLOCK * VEC
tIn = fx.logical_divide(In, fx.make_layout(tile, 1))
tOut = fx.logical_divide(Out, fx.make_layout(tile, 1))
tIn = fx.slice(tIn, (None, bid))
tOut = fx.slice(tOut, (None, bid))
tIn = fx.logical_divide(tIn, fx.make_layout(VEC, 1))
tOut = fx.logical_divide(tOut, fx.make_layout(VEC, 1))
copy = fx.make_copy_atom(fx.UniversalCopy(VEC * 32), fx.Float32)
rIn = fx.make_rmem_tensor(VEC, fx.Float32)
rOut = fx.make_rmem_tensor(VEC, fx.Float32)
fx.copy_atom_call(copy, fx.slice(tIn, (None, tid)), rIn)
# Transform
v = Vec(fx.memref_load_vec(rIn))
v = v * v # example: square
fx.memref_store_vec(v, rOut)
fx.copy_atom_call(copy, rOut, fx.slice(tOut, (None, tid)))
Element-wise Kernel Cookbook (GPU-Verified)
All recipes below follow the same vectorized copy_atom pattern (256 threads, vec_width=4, 128-bit loads).
Only the compute section between memref_load_vec and memref_store_vec differs.
# --- Scale: C = A * scalar ---
vA = Vec(fx.memref_load_vec(rA))
scale = Vec.filled(vec_width, 2.0, fx.Float32)
vC = vA * scale
# --- Multiply: C = A * B ---
vC = Vec(fx.memref_load_vec(rA)) * Vec(fx.memref_load_vec(rB))
# --- FMA: D = A * B + C ---
vAB = Vec(fx.memref_load_vec(rA)) * Vec(fx.memref_load_vec(rB))
vD = vAB + Vec(fx.memref_load_vec(rC))
# --- ReLU: C = max(A, 0) ---
vA = Vec(fx.memref_load_vec(rA))
zero_vec = Vec.filled(vec_width, 0.0, fx.Float32)
vC = vA.maximumf(zero_vec)
# --- Abs: C = |A| (arith.absf does NOT exist) ---
vA = fx.memref_load_vec(rA)
zero_vec = Vec.filled(vec_width, 0.0, fx.Float32)
neg_vA = -vA
is_neg = vA < zero_vec
vC = is_neg.select(neg_vA, vA)
Naive GEMM Template (for understanding, not performance)
@flyc.kernel
def naive_gemm(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor,
M: fx.Constexpr[int], N: fx.Constexpr[int], K: fx.Constexpr[int],
BM: fx.Constexpr[int], BN: fx.Constexpr[int]):
tid, bid = gpu.thread_id("x"), gpu.block_id("x")
bm, bn = bid // (N // BN), bid % (N // BN)
tm, tn = tid // BN, tid % BN
row, col = bm * BM + tm, bn * BN + tn
rsrc_a = buffer_ops.create_buffer_resource(A)
rsrc_b = buffer_ops.create_buffer_resource(B)
rsrc_c = buffer_ops.create_buffer_resource(C)
acc = fx.Float32(0.0)
for k in range_constexpr(K):
a = buffer_ops.buffer_load(rsrc_a, row * K + k, vec_width=1)
b = buffer_ops.buffer_load(rsrc_b, k * N + col, vec_width=1)
acc = acc + a * b
buffer_ops.buffer_store(acc, rsrc_c, row * N + col)
9. Environment and Debugging
IR Dump
FLYDSL_DUMP_IR=1 FLYDSL_DUMP_DIR=./dumps python my_kernel.py
Produces numbered .mlir files per pipeline stage plus final_isa.s.
Key Environment Variables
| Variable | Default | Description |
|---|---|---|
FLYDSL_DUMP_IR |
false | Dump IR at each stage |
FLYDSL_DEBUG_ENABLE_DEBUG_INFO |
false | Emit DWARF debug info (source-to-asm mapping) |
FLYDSL_RUNTIME_ENABLE_CACHE |
true | Enable kernel disk caching (in-memory cache is always active) |
FLYDSL_RUNTIME_CACHE_DIR |
~/.flydsl/cache | Cache directory |
FLYDSL_COMPILE_OPT_LEVEL |
2 | Optimization level (0-3) |
ARCH |
auto-detect | Override GPU architecture |
Disk Cache Invalidation
The JIT disk cache auto-invalidates when kernel source or closure values change. Set FLYDSL_RUNTIME_ENABLE_CACHE=0 only when modifying C++ passes or non-closure helper functions:
FLYDSL_RUNTIME_ENABLE_CACHE=0 python my_kernel.py # or: rm -rf ~/.flydsl/cache
Source-to-Assembly Debug Info
FlyDSL supports source-to-assembly mapping for rocprofv3 ATT traces via the MLIR
ensure-debug-info-scope-on-llvm-func pass (equivalent to Triton's add_di_scope).
How it works:
- FlyDSL's
FuncLocationTrackergenerates MLIRloc()metadata pointing to Python source lines - The
ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}pass converts MLIR locations into LLVMDISubprogramAttr/DICompileUnitAttrmetadata - The
-gflag ingpu-module-to-binarypreserves this metadata as.debug_linein the HSACO binary - rocprofv3 ATT reads
.debug_lineto producecode.jsonwith"source_file:line"entries
Pipeline position: After reconcile-unrealized-casts, before gpu-module-to-binary:
... -> reconcile-unrealized-casts
-> ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly} (conditional on enable_debug_info)
-> gpu-module-to-binary{format=fatbin opts=-g}
Verification: With FLYDSL_DUMP_IR=1, check final_isa.s for .file and .loc directives.
The PA decode kernel achieves 99.9% coverage (1109/1110 ISA instructions mapped to source).
Key insight: Without this pass, MLIR loc() metadata is silently dropped during MLIR-to-LLVM-IR
translation. The -g flag alone is useless — it preserves debug info, but there's none to preserve
without the DI scope pass.
Autotune Module
FlyDSL includes a Triton-style autotune module at /FlyDSL/python/flydsl/autotune.py:
from flydsl.autotune import autotune, Config, do_bench
@autotune(
configs=[
Config(block_dim=64, vec_width=4),
Config(block_dim=128, vec_width=4),
Config(block_dim=256, vec_width=4),
],
key=['const_n'], # re-tune when these arg values change
warmup=5, rep=25, # benchmark timing params
)
@flyc.jit
def myKernel(A, C, n: fx.Int32, const_n: fx.Constexpr[int],
block_dim: fx.Constexpr[int], vec_width: fx.Constexpr[int],
stream: fx.Stream = fx.Stream(None)):
...
Configkwargs becomeConstexprargs injected into@jitcallConfig.num_warps,waves_per_eu,maxnregare special compiler-level options- First call benchmarks all configs; subsequent calls use cached best
- Disk cache at
~/.flydsl/autotune/{func_name}.json do_bench(fn, warmup=5, rep=25)benchmarks using CUDA/HIP events, returns median ms
IMPORTANT: waves_per_eu does NOT work via gpu-module-to-binary opts=. It needs to be
set as an LLVM function attribute or through rocdl-attach-target. This is a known limitation.
DLTensorAdaptor bug: Do NOT use flyc.from_dlpack() with pre-wrapped tensors when calling
a @jit function with varying Constexpr values. The DLTensorAdaptor caches MLIR types from
the first ir.Context, which become invalid when a new context is created (causes segfault).
Pass raw torch.Tensor objects instead.
10. Troubleshooting
Common Issues
Constants/casts: Prefer
fx.Int32(...),fx.Int64(...),fx.Index(...), andfx.Float32(...). Usearith.constant(...)only at low-level boundaries.buffer_ops.buffer_loadoffset: Theoffsetparameter is in ELEMENTS, not bytes.Cache stale after code changes: The disk cache auto-invalidates on source/closure changes. Only set
FLYDSL_RUNTIME_ENABLE_CACHE=0or clear~/.flydsl/cache/if you changed C++ passes or non-closure helpers.LDS overflow: Check capacity (64KB on gfx942, 160KB on gfx950). Use
SmemAllocatorwhich tracks allocations.Dynamic vs Constexpr:
Constexpr[int]values are baked into IR -- different values produce different compiled kernels. UseInt32for truly dynamic values.Tensor layout marking: For dynamic shapes or alignment, use
flyc.from_dlpack(tensor).mark_layout_dynamic(leading_dim=0, divisibility=4).SmemAllocator finalize: Must call
allocator.finalize()inside the GPU module body (useCompilationContext.get_current().gpu_module_body).AMD wavefront size: Always 64 on gfx9xx. Use shifts [32, 16, 8, 4, 2, 1] for full-wave reduction.
tile_k alignment for GEMM:
tile_k * elem_bytesmust be divisible by 64 (K64-byte micro-step).INT4 (W4A8): A matrix is int8, B matrix is packed int4 (2 values/byte), unpacked to int8 in-kernel.
arith.absfdoes not exist: PreferVector/ArithValueoperators:neg = -v,is_neg = v < zero,out = is_neg.select(neg, v).Scalar broadcast to vector: Use
Vec.filled(width, value, fx.Float32)to create a splat constant vector. Do NOT use raw vector ops for ordinary arithmetic.
11. Comparison with Triton/Gluon
| Aspect | FlyDSL | Triton | Gluon |
|---|---|---|---|
| Layout control | Explicit layout algebra (Shape, Stride, Layout) | Implicit via block pointers | Implicit |
| Tiling | Manual via divide/product operations | Auto-tiling with tl.program_id |
Auto-tiling |
| Memory access | Copy atoms, buffer load/store, TiledCopy | tl.load/tl.store |
gluon.load/gluon.store |
| MFMA | Direct rocdl.mfma_* intrinsics |
tl.dot |
gluon.dot |
| Shared memory | SmemAllocator with explicit management | Implicit scratchpad | Implicit |
| Abstraction level | Low (near hardware) | Medium | Medium-High |
| Compilation | MLIR (Fly dialect -> LLVM -> HSACO) | MLIR (Triton dialect -> LLVM) | MLIR |
| Control | Maximum control over data layout and movement | Less control, more automation | Least control |
FlyDSL gives maximum control at the cost of verbosity. The layout algebra is the key differentiator -- it enables precise control over how data is arranged in registers, shared memory, and global memory, and how threads map to data.
12. Running Kernels
SSH to Remote Host
# Run a kernel
ssh -o LogLevel=ERROR hjbog-srdc-39.amd.com 'docker exec hungry_dijkstra bash -c "cd /FlyDSL && python3 my_kernel.py"'
# Run existing tests
ssh -o LogLevel=ERROR hjbog-srdc-39.amd.com 'docker exec hungry_dijkstra bash -c "cd /FlyDSL && python3 tests/kernels/test_vec_add.py"'
# Run benchmarks
ssh -o LogLevel=ERROR hjbog-srdc-39.amd.com 'docker exec hungry_dijkstra bash -c "cd /FlyDSL && bash scripts/run_benchmark.sh"'