---
name: test-gen
argument-hint: [file-path or "local"]
description: Generate tests for TorchRec source files with correct patterns (unit, distributed, hypothesis), proper BUCK targets, and test utilities. Use when asked to generate tests, add test coverage, or write tests for a module.
allowed-tools: Read, Write, Edit, Bash(sl:*), Bash(buck2:*), Grep, Glob, Task
---
# TorchRec Test Generator
Generate idiomatic TorchRec tests by reading source files, detecting the appropriate test type, scaffolding test code with correct patterns, and creating/updating BUCK targets.
## Usage Modes
### File Path Mode

```
/test-gen torchrec/distributed/sharding/my_sharder.py
/test-gen torchrec/modules/new_module.py
```
Generate tests for the specified source file.
### Auto-Detect Mode

```
/test-gen local
/test-gen
```
Detect changed files via `sl status` and generate tests for new or modified source files that lack test coverage.
## Workflow

### Phase 1: Identify Source Files

**File path mode:** Read the specified file.

**Auto-detect mode:**

- Run `sl status` to find changed files
- Filter to `.py` source files in `torchrec/` (exclude test files, `__init__.py`, and BUCK files)
- For each source file, check whether a corresponding test file exists at `$(dirname)/test(s)/test_$(basename)`
- Present the list of untested files and ask the user which to generate tests for
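The auto-detect filter can be sketched in plain Python. This is a minimal illustration, not part of the skill's tooling: it assumes `sl status` prints one `<code> <path>` line per file (Mercurial-style codes such as `M`, `A`, `?`) with paths relative to the directory that contains `torchrec/`. The helper names `changed_source_files`, `candidate_test_paths`, and `untested` are hypothetical.

```python
import os
from pathlib import PurePosixPath
from typing import Callable

# Status codes treated as "new or modified" (assumed Mercurial/Sapling convention).
CHANGED_CODES = {"M", "A", "?"}


def changed_source_files(status_output: str) -> list[str]:
    """Return torchrec/*.py source files from `sl status` output."""
    files: list[str] = []
    for line in status_output.splitlines():
        code, _, path = line.strip().partition(" ")
        if code not in CHANGED_CODES or not path:
            continue
        p = PurePosixPath(path)
        # Only .py files under torchrec/ (this also drops BUCK files).
        if p.parts[:1] != ("torchrec",) or p.suffix != ".py":
            continue
        # Exclude package markers and test files.
        if p.name == "__init__.py" or p.name.startswith("test_"):
            continue
        if "test" in p.parts or "tests" in p.parts:
            continue
        files.append(path)
    return files


def candidate_test_paths(source: str) -> list[str]:
    """The tests/ and test/ locations checked for an existing test file."""
    p = PurePosixPath(source)
    return [str(p.parent / d / f"test_{p.name}") for d in ("tests", "test")]


def untested(
    status_output: str, exists: Callable[[str], bool] = os.path.exists
) -> list[str]:
    """Changed source files with no test file at any candidate path."""
    return [
        src
        for src in changed_source_files(status_output)
        if not any(exists(t) for t in candidate_test_paths(src))
    ]
```

Injecting `exists` keeps the helper testable without touching the filesystem.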
### Phase 2: Analyze Source Code

Read the source file and classify it:

**Detection rules (in priority order):**

1. **Distributed test** if ANY of:
   - File is under `torchrec/distributed/`
   - Imports from `torch.distributed` or `torchrec.distributed`, or uses `ProcessGroup`
   - Defines sharders or sharded modules, or uses `ShardingType`
   - Uses `LazyAwaitable`, `all_to_all`, `all_reduce`, or `all_gather`
2. **Hypothesis-parameterized test** if ANY of:
   - Source defines enums, configs, or strategies with multiple variants
   - Source handles multiple `ShardingType` or `EmbeddingComputeKernel` values
   - Source has branching behavior based on config parameters
3. **Unit test** (default) if:
   - File is under `torchrec/modules/`, `torchrec/sparse/`, `torchrec/optim/`, or `torchrec/metrics/`
   - No distributed primitives are used

A file can be both distributed AND hypothesis-parameterized.
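A rough, regex-based version of these rules might look like the sketch below. It is a heuristic, not an authoritative classifier: the patterns, the `Classification` type, and the threshold of two distinct `ShardingType`/`EmbeddingComputeKernel` members are assumptions. "Branching on config parameters" is left to human judgment because a regex cannot detect it reliably.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class Classification:
    distributed: bool
    hypothesis: bool  # can be True together with distributed

    @property
    def unit(self) -> bool:
        # Unit test is the default whenever no distributed signal is found.
        return not self.distributed


# Heuristic patterns (assumptions; tune as needed).
_DIST_IMPORT = re.compile(
    r"^\s*(?:from|import)\s+(?:torch|torchrec)\.distributed\b", re.MULTILINE
)
_DIST_NAMES = re.compile(
    r"\b(?:ProcessGroup|ShardingType|LazyAwaitable|all_to_all|all_reduce|all_gather)\b"
)
_SHARDER_CLASS = re.compile(r"^\s*class\s+\w*(?:Sharder|Sharded)\w*", re.MULTILINE)
_ENUM_CLASS = re.compile(
    r"^\s*class\s+\w+\((?:\w+\.)?(?:Enum|IntEnum|StrEnum)\)", re.MULTILINE
)
_VARIANT_REF = re.compile(r"\b(?:ShardingType|EmbeddingComputeKernel)\.\w+")


def classify(path: str, source: str) -> Classification:
    distributed = (
        path.startswith("torchrec/distributed/")
        or _DIST_IMPORT.search(source) is not None
        or _DIST_NAMES.search(source) is not None
        or _SHARDER_CLASS.search(source) is not None
    )
    # Two or more distinct enum members referenced suggests multi-variant behavior.
    variants = set(_VARIANT_REF.findall(source))
    hypothesis = _ENUM_CLASS.search(source) is not None or len(variants) >= 2
    return Classification(distributed=distributed, hypothesis=hypothesis)
```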
Extract from the source file:
- Public classes and their methods
- Public functions
- Constructor signatures and required arguments
- Key data types (`KeyedJaggedTensor` (KJT), `EmbeddingBagCollection` (EBC), `KeyedTensor`, etc.)
- Dependencies and imports needed for tests
### Phase 3: Determine Test Location

Follow TorchRec convention:

- Source: `torchrec/foo/bar/my_module.py`
- Test: `torchrec/foo/bar/tests/test_my_module.py`

If a `tests/` directory doesn't exist, create it.
If a test file already exists, add new test methods rather than overwriting.
### Phase 4: Generate Test Code

Generate tests following the patterns below. See `test-patterns.md` for complete templates.
**For all test types:**

- BSD license header + `# pyre-strict`
- Type hints on all methods (return `-> None` for test methods)
- Use `self.assertEqual`, `self.assertTrue`, `torch.testing.assert_close` for assertions
- Cover: happy path, edge cases (empty inputs, single element), error conditions
- Name tests descriptively: `test_<what>_<condition>`
**For unit tests:**

- Inherit from `unittest.TestCase`
- Test each public method/function independently
- For modules: test `forward()` with representative inputs, verify output shapes and types
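A unit test following these conventions could look like the sketch below. To keep it runnable without TorchRec or PyTorch installed, it tests a hypothetical pure-Python `lengths_to_offsets` helper defined inline; a real test would import the function or module under test and use `torch.testing.assert_close` for tensor comparisons.

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest


def lengths_to_offsets(lengths: list[int]) -> list[int]:
    """Hypothetical function under test (stands in for the real source module)."""
    if any(length < 0 for length in lengths):
        raise ValueError("lengths must be non-negative")
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return offsets


class LengthsToOffsetsTest(unittest.TestCase):
    # Happy path.
    def test_lengths_to_offsets_basic(self) -> None:
        self.assertEqual(lengths_to_offsets([2, 0, 3]), [0, 2, 2, 5])

    # Edge case: empty input.
    def test_lengths_to_offsets_empty_input(self) -> None:
        self.assertEqual(lengths_to_offsets([]), [0])

    # Edge case: single element.
    def test_lengths_to_offsets_single_element(self) -> None:
        self.assertEqual(lengths_to_offsets([4]), [0, 4])

    # Error condition.
    def test_lengths_to_offsets_negative_length_raises(self) -> None:
        with self.assertRaises(ValueError):
            lengths_to_offsets([1, -1])
```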
**For distributed tests:**

- Inherit from `MultiProcessTestBase`
- Use a `@staticmethod` or module-level `_test_func(rank, world_size, **kwargs)` pattern
- Wrap per-rank logic in `with MultiProcessContext(rank, world_size, backend) as ctx:`
- Default to `world_size=2`; add `world_size=4` for sharding tests
- Use `backend="gloo"` unless testing GPU-specific behavior
- Add `@unittest.skipIf(torch.cuda.device_count() < N, "Not enough GPUs...")` for CUDA tests
**For hypothesis tests:**

- Add `@given(...)` with `st.sampled_from([...])` for enum/config parameters
- Add `@settings(verbosity=Verbosity.verbose, max_examples=N, deadline=None)`
- Use `assume()` to filter invalid parameter combinations
- Keep `max_examples` reasonable (4-8 for distributed tests, 10-20 for unit tests)
### Phase 5: Create/Update BUCK Target

Read the existing `BUCK` file in the `tests/` directory (or create one if it doesn't exist).

**For CPU-only unit tests:**

```python
python_unittest(
    name = "test_my_module",
    srcs = ["test_my_module.py"],
    deps = [
        "//caffe2:_torch",
        # ... source deps ...
    ],
)
```
**For GPU/distributed tests:**

```python
python_unittest(
    name = "test_my_module",
    srcs = ["test_my_module.py"],
    remote_execution = re_test_utils.remote_execution(
        mig = "false",
        platform = "gpu-remote-execution",
        resource_units = 2,
    ),
    deps = [
        "//caffe2:_torch",
        "//torchrec/distributed/test_utils:multi_process",
        # ... source deps ...
    ],
)
```
If hypothesis is used, add:

```python
supports_static_listing = False,
```

and add to `deps`:

```python
"fbsource//third-party/pypi/hypothesis:hypothesis",
```
**BUCK rules:**

- Use `load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")` for standard tests
- Add `load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")` for GPU tests
- Include `oncall("torchrec")` if already present in the BUCK file
- Derive deps from the test file's imports: map each `torchrec.*` import to its BUCK target by checking the source directory's BUCK file
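Collecting the imports can be partly automated. The sketch below is an assumption-laden helper, not the canonical mapping: `candidate_buck_deps` is hypothetical, it guesses a target of the form `//torchrec/<package>:<module>` (matching the `//torchrec/distributed/test_utils:multi_process` dep in the GPU template), and it maps `torch` to `//caffe2:_torch`. Every guessed target still has to be confirmed against the package's BUCK file, and package-level imports such as `from torchrec import ...` are returned as unresolved.

```python
import ast


def candidate_buck_deps(test_source: str) -> tuple[list[str], list[str]]:
    """Return (guessed_targets, unresolved_modules) for a test file's imports."""
    modules: list[str] = []
    for node in ast.walk(ast.parse(test_source)):
        if isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            modules.append(node.module)
        elif isinstance(node, ast.Import):
            modules.extend(alias.name for alias in node.names)

    targets: list[str] = []
    unresolved: list[str] = []
    for module in modules:
        parts = module.split(".")
        if parts[0] == "torch":
            target = "//caffe2:_torch"  # torch dep as written in the BUCK templates
        elif parts[0] != "torchrec":
            continue  # stdlib / third-party: handle separately (e.g. hypothesis)
        elif len(parts) < 3:
            # e.g. `from torchrec import ...`: no single module file to map.
            if module not in unresolved:
                unresolved.append(module)
            continue
        else:
            # Guess: //torchrec/<package path>:<module file name>
            target = "//" + "/".join(parts[:-1]) + ":" + parts[-1]
        if target not in targets:
            targets.append(target)
    return targets, unresolved
```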
### Phase 6: Verify

- Ask the user to review the generated test file
- Suggest running the test: `buck2 test fbcode//torchrec/path/to/tests:test_my_module`
- If hypothesis is used, suggest running with more examples: `buck2 test fbcode//torchrec/path/to/tests:test_my_module -- -s`
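The target label follows mechanically from the test file path when the BUCK target name matches the file stem, as in the `python_unittest` templates in Phase 5. A minimal sketch, assuming paths relative to `fbcode/`; `buck_test_target` and `buck_test_command` are hypothetical helpers.

```python
from pathlib import PurePosixPath
from typing import Optional


def buck_test_target(test_file: str, cell: str = "fbcode") -> str:
    """torchrec/x/tests/test_y.py -> fbcode//torchrec/x/tests:test_y (assumed naming)."""
    path = PurePosixPath(test_file)
    if path.suffix != ".py" or not path.name.startswith("test_"):
        raise ValueError(f"not a test file: {test_file}")
    return f"{cell}//{path.parent}:{path.stem}"


def buck_test_command(test_file: str, extra_args: Optional[list[str]] = None) -> list[str]:
    """Build the argv for `buck2 test`; extra_args go after `--` to the test runner."""
    cmd = ["buck2", "test", buck_test_target(test_file)]
    if extra_args:
        cmd += ["--", *extra_args]
    return cmd
```

The returned list can be passed directly to `subprocess.run(..., check=True)` without shell quoting.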
## Test Utilities Reference

Use these utilities when generating tests:

| Utility | Import | When to Use |
|---|---|---|
| `MultiProcessTestBase` | `torchrec.distributed.test_utils.multi_process` | All distributed tests |
| `MultiProcessContext` | `torchrec.distributed.test_utils.multi_process` | Per-rank setup/teardown |
| `ModelInput` | `torchrec.distributed.test_utils.test_model` | Generating test inputs for models |
| `TestSparseNN` | `torchrec.distributed.test_utils.test_model` | Test model with embedding tables |
| `sharding_single_rank_test` | `torchrec.distributed.test_utils.test_sharding` | Testing sharders |
| `create_test_sharder` | `torchrec.distributed.test_utils.test_sharding` | Creating test sharder instances |
| `skip_if_asan_class` | `torchrec.test_utils` | Skip entire class under ASAN |
| `seed_and_log` | `torchrec.test_utils` | Deterministic seeding with logging |
| `get_free_port` | `torchrec.test_utils` | Getting an available port for dist init |
## Constraints

- **NEVER** overwrite existing test methods. Add new methods to existing test classes or create new classes.
- **NEVER** add tests for private methods (starting with `_`) unless they contain complex logic that's critical to test.
- **ALWAYS** match the type-annotation style of the source file (modern `list[str]` vs `typing.List[str]`).
- **ALWAYS** check whether similar tests already exist before generating duplicates.
- **ALWAYS** prefer real implementations over mocks. Use `MultiProcessContext` + a real gloo PG for distributed tests, `ScopedConfigeratorFake` / JK overrides for config and feature flags, and in-memory fakes where they exist. Reach for `mock.patch`/`MagicMock` only when no real fake exists for the dependency, and call out why in a one-line comment.
- Keep generated tests focused and minimal: don't test framework behavior or trivial getters/setters.