---
name: sharding-stats
description: Investigate and explain TorchRec planner sharding statistics output, especially how HBM storage is computed per table and per rank. Use when the user asks about sharding stats, storage breakdown, or memory estimation.
allowed-tools: Read, Grep, Bash, Task
argument-hint: [table name, sharding type, or question about stats]
---
Sharding Stats Investigation Guide
Investigate and explain TorchRec planner sharding statistics, especially HBM storage computation, for: $ARGUMENTS
Instructions
You are analyzing the output of EmbeddingStats (in torchrec/distributed/planner/stats.py). This skill covers how to read the stats table and how each number is computed from source code.
Key Source Files
- torchrec/distributed/planner/stats.py: generates the bordered stats table output (EmbeddingStats class)
- torchrec/distributed/planner/shard_estimators.py: core estimation logic
  - EmbeddingStorageEstimator: orchestrates storage estimation
  - calculate_shard_storages(): assembles the final Storage per shard
  - _calculate_shard_io_sizes(): dispatches to sharding-type-specific I/O calculations (_calculate_rw_shard_io_sizes(), _calculate_tw_shard_io_sizes(), etc.)
  - _calculate_storage_specific_sizes(): tensor + optimizer + cache aux
  - _calculate_tensor_sizes(): proportional tensor size per shard
  - _calculate_optimizer_sizes(): optimizer state multiplier
  - calculate_pipeline_io_cost(): pipeline-type I/O multipliers
- torchrec/distributed/planner/storage_reservations.py: dense storage and KJT storage reservation
- torchrec/distributed/planner/types.py: Perf, Storage, ShardingOption dataclasses
- torchrec/distributed/planner/utils.py: bytes_to_gb (1 GB = 2^30 bytes), bytes_to_mb
- torchrec/distributed/embedding_types.py: sharder storage_usage() implementations
Stats Table Structure
The stats output has these sections:
- Per-Rank Summary: Rank, HBM (GB), DDR (GB), Perf (ms), Input (MB), Output (MB), Shards
- Parameter Info Table: Per-table details (FQN, Sharding, Compute Kernel, Perf, Storage, etc.)
- Batch Size & Compute Kernels: Global batch size, kernel counts and storage
- Imbalance Statistics: Total Variation, Total Distance, Chi Divergence, KL Divergence
- Peak Memory Estimation: Top-tier HBM pressure per rank
- Storage Reservation: Reserved, Planning, Dense, KJT storage
HBM Storage Computation — Step by Step
Per-Rank HBM Formula
used_hbm[rank] = sparse_hbm[rank] + dense_storage.hbm + kjt_storage.hbm
- sparse_hbm[rank] = sum of shard.storage.hbm for every embedding shard placed on that rank
- dense_storage = non-embedding model parameters (from HeuristicalStorageReservation)
- kjt_storage = KeyedJaggedTensor input buffers
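The per-rank aggregation can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in stats.py: ShardHBM and used_hbm_per_rank are hypothetical names, and each shard is reduced to its rank and its storage.hbm byte count.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ShardHBM:
    """Hypothetical stand-in for a placed shard: its rank and storage.hbm bytes."""

    rank: int
    hbm: int


def used_hbm_per_rank(
    shards: List[ShardHBM], dense_hbm: int, kjt_hbm: int, world_size: int
) -> Dict[int, int]:
    # sparse_hbm[rank]: sum of shard.storage.hbm over shards placed on that rank
    sparse_hbm = {rank: 0 for rank in range(world_size)}
    for shard in shards:
        sparse_hbm[shard.rank] += shard.hbm
    # Per the formula, the same dense and KJT reservation is added to every rank
    return {rank: hbm + dense_hbm + kjt_hbm for rank, hbm in sparse_hbm.items()}


print(used_hbm_per_rank([ShardHBM(0, 100), ShardHBM(1, 50), ShardHBM(0, 25)], 10, 5, 2))
# {0: 140, 1: 65}
```

Because dense and KJT storage are charged once per rank, a rank's HBM can be far above its share of embedding shards.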
Per-Shard HBM Formula (the core calculation)
Each shard's HBM is computed in calculate_shard_storages():
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
Where:
hbm_specific_size = tensor_size + optimizer_size + cache_aux_size
Step 1: Base Tensor Storage (sharder.storage_usage())
The sharder determines the raw tensor bytes:
- EmbeddingBagCollectionSharder: hbm_storage = num_embeddings × emb_dim × element_size
- EmbeddingCollectionSharder: hbm_storage = num_embeddings × emb_dim × element_size + num_embeddings × 4
  - The extra shape[0] × 4 bytes is metadata overhead for sequence embeddings
- For UVM caching kernels: the tensor goes to DDR, and HBM gets ddr_storage × caching_ratio
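A minimal sketch of the Step 1 rules, assuming a single unsharded table. base_tensor_storage is a hypothetical helper, not a TorchRec API. It applies the EmbeddingCollectionSharder overhead before the DDR/HBM split for UVM caching, which is an assumption the rules above do not spell out, and it rounds the cached HBM size to whole bytes.

```python
from typing import Dict, Optional


def base_tensor_storage(
    num_embeddings: int,
    emb_dim: int,
    element_size: int,
    sequence_sharder: bool,
    uvm_caching: bool = False,
    caching_ratio: Optional[float] = None,
) -> Dict[str, int]:
    """Hypothetical helper mirroring Step 1; returns bytes per storage tier."""
    tensor_bytes = num_embeddings * emb_dim * element_size
    if sequence_sharder:
        # EmbeddingCollectionSharder: extra num_embeddings × 4 bytes of metadata
        tensor_bytes += num_embeddings * 4
    if uvm_caching:
        if caching_ratio is None:
            raise ValueError("UVM caching needs a caching_ratio")
        # Table lives in DDR; HBM holds the cache (rounded to whole bytes here)
        return {"hbm": round(tensor_bytes * caching_ratio), "ddr": tensor_bytes}
    return {"hbm": tensor_bytes, "ddr": 0}


print(base_tensor_storage(80_000_000, 128, 2, sequence_sharder=False))
# {'hbm': 20480000000, 'ddr': 0}
```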
Step 2: Per-Shard Tensor Size (_calculate_tensor_sizes())
tensor_size = ceil(hbm_storage × prod(shard_size) / prod(full_shape))
For RW sharding with world_size shards: shard_size ≈ [num_embeddings / world_size, emb_dim] (row counts are whole numbers, so individual shards can differ slightly)
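The proportional split can be checked with a short sketch; shard_tensor_size is a hypothetical name. It uses exact integer ceil division so that large byte counts are not affected by floating-point rounding, which is a choice made for this sketch rather than a statement about the planner's arithmetic.

```python
import math
from typing import Sequence


def shard_tensor_size(
    hbm_storage: int, shard_size: Sequence[int], full_shape: Sequence[int]
) -> int:
    """tensor_size = ceil(hbm_storage × prod(shard_size) / prod(full_shape))."""
    numerator = hbm_storage * math.prod(shard_size)
    denominator = math.prod(full_shape)
    # Exact integer ceil division
    return -(-numerator // denominator)


print(shard_tensor_size(20_480_000_000, [833_333, 128], [80_000_000, 128]))
# 213333248
```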
Step 3: Optimizer Size (_calculate_optimizer_sizes())
optimizer_size = ceil(tensor_size × optimizer_multiplier)
| Optimizer | Multiplier |
|---|---|
| SGD | 0 |
| Adam | 2 |
| RowWiseAdagrad | 1 / emb_dim |
| Default/unknown | 1 |
| None (inference) | 0 |
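The multiplier table can be sketched as a lookup. Optimizers are plain strings here purely for illustration (the planner reads optimizer classes from the tensor's _optimizer_classes attribute), and optimizer_multiplier / optimizer_size are hypothetical names.

```python
import math
from typing import Optional


def optimizer_multiplier(optimizer: Optional[str], emb_dim: int) -> float:
    # String names stand in for optimizer classes in this sketch
    if optimizer is None:  # inference: no optimizer state
        return 0.0
    if optimizer == "SGD":
        return 0.0
    if optimizer == "Adam":  # two moment buffers
        return 2.0
    if optimizer == "RowWiseAdagrad":  # one value per row
        return 1.0 / emb_dim
    return 1.0  # default / unknown optimizer


def optimizer_size(tensor_size: int, optimizer: Optional[str], emb_dim: int) -> int:
    """optimizer_size = ceil(tensor_size × optimizer_multiplier)."""
    return math.ceil(tensor_size * optimizer_multiplier(optimizer, emb_dim))


print(optimizer_size(213_333_248, "RowWiseAdagrad", 128))
# 1666666
```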
Step 4: Cache Auxiliary State (_calculate_cache_aux_state_sizes())
Only applies to UVM caching (fused_uvm_caching kernel). For fused kernel: 0.
Step 5: I/O Sizes (sharding-type-specific)
Computed by _calculate_shard_io_sizes() → dispatches to type-specific functions.
Constants:
- input_data_type_size = 8 (BIGINT_DTYPE, int64 indices)
- output_data_type_size = tensor.element_size() (or output_dtype if specified)
For RW sharding (_calculate_rw_shard_io_sizes()):
batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i) / world_size
batch_outputs = batch_inputs # if non-pooled (sequence)
= sum(num_poolings_i × batch_size_i) # if pooled
input_size = ceil(batch_inputs × world_size × input_data_type_size) # per shard
output_size = ceil(batch_outputs × world_size × shard_dim × output_data_type_size) # per shard
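The RW formulas translate directly into a sketch that returns the per-shard input and output sizes in bytes. rw_shard_io_sizes is a hypothetical name; the input lists are indexed per feature, and the sketch ignores the case of empty (zero-row) shards.

```python
import math
from typing import Sequence, Tuple


def rw_shard_io_sizes(
    batch_sizes: Sequence[int],
    input_lengths: Sequence[float],
    num_poolings: Sequence[float],
    world_size: int,
    shard_dim: int,
    is_pooled: bool,
    output_data_type_size: int,
    input_data_type_size: int = 8,  # int64 indices
) -> Tuple[int, int]:
    """Per-shard (input_size, output_size) in bytes for one RW shard."""
    batch_inputs = sum(
        length * pooling * batch
        for length, pooling, batch in zip(input_lengths, num_poolings, batch_sizes)
    ) / world_size
    if is_pooled:
        batch_outputs = sum(p * b for p, b in zip(num_poolings, batch_sizes))
    else:
        batch_outputs = batch_inputs  # one embedding vector per looked-up index
    input_size = math.ceil(batch_inputs * world_size * input_data_type_size)
    output_size = math.ceil(
        batch_outputs * world_size * shard_dim * output_data_type_size
    )
    return input_size, output_size


# Four features whose input_lengths sum to 6066, as in the worked example
print(rw_shard_io_sizes([2560] * 4, [3000, 2000, 1000, 66], [1] * 4, 96, 128, False, 2))
# (124231680, 3975413760)
```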
For TW sharding (_calculate_tw_shard_io_sizes()):
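batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i) # no division
batch_outputs = batch_inputs if non-pooled, else sum(num_poolings_i × batch_size_i) # same rule as RW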
input_size = ceil(batch_inputs × world_size × input_data_type_size)
output_size = ceil(batch_outputs × world_size × emb_dim × output_data_type_size)
For CW sharding (_calculate_cw_shard_io_sizes()):
# Same as TW but output uses shard_sizes[i][1] (shard column dim) instead of full emb_dim
output_size = ceil(batch_outputs × world_size × shard_sizes[i][1] × output_data_type_size)
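A sketch of the TW and CW variants, assuming batch_outputs follows the same pooled/non-pooled rule as RW. The function names are hypothetical; CW returns one (input_size, output_size) pair per column shard because only the output depends on the shard's column count.

```python
import math
from typing import List, Sequence, Tuple


def _batch_inputs_outputs(
    batch_sizes: Sequence[int],
    input_lengths: Sequence[float],
    num_poolings: Sequence[float],
    is_pooled: bool,
) -> Tuple[float, float]:
    # TW and CW do not divide by world_size (unlike RW)
    batch_inputs = sum(
        length * pooling * batch
        for length, pooling, batch in zip(input_lengths, num_poolings, batch_sizes)
    )
    if is_pooled:
        batch_outputs = sum(p * b for p, b in zip(num_poolings, batch_sizes))
    else:
        batch_outputs = batch_inputs
    return batch_inputs, batch_outputs


def tw_shard_io_sizes(batch_sizes, input_lengths, num_poolings, world_size: int,
                      emb_dim: int, is_pooled: bool, output_data_type_size: int,
                      input_data_type_size: int = 8) -> Tuple[int, int]:
    batch_inputs, batch_outputs = _batch_inputs_outputs(
        batch_sizes, input_lengths, num_poolings, is_pooled)
    input_size = math.ceil(batch_inputs * world_size * input_data_type_size)
    output_size = math.ceil(batch_outputs * world_size * emb_dim * output_data_type_size)
    return input_size, output_size


def cw_shard_io_sizes(batch_sizes, input_lengths, num_poolings, world_size: int,
                      shard_sizes: Sequence[Sequence[int]], is_pooled: bool,
                      output_data_type_size: int,
                      input_data_type_size: int = 8) -> List[Tuple[int, int]]:
    batch_inputs, batch_outputs = _batch_inputs_outputs(
        batch_sizes, input_lengths, num_poolings, is_pooled)
    input_size = math.ceil(batch_inputs * world_size * input_data_type_size)
    # Same as TW, but each shard's output uses its own column count shard_size[1]
    return [
        (input_size,
         math.ceil(batch_outputs * world_size * shard_size[1] * output_data_type_size))
        for shard_size in shard_sizes
    ]
```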
Critical insight for sequence (non-pooled) embeddings:
When is_pooled=False, batch_outputs = batch_inputs, so the output buffer holds one full embedding vector per input index rather than one per pooled bag. With large input_lengths this buffer can dominate total storage: it is about 92% of the total in the worked example below.
Step 6: Pipeline I/O Cost (calculate_pipeline_io_cost())
output_contribution = output_size if count_ephemeral_storage_cost else 0
| Pipeline Type | Formula |
|---|---|
| NONE (catch-all) | input_size + output_size |
| TRAIN_SPARSE_DIST | 2 × input_size + output_contribution |
| TRAIN_PREFETCH_SPARSE_DIST | 3 × input_size + (1 + 6/max_pass) × prefetch_size + output_contribution |
| Inference | 0 |
- prefetch_size = input_size if the table is cached, else 0
- count_ephemeral_storage_cost defaults to False
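The pipeline table can be sketched as a single function. PipelineType here is a local stand-in enum, not an import from TorchRec; treating a missing max_pass as one pass and truncating the prefetch term to whole bytes are assumptions of this sketch.

```python
from enum import Enum
from typing import Optional


class PipelineType(Enum):
    # Local stand-in mirroring the pipeline names in the table
    NONE = "none"
    TRAIN_SPARSE_DIST = "train_sparse_dist"
    TRAIN_PREFETCH_SPARSE_DIST = "train_prefetch_sparse_dist"


def pipeline_io_cost(
    input_size: int,
    output_size: int,
    prefetch_size: int,
    pipeline_type: PipelineType,
    max_pass: Optional[int] = None,
    count_ephemeral_storage_cost: bool = False,
    is_inference: bool = False,
) -> int:
    if is_inference:
        return 0
    output_contribution = output_size if count_ephemeral_storage_cost else 0
    if pipeline_type == PipelineType.TRAIN_SPARSE_DIST:
        return 2 * input_size + output_contribution
    if pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST:
        passes = max_pass or 1  # assumption: no max_pass means a single pass
        prefetch_term = int((1 + 6 / passes) * prefetch_size)  # truncated here
        return 3 * input_size + prefetch_term + output_contribution
    return input_size + output_size  # NONE / catch-all


print(pipeline_io_cost(124_231_680, 3_975_413_760, 0, PipelineType.NONE))
# 4099645440
```

Note that only the NONE catch-all counts output_size unconditionally; the training pipelines add it only when count_ephemeral_storage_cost is set.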
Step 7: Final Per-Shard Storage
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
Step 8: Total Storage (shown in Parameter Info table)
The "Storage (HBM, DDR)" column in the parameter info table shows:
total_storage = sum(shard.storage for shard in sharding_option.shards)
This is the sum across ALL shards (all ranks), not per-rank.
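To make the per-table versus per-rank distinction concrete, here is a sketch that sums shard storages for one table. Storage is a minimal local stand-in for the planner dataclass of the same name, and total_storage is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Storage:
    # Minimal local stand-in: HBM and DDR bytes
    hbm: int
    ddr: int

    def __add__(self, other: "Storage") -> "Storage":
        return Storage(self.hbm + other.hbm, self.ddr + other.ddr)


def total_storage(shard_storages: List[Storage]) -> Storage:
    # Sum over every shard of the table, regardless of which rank holds it
    total = Storage(0, 0)
    for storage in shard_storages:
        total = total + storage
    return total


# 96 identical RW shards from the worked example
print(total_storage([Storage(4_314_645_354, 0)] * 96).hbm / 2**30)  # ~385.76 GB
```

Dividing this table total by world_size gives a per-rank figure only when shards are spread evenly, one per rank, as in RW sharding.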
Dense Storage Reservation
From HeuristicalStorageReservation:
dense_storage.hbm = (total_model_params - embedding_params) × multiplier + buffers, with parameter and buffer sizes measured in bytes (element_size × numel)
- Training multiplier = 6.0 (1× params + 2× optimizer state + 3× DDP gradient buffers)
- Inference multiplier = 1.0
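A minimal sketch of the dense reservation arithmetic, assuming the parameter and buffer byte totals are already known. dense_storage_hbm is a hypothetical helper that works on plain integers rather than walking an nn.Module, and rounding to whole bytes is a choice of this sketch.

```python
def dense_storage_hbm(
    total_param_bytes: int,
    embedding_param_bytes: int,
    buffer_bytes: int,
    is_inference: bool = False,
) -> int:
    # Training: 6.0 (1× params + 2× optimizer state + 3× DDP gradient buffers)
    # Inference: 1.0
    multiplier = 1.0 if is_inference else 6.0
    dense_param_bytes = total_param_bytes - embedding_param_bytes
    return round(dense_param_bytes * multiplier + buffer_bytes)


print(dense_storage_hbm(1000, 400, 50))  # 3650
```

The training multiplier means every byte of dense parameters costs about six bytes of HBM on each rank.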
KJT Storage Reservation
kjt_storage.hbm = total_kjt_size × kjt_multiplier
- Training multiplier = 20 (pipelined batches)
- Inference multiplier = 1
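A sketch of the KJT reservation. The definition of total_kjt_size used here (batch_size × sum(input_lengths) × 8 bytes of int64 indices) is an assumption for illustration, not something stated above; kjt_storage_hbm is a hypothetical name.

```python
import math
from typing import Sequence


def kjt_storage_hbm(
    batch_size: int,
    input_lengths: Sequence[float],
    is_inference: bool = False,
    input_data_type_size: int = 8,  # int64 indices
) -> int:
    # Assumption: total_kjt_size = batch_size × sum(input_lengths) × index bytes
    total_kjt_size = math.ceil(batch_size * sum(input_lengths) * input_data_type_size)
    kjt_multiplier = 1 if is_inference else 20  # 20 covers pipelined batches
    return total_kjt_size * kjt_multiplier


print(kjt_storage_hbm(2560, [3000, 2000, 1000, 66]))  # 2484633600
```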
Investigation Checklist
When analyzing a table's storage, determine these parameters:
- Tensor dtype: float32 (4 bytes) or float16/bfloat16 (2 bytes)? Cross-check with per-rank Output (MB) column.
- Sharder type: EmbeddingBagCollectionSharder vs EmbeddingCollectionSharder (EC adds shape[0] × 4 bytes)
- is_pooled: pooled (EmbeddingBag) vs sequence (Embedding); check the "Output" column
- Optimizer: check the _optimizer_classes attribute on the tensor; RowWiseAdagrad is common for RecSys
- Pipeline type: NONE, TRAIN_SPARSE_DIST, or TRAIN_PREFETCH_SPARSE_DIST
- count_ephemeral_storage_cost: defaults to False
- Compute kernel: fused, fused_uvm_caching, quant, etc.
- Caching ratio: only matters for UVM caching kernels
Worked Example: RW Sequence Embedding
Given: hash_size=80M, emb_dim=128, dtype=fp16, 4 features, sum(input_lengths)=6066, batch_size=2560, world_size=96, RW sharding, fused kernel, RowWiseAdagrad, PipelineType.NONE
- Tensor storage: 80M × 128 × 2 = 20,480,000,000 bytes (19.07 GB)
- Per-shard tensor (833,333 ≈ 80M / 96 rows per shard): ceil(20,480,000,000 × 833,333 / 80,000,000) = 213,333,248 bytes
- Optimizer (RowWiseAdagrad, multiplier 1/128): ceil(213,333,248 / 128) = 1,666,666 bytes (~1.6 MB)
- hbm_specific: 213,333,248 + 1,666,666 = 214,999,914 bytes (~205 MB)
- I/O (RW, non-pooled):
  - batch_inputs = 2560 × 6066 / 96 = 161,760
  - input_size = ceil(161,760 × 96 × 8) = 124,231,680 bytes (~118 MB)
  - output_size = ceil(161,760 × 96 × 128 × 2) = 3,975,413,760 bytes (~3.70 GB)
- Pipeline (NONE): 124,231,680 + 3,975,413,760 = 4,099,645,440 bytes (~3.82 GB)
- Per-shard total: 214,999,914 + 4,099,645,440 = 4,314,645,354 bytes (~4.02 GB)
- Total (96 shards): ~385.8 GB
Key insight: Output buffer is ~92% of total — sequence embeddings with large input_lengths dominate storage.
| Component | Per Shard | Total (96 shards) | % |
|---|---|---|---|
| Embedding weights | 203 MB | 19.07 GB | 4.9% |
| Optimizer (RowWiseAdagrad) | 1.6 MB | 0.15 GB | 0.04% |
| Input buffer (int64) | 118 MB | 11.1 GB | 2.9% |
| Output buffer (fp16) | 3.70 GB | 355.4 GB | 92.1% |
| Total | ~4.02 GB | ~385.8 GB | 100% |
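The worked example can be reproduced end to end with a short script. It hard-codes the example's inputs and uses 833,333 rows per shard exactly as the example does; it is an arithmetic check, not a call into the planner.

```python
import math

GB = 2**30  # bytes_to_gb uses 1 GB = 2^30 bytes

# Inputs from the worked example
hash_size, emb_dim, elem = 80_000_000, 128, 2  # fp16
world_size, batch_size, sum_input_lengths = 96, 2560, 6066
rows_per_shard = 833_333  # ≈ hash_size / world_size, as in the example

tensor_storage = hash_size * emb_dim * elem
tensor_size = math.ceil(tensor_storage * rows_per_shard / hash_size)
optimizer = math.ceil(tensor_size * (1 / emb_dim))  # RowWiseAdagrad
hbm_specific = tensor_size + optimizer

batch_inputs = batch_size * sum_input_lengths / world_size
input_size = math.ceil(batch_inputs * world_size * 8)  # int64 indices
output_size = math.ceil(batch_inputs * world_size * emb_dim * elem)  # non-pooled
pipeline = input_size + output_size  # PipelineType.NONE

per_shard = hbm_specific + pipeline
total = per_shard * world_size
print(f"per shard: {per_shard:,} bytes ({per_shard / GB:.2f} GB)")
print(f"total: {total / GB:.1f} GB, output share: {output_size * world_size / total:.1%}")
# per shard: 4,314,645,354 bytes (4.02 GB)
# total: 385.8 GB, output share: 92.1%
```

Changing is_pooled behaviour (output = batch_size × num_features rows instead of batch_inputs) is the quickest way to see how much of this total is attributable to sequence outputs.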
Example Stats Output
See sharding_stats_example.txt for a full example.