Understand FSDP Sharding Strategies
Walk every FSDP sharding strategy across the same toy transformer until all-gather and reduce-scatter become numbers, not folklore. By the end you can pick FULL_SHARD vs SHARD_GRAD_OP vs HYBRID_SHARD for a 7B model on 16 GPUs and defend it.
Phase 1: FSDP as PyTorch's Native ZeRO-3
See FSDP as PyTorch's native ZeRO-3
FSDP is ZeRO-3 wearing a PyTorch hat
6 min: FSDP and DeepSpeed ZeRO-3 implement the same idea (shard params, grads, and optimizer state, then all-gather on demand) with different APIs and slightly different defaults.
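To put a number on the shared idea, here is a back-of-the-envelope sketch of static state per rank with and without sharding. The 4/4/8 bytes-per-parameter split (fp32 params, fp32 grads, two fp32 Adam moments) and the function name are assumptions chosen for illustration; neither library reports memory this way.

```python
def static_bytes_per_rank(num_params, world_size, shard=True,
                          param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Rough static state per rank: params + grads + optimizer state.

    Assumes fp32 params and grads plus two fp32 Adam moments (illustrative
    defaults). Ignores activations, buffers, padding, and fragmentation.
    """
    total = num_params * (param_bytes + grad_bytes + optim_bytes)
    # Full sharding (ZeRO-3 style): each rank holds 1/world_size of everything.
    return total // world_size if shard else total


GIB = 1024 ** 3
replicated = static_bytes_per_rank(1_000_000_000, 8, shard=False)
sharded = static_bytes_per_rank(1_000_000_000, 8, shard=True)
print(f"replicated (DDP-style): {replicated / GIB:.1f} GiB per rank")
print(f"fully sharded:          {sharded / GIB:.1f} GiB per rank")
```

Under these assumptions, static state shrinks linearly with the number of ranks; what does not shrink is the transient memory for whichever unit is currently gathered.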
All-gather on the way in, reduce-scatter on the way out
7 min: Under FULL_SHARD, each FSDP unit costs one all-gather in forward, and one all-gather plus one reduce-scatter in backward. That's the whole comm story.
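The three collectives per unit can be turned into a rough traffic estimate. The sketch below assumes ring-style collectives, where each rank sends about (N-1)/N of the payload per all-gather or reduce-scatter pass, and compares that with DDP's single all-reduce (which is a reduce-scatter followed by an all-gather). Function names and the 2-byte element size are illustrative assumptions.

```python
def ring_bytes_sent(payload_bytes, world_size, passes):
    """Bytes each rank sends for `passes` ring all-gather/reduce-scatter passes."""
    return passes * payload_bytes * (world_size - 1) / world_size


def full_shard_step_bytes(num_params, world_size, bytes_per_el=2):
    payload = num_params * bytes_per_el
    # forward all-gather + backward all-gather + backward reduce-scatter
    return ring_bytes_sent(payload, world_size, passes=3)


def ddp_step_bytes(num_params, world_size, bytes_per_el=2):
    payload = num_params * bytes_per_el
    # ring all-reduce = one reduce-scatter pass + one all-gather pass
    return ring_bytes_sent(payload, world_size, passes=2)


fsdp = full_shard_step_bytes(7_000_000_000, 16)
ddp = ddp_step_bytes(7_000_000_000, 16)
ratio = fsdp / ddp
print(f"FULL_SHARD: {fsdp / 1e9:.1f} GB sent per rank per step")
print(f"DDP:        {ddp / 1e9:.1f} GB sent per rank per step")
print(f"ratio:      {ratio:.2f}x")
```

In this model the volume ratio is independent of model size and world size; whether the extra traffic costs wall-clock time depends on how much of it overlaps with compute.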
The FlatParameter is FSDP's unit of work β and you choose it
7 min: An auto-wrap policy decides which submodules get fused into one FlatParameter; that decision controls both peak memory and overlap, more than the sharding strategy does.
MixedPrecision is a third axis, orthogonal to sharding
7 min: FSDP's MixedPrecision config lets you store params in one dtype, compute in another, and reduce gradients in a third, independent of which sharding strategy you pick.
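A small sketch of why the dtype choices form their own axis: the payload of each collective depends only on the dtype it runs in, not on the sharding strategy. `PrecisionPlan` below is a hypothetical stand-in written for this example, not the FSDP class; in PyTorch the corresponding settings are the `param_dtype`, `reduce_dtype`, and `buffer_dtype` fields of `torch.distributed.fsdp.MixedPrecision`.

```python
from dataclasses import dataclass

DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}


@dataclass
class PrecisionPlan:
    """Hypothetical stand-in for the three dtype choices (illustration only)."""
    storage_dtype: str  # dtype of the sharded params the optimizer updates
    compute_dtype: str  # dtype params are all-gathered and computed in
    reduce_dtype: str   # dtype gradients are reduce-scattered in


def unit_payload_bytes(plan, unit_params):
    """Collective payload for one FSDP unit; no sharding strategy involved."""
    return {
        "all_gather": unit_params * DTYPE_BYTES[plan.compute_dtype],
        "reduce_scatter": unit_params * DTYPE_BYTES[plan.reduce_dtype],
    }


plan = PrecisionPlan("float32", "bfloat16", "float32")
payload = unit_payload_bytes(plan, unit_params=200_000_000)
print(payload)
```

With this plan the gradient reduce-scatter moves twice the bytes of the parameter all-gather, which is the usual trade for reducing gradients in full precision.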
Phase 2: Wrapping a Transformer and Reading FlatParameters
Wrap a transformer and watch FlatParameters form
Wrap a 2-layer transformer in 20 lines and inspect it
7 min: You can see exactly which submodules became FlatParameters by walking `model.named_modules()` after wrapping; no profiler is needed for the first sanity check.
Flip FULL_SHARD to SHARD_GRAD_OP and measure peak memory
7 min: FULL_SHARD frees each unit's gathered params after forward; SHARD_GRAD_OP keeps them unsharded until backward finishes. The difference shows up at peak as roughly one extra full copy of the parameters per rank.
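Before measuring, it helps to have a rough expectation. The sketch below is a simplified model, assuming a single dtype and ignoring prefetch, padding, and allocator behavior: under FULL_SHARD a rank holds its shard plus one gathered unit, while under SHARD_GRAD_OP units are not resharded after forward, so the full parameter set is resident by the start of backward. The function name and element counts are illustrative.

```python
def peak_param_elems(strategy, num_params, world_size, largest_unit_params):
    """Rough peak count of parameter elements resident on one rank."""
    shard = num_params // world_size
    if strategy == "FULL_SHARD":
        # Local shard plus one unit gathered at a time, freed after use.
        return shard + largest_unit_params
    if strategy == "SHARD_GRAD_OP":
        # Units stay unsharded after forward, so by the start of backward
        # roughly the whole parameter set is resident.
        return num_params
    raise ValueError(f"unsupported strategy: {strategy}")


full = peak_param_elems("FULL_SHARD", 1_000, 4, largest_unit_params=100)
grad_op = peak_param_elems("SHARD_GRAD_OP", 1_000, 4, largest_unit_params=100)
print(f"FULL_SHARD:    {full} elements")
print(f"SHARD_GRAD_OP: {grad_op} elements")
```

Multiply by the bytes per element of the compute dtype to compare against what `torch.cuda.max_memory_allocated()` reports; the measured gap will also include gradients, optimizer state, and activations.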
Activation checkpointing makes sharding look better than it is
7 min: FSDP's memory savings target static state (params + grads + optimizer); activations are an independent axis that often dominates peak memory, and activation checkpointing is the right lever for them.
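A toy estimate makes the independence visible: nothing in the activation formula depends on world size or sharding strategy. The model below is an assumption-heavy sketch: `per_layer_factor` (how many layer-input-sized tensors a transformer layer keeps for backward) is a made-up constant, and checkpointing is modeled as storing only each layer's input plus one layer's full activations during recompute.

```python
def activation_bytes(layers, batch, seq, hidden, checkpoint,
                     bytes_per_el=2, per_layer_factor=16):
    """Rough activation memory per rank; note there is no world_size argument."""
    layer_input = batch * seq * hidden * bytes_per_el
    full_layer = per_layer_factor * layer_input  # assumed constant factor
    if checkpoint:
        # Keep only each layer's input, plus one layer being recomputed.
        return layers * layer_input + full_layer
    return layers * full_layer


GIB = 1024 ** 3
plain = activation_bytes(32, 1, 4096, 4096, checkpoint=False)
ckpt = activation_bytes(32, 1, 4096, 4096, checkpoint=True)
print(f"no checkpointing:   {plain / GIB:.1f} GiB")
print(f"with checkpointing: {ckpt / GIB:.1f} GiB")
```

The absolute numbers depend entirely on the assumed factor; the point is the shape of the trade, where checkpointing replaces a per-layer cost with a per-layer input plus one layer of recompute.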
BackwardPrefetch.BACKWARD_PRE is the throughput knob nobody mentions
6 min: Setting `backward_prefetch=BackwardPrefetch.BACKWARD_PRE` overlaps the next unit's all-gather with the current unit's backward compute, often a 5-15% throughput win in exchange for holding one extra unit's unsharded parameters in memory.
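The overlap can be pictured with a toy timeline. This sketch assumes per-unit compute and all-gather times in milliseconds (made-up values), ignores reduce-scatter, and treats the no-prefetch case as fully serialized; it is a mental model, not a profile of real FSDP scheduling.

```python
def backward_ms(compute_ms, gather_ms, prefetch):
    """Toy backward timeline over FSDP units, listed in backward order."""
    if not prefetch:
        # Each unit waits for its own all-gather, then computes gradients.
        return sum(g + c for g, c in zip(gather_ms, compute_ms))
    # Only the first all-gather is exposed; each later one is issued before
    # the previous unit's gradient computation and hides behind it.
    total = gather_ms[0]
    for i, c in enumerate(compute_ms):
        next_gather = gather_ms[i + 1] if i + 1 < len(gather_ms) else 0
        total += max(c, next_gather)
    return total


compute = [10, 10, 10, 10]
gather = [4, 4, 4, 4]
serial = backward_ms(compute, gather, prefetch=False)
overlapped = backward_ms(compute, gather, prefetch=True)
print(f"no prefetch: {serial} ms, prefetch: {overlapped} ms")
```

In this model the benefit saturates once a unit's all-gather takes longer than the previous unit's compute, which is one reason the wrap policy and the prefetch setting interact.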
Three knobs, in priority order: wrap policy, strategy, prefetch
6 min: Most FSDP tuning collapses to three decisions in this order: wrap policy first (overlap), then strategy (memory), then prefetch (throughput). Get them in the right order and you're 90% of the way to a good config.
Phase 3: Choosing the Sharding Mode for the Topology
Pick the sharding mode that fits the topology
Single-node, plenty of memory: do you really need FULL_SHARD?
7 min
Two nodes, slow inter-node link: when HYBRID_SHARD wins
7 min
Activations OOM on long context: which knob first?
7 min
The throughput dropped 20% after enabling FSDP: what's the bug?
8 min
Phase 4: Defend a Strategy for 7B on 16 GPUs
Defend a strategy for 7B on 16 GPUs
7B model, 16 GPUs, 2 nodes: pick a strategy and defend it
8 min
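As a starting point for that defense, here is a rough per-rank memory comparison for the three strategies on this exact setup. It assumes 8 GPUs per node, fp32 params and grads, and fp32 Adam moments with no mixed precision; it counts SHARD_GRAD_OP parameters at their unsharded peak and ignores activations and transient buffers. All constants are assumptions for illustration.

```python
NUM_PARAMS = 7_000_000_000
WORLD_SIZE = 16
GPUS_PER_NODE = 8  # assumed: 2 nodes x 8 GPUs
BYTES_PER_PARAM = {"params": 4, "grads": 4, "optim": 8}  # assumed fp32 + Adam

# How many ranks each kind of state is divided across, per strategy.
SHARD_OVER = {
    # Everything sharded across all 16 ranks.
    "FULL_SHARD": {"params": WORLD_SIZE, "grads": WORLD_SIZE, "optim": WORLD_SIZE},
    # Grads and optimizer state sharded; params counted unsharded, as at peak.
    "SHARD_GRAD_OP": {"params": 1, "grads": WORLD_SIZE, "optim": WORLD_SIZE},
    # Sharded within a node, replicated across nodes; gradient shards are
    # additionally all-reduced between nodes.
    "HYBRID_SHARD": {"params": GPUS_PER_NODE, "grads": GPUS_PER_NODE,
                     "optim": GPUS_PER_NODE},
}


def per_rank_gb(strategy):
    """Rough per-rank memory in GB (1e9 bytes) for params + grads + optimizer."""
    divisors = SHARD_OVER[strategy]
    total = sum(NUM_PARAMS * BYTES_PER_PARAM[k] / divisors[k]
                for k in BYTES_PER_PARAM)
    return total / 1e9


estimates = {name: per_rank_gb(name) for name in SHARD_OVER}
for name, gb in estimates.items():
    print(f"{name:14s} {gb:6.2f} GB per rank")
```

These figures only bound the memory side of the argument; the other half is which collectives cross the inter-node link, which this sketch does not model.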
Frequently asked questions
- What is FSDP and how is it different from DDP?
- What's the difference between FULL_SHARD and SHARD_GRAD_OP?
- When should I use HYBRID_SHARD instead of FULL_SHARD?
- How do auto-wrap policies actually decide what becomes a FlatParameter?
- Why does FSDP throughput drop 20% compared to DDP, and how do I fix it?
Each of these questions is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.