name: tlx-api-reference description: > TLX DSL API reference for low-level GPU primitives. Use when writing or modifying TLX kernel code that uses barriers (mbarrier, named barriers), memory allocation (local_alloc, SMEM, TMEM), TMA operations, warp specialization (async_tasks, async_task), CLC (cluster launch control), or wgmma instructions. Covers Hopper and Blackwell hardware differences.
TLX API Quick Reference
Warp Specialization
| Function | Description | Arch |
|---|---|---|
tlx.async_tasks() |
Context manager wrapping all async task regions | Both |
tlx.async_task([task_ids]) |
Assign code to specific task IDs (e.g., [0] = producer, [1,2] = consumers) |
Both |
tlx.async_task(num_warps=N, num_regs=R) |
Explicit warp/register allocation for a task | Both |
tlx.async_task("default", num_regs=R) |
Default task for code outside explicit tasks | Both |
tlx.async_task_replica_id() |
Returns replica ID inside an async region | Both |
Warp specialization skeleton
with tlx.async_tasks():
with tlx.async_task([0]): # Producer
# TMA loads
with tlx.async_task([1, 2]): # Consumers
# MMA compute
Memory Barriers
mbarrier (shared-memory allocated)
| Function | Description | Arch |
|---|---|---|
tlx.alloc_barriers(num_barriers, arrive_count=1) |
Allocate SMEM barriers and initialize with arrive count | Both |
tlx.barrier_expect_bytes(bar, bytes, pred=None) |
Set expected transaction byte count on barrier | Both |
tlx.barrier_wait(bar, phase, pred=None) |
Wait until barrier phase flips (LOCAL mbarrier only) | Both |
tlx.barrier_arrive(bar, arrive_count=1, remote_cta_rank=None) |
Signal arrival at barrier. remote_cta_rank signals a barrier in a remote CTA — only valid when ctas_per_cga > 1, causes "Unexpected buffer remote view in 1cta mode" otherwise. Guard with if USE_2CTA: when kernel supports both modes. |
Both |
tlx.cluster_barrier() |
Full cluster-wide synchronization barrier | Both |
arrive_count rules:
arrive_countcontrols how many timesbarrier_expect_bytesmust be called before the barrier can complete a phase. It is NOT a count ofbarrier_arrivecalls.- For TMA barriers where only the leader CTA calls
barrier_expect_bytes: usearrive_count=1(default). - For barriers arrived by software from both CTAs (via
barrier_arrivewithremote_cta_rank), usearrive_count=NUM_CTAS. barrier_arriveinsidetlx.async_task:arrive_count= number of warp groupsbarrier_arriveoutsidetlx.async_task:arrive_count=1(only tid==0 arrives)
Named barriers (hardware-allocated, indices 0–15)
| Function | Description | Arch |
|---|---|---|
tlx.named_barrier_wait(bar_id, num_threads) |
Wait until num_threads arrive at bar_id | NVIDIA |
tlx.named_barrier_arrive(bar_id, num_threads) |
Signal arrival at bar_id | NVIDIA |
num_threads must be a multiple of 32 (warp size). Typically num_warp_groups * warps_per_group * 32.
Used for PingPong scheduling to prevent tensor core contention between consumer warp groups.
Memory Operations
SMEM / TMEM allocation
| Function | Description | Arch |
|---|---|---|
tlx.local_alloc(shape, dtype, num, storage=smem, reuse=None, layout=None) |
Allocate buffered tensor in SMEM or TMEM | Both (TMEM: Blackwell) |
tlx.storage_alias_spec(storage=smem, buffer_size_bytes=None) |
Define shared buffer region for multiple local_alloc calls via reuse |
Both |
tlx.local_view(buf, index) |
Get view of a single buffer from a multi-buffered tensor | Both |
tlx.local_slice(buf, start, end) |
Slice a sub-range of a buffered tensor | Both |
tlx.subslice(tensor, dim, start, size) |
Subslice a tensor along a dimension | Both |
tlx.local_load(buf) |
Load from SMEM/TMEM buffer into registers | Both |
tlx.local_store(val, buf) |
Store from registers into SMEM/TMEM buffer | Both |
tlx.local_trans(buf) |
Transpose a shared memory buffer | Both |
tlx.local_reinterpret(buf, dtype) |
Reinterpret buffer with a different dtype | Both |
tlx.remote_view(buf, remote_cta_rank) |
Get view of buffer in a remote CTA's SMEM | Both |
tlx.remote_shmem_store(val, buf) |
Store to remote CTA's shared memory | Both |
tlx.async_remote_shmem_store(val, buf) |
Async store to remote CTA's shared memory | Both |
tlx.tmem_copy(src, dst) |
Copy between TMEM buffers | Blackwell |
tlx.fence_async_shared() |
Memory fence for async shared memory operations | Both |
Storage kinds: tlx.storage_kind.smem, tlx.storage_kind.tmem (Blackwell), tlx.storage_kind.smemCluster
TMA (Tensor Memory Accelerator)
| Function | Description | Arch |
|---|---|---|
tlx.make_tensor_descriptor(ptr, shape, strides, block_shape) |
Create TMA descriptor from pointer (host-side) | Hopper+ |
tlx.allocate_tensor_descriptor(ptr, shape, strides, block_shape, swizzle_mode) |
Allocate and fill TMA descriptor in SMEM | Hopper+ |
tlx.reinterpret_tensor_descriptor(desc, dtype) |
Reinterpret TMA descriptor with different dtype | Hopper+ |
tlx.async_descriptor_load(desc, indices, barrier=None) |
Async TMA load from global → SMEM, tracked by barrier | Hopper+ |
tlx.async_descriptor_store(desc, val, indices) |
Async TMA store from registers → global | Hopper+ |
tlx.async_descriptor_store_wait() |
Wait for all pending TMA stores to complete | Hopper+ |
tlx.async_load(ptr, buf, barrier) |
Async bulk copy global → SMEM (cp.async) | Hopper+ |
tlx.async_load_commit_group() |
Commit async load group | Hopper+ |
tlx.async_load_wait_group(n) |
Wait for async load groups (n pending allowed) | Hopper+ |
Matrix Multiply (MMA)
| Function | Description | Arch |
|---|---|---|
tlx.async_dot(A, B, acc=None, use_acc=None, mBarriers=[], two_ctas=False) |
Warp-group MMA: D = A @ B + C. Maps to wgmma (Hopper) or tcgen05.mma (Blackwell) | Both |
tlx.async_dot_scaled(A, B, acc, A_scale, A_format, B_scale, B_format, ...) |
Scaled MMA with FP8 inputs: D = (Ascale_A) @ (Bscale_B) + D | Blackwell |
tlx.async_dot_wait(pendings, inp) |
Wait for N pending async dot operations to complete | Both |
tlx.tcgen05_commit(mBarrier, two_ctas=False) |
Make mbarrier track completion of prior tcgen05 ops. Use a SEPARATE mbarrier from async_dot | Blackwell |
Minimum tile sizes for async_dot: M ≥ 64, K ≥ 16, N ≥ 32
Pair-CTA MMA (two_ctas=True): M must be 128 per CTA.
Multi-CTA (Cluster) Kernels
ctas_per_cga=(N,1,1) in triton.Config sets the cluster size. The grid
specifies total CTAs; hardware divides by ctas_per_cga to get the number
of clusters. E.g., grid=(2,1,1) with ctas_per_cga=(2,1,1) = 1 cluster of
2 CTAs.
2-CTA tile scheduling in attention backward
Each CTA in a cluster gets its own program_id and its own tile_id from
CLC. Two CTAs in a cluster naturally get consecutive tiles (pid 0, pid 1).
No special tile scheduling is needed for 2-CTA — start_n = pid works
as-is. Grid size and n_tile_num do NOT change between 1-CTA and 2-CTA.
Think of 2-CTA as two independent 1-CTAs that handle their own K/V tiles and share Q/dO via multicast. For L2 efficiency, they process consecutive N-blocks.
2-CTA MMA semantics (two_ctas=True)
- A operand (TMEM): per-CTA, each CTA has different data
- B operand (SMEM): split across CTAs and combined by hardware via multicast
- Output (TMEM): split across CTAs along the M dimension, written to both CTAs
- Leader MMA writes to both leader TMEM and peer TMEM
2-CTA barrier patterns
- TMA loads with
two_ctas=True: only leader callsbarrier_expect_bytes(guarded byif is_leader:). Usearrive_count=1. - Software arrives (
barrier_arrivewithremote_cta_rank=0): both CTAs arrive on leader's barrier. Usearrive_count=NUM_CTAS. - MMA
mBarrierswithtwo_ctas=True: hardware signals when input reads complete. The TMEM output write may still be in-flight.
input_precision options: tf32, tf32x3, ieee
CLC (Cluster Launch Control) — Blackwell only
| Function | Description |
|---|---|
tlx.clc_create_context(num_consumers, num_stages=1) |
Create CLC pipeline context (allocates barriers + response buffers) |
tlx.clc_producer(context, p_producer, multi_ctas=False, k=0) |
Issue CLC try_cancel request from CTA 0 |
tlx.clc_consumer(context, p_consumer, multi_ctas=False, k=0, return_3d=False) |
Decode tile ID from CLC response, signal completion. Returns tile_id or -1. With return_3d=True, returns (ctaIdX, ctaIdY, ctaIdZ) tuple. |
For 2-CTA mode: set multi_ctas=True (uses "arrive remote, wait local" pattern).
Utility
| Function | Description | Arch |
|---|---|---|
tlx.cluster_cta_rank() |
Unique CTA ID within a cluster (all dims) | Both |
tlx.thread_id(axis) |
Thread ID along axis 0, 1, or 2 | Both |
tlx.dtype_of(tensor_or_desc) |
Get element type of tensor or tensor descriptor | Both |
tlx.size_of(dtype) |
Size of dtype in bytes | Both |
tlx.get_fp8_format_name(dtype) |
Get FP8 format string ("e5m2" or "e4m3") for scaled MMA | Both |
tlx.clock64() |
64-bit hardware clock value (for timing) | Both |
tlx.stoch_round(src, dst_ty, rand_bits) |
Hardware stochastic rounding FP32 → FP8/BF16/F16 | Blackwell |
Common patterns
Producer-consumer with mbarrier (pipelined GEMM)
bars_full = tlx.alloc_barriers(num_stages, arrive_count=1) # TMA arrives implicitly
bars_empty = tlx.alloc_barriers(num_stages, arrive_count=num_consumers)
# Producer: TMA load → signal full
tlx.barrier_expect_bytes(bar_full, nbytes)
tlx.async_descriptor_load(desc, indices, barrier=bar_full)
# Consumer: wait full → MMA → signal empty
tlx.barrier_wait(bar_full, phase)
tlx.async_dot(A, B, acc)
tlx.barrier_arrive(bar_empty)
PingPong with named barriers
# Consumer 0 waits for Consumer 1, then issues MMA
tlx.named_barrier_wait(9, 256) # 256 = 2 warp groups * 4 warps * 32 threads
qk = tlx.async_dot(q, k)
tlx.named_barrier_arrive(10, 256)
# Consumer 1 waits for Consumer 0's MMA to finish
tlx.named_barrier_arrive(9, 256)
tlx.named_barrier_wait(10, 256)
qk = tlx.async_dot(q, k)
Deep-dive docs
- API reference:
third_party/tlx/README.md - Barriers:
third_party/tlx/doc/tlx_barriers.md - Placeholder layouts:
third_party/tlx/doc/PlaceholderLayouts.md - Storage alias design:
third_party/tlx/doc/storage_alias_spec_design.md