1. Introduction
You have a model. You have a single GPU. Training takes 72 hours. You requisition a second machine with four more GPUs — and now you need your code to actually use them. This is the exact moment where most practitioners hit a wall. Not because distributed training is conceptually hard, but because the engineering required to do it correctly — process groups, rank-aware logging, sampler seeding, checkpoint barriers — is scattered across dozens of tutorials that each cover one piece of the puzzle.
This article is the guide I wish I had when I first scaled training beyond a single node. We will build a complete, production-grade multi-node training pipeline from scratch using PyTorch’s DistributedDataParallel (DDP). Every file is modular, every value is configurable, and every distributed concept is made explicit. By the end, you will have a codebase you can drop into any cluster and start training immediately.
What we will cover: the mental model behind DDP, a clean modular project structure, distributed lifecycle management, efficient data loading across ranks, a training loop with mixed precision and gradient accumulation, rank-aware logging and checkpointing, multi-node launch scripts, and the performance pitfalls that trip up even experienced engineers.
The full codebase is available on GitHub. Every code block in this article is pulled directly from that repository.
2. How DDP Works — The Mental Model
Before writing any code, we need a clear mental model. DistributedDataParallel (DDP) is not magic — it is a well-defined communication pattern built on top of collective operations.
The setup is straightforward. You launch N processes (one per GPU, potentially across multiple machines). Each process initialises a process group — a communication channel backed by NCCL (NVIDIA Collective Communications Library) for GPU-to-GPU transfers. Every process gets three identity numbers: its global rank (unique across all machines), its local rank (unique within its machine), and the total world size.
Each process holds an identical copy of the model. Data is partitioned across processes using a DistributedSampler — every rank sees a different slice of the dataset, but the model weights start (and stay) identical.
The critical mechanism is what happens during backward(). DDP registers hooks on every parameter. When a gradient is computed for a parameter, DDP buckets it with nearby gradients and fires an all-reduce operation across the process group. This all-reduce computes the mean gradient across all ranks. Because every rank now has the same averaged gradient, the subsequent optimizer step produces identical weight updates, keeping all replicas in sync — without any explicit synchronisation code from us.
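The averaging step can be sketched without any GPUs at all. The toy helper below (purely illustrative, not part of the repository) simulates an all-reduce mean over per-rank gradient lists with plain Python; real DDP performs the same reduction with NCCL inside backward():

```python
# Toy illustration of the all-reduce averaging DDP performs on gradients.
# One inner list per rank; real DDP does this with NCCL during backward().

def all_reduce_mean(per_rank_grads):
    """Average gradients elementwise across ranks; every rank gets the result."""
    world_size = len(per_rank_grads)
    averaged = [
        sum(rank_grads[i] for rank_grads in per_rank_grads) / world_size
        for i in range(len(per_rank_grads[0]))
    ]
    # After the collective, every rank holds the identical averaged gradient,
    # so identical optimizer steps keep all replicas in sync.
    return [list(averaged) for _ in range(world_size)]

grads = [[1.0, 2.0], [3.0, 4.0]]   # rank 0 and rank 1 local gradients
synced = all_reduce_mean(grads)
assert synced[0] == synced[1] == [2.0, 3.0]
```

This is why no explicit synchronisation code is needed: the collective itself guarantees every rank leaves backward() with the same gradients.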
This is why DDP is strictly superior to the older DataParallel: there is no single “master” GPU bottleneck, no redundant forward passes, and gradient communication overlaps with backward computation.
Figure 1: DDP gradient synchronization flow. All-reduce happens automatically via hooks registered during backward().
Key terminology
Rank: Globally unique process ID (0 to world_size - 1)
Local Rank: GPU index within a single machine (0 to nproc_per_node - 1)
World Size: Total number of processes across all nodes
Process Group: Communication channel (NCCL) connecting all ranks
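For a homogeneous cluster, the relationship between these identities is simple arithmetic: the global rank is the node's offset plus the local GPU index. A quick sketch (this helper is illustrative, not part of the repository):

```python
def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank on a homogeneous cluster: node offset plus local GPU index."""
    return node_rank * nproc_per_node + local_rank

# Two nodes with four GPUs each -> world_size 8, global ranks 0..7.
# Node 0 hosts ranks 0-3, node 1 hosts ranks 4-7.
world_size = 2 * 4
ranks = [global_rank(n, 4, l) for n in range(2) for l in range(4)]
assert ranks == list(range(world_size))
```

torchrun computes exactly this mapping for you and exposes the result through the RANK environment variable.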
3. Architecture Overview
A production training pipeline should never be a single monolithic script. Ours is split into six focused modules, each with a single responsibility. The dependency graph below shows how they connect — note that config.py sits at the bottom, acting as the single source of truth for every hyperparameter.
Figure 2: Module dependency graph. train.py orchestrates all other modules. config.py is imported by everyone
Here is the project structure:
pytorch-multinode-ddp/
├── train.py            # Entry point — training loop
├── config.py           # Dataclass configuration + argparse
├── ddp_utils.py        # Distributed setup, teardown, checkpointing
├── model.py            # MiniResNet (lightweight ResNet variant)
├── dataset.py          # Synthetic dataset + DistributedSampler loader
├── utils/
│   ├── logger.py       # Rank-aware structured logging
│   └── metrics.py      # Running averages + distributed all-reduce
├── scripts/
│   └── launch.sh       # Multi-node torchrun wrapper
└── requirements.txt
This separation means you can swap in a real dataset by editing only dataset.py, or replace the model by editing only model.py. The training loop never needs to change.
4. Centralized Configuration
Hard-coded hyperparameters are the enemy of reproducibility. We use a Python dataclass as our single source of configuration. Every other module imports TrainingConfig and reads from it — nothing is hard-coded.
The dataclass doubles as our CLI parser: the from_args() classmethod introspects the field names and types, automatically building argparse flags with defaults. This means you get --batch_size 128 and --no-use_amp for free, without writing a single parser line by hand.
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfig:
    """Immutable bag of every parameter the training pipeline needs."""
    # Model
    num_classes: int = 10
    in_channels: int = 3
    image_size: int = 32
    # Data
    batch_size: int = 64      # per-GPU
    num_workers: int = 4
    # Optimizer / Scheduler
    epochs: int = 10
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 1e-4
    # Distributed
    backend: str = "nccl"
    # Mixed Precision
    use_amp: bool = True
    # Gradient Accumulation
    grad_accum_steps: int = 1
    # Checkpointing
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 1
    resume_from: Optional[str] = None
    # Logging & Profiling
    log_interval: int = 10
    enable_profiling: bool = False
    seed: int = 42

    @classmethod
    def from_args(cls) -> "TrainingConfig":
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        defaults = cls()
        for name, val in vars(defaults).items():
            arg_type = type(val) if val is not None else str
            if isinstance(val, bool):
                parser.add_argument(f"--{name}", default=val,
                                    action=argparse.BooleanOptionalAction)
            else:
                parser.add_argument(f"--{name}", type=arg_type, default=val)
        return cls(**vars(parser.parse_args()))
Why a dataclass instead of YAML or JSON? Three reasons: (1) type hints are enforced by the IDE and mypy, (2) there is zero dependency on third-party config libraries, and (3) every parameter has a visible default right next to its declaration. For production systems that need hierarchical configs, you can always layer Hydra or OmegaConf on top of this pattern.
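To see the introspection pattern in isolation, here is a minimal, self-contained version of the same idea using only the standard library (the MiniConfig name and its fields are illustrative, not from the repository):

```python
import argparse
from dataclasses import dataclass


@dataclass
class MiniConfig:
    batch_size: int = 64
    lr: float = 0.01
    use_amp: bool = True

    @classmethod
    def from_args(cls, argv=None):
        parser = argparse.ArgumentParser()
        for name, val in vars(cls()).items():
            if isinstance(val, bool):
                # Generates paired --use_amp / --no-use_amp flags (Python 3.9+)
                parser.add_argument(f"--{name}", default=val,
                                    action=argparse.BooleanOptionalAction)
            else:
                parser.add_argument(f"--{name}", type=type(val), default=val)
        return cls(**vars(parser.parse_args(argv)))


cfg = MiniConfig.from_args(["--batch_size", "128", "--no-use_amp"])
assert cfg.batch_size == 128 and cfg.use_amp is False and cfg.lr == 0.01
```

Passing an explicit argv list, as above, also makes the parser trivially unit-testable.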
5. Distributed Lifecycle Management
The distributed lifecycle has three phases: initialise, run, and tear down. Getting any of these wrong can produce silent hangs, so we wrap everything in explicit error handling.
Process Group Initialization
The setup_distributed() function reads the three environment variables that torchrun sets automatically (RANK, LOCAL_RANK, WORLD_SIZE), pins the correct GPU with torch.cuda.set_device(), and initialises the NCCL process group. It returns a frozen dataclass — DistributedContext — that the rest of the codebase passes around instead of re-reading os.environ.
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass(frozen=True)
class DistributedContext:
    """Immutable snapshot of the current process's distributed identity."""
    rank: int
    local_rank: int
    world_size: int
    device: torch.device


def setup_distributed(config: TrainingConfig) -> DistributedContext:
    required_vars = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    missing = [v for v in required_vars if v not in os.environ]
    if missing:
        raise RuntimeError(
            f"Missing environment variables: {missing}. "
            "Launch with torchrun or set them manually.")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for NCCL distributed training.")
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend=config.backend)
    return DistributedContext(
        rank=rank, local_rank=local_rank,
        world_size=world_size, device=device)
Checkpointing with Rank Guards
The most common distributed checkpointing bug is all ranks writing to the same file simultaneously. We guard saving behind is_main_process(), and loading behind dist.barrier() — this ensures rank 0 finishes writing before other ranks attempt to read.
from pathlib import Path

import torch
import torch.distributed as dist


def save_checkpoint(path, epoch, model, optimizer, scaler=None, rank=0):
    """Persist training state to disk (rank-0 only)."""
    if not is_main_process(rank):
        return
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    state = {
        "epoch": epoch,
        "model_state_dict": model.module.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    if scaler is not None:
        state["scaler_state_dict"] = scaler.state_dict()
    torch.save(state, path)


def load_checkpoint(path, model, optimizer=None, scaler=None, device="cpu"):
    """Restore training state. All ranks load after barrier."""
    dist.barrier()  # wait for rank 0 to finish writing
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer and "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    if scaler and "scaler_state_dict" in ckpt:
        scaler.load_state_dict(ckpt["scaler_state_dict"])
    return ckpt.get("epoch", 0)
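A further hardening step worth considering (not shown in the listing above) is atomic writes: save to a temporary file in the same directory, then rename it over the target. A crash mid-write then leaves the previous checkpoint intact rather than a truncated file. A stdlib-only sketch of the idea, using json.dump as a stand-in for torch.save:

```python
import json
import os
import tempfile


def atomic_save(state, path):
    """Write state to a temp file, then atomically rename over the target.

    os.replace is atomic on POSIX filesystems: readers see either the old
    checkpoint or the new one, never a partial write. In the real pipeline
    torch.save(state, f) would take the place of json.dump.
    """
    dirname = os.path.dirname(path) or "."
    os.makedirs(dirname, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
        os.replace(tmp, path)  # the atomic swap
    except BaseException:
        os.unlink(tmp)  # never leave stray temp files behind
        raise


with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "epoch_3.json")
    atomic_save({"epoch": 3, "loss": 0.42}, target)
    with open(target) as f:
        assert json.load(f) == {"epoch": 3, "loss": 0.42}
```

The temp file must live in the same directory as the target, because rename is only atomic within a single filesystem.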
6. Model Design for DDP
We use a lightweight ResNet variant called MiniResNet — three residual stages with increasing channels (64, 128, 256), two blocks per stage, global average pooling, and a fully-connected head. It is complex enough to be realistic but light enough to run on any hardware.
The critical DDP requirement: the model must be moved to the correct GPU before wrapping. DDP does not move models for you.
def create_model(config: TrainingConfig, device: torch.device) -> nn.Module:
    """Instantiate a MiniResNet and move it to device."""
    model = MiniResNet(
        in_channels=config.in_channels,
        num_classes=config.num_classes,
    )
    return model.to(device)


def wrap_ddp(model: nn.Module, local_rank: int) -> DDP:
    """Wrap model with DistributedDataParallel."""
    return DDP(model, device_ids=[local_rank])
Note the two-step pattern: create_model() → wrap_ddp(). This separation is intentional. When loading a checkpoint, you need the unwrapped model (model.module) to load state dicts, then re-wrap. If you fuse creation and wrapping, checkpoint loading becomes awkward.
7. Distributed Data Loading
DistributedSampler is what ensures each GPU sees a unique slice of data. It partitions indices across world_size ranks and returns a non-overlapping subset for each. Without it, every GPU would train on identical batches — burning compute for zero benefit.
There are three details that trip people up:
First, sampler.set_epoch(epoch) must be called at the start of every epoch. The sampler uses the epoch number as a random seed for shuffling. If you forget this, every epoch will iterate over data in the same order, which degrades generalisation.
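The effect of set_epoch() is easy to simulate: shuffle the full index list with the epoch as the seed, then take a strided per-rank slice. This is a simplified model of what DistributedSampler does internally (the real sampler uses a torch generator seeded with seed + epoch, and pads indices when the dataset does not divide evenly):

```python
import random


def shard_indices(dataset_len, world_size, rank, epoch):
    """Epoch-seeded shuffle plus strided split, mimicking DistributedSampler."""
    indices = list(range(dataset_len))
    random.Random(epoch).shuffle(indices)  # the set_epoch(epoch) analogue
    return indices[rank::world_size]       # non-overlapping per-rank slice


# The ranks together cover the dataset exactly once, with no overlap...
epoch0 = [shard_indices(8, 2, r, epoch=0) for r in range(2)]
assert sorted(epoch0[0] + epoch0[1]) == list(range(8))
assert not set(epoch0[0]) & set(epoch0[1])

# ...and changing the epoch re-shuffles, so each epoch sees a new order.
epoch1 = [shard_indices(8, 2, r, epoch=1) for r in range(2)]
assert epoch0[0] != epoch1[0] or epoch0[1] != epoch1[1]
```

Skip the seed update and epoch0 equals epoch1 on every rank: exactly the silent bug set_epoch() prevents.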
Second, pin_memory=True in the DataLoader pre-allocates page-locked host memory, enabling asynchronous CPU-to-GPU transfers when you call tensor.to(device, non_blocking=True). This overlap is where real throughput gains come from.
Third, persistent_workers=True avoids respawning worker processes every epoch — a significant overhead reduction when num_workers > 0.
def create_distributed_dataloader(dataset, config, ctx):
    sampler = DistributedSampler(
        dataset,
        num_replicas=ctx.world_size,
        rank=ctx.rank,
        shuffle=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=config.num_workers > 0,
    )
    return loader, sampler
8. The Training Loop — Where It All Comes Together
This is the heart of the pipeline. The loop below integrates every component we have built so far: DDP-wrapped model, distributed data loader, mixed precision, gradient accumulation, rank-aware logging, learning rate scheduling, and checkpointing.
Figure 3: Training loop state machine. The inner step loop handles gradient accumulation; the outer epoch loop handles scheduler stepping and checkpointing.
Mixed Precision (AMP)
Automatic Mixed Precision (AMP) keeps master weights in FP32 but runs the forward pass and loss computation in FP16. This halves memory bandwidth requirements and enables Tensor Core acceleration on modern NVIDIA GPUs, often yielding a 1.5–2x throughput improvement with negligible accuracy impact.
We use torch.autocast for the forward pass and torch.amp.GradScaler for loss scaling. A subtlety: we create the GradScaler with enabled=config.use_amp. When disabled, the scaler becomes a no-op — same code path, zero overhead, no branching.
Gradient Accumulation
Sometimes you need a larger effective batch size than your GPU memory allows. Gradient accumulation simulates this by running multiple forward-backward passes before stepping the optimizer. The key is to divide the loss by grad_accum_steps before backward(), so the accumulated gradient is correctly averaged.
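The arithmetic behind that division is worth spelling out: summing the gradients of loss_i / N over N micro-batches equals the gradient of the mean loss, which is what a single large batch would have produced. A scalar sanity check (plain numbers standing in for gradients):

```python
# For loss = w * x, dloss/dw = x, so these inputs double as per-micro-batch
# gradients. Four micro-batches accumulated before one optimizer step.
micro_batch_grads = [2.0, 4.0, 6.0, 8.0]
N = len(micro_batch_grads)

# Gradient accumulation: each micro-loss is divided by N before backward(),
# and the contributions sum into the same .grad buffer.
accumulated = sum(g / N for g in micro_batch_grads)

# Equivalent single large batch: gradient of the mean loss over all samples.
large_batch = sum(micro_batch_grads) / N

assert abs(accumulated - large_batch) < 1e-12  # identical up to float error
```

Omit the division and the accumulated gradient is N times too large, which behaves like silently multiplying the learning rate by N.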
def train_one_epoch(model, loader, criterion, optimizer, scaler, ctx, config, epoch, logger):
    model.train()
    tracker = MetricTracker()
    total_steps = len(loader)
    use_amp = config.use_amp and ctx.device.type == "cuda"
    autocast_ctx = torch.autocast("cuda", dtype=torch.float16) if use_amp else nullcontext()
    optimizer.zero_grad(set_to_none=True)
    for step, (images, labels) in enumerate(loader):
        images = images.to(ctx.device, non_blocking=True)
        labels = labels.to(ctx.device, non_blocking=True)
        with autocast_ctx:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss = loss / config.grad_accum_steps  # scale for accumulation
        scaler.scale(loss).backward()
        if (step + 1) % config.grad_accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)  # memory-efficient reset
        # Track raw (unscaled) loss for logging
        raw_loss = loss.item() * config.grad_accum_steps
        acc = compute_accuracy(outputs, labels)
        tracker.update("loss", raw_loss, n=images.size(0))
        tracker.update("accuracy", acc, n=images.size(0))
        if is_main_process(ctx.rank) and (step + 1) % config.log_interval == 0:
            log_training_step(logger, epoch, step + 1, total_steps,
                              raw_loss, optimizer.param_groups[0]["lr"])
    return tracker
Two details worth highlighting. First, zero_grad(set_to_none=True) deallocates gradient tensors instead of filling them with zeros, saving memory proportional to the model size. Second, data is moved to the GPU with non_blocking=True — this allows the CPU to continue filling the next batch while the current one transfers, exploiting the pin_memory overlap.
The Main Function
The main() function orchestrates the full pipeline. Note the try/finally pattern guaranteeing that the process group is torn down even if an exception occurs — without this, a crash on one rank can leave other ranks hanging indefinitely.
def main():
    config = TrainingConfig.from_args()
    ctx = setup_distributed(config)
    logger = setup_logger(ctx.rank)
    torch.manual_seed(config.seed + ctx.rank)

    model = create_model(config, ctx.device)
    model = wrap_ddp(model, ctx.local_rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs)
    scaler = torch.amp.GradScaler(enabled=config.use_amp)

    start_epoch = 1
    if config.resume_from:
        start_epoch = load_checkpoint(config.resume_from, model.module,
                                      optimizer, scaler, ctx.device) + 1

    dataset = SyntheticImageDataset(size=50000, image_size=config.image_size,
                                    num_classes=config.num_classes)
    loader, sampler = create_distributed_dataloader(dataset, config, ctx)
    criterion = nn.CrossEntropyLoss()

    try:
        for epoch in range(start_epoch, config.epochs + 1):
            sampler.set_epoch(epoch)
            tracker = train_one_epoch(model, loader, criterion, optimizer,
                                      scaler, ctx, config, epoch, logger)
            scheduler.step()
            avg_loss = all_reduce_scalar(tracker.average("loss"),
                                         ctx.world_size, ctx.device)
            if is_main_process(ctx.rank):
                log_epoch_summary(logger, epoch, {"loss": avg_loss})
            if epoch % config.save_every == 0:
                save_checkpoint(f"{config.checkpoint_dir}/epoch_{epoch}.pt",
                                epoch, model, optimizer, scaler, ctx.rank)
    finally:
        cleanup_distributed()
9. Launching Across Nodes
PyTorch’s torchrun (introduced in v1.10 as a replacement for torch.distributed.launch) handles spawning one process per GPU and setting the RANK, LOCAL_RANK, and WORLD_SIZE environment variables. For multi-node training, every node must specify the master node’s address so that all processes can establish the NCCL connection.
Here is our launch script, which reads all tunables from environment variables:
#!/usr/bin/env bash
set -euo pipefail

NNODES="${NNODES:-2}"
NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
NODE_RANK="${NODE_RANK:-0}"
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
MASTER_PORT="${MASTER_PORT:-12355}"

torchrun \
  --nnodes="${NNODES}" \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --node_rank="${NODE_RANK}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  train.py "$@"
For a quick single-node test on one GPU:
torchrun --standalone --nproc_per_node=1 train.py --epochs 2
For two-node training with four GPUs each, run on Node 0:
MASTER_ADDR=10.0.0.1 NODE_RANK=0 NNODES=2 NPROC_PER_NODE=4 bash scripts/launch.sh
And on Node 1:
MASTER_ADDR=10.0.0.1 NODE_RANK=1 NNODES=2 NPROC_PER_NODE=4 bash scripts/launch.sh
Figure 4: Multi-node architecture. Each node runs 4 GPU processes; NCCL all-reduce synchronizes gradients across the ring.
10. Performance Pitfalls and Tips
After building hundreds of distributed training jobs, these are the mistakes I see most often:
Forgetting sampler.set_epoch(). Without it, data order is identical every epoch. This is the single most common DDP bug and it silently hurts convergence.
CPU-GPU transfer bottleneck. Always use pin_memory=True in your DataLoader and non_blocking=True in your .to() calls. Without these, the CPU blocks on every batch transfer.
Logging from all ranks. If every rank prints, output is interleaved garbage. Guard all logging behind rank == 0 checks.
zero_grad() without set_to_none=True. The default zero_grad() fills gradient tensors with zeros. set_to_none=True deallocates them instead, reducing peak memory.
Saving checkpoints from all ranks. Multiple ranks writing the same file causes corruption. Only rank 0 should save, and all ranks should barrier before loading.
Not seeding with rank offset. torch.manual_seed(seed + rank) ensures each rank’s data augmentation is different. Without the offset, augmentations are identical across GPUs.
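The rank-offset rule is easy to sanity-check with Python's stdlib RNG standing in for torch's generator:

```python
import random

seed = 42
# Per-rank generators seeded with seed + rank, as in the pipeline's main().
streams = [random.Random(seed + rank) for rank in range(4)]
draws = [rng.random() for rng in streams]

# Offsetting by rank gives each process a distinct augmentation stream...
assert len(set(draws)) == len(draws)

# ...whereas the same seed on every rank yields identical draws everywhere,
# so all GPUs would apply identical "random" augmentations.
same = [random.Random(seed).random() for _ in range(4)]
assert len(set(same)) == 1
```

Note that model initialisation is the opposite case: DDP broadcasts rank 0's weights at wrap time, so replicas start identical regardless of the seed offset.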
When NOT to use DDP
DDP replicates the entire model on every GPU. If your model does not fit in a single GPU’s memory, DDP alone will not help. For such cases, look into Fully Sharded Data Parallel (FSDP), which shards parameters, gradients, and optimizer states across ranks, or frameworks like DeepSpeed ZeRO.
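A back-of-the-envelope estimate makes the decision concrete. The sketch below counts only parameters, gradients, and SGD momentum in FP32 (so 12 bytes per parameter), ignoring activations and buffers entirely; Adam-style optimizers and real FSDP configurations shift the numbers, so treat this as a rough planning tool, not a capacity guarantee:

```python
def training_state_gib(num_params, world_size=1, sharded=False):
    """Rough per-GPU GiB for params + grads + momentum, all FP32.

    Assumes plain SGD with momentum (3 FP32 copies = 12 bytes/param).
    DDP replicates this state on every GPU; FSDP/ZeRO divide it by
    world_size. Activation memory is ignored entirely.
    """
    bytes_per_param = 4 + 4 + 4  # fp32 weights + grads + momentum
    total = num_params * bytes_per_param
    per_gpu = total / world_size if sharded else total
    return per_gpu / 2**30


# A 7B-parameter model: DDP replicates roughly 78 GiB of training state on
# every GPU, while sharding across 8 ranks drops it below 10 GiB per GPU.
ddp = training_state_gib(7e9)
fsdp = training_state_gib(7e9, world_size=8, sharded=True)
assert ddp > 70 and fsdp < 10
```

When the replicated number exceeds your GPU's memory before you have spent a single byte on activations, DDP is off the table and sharding becomes mandatory.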
11. Conclusion
We’ve gone from a single-GPU training mindset to a fully distributed, production-grade pipeline capable of scaling across machines — without sacrificing clarity or maintainability.
But more importantly, this wasn’t just about making DDP work. It was about building it correctly.
Let’s distill the most important takeaways:
Key Takeaways
- DDP is deterministic engineering, not magic. Once you understand process groups, ranks, and all-reduce, distributed training becomes predictable and debuggable.
- Structure matters more than scale. A clean, modular codebase (config → data → model → training → utils) is what makes scaling from 1 GPU to 100 GPUs feasible.
- Correct data sharding is non-negotiable. DistributedSampler + set_epoch() is the difference between true scaling and wasted compute.
- Performance comes from small details. pin_memory, non_blocking, set_to_none=True, and AMP collectively deliver massive throughput gains.
- Rank-awareness is essential. Logging, checkpointing, and randomness must all respect rank — otherwise you get chaos.
- DDP scales compute, not memory. If your model doesn't fit on one GPU, you need FSDP or ZeRO — not more GPUs.
The Bigger Picture
What you’ve built here is not just a training script — it’s a template for real-world ML systems.
This exact pattern is used in:
- Production ML pipelines
- Research labs training large models
- Startups scaling from prototype to infrastructure
And the best part?
You can now:
- Plug in a real dataset
- Swap in a Transformer or custom architecture
- Scale across nodes with zero code changes
What to Explore Next
Once you’re comfortable with this setup, the next frontier is memory-efficient and large-scale training:
- Fully Sharded Data Parallel (FSDP) → shard model + gradients
- DeepSpeed ZeRO → shard optimizer states
- Pipeline Parallelism → split models across GPUs
- Tensor Parallelism → split layers themselves
These techniques power today’s largest models — but they all build on the exact DDP foundation you now understand.
Distributed training often feels intimidating — not because it’s inherently complex, but because it’s rarely presented as a complete system.
Now you’ve seen the full picture.
And once you see it end-to-end…
Scaling becomes an engineering decision, not a research problem.
But for the vast majority of training workloads — from ResNets to medium-scale Transformers — the DDP pipeline we built here is exactly what production teams use. Scale it by adding nodes and GPUs; the code handles the rest.
The complete, production-ready codebase for this project is available here: pytorch-multinode-ddp
References
[1] PyTorch Distributed Overview, PyTorch Documentation (2024), https://pytorch.org/tutorials/beginner/dist_overview.html
[2] S. Li et al., PyTorch Distributed: Experiences on Accelerating Data Parallel Training (2020), VLDB Endowment
[3] PyTorch DistributedDataParallel API, https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
[4] NCCL: Optimized primitives for collective multi-GPU communication, NVIDIA, https://developer.nvidia.com/nccl
[5] PyTorch AMP: Automatic Mixed Precision, https://pytorch.org/docs/stable/amp.html

