How To Scale Your Model
August 20, 2025.
This is an extremely long-winded, scattered, and at many times incoherent post, the result of my note-taking while reading the entire "How To Scale Your Model" series. I highly recommend the series; it was a great read. I'd love to get more hands-on experience with large-scale training; it would be a dream come true.
Again, don't try to read this. This is for my own reference. I've answered most of the quizzes, which might be the only thing worth looking at here if you're not me.
Rooflines
Typically, computation within a single chip can be overlapped with communication, both within a chip and between chips. This means we can lower-bound training and inference time by the maximum of the computation time T_math and the communication time T_comms. The upper bound, when there is no comp/comms overlap, is therefore given by the sum T_math + T_comms. Since max(a, b) ≤ a + b ≤ 2·max(a, b), the upper and lower bounds differ by at most a factor of 2.
If comms and math are perfectly overlapped and T_math > T_comms, we are compute-bound. If T_comms > T_math, we are communication-bound.
The arithmetic intensity of an algorithm measures how many FLOPs it performs per byte. It is given by the ratio of the total FLOPs it performs to the number of bytes it needs to communicate (either within a chip or between chips). An accelerator has a peak arithmetic intensity given by its FLOPs/s divided by its bandwidth. For a TPU v5e this is 240 FLOPs/byte, and for an H100 it's 226. This means that if an algorithm has a lower arithmetic intensity than 240 (or 226), it will be bound by byte loading and we won't make good use of our hardware. An example of this is a dot product. Take two bf16 vectors of shape [N]. To perform a dot product we need to load both vectors from HBM, each of which is 2N bytes, perform N multiplications followed by N-1 additions, and then write 2 bytes back into HBM:
Intensity(dot product) = (N + N - 1) / (2N + 2N + 2) = (2N - 1) / (4N + 2) → 1/2 as N grows.
So the dot product intensity is 1/2, meaning the algorithm performs 0.5 floating point operations per byte loaded, an extremely low value that will be memory-bandwidth bound on any modern hardware.
Matrix Multiplication
A matrix multiplication of two shapes [B, D] x [D, F] will need to load 2BD + 2DF bytes, perform 2BDF FLOPs (think of it as BF dot products, one per (row of B, column of F) pair, each over D elements costing 2D FLOPs), and write 2BF bytes back. Hence
Intensity(matmul) = 2BDF / (2BD + 2DF + 2BF)
If the batch size is small relative to D and F, the denominator is dominated by 2DF and the intensity is approximately B. For a bf16 matmul to be compute-bound on most TPUs, the local token batch size should therefore be greater than 240.
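The two intensity formulas above can be checked numerically. This is a minimal sketch assuming bf16 (2 bytes per element) and the 240 FLOPs/byte TPU v5e figure from these notes; the shapes are arbitrary illustrations.

```python
BYTES_BF16 = 2
TPU_V5E_INTENSITY = 240  # peak FLOPs per HBM byte, from the notes

def dot_intensity(n: int, bytes_per_elem: int = BYTES_BF16) -> float:
    """FLOPs per byte for a dot product of two length-n vectors."""
    flops = n + (n - 1)                          # n multiplies, n - 1 adds
    bytes_moved = (2 * n + 1) * bytes_per_elem   # read both vectors, write one scalar
    return flops / bytes_moved

def matmul_intensity(b: int, d: int, f: int, bytes_per_elem: int = BYTES_BF16) -> float:
    """FLOPs per byte for [B, D] x [D, F] -> [B, F]."""
    flops = 2 * b * d * f
    bytes_moved = bytes_per_elem * (b * d + d * f + b * f)
    return flops / bytes_moved

def is_compute_bound(b: int, d: int, f: int) -> bool:
    return matmul_intensity(b, d, f) > TPU_V5E_INTENSITY

print(dot_intensity(1_000_000))            # just under 0.5
print(matmul_intensity(512, 8192, 32768))  # a bit under B = 512
print(is_compute_bound(64, 8192, 32768), is_compute_bound(512, 8192, 32768))  # False True
```

The intensity lands slightly below B because the BD and BF terms in the denominator are small but not zero.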
Network communication rooflines
Now, let's instead imagine that we perform a matmul across 2 TPUs/GPUs. Take the above example again and imagine we split along the D dimension. Each accelerator now multiplies its half of each matrix (2B(D/2)F = BDF FLOPs), then copies its partial sums, 2BF bytes, to the other accelerator.
Intensity (matmul 2 chips) = BDF/2BF = D/2
The arithmetic intensity depends entirely on the dimension D that we've split along. This is because the comms bytes (2BF) stay fixed while the math FLOPs scale with D. So the crossover to compute-bound occurs when D/2 exceeds the hardware's intensity.
Question 1
- This is just half of the bytes from before: (2BD + 2DF + 2BF) / 2 = BD + DF + BF.
- The number of operations is unchanged; it's still 2BDF. However, T_math changes, because the accelerator theoretically has higher peak throughput in int8.
- The intensity is 2BDF / (BD + DF + BF). Making the same assumption as before (B much smaller than D and F, so DF dominates the denominator), we get 2BDF/DF = 2B. For the threshold we need 2B > 3.94e14/8.1e11 ≈ 486, i.e. B > 243, which is basically the same as in bf16.
- With the given parameters we have T_math = 2BDF / 3.94e14 and T_comms = (BD + DF + BF)/8.1e11. The lower bound assumes perfect overlap of T_math and T_comms, so it is the maximum of the two; the upper bound is T_math + T_comms.
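The lower and upper bounds from the last bullet, sketched in Python with the TPU v5e int8 numbers used above (3.94e14 OPs/s, 8.1e11 bytes/s). The example shapes are placeholders, not part of the question.

```python
INT8_OPS_PER_S = 3.94e14  # TPU v5e int8 throughput (from the notes)
HBM_BYTES_PER_S = 8.1e11  # TPU v5e HBM bandwidth (from the notes)

def int8_matmul_time_bounds(b: int, d: int, f: int) -> tuple[float, float]:
    """(lower, upper) bounds in seconds for an int8 [B, D] x [D, F] matmul."""
    t_math = 2 * b * d * f / INT8_OPS_PER_S
    t_comms = (b * d + d * f + b * f) / HBM_BYTES_PER_S  # 1 byte per element
    lower = max(t_math, t_comms)   # perfect overlap of math and comms
    upper = t_math + t_comms       # no overlap at all
    return lower, upper

lo, hi = int8_matmul_time_bounds(1024, 8192, 8192)
print(f"{lo * 1e6:.0f} us <= T <= {hi * 1e6:.0f} us")
```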
Question 2
To determine the batch size we first need the intensity. First, the FLOPs: they are not influenced by the dtype, so we still have 2BDF FLOPs. Now the communication bytes: we load the activations in bf16 and the weights in int8, so we load 2BD + DF bytes. Then we perform the matmul in bf16 and write the result back in bf16, which is 2BF bytes. So that gives us
Intensity(matmul int8 + bf16) = 2BDF / (2BD + DF + 2BF). Under the same conditions as before, this is ≈ 2BDF / DF = 2B. The accelerator intensity for bf16 FLOPs is still 1.97e14 / 8.1e11 ≈ 243, so we need 2B > 243, i.e. B > ~120. That means if we can do int8 weight quantization but still do bf16 FLOPs, we become compute-bound at much lower local token batch sizes, which is great! That is what we want, remember.
Question 4
Let's look at the FLOPs and bytes. First, we load BD + BDF bytes from HBM (1 byte per element), then we perform the matmul, which is still 2BDF FLOPs. Then we store BF bytes back to HBM. That gives us an intensity of
2BDF / (BD + BDF + BF)
Since BDF dominates the denominator, this is essentially 2. Our intensity is constant, which is bad because we will be memory-bandwidth bound no matter the batch size.
Question 5
For an H100, bf16 throughput is 1.979e15 FLOPs/s with sparsity and half that (≈9.9e14) without. HBM bandwidth is 3.35e12 bytes/s, resulting in a critical batch size of 9.9e14 / 3.35e12 ≈ 295.
TPUs and GPUs
TPUs are fairly simple architectures. They have high bandwidth memory (HBM), typically on the order of tens of gigabytes. The HBM transfers data into the TensorCore, which performs computation in either the MXU or the VPU. Data is transferred through VMEM.
![[Screenshot 2025-08-06 at 12.53.37.png]]
The reported HBM bandwidth is the bandwidth between HBM and the TensorCore (through VMEM). VMEM is much smaller than HBM, on the order of megabytes, so data is loaded in chunks from HBM into VMEM and through the TensorCore. The bandwidth between VMEM and the MXU is much higher than HBM bandwidth; otherwise we would be communication-bound most of the time. VMEM bandwidth is around 22x higher than HBM bandwidth, which means an MXU operation reading from/writing to VMEM requires an arithmetic intensity of only 10-20 to achieve peak FLOPs utilization. So if we can fit our weights into VMEM instead of HBM, our matrix multiplications can be FLOPs-bound at much smaller batch sizes.
![[Screenshot 2025-08-06 at 12.59.30.png]] ![[Screenshot 2025-08-06 at 13.18.34.png]]
Question 1
The model has 2e11 parameters, split across 32 TPU v4p chips. In bf16 that is 400e9 / 32 = 12.5e9 bytes per chip, and each chip has an HBM bandwidth of 1.23e12 bytes/s. That gives us ~10ms to load each chip's shard from HBM into the MXU. This is a lower bound on the latency of sampling from the model, because sampling requires loading all parameters from HBM, so it cannot take less than 10ms.
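The same lower bound as a tiny Python helper, assuming bf16 weights and the per-chip TPU v4p HBM bandwidth quoted above; `min_step_latency` is a hypothetical name.

```python
def min_step_latency(n_params: float, n_chips: int, hbm_bytes_per_s: float,
                     bytes_per_param: int = 2) -> float:
    """Seconds to stream every chip's parameter shard out of HBM once."""
    bytes_per_chip = n_params * bytes_per_param / n_chips
    return bytes_per_chip / hbm_bytes_per_s

t = min_step_latency(2e11, 32, 1.23e12)  # bf16, 32 chips, TPU v4p HBM bandwidth
print(f"{t * 1e3:.1f} ms")  # 10.2 ms
```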
Question 2
A v5e pod is 16x16. With a host being 4x2 chips, that means we have 32 CPU hosts. We have 256 chips, and the v5e has only 1 TensorCore per chip, so that is 256 cores. The pod's FLOPs/s is 256 × 1.97e14 = 5.0432e16 and its total HBM is 256 × 16GB ≈ 4TB.
Question 3
First we have to move A [D, F] and x [B, D] from host DRAM into the MXU. We know from before that the arithmetic intensity of this matmul is
2BDF / (2BD + 2DF + 2BF) ≈ B (assuming B << D, F).
Our T_math is given by 2BDF / 9.2e14. Our communication is bottlenecked by PCIe, so we need (2BD + 2DF + 2BF) / 1.5e10 to transfer data to and from the TPU. Since we want computation to take longer than data loading, assuming we can completely overlap comms and math, we need 2BDF / 9.2e14 > (2BD + 2DF + 2BF) / 1.5e10. Simplifying with our assumption of B << D and F = 4D, we get 8BD^2 / 9.2e14 > 8D^2 / 1.5e10, or B > 9.2e14 / 1.5e10 ≈ 61,000.
Question 4
Weight matrix W int8[16384, 4096] and activation matrix x int8[B, 4096], on a TPU v5e. First, we have to move both from HBM into the MXU. We have 67,108,864 + 4096B bytes to move and a bandwidth of 8.1e11 bytes/s, giving us T_comms = (67,108,864 + 4096B) / 8.1e11 ≈ 8.3e-5 + 5.1e-9·B seconds. Then we perform the matmul, which is 2 × 16384 × 4096 × B FLOPs, giving us T_math = 2 × 16384 × 4096 × B / 3.94e14 ≈ 3.4e-7·B seconds.
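To finish the estimate, we can search for the batch size where T_math overtakes T_comms. A small Python sketch using only the numbers above; like the estimate, it ignores writing the output back to HBM.

```python
W_BYTES = 16384 * 4096        # int8 W, 1 byte per element
X_BYTES_PER_ROW = 4096        # int8 activations, per batch row
HBM_BYTES_PER_S = 8.1e11      # TPU v5e (from the notes)
INT8_OPS_PER_S = 3.94e14      # TPU v5e (from the notes)

def t_comms(b: int) -> float:
    # Ignores the output write-back, as in the estimate above.
    return (W_BYTES + X_BYTES_PER_ROW * b) / HBM_BYTES_PER_S

def t_math(b: int) -> float:
    return 2 * 16384 * 4096 * b / INT8_OPS_PER_S

# Smallest batch size at which the matmul becomes compute-bound.
b_crit = next(b for b in range(1, 100_000) if t_math(b) > t_comms(b))
print(b_crit)  # 247
```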
Partitioning Notation and Collective Operations
Pop quiz: A's first dimension is partitioned across 2x8 = 16 devices (mesh axes X and Y), and the second dimension is not sharded. So each device holds an int8[8, 2048] array, i.e. 16,384 bytes, and these shards are replicated across the Z axis. In total, across all devices, this uses 524,288 bytes.
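A minimal Python sketch of the per-device and total byte counts. The mesh sizes and the full shape are assumptions inferred from the numbers above: 16 shards of 8 rows give int8[128, 2048], and the 524,288-byte total implies Z = 2.

```python
from math import prod

def bytes_per_device(shape, shards_per_dim, bytes_per_elem=1):
    """Local shard size; shards_per_dim[i] is how many ways dim i is split."""
    local = [dim // k for dim, k in zip(shape, shards_per_dim)]
    return prod(local) * bytes_per_elem

mesh = {"X": 2, "Y": 8, "Z": 2}   # assumed; Z = 2 is what the 524,288-byte total implies
shape = (128, 2048)               # int8; 16 shards of 8 rows each
per_device = bytes_per_device(shape, (mesh["X"] * mesh["Y"], 1))
total = per_device * prod(mesh.values())  # every device holds one (replicated) shard
print(per_device, total)  # 16384 524288
```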
Case 1 - No Communication Required
Lemma: when multiplying partitioned tensors, the computation is valid and the output follows the sharding of the inputs unless the contracting dimension is sharded or both tensors have a non-contracting dimension sharded along the same axis. For example, this works great
![[Screenshot 2025-08-08 at 11.19.07.png]]
with no communication whatsoever. Why? Because every device has a local block of A and B that contains the full contracting dimension, which is enough to compute its own block of the output. With the complete contracting dimension we can perform a complete dot product for each (row, column) pair and keep the result locally, regardless of how the non-contracting dimensions are sharded.
Case 2 - One multiplicand has a sharded contracting dimension
We need to perform an AllGather. An AllGather removes sharding along an axis and reassembles the shards spread across devices onto each device along that axis.
![[Screenshot 2025-08-08 at 11.56.07.png]]
Then we perform the matmul on each device in full.
When performing an AllGather (or ReduceScatter or AllReduce) in a bandwidth-bound regime, the time of the communication depends only on the size (in bytes) of the full array and the bandwidth (the ICI bandwidth, or whatever link the devices are connected by), NOT ON THE NUMBER OF DEVICES.
However, there is a caveat. The time for a hop, that is the time it takes to send a shard of the array to two neighbour devices (bidirectional communication), is given by:
![[Screenshot 2025-08-08 at 12.01.01.png]]
because V is the total size of the full array in bytes and X is the number of devices, so the local shard is V/X bytes; we send it in two directions, and the hop time is the bytes sent divided by the bandwidth. This is all good, but if V/X becomes very small, T_hop may be bound by the per-hop overhead rather than this formula: ![[Screenshot 2025-08-08 at 12.03.06.png]]
where T_min is the intrinsic overhead of our connection. In this case, the total time does indeed depend on X: ![[Screenshot 2025-08-08 at 12.03.44.png]]
Pop Quiz: The per-device array is [512, 8192]. The formula says T = V / W_ICI, where V is the size of the full array, [2048, 8192], so it is sufficient to compute the bytes of that array and divide by W_ICI. However, our device mesh is 32 chips, larger than a single host, so the gather goes over ICI and we should also consider its latency. The bandwidth estimate gives us 0.00037 seconds. To check whether we are latency-bound: with 4 devices along the gathered axis we do at most 3 hops, which is around 3us, far below 370us.
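The AllGather cost model above (X/2 hops, each taking max(T_min, 2V/(X·W_ICI))), sketched in Python. Assumptions, not stated in the notes: the array is bf16, W_ICI = 9e10 bytes/s (bidirectional, my understanding of TPU v5e's per-axis figure), T_min = 1us from the "1us per hop" used elsewhere here, and a wraparound ring so only X/2 hops are needed.

```python
def allgather_time(total_bytes: float, n_devices: int, w_ici: float,
                   t_min: float = 1e-6) -> float:
    """Ring AllGather over X devices: X/2 hops, each moving a V/X shard both ways.

    Each hop is floored at the per-hop overhead t_min (latency-bound regime).
    """
    t_hop = max(t_min, 2 * total_bytes / (n_devices * w_ici))
    return t_hop * n_devices / 2

V = 2048 * 8192 * 2   # full [2048, 8192] array, assuming bf16
W_ICI = 9e10          # assumed bidirectional ICI bandwidth, bytes/s
print(allgather_time(V, 4, W_ICI))      # ~3.7e-4 s, the same as V / W_ICI
print(allgather_time(1_000, 4, W_ICI))  # tiny array: latency-bound, 2 hops x 1 us
```

In the bandwidth-bound regime the device count cancels out, which is exactly the "independent of X" property; once the shard gets small, the T_min floor makes the total grow with X.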
Case 3 - Both multiplicands have sharded contracting dimensions
How expensive is an AllReduce? Our intuition says that we need to pass the result on each rank to every other device, similar to an AllGather. But we are passing a significantly larger tensor around, because each device's partial result has the same shape as the full output tensor.
Question 1
The array is sharded along axis X (4 devices), so each shard is 1/4 the size of A. For each position along X there are 16 devices (the rest of the mesh) holding the same shard, so each X position holds 16 × A/4 = 4A in total. The total size of A across the device mesh is therefore 4A × 4 = 16A, a ratio of 16. If the array were not sharded at all, it would have taken up 64A.
Question 2
We are gathering across axis X, which contains 4 devices. We know from before that T_total = V / W_ICI. However, note that while we gather across X, D is sharded along Y, so in total we only have to gather 2BD/Y bytes.
For AllGather_XY we gather across both axes XY, so the total bytes gathered is 2BD. We are also working across two axes, meaning we have double the bandwidth, giving us T_total = 2BD / (2·W_ICI) = 46us. This is far from the latency bound of 4us (1us per hop), so we are fine.
Finally, an AllReduce. We know that AllReduce = 2 × AllGather. Each shard has size 2BD / (XY) = 2BD/16 = BD/8.
Question 3
Given that we are latency-bound, the time is simply the latency overhead of one hop (1us) multiplied by the number of hops (2), so 2us.
Question 4
AllGather
Comms: gathering across X means sending 2DF bytes, so T_comms = 2DF/W_ICI.
FLOPs: after communicating, we perform a full matmul X[B, D] * Y[D, F], which is 2BDF FLOPs.
Assuming we can overlap comms and math, we get: T_allgather = max(2BDF/C, 2DF/W_ICI)
AllReduce
First, we perform the local matmul X[B, Dx] * Y[Dx, F], which gives T_math = 2BDF/XC, and we end up with Z[B, F]{Ux}. Then we perform an AllReduce, which costs 2 AllGathers of the 2BF-byte result, giving us T_comms = 4BF/W_ICI.
T_allreduce = max(2BDF/XC, 4BF/W_ICI)
The AllReduce strategy is compute-bound when D/CX > 2/W_ICI, i.e. when D/2X > C/W_ICI. Taking v5p as an example, C/W_ICI ≈ 2550, so we need D/5100 > X, i.e. D > 5100·X, which is generally not the case. So with AllReduce we are generally comms-bound. Under strategy 1 (AllGather) we are comms-bound when B < C/W_ICI = 2550, which is often true. So if B < 2550 we compare the comms times:
T_comms_allreduce < T_comms_allgather <=> 4BF/W_ICI < 2DF/W_ICI <=> 2B < D. So if D is more than twice B, AllReduce is faster than AllGather. That is to say, the AllReduce strategy is typically faster when our batch size is small; as the batch size grows, AllGather is preferred.
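A Python sketch comparing the two strategies' times from the formulas above, working in units where W_ICI = 1 and C = 2550 (the v5p C/W_ICI ratio from the notes); the shapes and X = 4 are illustrative assumptions.

```python
C_OVER_W = 2550.0  # TPU v5p: bf16 FLOPs/s per ICI byte/s (from the notes)

def strategy_times(b, d, f, x, c=C_OVER_W, w=1.0):
    """(T_allgather, T_allreduce) in arbitrary units with W_ICI = 1."""
    t_allgather = max(2 * b * d * f / c, 2 * d * f / w)
    t_allreduce = max(2 * b * d * f / (c * x), 4 * b * f / w)
    return t_allgather, t_allreduce

small_b = strategy_times(b=256, d=8192, f=32768, x=4)   # 2B < D
large_b = strategy_times(b=8192, d=4096, f=32768, x=4)  # 2B > D
print(small_b[1] < small_b[0], large_b[0] < large_b[1])  # True True
```

With a small batch the AllReduce wins; once 2B exceeds D, gathering the weights is cheaper.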
Question 5
This is the case where the inputs are sharded only across non-contracting dimensions:
A[Bx, D] x B[D, Fyz] -> C[Bx, Fyz]
This works because it requires no communication, and the matmul FLOPs per device are reduced by 64, because each device only performs the matmul on its own block:
T_math = 2BDF / 64C
Question 6
We are multiplying matrices that are sharded along the contracting dimension. This is Case 3, where we perform a local matmul giving partial sums
C[Ix, K]{Uy}
followed by an AllReduce_Y. Our computation is
T_math = 2IJK/XYC
and then the AllReduce costs the time of 2 AllGathers. The array being reduced is bf16 C[Ix, K], i.e. 2IK/X bytes, giving us
T_comms = 2 × 2IK/(X·W_ICI) = 4IK/(X·W_ICI).
B
Okay, let's analyze. Both of our output dimensions keep their sharding, and the contracting dimension is sharded in one of the input matrices. This is the same setting we analyzed earlier in Question 4. One option is to perform an AllGather over J, turning B[Jx, Ky] -> B[J, Ky], which allows us to just perform local matmuls on each rank. To calculate the comms cost, we divide the number of bytes to communicate, 2JK/Y, by the ICI bandwidth:
T_comms = 2JK / YW_ICI
Then, performing local matmuls on each rank, where I is sharded over X devices, and K sharded over Y, we get
T_math = 2IJK / XYC
where C is our peak accelerator FLOPs.
C
In the final case we have a simpler variant of the aforementioned question, because this time our contracting dimension J is already replicated, so we don't need to communicate anything. This is Case 1. We can just perform the matmul on each rank:
T_math = 2IJK/XYC
Question 7
To reduce latency, we would shard
B[Dx, F] x D[F, Dy] -> O[Dx, Dy]
Transformers
Notation: x: [P], y: [P], A: [N, P], B: [P, M]
Overall parameter counts and FLOPs of a Transformer are fairly easy to calculate and are summarized below, using MHA (the same number of query, key and value heads).
- B: batch size
- T: sequence length (query)
- D: d_model, the hidden dimension
- F: d_ff, the MLP hidden dimension
- N: number of query heads
- H: attention head dimension
- V: vocabulary size
For the shapes above, the matmul A x B requires 2NPM FLOPs. Training requires 6NPM FLOPs, because we need 2NPM for the forward pass and 4NPM for the backward pass.
| Component | Params per layer | Training FLOPs per layer |
|---|---|---|
| MLP | 3DF | 18BTFD |
| Attention | 4DNH | 24BTDNH + 12BT^2NH |
| Other | D | BTD |
| Vocab | DV | 12BTDV |
- D = NH (Typically)
- Typically we see F = 4D, which means the parameter count of the MLP block dominates the total parameter count.
- The total training FLOPs budget is well approximated as 6 × num_params × num_tokens for reasonable context lengths. (This does not account for attention, which starts to dominate at T > 8D.)
- During inference, the KV cache is roughly 2SLNH = 2TLNH values per sequence (S = T is the sequence length, L the number of layers). It scales linearly with sequence length, number of layers, number of heads, and attention head dimension.
Question 1
Total parameters: 17.3B. The attention parameters make up about 25% of the total parameter count. Finally, the KV cache per token is 2TLNH/T = 2LD = 524,288 values, which in int8 is 512KiB.
Question 2
A[Bx, Dy] x W[Dy, F]. The total FLOPs if nothing were sharded would be 2BDF. Since our matrices are sharded, each device performs 2BDF/(XY) FLOPs, which for this device mesh (XY = 32) is BDF/16 per device. Since the computation is not sharded across Z, it is repeated Z times, meaning 2BDF·Z FLOPs in total.
Question 3
A[I,J,K,L] x B[I,J,M,N,O] -> C[K,L,M,N,O]
Here we assume that I and J are the contracting dimensions. That means we are doing 2·I·J·K·L·M·N·O FLOPs.
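The "2 × product of all index sizes" rule can be wrapped in a small, hypothetical helper that reads an einsum string. It assumes a two-operand contraction where every index either appears in the output or is contracted between the operands; the sizes below are arbitrary.

```python
from math import prod

def einsum_flops(spec: str, sizes: dict) -> int:
    """2 x (product of every distinct index size) for a two-operand contraction."""
    inputs = spec.split("->")[0]
    indices = set(inputs.replace(",", ""))
    return 2 * prod(sizes[i] for i in indices)

sizes = dict(i=2, j=3, k=4, l=5, m=6, n=7, o=8)
print(einsum_flops("ijkl,ijmno->klmno", sizes))  # 80640 = 2 * 2*3*4*5*6*7*8
```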
Question 4
Q[BTNH] -> Q[BTKGH] (reshape)
Q[BTKGH] x K[BSKH] -> O[BTSKG]
U = softmax_S(O[BTSKG])
U[BTSKG] x V[BSKH] -> X[BTKGH]
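A minimal NumPy sketch of these four steps for grouped-query attention. Like the steps above, it omits the 1/sqrt(H) scaling and causal masking; the shapes in the usage lines are arbitrary.

```python
import numpy as np

def gqa_attention(q, k, v, num_kv_heads):
    """q: [B, T, N, H]; k, v: [B, S, K, H] with N = K * G. Returns [B, T, N, H]."""
    b, t, n, h = q.shape
    g = n // num_kv_heads
    q = q.reshape(b, t, num_kv_heads, g, h)               # Q[BTNH] -> Q[BTKGH]
    logits = np.einsum("btkgh,bskh->btskg", q, k)         # O[BTSKG]
    logits = logits - logits.max(axis=2, keepdims=True)   # stable softmax over S
    u = np.exp(logits)
    u = u / u.sum(axis=2, keepdims=True)                  # U = softmax_S(O)
    x = np.einsum("btskg,bskh->btkgh", u, v)              # X[BTKGH]
    return x.reshape(b, t, n, h)

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 3, 4, 8))   # B=1, T=3, N=4 query heads, H=8
k = rng.standard_normal((1, 5, 2, 8))   # S=5, K=2 KV heads, so G=2
v = rng.standard_normal((1, 5, 2, 8))
out = gqa_attention(q, k, v, num_kv_heads=2)

# Sanity check: GQA equals MHA with each KV head repeated G times.
mha = gqa_attention(q, np.repeat(k, 2, axis=2), np.repeat(v, 2, axis=2), num_kv_heads=4)
print(out.shape, np.allclose(out, mha))  # (1, 3, 4, 8) True
```

Only K heads of keys and values are ever stored, which is where the KV cache savings come from.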
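Question 5
24BTDNH = 12BT^2NH <=> T = 2D, i.e. the quadratic attention FLOPs equal the attention projection FLOPs once the sequence length reaches twice d_model.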
Question 7
The number of training FLOPs divided by the total training time gives us the achieved FLOPs/s, which is 327 teraflops. The accelerator's peak fp8 throughput is 1513 teraflops, giving a hardware utilization of 21.7%.
Question 8
We normally need to load DF weights from memory, but with E experts in an MoE setup that becomes EDF.
The FLOPs for a batch of B tokens are 2kBDF, since each token is routed to k experts. To be compute-bound we need an arithmetic intensity over 240, because that is the arithmetic intensity of our accelerator (TPU v5e). Arithmetic intensity is given by
FLOPs / bytes moved
which is 2kBDF / EDF > 240 (counting 1 byte per weight), i.e. kB/E > 120 or B > 120E/k. For DeepSeek this means B > 3840.
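The same bound as a Python one-liner. E = 256 routed experts and k = 8 active per token are my assumption for DeepSeek (V3's routed-expert configuration); the notes only give the final 3840. It also assumes 1 byte per weight and that every expert's weights are loaded.

```python
def moe_min_batch(num_experts: int, k: int, hw_intensity: float = 240.0) -> float:
    """Tokens per batch so that 2kBDF / (E*D*F bytes) exceeds hw_intensity.

    Assumes 1 byte per weight and that all E experts' weights are loaded.
    """
    return hw_intensity * num_experts / (2 * k)

print(moe_min_batch(num_experts=256, k=8))  # 3840.0
```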
Scaling Training
We approximate a Transformer as a stack of MLP blocks, because for large models attention makes up a comparatively small fraction of the FLOPs. As such, we can look at each transformer layer as: ![[Screenshot 2025-08-11 at 14.43.35.png]] Using our established notation this looks like
In[B, D] x W_in[D, F] x W_out[F, D] -> Out[B, D].
The 4 parallelism schemes can be thought of as uniquely defined by a sharding of In, W_in, W_out and Out in the above diagram. Let's go through them.
Data Parallelism
In[Bx, D] x W_in[D, F] x W_out[F, D] -> Out[Bx, D].
Naturally, the activations are sharded across mesh axis X, and no communication is required until the backward pass. Parameters and optimizer states are replicated on each device.
Fully Sharded Data Parallelism
The activations are still sharded across axis X (just like DP), but now the parameters are also sharded along the same mesh axis.
In[Bx, D] x W_in[Dx, F] x W_out[F, Dx] -> Out[Bx, D].
Note that in this setup, due to the parameter sharding W_in[Dx, F], we now have a sharded contracting dimension, meaning we require some kind of collective. In FSDP that means we AllGather the weights just-in-time before use in the forward pass. Optimizer states are also sharded.
Tensor Parallelism
Activations are sharded across D (d_model) and parameters are sharded along F (d_ff). This requires an AllGather of the activations before each block and a ReduceScatter after it. Compatible with FSDP.
In[B, Dy] x W_in[D, Fy] x W_out[Fy, D] -> Out[B, Dy].
Order of operations during forward pass
AllGather_Y: In[B, Dy] -> In[B, D]
In[B, D] x W_in[D, Fy] -> H[B, Fy]
H[B, Fy] x W_out[Fy, D] -> Out[B, D]{Uy}
ReduceScatter_Y: Out[B, D]{Uy} -> Out[B, Dy]
Pipeline Parallelism
Weights are sharded along the layer dimension, and activations are microbatched and rolled along the layer dimension. Communication between pipeline stages is minimal (just moving activations over a single hop).
Data Parallelism
When your model fits on a single chip with even a tiny batch size (>240 tokens, so as to be compute bound), you should always use simple data parallelism. Remember, assuming we can overlap math and comms, when T_math > T_comms we are compute bound. For a matrix multiplication
Intensity(matmul) = 2BDF / (2BD + 2DF + 2BF) ≈ B (for small B relative to D and F)
meaning that when B > Accelerator intensity, which is 240 for a TPUv5e, we are compute bound.
Anyway, like we said, if we can fit the model on a single chip, and have a batch size > 240, we should always use data parallelism. Pure DP splits the activations across the number of accelerators so long as the number of accelerators is smaller than our batch size. The forward pass requires 0 communication, but at the end of every step we need an AllReduce of the gradients in order to sync before updating. ![[Screenshot 2025-08-11 at 15.39.20.png]]
All communication happens in the backward pass! There is also a really neat property of the backward pass that the AllReduces are not in the "critical path". Meaning that I don't need dW_out to compute dW_in, and hence we can overlap comms/compute really nicely. The overall communication cost can still bottleneck us if it exceeds our total compute cost, but it is much more forgiving from an implementation standpoint.
Why do DP? Pure DP reduces activation memory pressure by splitting our activations over the batch dimension, allowing us to almost arbitrarily increase batch size as long as we have more chips to split over. Especially during training when our activations often dominate our memory usage this is very helpful.
However, this does nothing to reduce the memory pressure from model parameters or optimizer states. Typically, with the Adam optimizer, we require about 10 bytes per parameter (bf16 parameters plus fp32 first and second moments). So with 96GB of HBM per TPU v5p chip, the largest model we can train is about 9B parameters.
When do we become bottlenecked by comms? We only need to perform comms in the backward pass where we require an AllReduce on a weight matrix W[D, F]. From previous sections we know that an AllReduce is 2x the time of an AllGather, and AllGather comms time is given by the total size of array we are communicating divided by our bandwidth. This is 2 * total bytes / W_ICI. We also require two of these per layer. This gives
T_comms = V / W_ICI = 2 * 2 * 2 * DF / W_ICI = 8DF/W_ICI
For the matmuls, each layer comprises two matmuls in the forward pass and four matmuls in the backward pass, each of which requires 2(B/X)DF FLOPs. Thus, for a single layer in the backward pass we have:
T_math = 8BDF/XC
Since we overlap comms and math,
T = max(8DF/W_ICI, 8BDF/XC) = 8DF · max(1/W_ICI, B/XC)
We become compute bound when B/X > C/W_ICI. This means, for data parallelism to remain compute-bound, we need the per device batch size to be larger than C / W_ICI. This naturally follows from the fact that per device computation time scales linearly with batch size, while communication is independent of this quantity (since we are communicating the weights not the activations). When per device computation is large enough, we are compute bound. Note the striking resemblance of the B > C/W_ICI condition to the single device compute-bound rule B > 240; in that case as well the rule came from the fact that computation time scaled with batch size while data-transfer size was independent of batch size.
Putting some real numbers on this, for TPU v5p we need our per-device batch size to be at least 2550 to avoid being communication-bound. Since we can do DP over multiple axes, if we dedicate all three axes of a pod to DP, we triple our bandwidth and can scale down to a per-device batch size of only 850! Remember, we use B to refer to the total batch size in tokens. Clearly, however, a batch is made of K sequences of T tokens each, so how can we do this? As far as the MLP goes, this does not matter: tokens are tokens, and they are processed independently. You are free to do parallelism over both the batch and sequence dimensions; this is called context or sequence parallelism, but it is still data parallelism.
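The DP compute-bound rule as a small Python sketch, using the v5p C/W_ICI ≈ 2550 figure from these notes and assuming bandwidth scales linearly with the number of mesh axes used for DP. The 4M-token, 4096-chip example is illustrative.

```python
C_OVER_W_ICI = 2550  # TPU v5p (from the notes)

def dp_min_per_device_batch(n_axes: int = 1) -> float:
    """Smallest per-device token batch B/X that keeps pure DP compute-bound."""
    return C_OVER_W_ICI / n_axes

def dp_is_compute_bound(global_batch: int, n_devices: int, n_axes: int = 1) -> bool:
    return global_batch / n_devices > dp_min_per_device_batch(n_axes)

print(dp_min_per_device_batch(1), dp_min_per_device_batch(3))  # 2550.0 850.0
print(dp_is_compute_bound(4_000_000, 4096, n_axes=3))          # 976.6 > 850 -> True
```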
FSDP
FSDP is a variant of data parallelism that splits the model parameters and optimizer states across the data parallel shards and efficiently gathers and scatters them as needed. Compared to pure DP, FSDP drastically reduces per device memory usage and saves on backward pass FLOPs, with very minimal overhead.
![[Screenshot 2025-08-11 at 20.38.43.png]]
Syntax: In[Bx, D] x W_in[Dx, F] x W_out[F, Dx] -> Out[Bx, D].
Remember that an AllReduce can be decomposed into an AllGather and a ReduceScatter. FSDP applies this to DP. Instead of doing full gradient AllReduce, FSDP shards the weights and optimizer states across chips, AllGather them at each layer during the forward pass and ReduceScatter across the weights during the backward pass at no extra cost. FSDP has basically no overhead compared to DP.
Standard DP involves a lot of duplicated work. Each TPU AllReduces the full gradient, then updates the full optimizer state (identical work on all TPUs), and updates the parameters (again duplicated). With ZeRO sharding (sharding gradients/optimizer state), instead of the AllReduce we ReduceScatter the gradients, update only our shard of the optimizer state, and update our shard of the parameters.
FSDP has the same relative FLOPs and comms cost as pure DP
T = max(B/XC, 1/W_ICI)
This is great because it means if our per device batch size is big enough to be compute-bound for pure DP, we can - without worrying about leaving the compute-bound regime - simply upgrade to FSDP, saving a massive amount of parameter and optimizer state memory. Although we do add comms during the forward pass, the cost is immaterial since it overlaps with forward pass FLOPs.
As a concrete example, DeepSeek-V2 used a batch size of 40M tokens. With all three axes dedicated to FSDP (per-device batch ≥ 2550/3 = 850), that would allow us to scale to 40M / 850 ≈ 47,000 chips before hitting the bandwidth limit.
Takeaway: FSDP and pure DP become bandwidth-bound on a TPU v5p when the per-device batch size B/X < 2550/n_axes.
Tensor Parallelism
In a FSDP AllGather we move the weights across chips. We can also shard the feedforward dimension of the model and move the activations during the layer - this is called "1D model parallelism" or megatron sharding. This can unlock a smaller efficient batch size per pod.
Syntax: In[B, Dy] x W_in[D, Fy] x W_out[Fy, D] -> Out[B, Dy]
Gather Weights vs Activations
In FSDP, looking at the flow, the input is sharded along the batch dimension and the weights are sharded along D:
In[Bx, D] x W_in[Dx, F] -> Tmp[Bx, F]
This means, to be able to perform a matmul, we will have to perform an AllGather_X on W_in AllGather_X W_in[Dx, F] -> W_in[D, F] before we can compute the matmul. This means we gather the weights across our network.
In TP, our syntax is In[B, Dy] x W_in[D, Fy] -> Tmp[B, Fy] Which means we now have to gather the activations across Y to be able to perform the matmul.
This is cheaper than ZeRO sharding when the activations are smaller than the weights! This is typically true only with some amount of ZeRO sharding added (which reduces the size of the gather). This is one of the reasons we tend to mix ZeRO sharding and model parallelism. TP comms cost is BD, and FSDP is DF. Meaning TP is cheaper than FSDP when
BD < DF <=> B < F
i.e. when the batch size is smaller than the feed-forward dimension.
TP Algorithm
In[B, D] = AllGather_Y(In[B, Dy])
In[B, D] x W_in[D, Fy] = Tmp[B, Fy]
Tmp[B, Fy] x W_out[Fy, D] = Out[B, D]{Uy}
Out[B, Dy] = ReduceScatter_Y(Out[B, D]{Uy})
One nice thing about TP is that it interacts nicely with the two matrices of the Transformer MLP, because we only do an AllGather at the start and a ReduceScatter of Out at the end. What is the cost of this (modelling only the forward pass)?
T_comms = (2BD + 2BD) / W_ICI = 4BD/W_ICI
T_math = 2BDF/YC + 2BFD/YC = 4BDF/YC
T = max(4BDF/YC, 4BD/W_ICI) = 4BD · max(F/YC, 1/W_ICI)
Noting that we want compute cost to be greater than comms cost, we get
F/YC > 1/W_ICI <=> F > Y·C/W_ICI
On a TPU v5p, C/W_ICI ≈ 2550 in bf16, meaning we are compute-bound when F/2550 > Y. So, for a given feed-forward dimension F, we can only increase our mesh axis Y up to a certain point before we become comms-bound.
Takeaway: Model parallelism becomes comms-bound when Y > n_axes × F/2550. For most models this is between 8 and 16-way model parallelism.
- On a TPU v5p with Llama 3-70B, which has D=8192 and F≈30,000, we can comfortably do 8-way model parallelism but will be communication-bound at 16-way model parallelism.
- For Gemma 7B we become comms bound at 19-way MP.
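The TP limit Y < n_axes × F / (C/W_ICI) as a short Python sketch, using the v5p ratio of 2550 and the F ≈ 30,000 figure for Llama 3-70B from these notes.

```python
def max_tp_degree(f: int, n_axes: int = 1, c_over_w: float = 2550.0) -> float:
    """Largest Y keeping TP compute-bound: F / Y > C / W_ICI, scaled by mesh axes."""
    return n_axes * f / c_over_w

y_max = max_tp_degree(30_000)  # Llama 3-70B's F as used above
print(round(y_max, 1), 8 < y_max, 16 < y_max)  # 11.8 True False
```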
Mixed FSDP and TP
Syntax: In[Bx, Dy] x W_in[Dx, Fy] x W_out[Fy, Dx] -> Out[Bx, Dy]
FSDP and TP are easily combined, by sharding W_in and W_out along both axes we both save memory and compute. Because we shard B along X, we reduce the size of the model-parallel AllGathers and because we shard F along Y, we reduce the communication overhead of FSDP.
Algorithm
In[Bx, D] = AllGather_Y(In[Bx, Dy])
W_in[D, Fy] = AllGather_X(W_in[Dx, Fy])
In[Bx, D] x W_in[D, Fy] = Tmp[Bx, Fy]
W_out[Fy, D] = AllGather_X(W_out[Fy, Dx])
Tmp[Bx, Fy] x W_out[Fy, D] = Out[Bx, D]{Uy}
Out[Bx, Dy] = ReduceScatter_Y(Out[Bx, D]{Uy})
There are more communications, but note that they are smaller. For example, the first step of the TP algorithm was to AllGather(In[B, Dy]), this collective is now reduced by a factor X because we only move AllGather(In[Bx, Dy]).
A simple but key maxim is that FSDP moves weights and TP/MP moves activations. That means that as we shard the batch more (especially as we do more data parallelism, meaning higher X), model parallelism becomes cheaper, as we noted above, because the per-shard activations In[Bx, D] that we move are smaller.
- Model parallelism performs AllGather_Y([Bx, Dy]), which shrinks as X grows
- FSDP performs AllGather_X([Dx, Fy]) which shrinks as Y grows
Thus, by combining both, we can push our minimum batch size per replica down even further. Let X be the number of chips used for FSDP and Y the number of chips dedicated to TP, with N = XY the total number of chips. M_X is the number of mesh axes over which we do FSDP, and M_Y the number for TP.
Total elements communicated by each collective (multiply by 2 for bf16 bytes):
AllGather_Y(In[Bx, Dy]) = BD/X
AllGather_X(W_in[Dx, Fy]) = DF/Y
AllGather_X(W_out[Fy, Dx]) = FD/Y
ReduceScatter_Y(Out[Bx, D]{Uy}) = BD/X
T_FSDPcomms = 4DF/(Y·W_ICI)
T_TPcomms = 4BD/(X·W_ICI)
Our total math time is
T_math = 2BDF/NC + 2BFD/NC = 4BDF/NC
Under the assumption that we do not overlap comms on the X and Y axis, the total comms time is
T_comms = T_FSDPcomms + T_TPcomms
First, let's identify the optimal values for X and Y to minimize total communication. Since FLOPs are independent of X and Y, the optimal settings are those that simply minimize comms. We find that: ![[Screenshot 2025-08-13 at 10.16.39.png]]
meaning that for a given B,F,N we know what amount of FSDP is optimal. Plugging in real values, N=64 (corresponds to a 4x4x4 array of chips), B=48000, F=32768 gives X ≈ 13.9. So we would choose X to be 16 and Y to be 4.
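Rather than the closed form, we can brute-force the split over powers of two. This is a sketch, not the series' method: it minimizes T_FSDPcomms + T_TPcomms with the common 4D/W_ICI factor dropped, and assumes each term's bandwidth scales with the number of mesh axes it spans. M_X = 2 (FSDP over two axes of the 4x4x4 mesh) and M_Y = 1 are my assumptions, chosen to match the X = 16, Y = 4 split.

```python
def best_fsdp_tp_split(b: int, f: int, n: int, m_x: int = 1, m_y: int = 1):
    """Pick X (FSDP) and Y = N / X (TP), powers of two, minimizing comms.

    Cost is T_FSDP + T_TP with the common 4D / W_ICI factor dropped, and
    each term's bandwidth scaled by the number of mesh axes it spans.
    """
    best = None
    x = 1
    while x <= n:
        y = n // x
        cost = f / (y * m_x) + b / (x * m_y)
        if best is None or cost < best[0]:
            best = (cost, x, y)
        x *= 2
    return best[1], best[2]

print(best_fsdp_tp_split(b=48_000, f=32_768, n=64, m_x=2, m_y=1))  # (16, 4)
```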
Takeaway: Combining TP and FSDP allows us to drop to a per device batch size B/N of 2*2550^2/F.
![[Screenshot 2025-08-13 at 10.29.43.png]]
Summary
- If we can fit the whole model in memory, and our batch size is > 240, we should always do pure DP.
- For a TPUv5p C/W_ICI ≈ 2,550.
- Pure data parallelism is compute-bound when B/X > C/W_ICI, meaning when the per-device batch size is large enough.
- FSDP is compute-bound when B/X > C/W_ICI, same as for DP. Meaning that if our per device batch size B/X is large enough, we can simply upgrade to FSDP and save a lot of memory.
- Tensor parallelism is compute-bound when F > 2550·Y. This is independent of batch size, and for most models allows around 8-16-way TP (Y).
- Mixed FSDP + TP allows us to drop the batch size to 2 * 2550^2/F ≈ 400.
If we have Llama 3-70B with F ≈ 30,000 on a TPU v5p, we will be comms-bound under model parallelism at Y > n_axes × F/2550 ≈ n_axes × 11, so anything over 11-way per axis.
Pure FSDP becomes ICI-bound when the per-device batch size is < 2550/n_axes. So if we want to use a total batch size of 2M tokens with all 3 axes, the most chips we can use is:
B/X > 2550/n_axes <=> 2e6/X > 2550/3 <=> X < 2e6/850 ≈ 2,350 (roughly 2,400)
Mixed FSDP + TP means we are ICI-bound at a per-device batch size < 432. With 2 axes, that gives
2e6/X > 432/2 <=> X < 2e6/216 ≈ 9,259 chips.
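Both chip limits as a Python sketch, taking the per-device thresholds (2550/3 for pure FSDP over 3 axes, 432/2 for mixed FSDP + TP over 2 axes) straight from the lines above.

```python
def max_chips(global_batch: float, min_per_device_batch: float) -> float:
    """Most chips we can spread the batch over before becoming ICI-bound."""
    return global_batch / min_per_device_batch

fsdp = max_chips(2e6, 2550 / 3)   # pure FSDP over 3 axes
mixed = max_chips(2e6, 432 / 2)   # mixed FSDP + TP, 2 axes, threshold from the notes
print(round(fsdp), round(mixed))  # 2353 9259
```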
Takeaways
- Generally, increasing parallelism, or reducing the batch size, both tend to make us more communication bound, because they decrease the amount of per device computation.
- Up to a reasonable context length (~32k) we can get away with modeling a transformer as a stack of MLP blocks and define each of several parallelism schemes by how they shard the two/three main matmuls per layer.
- During training there are 4 main parallelism schemes we consider, each of which has its own bandwidth and compute requirements.
| Strategy | Description |
|---|---|
| Data Parallelism | Activations are batch sharded, everything else is fully replicated, and we all-reduce gradients during the backward pass |
| FSDP | Activations, weights and optimizer states are sharded; weights are all-gathered just before use and gradients are reduce-scattered |
| Model Parallelism (aka megatron / tensor) | Activations are sharded along d_model, weights are sharded along d_ff, activations are all-gathered before W_in, and the result is reduce-scattered after W_out |
| Mixed FSDP + Model Parallelism | Both of the above |
Syntax:
![[Screenshot 2025-08-13 at 13.11.06.png]]
Question 1: The MLP layer has 3 matrices, each with DF parameters, so FFN: 3DFL.
The attention has 4 projection matrices W_Q, W_K, W_V, W_O with shapes DNH (Q, O) and DKH (K, V). Because K=N we get Attention: 4DNHL.
Vocabulary params = 2VD
Question 2 BS=16M tokens, Adam optimizer.
Parameters in bf16: 13B * 2 = 26GB
Optimizer state: First moment estimate and second moment estimate per parameter stored in fp32. That is 8 bytes per param: 104GB
The activations after each matmul are shape BF, BF, BD. At bf16, these take up: 4BFL + 2BDL = 2LB(2F + D) = 4.19e13 = 42TB
Question 3: Some numbers first. A 32k sequence length and a 3M-token batch size give us a batch of 96 sequences. On a TPU v5p 16x16x16 slice we have 393TB of HBM.
- Using pure data parallelism, the model is replicated across the devices. Because the params + optimizer states take up 130GB, they need to fit on every device. TPU v5p chips have 96GB of HBM, so they don't fit.
- Pure FSDP means we shard parameters, optimizer states and activations across the batch. Replacing 16M with 3M in Question 2, we get activations of size 7.86e12 bytes. We also have 1.3e11 bytes of parameters and optimizer states, bringing us to almost exactly 8e12 = 8TB. This is well under the 393TB of HBM available. Remember the activations and optimizer states are sharded across our slice, so the per-device memory usage is 1.95GB. Now let's analyze whether we are compute or communication bound. We know from earlier that under FSDP we are comms bound if B/X < C/(W_ICI * n_axes). For a TPU v5p with 3 axes this threshold is 2550/3 = 850, so we are comms bound if B/X < 850. Our per-device batch size is 3e6/16^3 = 732, so yes, we are comms bound. Another way of looking at it: with 4096 chips and a threshold of 850 tokens per chip, we need a batch size of at least 3.48M tokens.
- Mixed FSDP + TP. Here we are comms bound if B/X < 2 * 2550^2 / F = 940, so it's actually worse than FSDP.
Question 4: Dropping the batch size to 1M gives a per-device batch size of 244. That is barely enough to stay above our on-device compute-bound threshold.
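The Question 3 and Question 4 numbers can be checked with a short script. This is a sketch under the same assumptions as the answers above: a 4,096-chip v5p slice with 96GB HBM per chip, the 4.19e13-byte activation figure from Question 2 scaled linearly with batch size, and 10 bytes per parameter for weights plus Adam state.

```python
# Question 3/4 arithmetic on a 16x16x16 TPU v5p slice.
# Assumed: Question 2's 4.19e13 bytes of activations at a 16M-token batch,
# 13B params at 10 bytes/param (bf16 weights + fp32 Adam moments).
chips = 16**3
hbm_per_chip = 96e9

activations = 4.19e13 * 3e6 / 16e6        # scale down to a 3M-token batch
params_and_opt = 13e9 * 10                 # 130 GB
total = activations + params_and_opt
per_chip = total / chips                   # everything is sharded under FSDP

fsdp_threshold = 2550 / 3                  # C / (W_ICI * n_axes), 3 axes
per_chip_batch = 3e6 / chips               # ~732 tokens
min_batch_for_compute_bound = fsdp_threshold * chips   # ~3.48M tokens
q4_per_chip_batch = 1e6 / chips            # Question 4: ~244 tokens

print(f"total memory  : {total/1e12:.2f} TB of {chips*hbm_per_chip/1e12:.0f} TB")
print(f"per chip      : {per_chip/1e9:.2f} GB")
print(f"per-chip batch: {per_chip_batch:.0f} (threshold {fsdp_threshold:.0f})")
print(f"min batch     : {min_batch_for_compute_bound/1e6:.2f}M tokens")
print(f"Q4 per-chip batch: {q4_per_chip_batch:.0f}")
```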
Training Llama 3 on TPUs
A practical example of training the Llama 3 herd of models. The herd includes 8B, 70B and 405B models.
Starting off with the 70B model, let's have a closer look at it:
![[Screenshot 2025-08-13 at 14.48.06.png]]
Params
Per layer params FFN: 3DF Attention: W_Q and W_O have DNH, W_K and W_V have DKH. In total we have 2DNH + 2DKH. Vocab: 2VD
In total that gives L(3DF + 2DNH + 2DKH) + 2VD = 70.52B
The FFW makes up 80% of the total parameter count, attention is 17%, and the (input + output) embeddings are 3%.
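Here is a sketch of the parameter count above in Python. It assumes the published Llama 3 70B configuration (L=80, D=8192, F=28672, N=64 query heads, K=8 KV heads, H=128, V=128256), which isn't spelled out in these notes. With these dims the total lands at roughly 70.5B, and the 80/17/3 split comes out the same.

```python
# Parameter count for Llama 3 70B.
# Assumed config: L=80, D=8192, F=28672, N=64, K=8, H=128, V=128256.
L, D, F, N, K, H, V = 80, 8192, 28672, 64, 8, 128, 128_256

ffn = L * 3 * D * F                            # W_in (x2 for gated MLP) + W_out
attn = L * (2 * D * N * H + 2 * D * K * H)     # W_Q, W_O + W_K, W_V
vocab = 2 * V * D                              # input + output embeddings
total = ffn + attn + vocab

print(f"total: {total/1e9:.2f}B")
for name, n in [("FFN", ffn), ("attention", attn), ("vocab", vocab)]:
    print(f"{name:9s}: {100 * n / total:.0f}%")
```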
FLOPS
The FLOPs are typically 6 * num_params per token, meaning 420B FLOPs per token. That means we are doing roughly half a TFLOP per token per training step. Assuming we are compute-bound, this should take about 0.9ms per token on a single TPU v5p.
To calculate exactly: the FFN layer has 3 big matmuls, each of which takes 2BTDF FLOPs. That is in total 3 * 2BTDF * 3 = 18BTDF, where the second factor 3 comes from the fact that we do 1 of these in the forward pass and 2 in the backward pass. The attention consists of 6 large matmuls: the Q, K, V and O projections plus the two dot-product-attention matmuls (QK^T and AV). Ignoring softmax and other elementwise work, the forward pass of these 6 matmuls requires
2BTDNH + 2BSDKH + 2BSDKH + 2BTSKGH + 2BTSKGH + 2BTNHD = 4BTNHD + 4BSDKH + 4BTSKGH, where G = N/K is the number of query heads per KV head.
Given 15T training tokens, that gives us a total of 6.3e24 FLOPs. If we are compute-bound this would take 158k days on a single TPU v5p, roughly 435 years.
If we assume a full TPU v5p pod with 8960 chips and an MFU of 40%, this would take around 1000 hours ≈ 44 days, which is fairly reasonable, assuming we can actually achieve 40% MFU.
Llama 3 70B was pretrained with a batch size of 4M tokens. The parameters plus optimizer state will take up 70e9 * (2 + 4 + 4) bytes = 700GB. For the activations, we checkpoint 4 times per layer. The size depends on which shapes we checkpoint, but let's assume they are all [B, D] matrices, which means
2 * 4e6 * 8192 * 4 * 80 ≈ 21TB
That means we need at least 21.6TB of HBM. Since each TPU has 96GB of HBM, we need at least 225 TPUs. But, as we've already established, the total FLOPs to train is 6.3e24, which with 225 TPUs running at 100% MFU would take 706 days to complete. It becomes quite evident that our clusters are growing because of FLOPs, not memory.
If we assume we are using 8960 chips, we need only about 2.4GB per chip, which is basically nothing.
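The training-cost arithmetic can be sketched in a few lines. This assumes TPU v5p peak bf16 throughput of 4.59e14 FLOPs/s and 96GB of HBM per chip. It also uses the same rough 6 * params FLOPs rule and the four [B, D] checkpoints per layer assumed in the notes.

```python
# Training-cost arithmetic for Llama 3 70B using the 6 * params rule.
# Assumed TPU v5p numbers: 4.59e14 bf16 FLOPs/s and 96 GB HBM per chip.
PARAMS = 70e9
TOKENS = 15e12
V5P_FLOPS = 4.59e14
V5P_HBM = 96e9

flops_per_token = 6 * PARAMS               # ~4.2e11
total_flops = flops_per_token * TOKENS     # ~6.3e24

single_chip_days = total_flops / V5P_FLOPS / 86_400
pod_days = total_flops / (8960 * V5P_FLOPS * 0.4) / 86_400   # 40% MFU

# Memory: bf16 params + fp32 Adam moments, plus 4 bf16 [B, D] checkpoints per layer.
weights_and_opt = PARAMS * (2 + 4 + 4)               # 700 GB
activations = 2 * 4e6 * 8192 * 4 * 80                # ~21 TB
min_chips = (weights_and_opt + activations) / V5P_HBM

print(f"single chip: {single_chip_days:,.0f} days; pod at 40% MFU: {pod_days:.0f} days")
print(f"memory: {(weights_and_opt + activations)/1e12:.1f} TB -> at least {min_chips:.0f} chips")
```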
How to shard for training
Assuming the same setting from before, we want to train LLaMa 3 70B with a 4M token batch size (1024 sequences of length 4096 per batch) on a TPU v5p pod of 8960 chips. Let's try to identify the best sharding strategy.
Question: Under the assumptions above, can we train our model with FSDP alone? To start, let’s say we can’t do any sequence/context parallelism. This should be the first idea you have, since it’s simple and will introduce no extra communication if it works
FSDP shards parameters, activations and optimizer states, so memory-wise we will fit on 8960 chips, no problem. However, the question says we can't do sequence/context parallelism, and since we only have 1024 sequences, we can shard the batch across at most 1024 chips. So no, pure FSDP can't use the full pod.
Question: Let’s relax the requirement of not doing any sequence sharding. If we allow ourselves to do FSDP over both the batch and sequence axes, can we train LLaMA 3-70B with only FSDP on 8960 chips?
Okay, now this is a normal FSDP setting as discussed before. Remember, FSDP is communication bound when B/X < C/(W_ICI * n_axes). We're using TPU v5p, so C/W_ICI is 2550, and n_axes is 3 since we shard over all three axes of the pod. So we are comms bound when the per-device batch size is < 850. With the given setup, our per-device batch size is 4e6 / 8960 = 446, so we are comms bound.
Question: Now let’s look at mixed tensor parallelism and FSDP. Does there exist some combination that lets us remain compute-bound? What amount of FSDP and tensor parallelism should we do if so?
Let's see. For mixed FSDP + TP we have established that we are comms bound when B/N < 2550^2 / (M_X M_Y F) = {where M_X M_Y must be 2 in a 3D mesh} = 2550^2/2F = 113. Our per-device batch size of 446 is well above 113, so we can be compute bound! To compute the optimal amount of FSDP/TP we go to our trusty derivation
X_opt = sqrt(B * M_X * N / (F * M_Y)) ≈ 1618. Rounding to the nearest power of 2 we get 2048. That means we should do roughly 2048-way FSDP and the remaining ~4-way tensor parallelism.
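Both sharding checks, plus the X_opt formula, fit in one small Python sketch. It assumes C/W_ICI = 2550, F = 28672 (Llama 3 70B's d_ff) and M_X = 2, M_Y = 1. It uses the exact 1024 * 4096 token batch, so the per-chip batch is ~468 rather than the rounded 446.

```python
import math

# Sharding check for Llama 3 70B on 8,960 TPU v5p chips with a 4M-token batch.
# Assumed: C/W_ICI = 2550, F = 28672, M_X = 2 (FSDP axes), M_Y = 1 (TP axis).
B, N, F = 1024 * 4096, 8960, 28672
C_OVER_W_ICI = 2550
M_X, M_Y = 2, 1

per_chip_batch = B / N                                   # ~468 tokens
fsdp_threshold = C_OVER_W_ICI / 3                        # 850: pure FSDP needs more
mixed_threshold = C_OVER_W_ICI**2 / (M_X * M_Y * F)      # ~113: mixed needs more

x_opt = math.sqrt(B * M_X * N / (F * M_Y))               # ~1,619
x_pow2 = 2 ** round(math.log2(x_opt))                    # nearest power of 2

print(f"per-chip batch {per_chip_batch:.0f}: "
      f"pure FSDP needs > {fsdp_threshold:.0f}, mixed needs > {mixed_threshold:.0f}")
print(f"X_opt = {x_opt:.0f} -> {x_pow2}-way FSDP, ~{N / x_pow2:.1f}-way TP")
```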
Inference
Naive sampling from a transformer. Put prompt in, get log p(next token | previous tokens). Sample from distribution, put prompt + next token in. Repeat.
This works, but we never do this in practice. Due to the causal dependency of the Transformer decoder, each token only depends on the tokens before it. So at the second step in the image above, we are recomputing the same thing for all previous tokens that we already processed in step 1. The forward pass is O(n²) on the FFW and O(n³) on the attention mechanism to generate n tokens; that is expensive!
Instead of doing the full forward pass every time, we can save some intermediate activations from each forward pass that allow us to avoid re-processing previous tokens. Specifically, since a given token only attends to the previous tokens during dot product attention, we can simply write each token's key and value projections into a new data structure called the KV cache. Once we've saved these key/value projections from past tokens, future tokens can compute their attention products against the cache without performing any new FLOPs on earlier tokens. Amazing! This naturally divides inference into two separate stages:
- Prefill: this is the first step in the image above, where we have yet to process the prompt. At this step we process all the tokens in the prompt at the same time, saving the resulting activations (specifically key-value projections) in a KV cache. We also save the logits for the last token.
- Generation: given a KV cache and the previous logits, we sample a new token, feed that token into the Transformer and produce a new set of logits. We also append the new KV activations to the KV cache.
Here's a new visualization with a KV cache
By sampling with a KV cache we reduced our time complexity to generate n tokens to O(n) in the FFW and O(n²) on the attention, since we never reprocess a previous token. We will see that prefill and generate are two very different tasks, with the KV cache being a novel and significant source of complexity.
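To make the KV cache idea concrete, here is a toy NumPy sketch of a single attention head with random weights (not a real Transformer layer). It compares recomputing K and V for the whole prefix at every step against appending one new K/V row per step. The outputs match; only the amount of repeated work differs.

```python
import numpy as np

# Toy single-head causal attention with random weights.
# Compares naive prefix re-computation against decoding from a KV cache.
rng = np.random.default_rng(0)
D, T = 16, 8
W_q, W_k, W_v = (rng.standard_normal((D, D)) for _ in range(3))
xs = rng.standard_normal((T, D))   # stand-in for token embeddings

def attend(q, K, V):
    """Softmax attention of one query vector over all cached keys/values."""
    s = K @ q / np.sqrt(D)
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def naive(xs):
    # Re-project K and V for the entire prefix at every step: O(t) work per step.
    out = []
    for t in range(1, len(xs) + 1):
        prefix = xs[:t]
        out.append(attend(prefix[-1] @ W_q, prefix @ W_k, prefix @ W_v))
    return np.stack(out)

def cached(xs):
    # Project only the newest token and append its K/V to the cache: O(1) projections per step.
    K_cache, V_cache, out = [], [], []
    for x in xs:
        K_cache.append(x @ W_k)
        V_cache.append(x @ W_v)
        out.append(attend(x @ W_q, np.stack(K_cache), np.stack(V_cache)))
    return np.stack(out)

print("outputs match:", np.allclose(naive(xs), cached(xs)))
```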
What do we want to optimize? A part of inference that's totally new compared to training: latency. During training we focus on throughput, the total tokens processed per second. During inference we have to worry about how fast we're producing tokens, measured as both Time to First Token (TTFT) and the per-token latency. This is different for different use cases:
- Chat interfaces / streaming tasks need to run cheaply at scale while having a low TTFT, generating tokens fast enough to exceed human reading speed
- Offline batch inference for evals and data generation only cares about the bulk cost of inference and is blind to the latency of individual samples
- Edge inference only needs to service one user at a time at the lowest possible latency.
Maximizing hardware utilization is still critical and helps with cost and TTFT, but unlike training it does not necessarily translate to better experience for individual users in all contexts. Many optimizations at the accelerator, systems and model arch level make tradeoffs between latency, throughput, context length and model quality.
A granular view of the Transformer
Before, when we were looking at the training perspective, we treated Transformers as a stack of MLP layers. While this is often reasonable from a FLOPs and memory standpoint, it is not sufficient to properly model inference. The major components of the Transformer forward pass are:
- a bunch of linear operations, including the MLP: W_in and W_out; the attention QKVO projections: W_Q, W_K, W_V, W_O. These all involve reading parameters and a batch of activations from HBM, doing some flops and then writing the result back to HBM.
- dot product attention: we need to read a batch of key-value projections and a batch of query activations from HBM, do a few inner products and some softmax operations, and write back to HBM.
- everything else, including layer norms, activation functions, token sampling, updating the KV cache and positional embeddings. These take some FLOPs but are dominated by, or fused into, the above.
Linear operations: what bottlenecks us?
Let's look at one of the linear operations, which take the form of a bf16[B, D] batch times a bf16[D, F] weight matrix. This could be either one of the big W_in/W_out matrices in the MLP block or one of the smaller attention projections. To perform this matmul we need to load both operands from HBM, perform the matmul, and store the result back into HBM. That means we have to read 2BD + 2DF bytes from HBM, perform the matmul, and then write back 2BF bytes. Assuming a TPU v5e, the time this takes is given by
T_comms = bytes moved / bandwidth = 2(BD+DF+BF) / W_HBM
Then the matmul we are performing is obviously 2BDF FLOPs, and the time it takes is
T_math = computation FLOPs / accelerator FLOPs = 2BDF / C
We are compute bound if
T_math > T_comms <-> computation FLOPs / accelerator FLOPs > bytes moved / bandwidth <-> computation FLOPs / bytes moved > accelerator FLOPs / bandwidth <-> intensity(algorithm) > intensity(TPU v5e)
where intensity(TPU v5e BF16) = 1.97e14 / 8.1e11 = 243
With this we get that
2BDF / 2(BD+DF+BF)> 243
which can have different characteristics depending on the size relation between B, D, F. Typically F > D >> B, which gives
BDF / DF(B/F + 1 + B/D) ≈ B -> B > 243 = B_crit
If we quantize our weights or use lower-precision FLOPs for the matrix multiplication, this critical batch size can change. For instance, if our weights are quantized to int8 (activations still in bf16), we get
2BDF / (2BD + DF + 2BF) = 2BDF / DF(2B/F + 1 + 2B/D) ≈ 2B > 243 -> B_crit = 243/2 ≈ 120
or if we do our FLOPs in int8 / fp8, we now load everything in int8, meaning
2BDF / (BD + DF + BF) ≈ 2B > intensity(TPU v5e int8) = 3.94e14 / 8.1e11 = 486 -> B > 243
so basically nothing changes if we do everything in int8. We are moving 2x less data, which reduces communication load, but our accelerator is also 2x faster, so it evens out.
We can draw some general conclusions from this: if we let beta = bits per param / bits per activation, and alpha_hbm = intensity(accelerator) = C/W_hbm, then our critical batch size is B_crit = beta * alpha_hbm.
Takeaway: Transformer matmuls are compute bound iff the per-replica token batch size is greater than B_crit = C/W_hbm * (bits per param / bits per activation) = beta * alpha_hbm. For bf16 activations on a TPU v5e this is 240 tokens; for an H100 it is about 280 tokens.
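A small Python sketch can check the B_crit argument numerically rather than through the B << D, F approximation. It assumes TPU v5e numbers (1.97e14 bf16 FLOPs/s, 8.1e11 B/s HBM bandwidth) and Llama-3-70B-like D = 8192, F = 28672. It searches for the smallest batch whose exact matmul intensity beats the hardware's.

```python
# Exact arithmetic intensity of a [B, D] x [D, F] matmul versus the TPU v5e roofline.
# Assumed: 1.97e14 FLOPs/s and 8.1e11 B/s HBM bandwidth; D, F are example sizes.
ALPHA_V5E = 1.97e14 / 8.1e11    # ~243 FLOPs/byte

def matmul_intensity(B, D, F, w_bytes=2, a_bytes=2):
    flops = 2 * B * D * F
    bytes_moved = a_bytes * B * D + w_bytes * D * F + a_bytes * B * F
    return flops / bytes_moved

def critical_batch(D, F, w_bytes=2, a_bytes=2):
    """Smallest token batch B whose matmul intensity exceeds the hardware intensity."""
    B = 1
    while matmul_intensity(B, D, F, w_bytes, a_bytes) <= ALPHA_V5E:
        B += 1
    return B

D, F = 8192, 28672
print(critical_batch(D, F))              # bf16 weights: slightly above 243
print(critical_batch(D, F, w_bytes=1))   # int8 weights, bf16 activations: about half
```

The exact crossover sits a little above 243 because the B/D and B/F terms in the denominator are small but not zero.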
Remember that batch size here refers to the token batch size. During training, we'll have a very high algorithmic intensity because we reuse the same weights over a very large batch. This high intensity carries over to prefill since user prompts are typically hundreds if not thousands of tokens long. If a sequence is longer than 240 tokens and fed into a dense model we expect it to be compute-bound and all is well. Prompts shorter than this can technically be batched together to achieve higher utilization but this is typically not necessary.
Takeaway: During prefill, all matrix multiplications are basically always compute-bound. Therefore simply maximizing hardware utilization or MFU is enough to maximize throughput per chip (cost) and latency (in the form of TTFT). Unless prompts are extremely short, batching at a per-prompt level only adds latency for a small improvement in prefill throughput.
However, when we move to the decoding/generation stage we can only do our forward passes one token at a time. Thus we can only (easily) achieve good utilization by batching multiple requests together, parallelizing over the batch dimension. Apparently, batching over concurrent requests is hard without affecting latency, so it is much harder to saturate the hardware FLOPs during generation.
Takeaway: During generation, the total token batch size must be greater than B_crit to be compute bound on the linear/feed-forward operations. Because generation processes only one token per sequence at a time, this requires batching multiple requests, which is hard.
You have to realize that handling 240 concurrent requests means handling 240 separate KV caches. That means this is difficult to achieve in practice. In contrast, pushing more than 240 tokens through during the prefill is pretty routine.
Attention!
Things get more complicated as we turn to Attention :) Let's look at pure multi-head scaled dot product attention. In a single Flash Attention fusion, ignoring softmax, masks etc., we:
- Read Q activations of shape bf16[B, T, D] (assuming D=NH) from HBM
- Read the KV cache which is a pair of bf16[B, S, D] tensors from HBM
- Perform 2BTSD FLOPs in the QK matmul; with flash attention we don't need to write the bf16[B, S, T] attention matrix back into HBM
- Perform AV matmul taking 2BTSD FLOPs
- Write the resulting bf16[B,T,D] tensor back into HBM.
Putting this together we get
Multihead attention intensity = FLOPs / bytes moved = 4BTSD / (2BTD + 4BSD + 2BTD) = TS / (T + S)
During prefill S=T, giving us T/2. This is great because it means the arithmetic intensity of attention during prefill is O(T), so it is quite easy to be compute-bound for attention as long as our sequence length is fairly large. But during generation T = 1 << S, giving us
ST/(T+S) = S/(S+1) -> 1 as S grows.
This is bad, since we cannot do anything to improve the arithmetic intensity of attention during generation. We're doing a tiny amount of FLOPs while loading a massive KV cache. So we are basically always memory bandwidth bound during attention.
Takeaway: During prefill, attention is typically compute bound for any reasonable sequence length (roughly > 480 on a v5e), while during generation our arithmetic intensity is roughly 1 and constant, so we are ALWAYS memory bandwidth bound.
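The TS/(T+S) formula is short enough to evaluate directly. This sketch assumes the TPU v5e intensity of 1.97e14 / 8.1e11 FLOPs/byte. It finds the shortest prefill for which attention crosses the roofline, and it shows that single-token generation stays just under 1.

```python
# Arithmetic intensity of multi-head attention: 4BTSD FLOPs over
# (4BTD + 4BSD) bytes (read Q, read K and V, write O; all bf16).
def attention_intensity(T, S):
    return T * S / (T + S)

ALPHA_V5E = 1.97e14 / 8.1e11   # assumed TPU v5e bf16 intensity, ~243

print(attention_intensity(4096, 4096))   # prefill with S = T: T/2
print(attention_intensity(1, 8192))      # generation with T = 1: just under 1

# Shortest prefill (S = T) whose attention is compute bound on a v5e
min_prefill = next(T for T in range(1, 10_000) if attention_intensity(T, T) > ALPHA_V5E)
print(min_prefill)
```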
Let's think about this. During the linear portions of the model we are compute bound because the parameters (the memory bandwidth heavy components) are reused over many batch items. However, every batch item has its own KV cache, so a bigger batch size means more kv caches. We will almost always be memory bound here unless the architecture is adjusted.
Theoretical estimates for LLM latency and throughput
\begin{align} \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\text{MLP (can be compute-bound)}} \end{align}
Throughput / latency pareto charts from PaLM. We can trade throughput for latency up to a certain point. Below a batch size of roughly 240, the MLP step time is dominated by loading parameters, which depends only on parameter size and bandwidth, so extra batch items are almost free and throughput grows with batch size. Beyond the 240 mark, the MLP FLOPs begin to dominate over communication time, so step time grows linearly with batch size and throughput flattens out. Once the MLP becomes compute bound the throughput is given by
Throughput = Batch Size / (Attention Time + MLP Time)
= Batch Size / (Batch Size × KV_factor + Batch Size × MLP_factor)
= 1 / (KV_factor + MLP_factor)
which is independent of batch size, i.e. throughput plateaus.
Takeaway: If we care about generation throughput, use the largest per-chip batch size possible. Any per-chip batch size above the arithmetic intensity (B_crit, typically 120 or 240 depending on quantization) will maximize throughput. You may need to increase topology to achieve this. Smaller batch sizes will allow you to improve latency at the cost of throughput.
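The theoretical step-time formula above can be turned into a small function to watch throughput saturate. This is a sketch with assumed numbers: 8x TPU v5e (8 * 1.97e14 FLOPs/s, 8 * 8.1e11 B/s), a 13B-parameter bf16 model, and 6.7GB of KV cache per sequence. Setting the KV size to zero isolates the MLP term. In that case throughput grows with batch size while parameter loading dominates, then flattens at FLOPs / (2 * params) once the MLP is compute bound.

```python
# Theoretical decode step time: KV loading (always bandwidth bound) plus an MLP
# term that is the max of compute time and parameter-loading time.
# Assumed hardware/model: 8x TPU v5e, 13B bf16 params, 6.7 GB KV per sequence.
FLOPS = 8 * 1.97e14
HBM_BW = 8 * 8.1e11
PARAMS = 13e9
PARAM_BYTES = 2 * PARAMS
KV_BYTES = 6.7e9

def step_time(batch, kv_bytes=KV_BYTES):
    attention = batch * kv_bytes / HBM_BW
    mlp = max(2 * batch * PARAMS / FLOPS, PARAM_BYTES / HBM_BW)
    return attention + mlp

def throughput(batch, kv_bytes=KV_BYTES):
    return batch / step_time(batch, kv_bytes)

for b in [1, 8, 64, 240, 1024]:
    print(f"batch {b:5d}: {1e3 * step_time(b):8.2f} ms, {throughput(b):8.1f} tok/s "
          f"(no KV: {throughput(b, kv_bytes=0):9.1f} tok/s)")
```

With a real KV cache the attention term dominates here, so throughput is capped near HBM_BW / KV_BYTES long before the MLP ever becomes compute bound.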
What about memory?
We've looked at bandwidth and FLOPs of attention and linear operations during inference, but not the memory. Memory looks quite different during inference thanks to the KV cache. During training, we have parameters, activations and optimizer states, where activations typically dominate the memory requirements. During inference, many of the things we have to store during training disappear. We don't have an optimizer, and we don't perform a backward pass, so we don't need to save activations. It's actually just the parameters, with the addition of the KV cache. For the coming section, let's look at a real model to demonstrate how different things are in inference.
LLAMA 2 13B
L = 40
D = 5120
F = 2.7*D
N = 40
K = 40
H = 128
V = 32000
As we said, during inference our parameters require memory. Counting these we have, per layer:
FFW: 3DF Attention: DNH + DKH + DKH + NHD = {N=K} = 4DNH Vocab = 2DV
Adding these up gives 13e9 parameters, as expected. As we saw in the last section, storing parameters in bf16 with optimizer state in float32 may use around 130GB (2 bytes for params, 4 for m_t and 4 for v_t). This pales in comparison to the gradient checkpoints (activations), which can take several TBs.
During inference we only store one copy of the params, in something like bf16, using at most 26GB. But we can often do even better with quantization. Activations are negligible. There are no optimizer states. The main difference is the KV cache. The total size of the KV cache for T tokens is
KV cache size = 2 * bytes per float * L * T * K * H
where H is the dimension of each head, K the number of KV heads, L the number of layers, and the 2 comes from storing both K and V. This can get big fast. For our model at hand, an 8192-token sequence in bf16 takes
(2 * 2 * L * T * K * H) / 1e9 = 6.7GB
At a batch size of just 4, we've exceeded the memory usage of our parameters.
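The Llama 2 13B numbers above are easy to reproduce from the config list. The sketch below counts parameters and computes the bf16 KV cache size per sequence. The k_heads argument is a hypothetical knob for a GQA-style variant with fewer KV heads. With 8 KV heads instead of 40, the cache shrinks exactly 5x.

```python
# Llama 2 13B: parameter count and bf16 KV cache size, from the config above.
L, D, N, K, H, V = 40, 5120, 40, 40, 128, 32_000
F = 13_824                                   # = 2.7 * D

params = L * (3 * D * F + 4 * D * N * H) + 2 * D * V
param_bytes = 2 * params                     # bf16

def kv_cache_bytes(T, batch=1, k_heads=K, bytes_per_float=2):
    """2 (K and V) * bytes * layers * tokens * KV heads * head dim, times batch."""
    return batch * 2 * bytes_per_float * L * T * k_heads * H

print(f"params: {params/1e9:.2f}B ({param_bytes/1e9:.1f} GB in bf16)")
print(f"KV cache, T=8192: {kv_cache_bytes(8192)/1e9:.2f} GB per sequence")
print(f"batch 4 KV exceeds params: {kv_cache_bytes(8192, batch=4) > param_bytes}")
# Hypothetical GQA-style variant with 8 KV heads instead of 40
print(f"with 8 KV heads: {kv_cache_bytes(8192, k_heads=8)/1e9:.2f} GB per sequence")
```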
Modelling throughput and latency for Llama 2 13B
What happens when we perform inference at different batch sizes on 8x TPU v5e?
| Batch Size | 1 | 8 | 16 | 32 | 64 | 240 |
|---|---|---|---|---|---|---|
| KV Cache Memory (GiB) | 6.7 | 53.6 | 107.2 | 214.4 | 428.8 | 1608 |
| Total Memory (GiB) | 32.7 | 79.6 | 133.2 | 240.4 | 454.8 | 1634 |
| Theoretical Step Time (ms) | 4.98 | 12.13 | 20.30 | 36.65 | 69.33 | 249.09 |
| Theoretical Throughput (tokens/s) | 200.61 | 659.30 | 787.99 | 873.21 | 923.13 | 963.53 |
We note that the KV cache dominates our memory footprint, amortizing the parameter cost. 8x TPU v5e give us 128GB of HBM, 6.5 TiB/s of HBM bandwidth and 1600TF/s of compute. Increasing the batch size increases our throughput, as we expect, but with diminishing returns. We will OOM at batch sizes of 16 and above, since the total memory there (133.2 GiB) already exceeds our 128GB of HBM. If we keep the number of params the same, but are able to magically make our KV cache 5x smaller:
| Batch Size | 1 | 8 | 16 | 32 | 64 | 240 |
|---|---|---|---|---|---|---|
| KV Cache Memory (GiB) | 1.34 | 10.72 | 21.44 | 42.88 | 85.76 | 321.6 |
| Total Memory (GiB) | 27.34 | 36.72 | 47.44 | 68.88 | 111.76 | 347.6 |
| Theoretical Step Time (ms) | 4.17 | 5.60 | 7.23 | 10.50 | 17.04 | 52.99 |
| Theoretical Throughput (tokens/s) | 239.94 | 1,429.19 | 2,212.48 | 3,047.62 | 3,756.62 | 4,529.34 |
Now we will OOM at batch sizes > 64. We still see diminishing returns, but throughput scales better up to 240.
Takeaway: The size of the KV cache has a lot of bearing over the ultimate inference performance of the model. At longer sequence lengths the attention time dominates MLP time, which means that reducing the KV cache size by a factor 1/X will roughly reduce the step time by the same factor 1/X (and increase throughput by X).
Tricks for improving generation throughput and latency
Many techniques have been developed targeting the KV cache specifically.
MQA, GQA, MLA: share each KV head across several query heads (MQA uses a single KV head, GQA uses groups of query heads), or compress keys and values into a low-rank latent (MLA), shrinking the KV cache.
Mixing in local attention: local attention caps the context to a small-to-moderate max length. At training and prefill time, this involves masking the attention matrix to a diagonal strip instead of a triangle.
Sharing KV cache across layers: the model can learn to share the same KV cache across layers in some pattern. While this reduces KV cache size and therefore helps us increase batch size, shared KV caches may need to be read from HBM multiple times, so it does not necessarily improve step time.
Quantization: inference is less sensitive to the precision of parameters and KVs. By quantizing the parameters and KV cache (e.g. to int8, int4, fp8 etc.) we save memory bandwidth on both, decrease the batch size required to reach the compute roofline, and save memory to run at bigger batch sizes. Quantization has the added advantage that it can be applied post training, even if the model was not trained with quantization.
Using ragged HBM reads and Paged Attention: we allocate 8k of context for each KV cache, but it is often not necessary to read the entire KV cache from memory, since requests have a wide range of lengths and don't use the max context of the model.
Paged Attention is a refinement upon this that stores KV caches in OS-style page tables and mostly avoids padding the KV caches altogether. This adds a lot of complexity but means every batch only uses as much memory as it needs. Instead of allocating a standard 8k to every batch request we instead only use the necessary amount for each request.
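Here is a toy Python sketch of the page-table idea, to make it concrete. It is a hypothetical allocator, not vLLM's or any real system's implementation. PAGE_SIZE and the request lengths are made-up illustration values. Each request gets fixed-size pages on demand, and freed pages are reused, instead of receiving a full 8k slot up front.

```python
# Toy paged KV allocator (hypothetical, for illustration only).
# PAGE_SIZE and the request lengths below are made-up values.
MAX_CONTEXT = 8192
PAGE_SIZE = 16  # tokens per page

class PagedKVAllocator:
    def __init__(self):
        self.free_pages = []     # recycled physical page ids
        self.next_page = 0       # next never-used physical page id
        self.page_table = {}     # request id -> list of physical page ids

    def ensure_capacity(self, req, num_tokens):
        """Grow a request's page list until it can hold num_tokens tokens."""
        pages = self.page_table.setdefault(req, [])
        while len(pages) * PAGE_SIZE < num_tokens:
            if self.free_pages:
                pages.append(self.free_pages.pop())
            else:
                pages.append(self.next_page)
                self.next_page += 1

    def release(self, req):
        """Return a finished request's pages to the free list."""
        self.free_pages.extend(self.page_table.pop(req, []))

    def tokens_reserved(self):
        return sum(len(p) for p in self.page_table.values()) * PAGE_SIZE

lengths = {"a": 300, "b": 1200, "c": 5000, "d": 40}
alloc = PagedKVAllocator()
for req, n in lengths.items():
    alloc.ensure_capacity(req, n)

print("fixed 8k slots:", MAX_CONTEXT * len(lengths), "tokens")
print("paged         :", alloc.tokens_reserved(), "tokens")
```

Padding is limited to at most one partially filled page per request, which is where the memory savings come from.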
Big Picture: All told, these KV cache optimizations can reduce KV cache sizes by over an order of magnitude compared to a standard MHA Transformer. This can lead to an order-of-magnitude improvement in the overall cost of the Transformer.
Distributing Inference Over Multiple Accelerators
We've mostly handwaved how we are scaling beyond a single chip. Let's look at this now, prefill and generation separately.
Prefill: the roofline calculations are almost identical to training, and almost all the same techniques apply: model (megatron) parallelism, sequence sharding (for sufficiently long context), pipelining, even FSDP are all viable! You just have to keep the KVs kicking around so you can do generation later. As in training, increasing the number of chips gives us access to more FLOPs but adds communication overhead. General rule for sharding prefill, assuming we're doing prefill on a single sequence (no batch dim):
- Model sharding: we typically do some amount of model parallelism first, up to the point we become ICI-bound. From section 5, this is around F/2550 for 1 axis.
- Sequence parallelism: beyond this we do sequence parallelism (like data parallelism but sharding across the sequence dimension). While SP introduces some extra communication in attention, it is typically fairly small at longer contexts. As with training, we can overlap comms and math.
Takeaway: during prefill, almost any sharding that can work during training can work fine. Do model parallelism up to ICI bound, then sequence parallelism
Generation is a different beast. For one thing, it is harder to get a large batch size since we need to batch many requests together. Latency targets are lower. Together, this means we are typically more memory bound and more sensitive to communication overhead.
- FSDP is impossible. We are memory bound in loading our parameters and KV caches from HBM to MXU, we do not want to move them via ICI which is orders of magnitude slower than HBM. If anything we want to move activations rather than weights. Activations are considerably smaller.
- There is no reason to do data parallelism. Pure data parallelism is unhelpful, we are already memory bound on a single chip and DP replicates parameters, which doesn't make parameter loading faster. You're better off spinning up multiple copies of the model instead.
- No sequence = no sequence sharding
This mostly leaves us with model sharding for dense model generation.
Note on ICI bounds for generation. During training we want to be compute bound, hence we try to identify at what point our ICI comms take longer than our FLOPs. However, during generation, if we're memory bound (HBM to MXU) by parameter loading, we can increase model sharding beyond the aforementioned point and improve latency at a minimal throughput cost. More model sharding gives us more HBM to load our weights over, and our FLOPs don't matter (in the sense that FLOPs time isn't the bottleneck, so the thing we need to worry about is ICI time exceeding parameter loading time).
T_HBM_comms = 2DF / (Y W_hbm)
T_ICI_comms = 2BD / W_ICI
T_ICI_comms > T_HBM_comms <-> 2BD/W_ICI > 2DF/(Y W_hbm) <-> W_hbm / W_ICI > F/(BY) <-> Y > F/(B * beta)
Beta = W_hbm / W_ICI is the ratio between HBM and ICI bandwidth, which is usually around 8. That means, for the Llama 2 13B model above (F = 13,824) at a batch size of 32, we can shard up to Y ≈ 54 ways without a meaningful hit to throughput. This assumes we can fully shard our KV caches 54 ways, which is difficult.
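Rearranging 2BD/W_ICI > 2DF/(Y W_hbm) gives a one-line rule of thumb: ICI time only overtakes HBM parameter-loading time once Y exceeds F/(B * beta), where beta = W_hbm/W_ICI. This sketch assumes beta = 8 and Llama 2 13B's F = 13,824. It shows how the useful amount of model sharding shrinks as the decode batch grows.

```python
# Maximum model sharding during generation before ICI time exceeds
# HBM parameter-loading time: Y < F / (B * beta), with beta = W_hbm / W_ICI.
# Assumed: beta = 8, F = 13,824 (Llama 2 13B d_ff).
def max_model_sharding(F, batch, beta=8.0):
    return F / (batch * beta)

for batch in [8, 32, 128]:
    print(f"batch {batch:4d}: up to ~{max_model_sharding(13_824, batch):.0f}-way")
```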
Takeaway: our only option during generation is some variant of model parallelism. We aim to move activations instead of KV caches or parameters because we are memory bound and we want to limit data transfer over ICI. When our batch size is large, we do MP up to the FLOPs-ICI bound (F/alpha). When our batch size is smaller, we can improve latency by model sharding more. When we want to model shard more ways than we have KV heads, we can shard our KVs along the batch dimension as well.
Sharding the KV cache
We almost always prefer to avoid replicating the cache, since it is the primary source of attention latency. To do this, we megatron shard across the head dimension, which limits us to K-way sharding. So for models with a small number of KV heads we shard the head dimension as much as possible and then shard the batch dimension. Given a KV cache [2, B, S, K, H] we shard it as [2, Bz, S, Ky, H]. This means the KV cache is completely distributed.
X[B, D] = (existing activations, unsharded from previous layer)
K[Bz, S, Ky, H], V[Bz, S, Ky, H] = ... (existing KV cache, batch sharded)
Q[B, Nyz, H] = X[B, D] x W_Q[D, Nyz, H]
Q[Bz, Ny, H] = AllToAll_z->b(Q[B, Nyz, H])
Q[Bz, Ky, M, H] = Q[Bz, Ny, H] (split N -> K, M)
O[Bz, S, Ky, M] = Q[Bz, Ky, M, H] x K[Bz, S, Ky, H]
O[Bz, S, Ky, M] = softmax(O[Bz, S, Ky, M])
O[Bz, Ky, M, H] = O[Bz, S, Ky, M] x V[Bz, S, Ky, H]
O[B, Ky, Mz, H] = AllToAll_z->M(O[Bz, Ky, M, H])
O[B, Nyz, H] = Reshape(O[B, Ky, Mz, H])
X[B, D]{U_yz} = O[B, Nyz, H] x W_O[Nyz, H, D]
X[B, D] = AllReduce(X[B, D]{U_yz})
This is kind of complicated, but we can see that sharding over the batch dimension like this requires us to perform 2 AllToAll collectives: one to shift the Q activations to the batch sharding so we can compute attention with batch sharding, and one to shift the sharded attention output back to pure model sharding. The new comms are modestly expensive since they operate on our small activations, while in return we save a huge amount of memory bandwidth loading the KVs.
Designing an effective inference engine
So far we've looked at how to optimize and shard individual prefill and generate operations in isolation. How do we combine these? The simplest method is to run a batch of prefills, then a batch of generations. This is easy to implement and is the inference setup in most codebases, but it has multiple drawbacks:
- Latency is terrible. We couple the prefill and generate batch sizes. Time to first token is terrible at big prefill batch sizes, since you need to finish all prefills before any user can see any tokens. Generate throughput is terrible at small batch sizes.
- We block shorter generations on longer ones. Many sequences will finish before others, leaving empty batch slots during generation.
- Prefills are padded to the longest sequence and we waste a lot of compute.
A slightly better approach involves performing prefill at batch size 1 (where it is compute bound but has reasonable latency), but batch multiple requests during generation. This will avoid wasted TTFT from batched prefill while keeping generation throughput high. We call this interleaved configuration since we interleave prefill and generation steps. This is very powerful for bulk generation applications like evaluations where throughput is the main goal. We want to batch generation to improve throughput, and because prefill is already compute bound at batch size 1, this combination achieves this well. However, if we are serving a user, this configuration can lead to jittery and slow response on average. Other user prefills are placed on the critical path of the overall latency of a request.
To get around this, we separate decode and prefill.
Serving LLama
Question: Now let’s dig into the question of sharding. Let’s say we wanted to serve in bfloat16 on a TPU v5e 4x8. What sharding would we use for our model on a TPU v5e 4x8 during generation? Can we avoid being communication bound?
The only parallelism we can apply is model parallelism. TP becomes ICI bound when Y > n_axes * F / 2200 = 26. That means with a 4x8 configuration we cannot do full 32-way TP without becoming ICI bound in the usual sense. The most we can do is 4x4, and even that might be pushing it considering we cannot always perfectly overlap comms and math.
But remember that during generation we are in the memory-bandwidth-bound regime due to parameter loading, which means we can increase model sharding beyond the traditional point at a minimal throughput cost. More model sharding means more HBM to load our weights over, and our FLOPs don't "matter" in the sense that FLOP time isn't bottlenecking us. All we need to worry about is ICI time exceeding parameter loading time.
T_comms_hbm = 2DF / (Y * W_hbm)
T_comms_ici = 2BD / W_ICI
T_comms_ici > T_comms_hbm <-> 2BD/W_ICI > 2DF/(Y * W_hbm) <-> F/(BY) < q <-> Y > F/(Bq) = 3185/B, where q = W_hbm / W_ICI ≈ 9 on a TPU v5e.
We have 32 chips, which means our batch size can be at most 3185/32 ≈ 99. As long as our batch size is below this, we will be HBM bound on 32 chips. We can sanity check this further by looking at the raw values for a 4x8 and a batch size of 64:
T_comms_hbm = 2DF / (Y * W_hbm) ≈ 0.018ms
T_comms_ici = 2BD / W_ICI ≈ 0.012ms
T_math = 2BDF / (YC) ≈ 0.0048ms
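A minimal sketch of the numbers above, assuming LLaMA 3-70B's D = 8192 and F = 28672 and the TPU v5e figures used in this section (W_ICI = 9e10 bytes/s is the value that gives q ≈ 9). It only models the simplified per-layer MLP terms from the text, not attention or the full model.

```python
# Per-layer MLP timing for LLaMA 3-70B decode on a TPU v5e 4x8 (Y = 32 chips).
# Simplified terms as in the text: bf16 weights, one [D, F] matrix.
D, F = 8192, 28672   # d_model, d_ff
Y = 32               # model-parallel degree
W_HBM = 8.1e11       # HBM bandwidth per chip, bytes/s
W_ICI = 9e10         # ICI bandwidth assumed in the text, bytes/s
C = 1.97e14          # bf16 FLOPs/s per chip


def mlp_times(batch):
    """Return (T_hbm, T_ici, T_math) in seconds for one sharded MLP matmul."""
    t_hbm = 2 * D * F / (Y * W_HBM)          # load this chip's weight shard
    t_ici = 2 * batch * D / W_ICI            # move bf16 activations over ICI
    t_math = 2 * batch * D * F / (Y * C)     # FLOPs on this chip
    return t_hbm, t_ici, t_math


# ICI time stays below HBM time while batch < F * W_ICI / (Y * W_HBM) = 3185 / Y
max_batch = F * W_ICI / (Y * W_HBM)
t_hbm, t_ici, t_math = mlp_times(64)
print(f"max batch before ICI bound: {int(max_batch)}")
print(f"B=64: T_hbm={t_hbm*1e3:.4f}ms T_ici={t_ici*1e3:.4f}ms T_math={t_math*1e3:.4f}ms")
```

At B = 64 the ordering T_hbm > T_ici > T_math confirms we are still HBM bound on 32 chips.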
Takeaway: the maximum amount of useful model parallelism depends on d_ff and the number of axes over which you're sharding the model. This value typically ranges between 8 and 32 depending on model size: the larger the model, the more you can parallelize before becoming ICI bound. You can scale beyond this limit to improve latency at some throughput cost.
Prefill
We've mostly ignored prefill because it is much simpler to deal with. Let's put a couple of concepts together and think about the end-to-end picture.
Question: Assume we achieve a 40% FLOPs utilization during prefill. How long will a prefill of length 8192 take on 16 TPU v5e chips?
A 40% FLOPs utilization means we are achieving 16 * 1.97e14 * 0.4 = 1.26e15 FLOPs/s. At 8192 tokens we are solidly in the compute-bound regime, and the forward pass uses 2 * num params * num tokens = 2 * 70e9 * 8192 ≈ 1.15e15 FLOPs, so this takes 1.15e15 / 1.26e15 ≈ 0.91 seconds.
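The same arithmetic in a few lines, assuming the model is LLaMA 3-70B (70e9 parameters) and counting only the 2 × params × tokens matmul FLOPs (attention FLOPs ignored, as above):

```python
N_CHIPS = 16
C = 1.97e14        # TPU v5e bf16 FLOPs/s per chip
MFU = 0.40         # assumed FLOPs utilization
PARAMS = 70e9      # LLaMA 3-70B
TOKENS = 8192

flops = 2 * PARAMS * TOKENS          # ~2 FLOPs per parameter per token
achieved = N_CHIPS * C * MFU         # FLOPs/s actually delivered
prefill_s = flops / achieved
print(f"{flops:.3e} FLOPs / {achieved:.3e} FLOPs/s = {prefill_s:.2f} s")
```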
Question: Assume we have a median prefill length of 8192 tokens and a median decode length of 4096 tokens. Say we have a generate batch size of 32. On average how many sequences finish decoding per step? On average how many tokens are evicted from our KV cache each step?
We decode one token per step, and each sequence needs to decode 4096 tokens. At a batch size of 32 that means 32/4096 = 1/128 sequences finish every step.
Our KV cache length is 8192 + 4096 (assuming a fixed size). This KV cache is dropped when we finish a sequence, so we evict (8192 + 4096) * 32/4096 = 96 tokens every step.
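A tiny check of both answers, using the median lengths from the question and assuming every sequence holds a fixed-size cache of prefill + decode tokens until it finishes:

```python
batch = 32
prefill_len, decode_len = 8192, 4096

finished_per_step = batch / decode_len                        # sequences completing per step
evicted_per_step = (prefill_len + decode_len) * finished_per_step
print(finished_per_step, evicted_per_step)                    # 0.0078125 96.0
```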
Question: Assume we do disaggregated serving with a median prefill length of 8192 and a median decode length of 512. Assume the prefill and generate latencies calculated above in bfloat16. What ratio of prefill:generate servers will you need to keep both fully saturated?
Prefill latency for 8192 tokens was 0.91 seconds.
Generation step latency is 0.019s at batch size 32 (the earlier figure was for batch size 43, but let's say 32).
Let P be the number of prefill servers and G the number of generation servers. We feed sequences into generation at a rate of P / prefill latency and consume them at a rate of B * G / (step latency * decode steps). This gives
P/0.91 = 32G/(512 * 0.019) <-> P = 3G
so we need about 3 times more prefill servers than generation servers.
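The balance condition written out, using the latencies quoted above (0.91s per prefill, 0.019s per generate step at batch 32, 512 decode steps):

```python
prefill_latency = 0.91   # s per 8192-token prefill
step_latency = 0.019     # s per generate step at batch 32
batch, decode_steps = 32, 512

# Sequences per second one server can produce (prefill) or finish (generate).
prefill_rate = 1 / prefill_latency
generate_rate = batch / (step_latency * decode_steps)

# Saturate both: P * prefill_rate == G * generate_rate  ->  P / G
prefill_per_generate = generate_rate / prefill_rate
print(f"P/G = {prefill_per_generate:.2f}")
```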
As we've established before, during generation the time is dominated by parameter loading at small batch sizes. As we cross a batch size of ~120 (this is for int8 weights), we become compute bound in the MLPs and the FLOPs start to dominate our time share. However, as we increase context, the only term that grows is the KV cache comms: we have to move a lot more bytes around, so at longer context lengths (in this case already at 4096) the KV comms are larger than the FLOPs. Naturally, the KV cache size also grows with batch size, so at significant batch sizes the KV cache dominates the total time share. At a context length of 16384 we can't even increase our batch size enough to reach the MLP compute-bound regime anymore: as the context grows, the total memory used by the KV cache shrinks our maximum batch size.
Takeaway: for LLaMA 3-70B, we are strongly KV cache memory bandwidth-bound (and HBM-bound) in almost all of these configurations, highlighting just how important reducing KV cache size is for generation throughput. Also note just how dramatic the latency/throughput tradeoff remains here.
Question 1: How many FLOPs does each forward pass for LLaMA 3-405B use per-token? Assuming we’re FLOPs bound, what is a lower bound on a single forward pass on N chips on TPU v5e? What if we’re comms bound? Ignore the fact that the model does not fit on a single chip.
The FLOPs in a forward pass are 2 * num params per token, which gives 810e9 FLOPs. Assuming we are FLOPs bound, our lower bound is just the time it takes for our N chips to perform the FLOPs of the forward pass. The TPU v5e does 1.97e14 FLOPs/s, so the answer is:
810e9 / (N * 1.97e14) ≈ 4.1ms / N
If we are comms bound, the lower bound is given by the time it takes to move our (bf16) parameters from HBM into the MXU. Not sure if they are assuming N chips here, but in the single-chip case the lower bound is
2 * 405e9 / 8.1e11 = 1.0s (or 1.0s / N if each of N chips loads only its shard)
because the TPU v5e HBM bandwidth is 8.1e11 bytes/s.
If we are comms bound in the sense of ICI bound, then we assume the model is sharded over N chips and we get
2 * 405e9 / (N * W_ICI) = 2 * 405e9 / (N * 9e10)
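These three bounds as functions of N, a sketch assuming 405e9 parameters, bf16 weights (2 bytes) for the memory terms, and the per-chip TPU v5e figures used above. The HBM bound divides by N under the assumption that each chip loads only its 1/N shard.

```python
PARAMS = 405e9
C = 1.97e14       # bf16 FLOPs/s per chip
W_HBM = 8.1e11    # HBM bytes/s per chip
W_ICI = 9e10      # ICI bytes/s figure used in the text


def flops_bound(n):
    return 2 * PARAMS / (n * C)        # 2 FLOPs per param for one token


def hbm_bound(n):
    return 2 * PARAMS / (n * W_HBM)    # bf16 weights split over n chips


def ici_bound(n):
    return 2 * PARAMS / (n * W_ICI)    # bf16 weights moved over ICI


for n in (1, 32, 64):
    print(n, f"{flops_bound(n):.2e}", f"{hbm_bound(n):.2e}", f"{ici_bound(n):.2e}")
```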
Question 2: Assume we want to serve LLaMA 3-8B with BS240 using int8 weights and int8 KV caches. How many bytes are used by (a) model parameters (b) KV caches and (c) peak working activations (roughly)? What’s the smallest topology we can run this on?
Bytes: (a) 8GB. (b) The KV cache is [2, bytes per param, L, K, H] per token, where L is the number of layers, K the number of KV heads and H the head dimension. LLaMA 3-8B has 32 layers, 8 KV heads and a head dimension of 128, which gives 2 * 1 * 32 * 8 * 128 = 65,536 bytes ≈ 65.5KB per token. With batch size 240, the total KV cache is 65.5e3 * 240 * S bytes, where S is the context length. (c) Ignoring activations because they are roughly negligible.
To determine the smallest topology we can run this on, let's assume our context length is 2048, which means the KV cache requires ~32GB. We therefore require ~40GB in total, meaning a 2x2 is sufficient (4 * 16GB = 64GB). If we want to increase to a 4096 context length, the KV cache grows to ~64GB and we need ~72GB in total, which requires a 4x2 (8 * 16GB = 128GB).
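A sketch of the memory check, assuming LLaMA 3-8B's config of 32 layers, 8 KV heads and head dimension 128, 16GB of HBM per TPU v5e, and a short hand-picked list of v5e slice shapes (the list is illustrative, not exhaustive). Activations are ignored, as above.

```python
L, K, H = 32, 8, 128              # LLaMA 3-8B: layers, KV heads, head dim
KV_BYTES = 1                      # int8 KV cache
PARAM_BYTES = 8e9                 # int8 weights
HBM_PER_CHIP = 16e9               # TPU v5e
# A few v5e slice shapes and their chip counts, smallest first.
SLICES = {"1x1": 1, "2x2": 4, "4x2": 8, "4x4": 16}

kv_per_token = 2 * KV_BYTES * L * K * H          # K and V for every layer


def smallest_slice(batch, seq_len):
    """Smallest slice whose total HBM holds params + KV caches (activations ignored)."""
    total = PARAM_BYTES + batch * seq_len * kv_per_token
    for name, chips in SLICES.items():
        if chips * HBM_PER_CHIP >= total:
            return name, total
    return None, total


for seq_len in (2048, 4096):
    name, total = smallest_slice(240, seq_len)
    print(seq_len, name, f"{total / 1e9:.1f}GB")
```

In practice you would leave some headroom for runtime overhead rather than fill HBM exactly.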
Question 3: How would you serve LLaMA 3-405B on TPU v5e? Assume int8 weights and bfloat16 FLOPs. Let’s say we have a firm limit of 15ms / token, what’s the highest throughput configuration we could achieve? What is the theoretical minimum step time?
Let's see. Under this configuration our highest throughput is when we become compute bound by our MLPs at B_crit = 120.
The question is what throughput are we achieving at 15ms / token.
We have the general step time formula as \begin{align} \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\text{MLP (can be compute-bound)}} \leq 15\,\text{ms} \end{align} For int8 parameters and bf16 FLOPs, the MLP turns compute bound when the batch size exceeds B_crit = 120 (both MLP terms divide by the same chip count, so under model parallelism this is the full batch, not a per-chip batch).
Let's perform some sanity checks, starting with seeing if the MLP can even be compute bound with the latency limit.
When the MLP becomes compute bound we have
2BP / (NC) < 15ms <-> B/N < 3.6 (tokens per chip per step)
So reaching B_crit = 120 within the 15ms cap needs N > 120 / 3.6 ≈ 33 chips; on anything smaller, the batch we can afford keeps the MLP memory-bandwidth (comms) bound rather than compute bound. Looking at the parameter loading:
bytes / (N * W_hbm) = 405e9 / (N * 8.1e11) = 0.5/N < 0.015 <-> N > 33.3
means we need at least 34 chips to hide the parameter loading within the 15ms limit. A 4x8 topology is too small, so we need to move to an 8x8 with 64 chips. At this size, the MLP step time is
0.5/64 ≈ 7.8ms
which leaves 7.2ms for attention. The KV cache is [2, bytes per param, L, K, H] per token, which for LLaMA 3-405B (126 layers, 8 KV heads, head dimension 128) is 2 * 1 * 126 * 8 * 128 ≈ 258KB per token.
Batch size * sequence length * KV cache per token / (N * W_HBM) < 7.2ms <-> B * S < ~1.45e6 tokens
This budget assumes the MLP stays at 7.8ms, which on 64 chips holds up to about B = 120 (beyond that, 2BP/(NC) exceeds 7.8ms). At B ≈ 120, that leaves room for roughly 12k tokens of context per sequence.
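The general step-time formula for this configuration, a sketch assuming int8 weights and KV cache, LLaMA 3-405B's 126 layers / 8 KV heads / head dimension 128, and the per-chip TPU v5e figures used above:

```python
PARAMS = 405e9
PARAM_BYTES = 405e9                     # int8 weights
C = 1.97e14                             # bf16 FLOPs/s per chip
W_HBM = 8.1e11                          # bytes/s per chip
KV_PER_TOKEN = 2 * 1 * 126 * 8 * 128    # int8 K and V, 126 layers, 8 KV heads, d=128


def step_time(batch, seq_len, n_chips):
    """Attention (bandwidth bound) + MLP (max of compute and weight loading)."""
    attn = batch * seq_len * KV_PER_TOKEN / (n_chips * W_HBM)
    mlp = max(2 * batch * PARAMS / (n_chips * C),
              PARAM_BYTES / (n_chips * W_HBM))
    return attn + mlp


for batch in (120, 176):
    print(batch, f"{step_time(batch, 8192, 64) * 1e3:.1f} ms")
```

At 8k context on 64 chips, B = 120 stays inside the 15ms budget while B = 176 does not: past roughly 120, the MLP compute term overtakes weight loading, so the attention budget alone is not enough to size the batch.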
Serving LLaMA
Looking at LLaMA 3-70B again:
| hyperparam | value |
|---|---|
| n_layers (L) | 80 |
| d_model (D) | 8,192 |
| d_ff (F) | 28,672 |
| n_heads (N) | 64 |
| n_kv_heads (K) | 8 |
| d_qkv (H) | 128 |
| n_embeddings (V) | 128,256 |

This model has a KV cache of [2, L, K, H] per token; assuming int8, that is 2 * 80 * 8 * 128 = 163,840 bytes (≈ 164KB). That means just a single sequence of length 8192 requires 1.34GB of memory.
Serving this model with a batch size 32 and 8192 sequence length in int8 would require:
Params: 70GB
KV cache: 2 * L * K * H * 32 * 8192 = 43GB
That is a total of 113GB, which would require a 4x2 TPU v5e, or perhaps even a 4x4 to account for overhead.
To calculate the decode latency we look at
\begin{align} \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\text{MLP (can be compute-bound)}} \end{align}
First, calculating the attention time is straightforward:
T_attn = B * S * (KV cache per token) / (8 * W_HBM) = 32 * 8192 * 160e3 / (8 * 8.1e11) ≈ 6.47ms
Remember that this is decoding/generation, so T = 1 and S = 8192. For the MLP to be compute bound we need a batch size > 120 (int8 params). Under model parallelism every chip processes the full batch of 32, so we are well into the memory-bound regime. That gives:
T_mlp_comms = 70e9 / (8 * W_HBM) ≈ 10.8ms
Meaning our per-step latency is ≈ 17.3ms and our throughput is 32 / 17.3e-3 ≈ 1850 tokens/second, or ≈ 231 tokens/sec/chip.
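The general step-time formula applied to this config, a sketch assuming int8 weights and KV cache (70e9 bytes of weights, 163,840 bytes of KV per token) and the TPU v5e figures used above. It uses the exact 163,840 bytes per token rather than 160e3, so it lands slightly above the hand estimate (about 17.4ms).

```python
def step_time(batch, seq_len, n_chips, params, param_bytes, kv_per_token,
              flops_per_chip=1.97e14, hbm_bw=8.1e11):
    """Attention (always bandwidth bound) + MLP (max of compute and weight loading)."""
    attn = batch * seq_len * kv_per_token / (n_chips * hbm_bw)
    mlp = max(2 * batch * params / (n_chips * flops_per_chip),
              param_bytes / (n_chips * hbm_bw))
    return attn + mlp


kv = 2 * 80 * 8 * 128                       # LLaMA 3-70B int8 KV bytes per token
t = step_time(32, 8192, 8, 70e9, 70e9, kv)  # batch 32, 8k context, 4x2 slice
tokens_per_s = 32 / t
print(f"{t * 1e3:.1f} ms/step, {tokens_per_s:.0f} tok/s, {tokens_per_s / 8:.0f} tok/s/chip")
```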
The one caveat to check here is if we are ICI bound on our matmuls. In the above equations we have assumed that we are HBM bound, so we need to make sure this is true. In theory we are ICI bound if Y > n_axis * F/2200 = 26. Let's remind ourselves where this comes from.
TP: In[B, D_Y] *_D W_in[D, F_Y] *_F W_out[F_Y, D] -> Out[B, D_Y]
Here we AllGather the activations before the first matmul, then ReduceScatter them after the second. There are two matmuls, requiring 2BDF/Y + 2BFD/Y FLOPs.
T_comms = 2 * 2 * B * D / W_ICI
T_math = 4BDF / (YC)
T_comms < T_math <-> 4BD/W_ICI < 4BDF/(YC) <-> YC/W_ICI < F <-> Y < F / (C/W_ICI) = F/2200, since C/W_ICI ≈ 2200 on a TPU v5e.
That means we are compute bound as long as Y < F/2200. Note that this is independent of the precision of the computation. As an example, under int8 FLOPs C_int8/W_ICI doubles, but at the same time our communication volume is halved, so the two factors cancel.
If we were to run on a 4x4 we would still be fine ICI-wise, and our latency would drop by the same factor as we increase our number of TPUs (2x), down to ≈ 8.6ms.
Throughput
When we want to optimize for throughput, we ideally want to be compute bound, meaning we want to come close to utilizing all of the TPU's MXU capacity. Typically that means increasing the batch size as much as possible so we are doing as much work as possible.
Repeating earlier sections, let's determine when a TPU v5e matmul becomes compute bound. This happens when the time spent doing math (FLOPs) exceeds the time spent moving data from HBM into the MXU.
Typically, we denote a matmul in this context as bf16[B, D] x bf16[D, F]. Hence we get
T_math = 2BDF / C
T_comms = (2BD + 2DF + 2BF) / W_hbm
Compute bound iff: T_math > T_comms <-> 2BDF/C > (2BD + 2DF + 2BF) / W_hbm <-> 2BDF/(2BD + 2DF + 2BF) > C/W_hbm <-> (assuming B << D < F, the left side ≈ B) B > C/W_hbm = 243
In BF16 the batch size needs to be larger than 243 to be compute bound. If our weights are int8 but FLOPs in bf16 then we will communicate half the bytes which means the necessary batch size is halved to 120. If we also perform our FLOPs in int8 these two factors cancel and we are back to 243. If our FLOPs precision is p_flops and our weight precision is p_w we can generalize the formula to be
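B > (p_w / p_flops) * C / W_hbm
where p_w and p_flops are in bytes and C is the bf16 peak (this assumes peak FLOPs double each time the compute precision halves, as with int8 on TPU v5e).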
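The generalized threshold as a function, a sketch assuming TPU v5e's bf16 peak and HBM bandwidth and that int8 FLOPs run at 2x the bf16 rate:

```python
C_BF16 = 1.97e14   # TPU v5e bf16 FLOPs/s
W_HBM = 8.1e11     # TPU v5e HBM bytes/s


def b_crit(weight_bytes, flop_bytes):
    """Batch size above which a weight-loading matmul becomes compute bound.

    Assumes peak FLOPs scale as 2 / flop_bytes relative to bf16 (int8 = 2x bf16 on v5e).
    """
    return weight_bytes / flop_bytes * C_BF16 / W_HBM


for w, f in [(2, 2), (1, 2), (1, 1)]:
    print(f"weights {w}B, FLOPs {f}B -> B_crit ~ {b_crit(w, f):.0f}")
```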
The case of int8 weights and bf16 FLOPs is fairly common, since quantizing parameters losslessly is often easier than doing low precision arithmetic.
Question: What is the smallest TPU v5e topology we could serve LLaMA 3-70B on using bfloat16, int8, and int4 (both KVs and parameters) with 8k context?
| dtype | param size | KV size / token (bytes) | min TPU v5es | actual min slice | remaining HBM for KV caches (GB) | num KV caches @ 8k |
|---|---|---|---|---|---|---|
| bf16 | 140GB | 328kB | 8.75 | 4x4 = 16 chips | 116 | 43 |
| int8 | 70GB | 164kB | 4.38 | 4x2 = 8 chips | 58 | 43 |
| int4 | 35GB | 82kB | 2.19 | 2x2 = 4 chips | 29 | 43 |

So in theory we can fit LLaMA 70B on just 4 chips, but with only 43 KV caches (the count is the same in every row, since param size, KV size and slice HBM all halve together). Note that this is our batch size! That means we will have very poor utilization. Ideally we want to use a larger topology to push our batch size up to 240.
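The table can be regenerated with a short loop, a sketch assuming 16GB of HBM per TPU v5e, a hand-picked list of slice sizes, LLaMA 3-70B's KV shape (80 layers, 8 KV heads, head dimension 128), and that parameters and KV caches use the same dtype:

```python
import math

HBM_PER_CHIP = 16e9                        # TPU v5e
SLICE_CHIPS = [1, 4, 8, 16, 32, 64]        # 1x1, 2x2, 4x2, 4x4, 4x8, 8x8
PARAMS = 70e9                              # LLaMA 3-70B
KV_ELEMS_PER_TOKEN = 2 * 80 * 8 * 128      # K and V x layers x KV heads x head dim
SEQ_LEN = 8192

results = {}
for dtype, nbytes in [("bf16", 2), ("int8", 1), ("int4", 0.5)]:
    param_bytes = PARAMS * nbytes
    kv_per_token = KV_ELEMS_PER_TOKEN * nbytes
    chips = next(n for n in SLICE_CHIPS if n * HBM_PER_CHIP > param_bytes)
    free = chips * HBM_PER_CHIP - param_bytes
    results[dtype] = (chips, math.floor(free / (kv_per_token * SEQ_LEN)))
    print(dtype, f"{param_bytes / 1e9:.0f}GB params", chips, "chips",
          f"{free / 1e9:.0f}GB free", results[dtype][1], "KV caches")
```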
Question: Assume we use the largest batch size that fits on these topologies, what latency could we expect for each generate step?
Again we return to the latency equation from earlier
\begin{align} \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\text{MLP (can be compute-bound)}} \end{align} Where the MLP is compute bound if the token batch size is greater than B_crit.
At bf16, B_crit is ~240 and our batch of 43 is far below it, so we are purely memory bound, which means the latency is
step latency = (batch size * KV cache size + parameter size) / memory bandwidth ≈ 19.7ms
We can alternatively realize that what we're doing here is just taking all the bytes that fit into each TPU v5e's 16GB of HBM and moving them into the MXU. That takes
step latency = 16e9 / 8.1e11 ≈ 19.8ms
Takeaway: we can always lower bound decode latency by asking how long it takes to load all the model parameters from HBM into the MXU. When KV caches are small, you can think about each layer as just loading the weights chunk-by-chunk and then discarding them. Unless we're using large batch sizes or lots of inter-device comms, this is often a reasonable bound. When our batch size is bigger, we need to model the KV cache as well, since it dominates the parameters.
Likewise, in the FLOPs-bound regime (e.g. training or big-batch inference) the lower bound is determined by our FLOPs:
Total FLOPs / (N * C) = 2 * param count * B / (N * C)
where N is the number of accelerators we shard over and C the per-accelerator FLOPs/s. This is a lower bound, assuming no communication.
Question: For each of these, what throughput per chip does this give us (in terms of queries / chip)? You can assume our median decode length is 512 tokens.
At BF16, our per step latency is 19ms. That means our throughput is given by
B / (per step latency * median steps * N) = 43/(0.019 * 512 * N) = 4.42/N. Plugging in N we get
| dtype | QPS / chip |
|---|---|
| bfloat16 | 0.28 |
| int8 | 0.55 |
| int4 | 1.11 |

Doubling our topology (4x8 for bf16, 4x4 for int8, 4x2 for int4) means we can increase our batch size. With a 4x8 slice in bf16 we now have 512 - 140 = 372GB left over for KV caches, which means we can fit a batch size of ~138. This gives a throughput of about 14/N, giving

| dtype | QPS / chip |
|---|---|
| bfloat16 | 0.43 |
| int8 | 0.87 |
| int4 | 1.75 |
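Both QPS tables follow from one formula. A sketch assuming the 19ms step time used above (every chip's HBM is full, so the step time barely changes with batch), 512 decode steps, 43 KV caches on the smallest slices, and ~138 on the doubled slices (free HBM divided by the per-sequence KV bytes at 8k):

```python
STEP_S = 0.019        # per-step latency used in the text (every chip's HBM full)
DECODE_STEPS = 512    # median decode length


def qps_per_chip(batch, chips):
    """Completed queries per second per chip when each query needs DECODE_STEPS steps."""
    return batch / (STEP_S * DECODE_STEPS * chips)


smallest = {"bf16": (43, 16), "int8": (43, 8), "int4": (43, 4)}
doubled = {"bf16": (138, 32), "int8": (138, 16), "int4": (138, 8)}
for dtype in smallest:
    print(dtype, f"{qps_per_chip(*smallest[dtype]):.2f}", f"{qps_per_chip(*doubled[dtype]):.2f}")
```

With batch 138 the results land slightly above the rounded 14/N figures in the second table.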

Question: Now let’s dig into the question of sharding. Let’s say we wanted to serve in bfloat16 on a TPU v5e 4x8. What sharding would we use for our model on a TPU v5e 4x8 during generation? Can we avoid being communication bound?
Right, so I think the first thing we want to look at is what happens if we apply TP. For a TPU v5e at bf16, TP is ICI bound when Y > n_axes * F/2200 (the notes use 2550, which is for TPU v5p). If we shard across 2 axes, with our model config this is Y > 26. That means we cannot TP shard across our entire slice without becoming ICI bound: we can TP across a 4x4 slice but not a 4x8. And even this is generally optimistic, since we rarely perfectly overlap communication.
Takeaway: we cannot actually serve on a 4x8 with pure model parallelism. The best we can do is 4x2 or maybe a 4x4.
GPUs
Question 1 [CUDA cores]: How many fp32 CUDA cores (ALUs) does an H100 have? B200? How does this compare to the number of independent ALUs in a TPU v5p?
An H100 SM has 4 SM subpartitions. Each subpartition contains a warp scheduler driving a SIMD/SIMT vector unit whose lanes (ALUs) NVIDIA calls CUDA cores. The CUDA cores perform vector arithmetic, similar to the TPU VPU: each ALU can generally do 1 arithmetic op per cycle, e.g. an fp32 add. Each subpartition contains 32 fp32 cores (and a smaller number of int32 and fp64 cores), so each SM has 128 fp32 CUDA cores. The H100 has 132 SMs, which means it has 132 * 128 = 16,896 fp32 CUDA cores. The B200 has 148 SMs per chip, totalling 148 * 128 = 18,944 CUDA cores. The TPU v5p has 2 TensorCores, each with a VPU with (8, 128) lanes and 4 independent ALUs per lane, so 2 * 4 * 8 * 128 = 8192 ALUs.
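The counts above as arithmetic, assuming 4 subpartitions of 32 fp32 lanes per SM on both GPUs and the SM counts quoted in the answer:

```python
FP32_LANES_PER_SUBPARTITION = 32
SUBPARTITIONS_PER_SM = 4
cores_per_sm = SUBPARTITIONS_PER_SM * FP32_LANES_PER_SUBPARTITION   # 128

h100_cores = 132 * cores_per_sm     # 132 SMs
b200_cores = 148 * cores_per_sm     # 148 SMs
tpu_v5p_alus = 2 * 4 * 8 * 128      # TensorCores x ALUs per lane x (8, 128) VPU lanes
print(h100_cores, b200_cores, tpu_v5p_alus)   # 16896 18944 8192
```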
Question 2 [Vector FLOPs calculation]: A single H100 has 132 SMs and runs at a clock speed of 1.59GHz (up to 1.98GHz boost). Assume it can do one vector op per cycle per ALU. How many vector fp32 FLOPs can be done per second? With boost? How does this compare to matmul FLOPs?
As established, an H100 has 16,896 fp32 CUDA cores (ALUs) and does one vector op per ALU per cycle. That means it does 16,896 * 1.59e9 ≈ 2.69e13 FLOPs/s, or ≈ 3.35e13 FLOPs/s with boost. In comparison, it can do 990 TFLOPs/s of bf16 matmuls, about 30x more FLOPs/s than the boosted vector rate.
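The same calculation, assuming one fp32 op per CUDA core per cycle (as the question states) and the 990 TFLOPs/s bf16 matmul figure quoted above:

```python
cores = 132 * 128               # H100 fp32 CUDA cores
vector_base = cores * 1.59e9    # one op per ALU per cycle at base clock
vector_boost = cores * 1.98e9   # same at boost clock
matmul_bf16 = 990e12            # dense bf16 matmul FLOPs/s figure used above
print(f"{vector_base:.3e} {vector_boost:.3e} ratio={matmul_bf16 / vector_boost:.1f}x")
```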
Question 3 [GPU matmul intensity]: What is the peak fp16 matmul intensity on an H100? A B200? What about fp8? By intensity we mean the ratio of matmul FLOPs/s to memory bandwidth.
The fp16 intensity of an H100 is 1979e12 / 3.35e12 ≈ 590, but that is with sparsity; without sparsity it is ≈ 295. For the B200, fp16 gives 2250e12 / 8e12 ≈ 281. This means that, in similar fashion to the TPU, we need a batch size of around 280 to be compute bound in a matmul. For both the H100 and the B200 we get exactly 2x the FLOPs in fp8 compared to fp16, so the peak intensity doubles to 590 and 562.
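The intensities as a small table, assuming dense (no sparsity) fp16 peaks of 989.5 TFLOPs/s for the H100 and 2250 TFLOPs/s for the B200, HBM bandwidths of 3.35TB/s and 8TB/s, and fp8 at exactly 2x fp16:

```python
# Dense (no sparsity) fp16/bf16 peak FLOPs/s and HBM bandwidth in bytes/s.
SPECS = {"H100": (989.5e12, 3.35e12), "B200": (2250e12, 8e12)}

intensity = {name: flops / bw for name, (flops, bw) in SPECS.items()}
for name, value in intensity.items():
    print(f"{name}: fp16 ~{value:.0f}, fp8 ~{2 * value:.0f} FLOPs/byte")
```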
Question 4 [Matmul runtime]: Using the answer to Question 3, how long would you expect an fp16[64, 4096] * fp16[4096, 8192] matmul to take on a single B200? How about fp16[512, 4096] * fp16[4096, 8192]?
We expect fp16[64, 4096] * fp16[4096, 8192] to be comms bound. We can verify this explicitly:
T_comms = (2 * 4096 * 8192 + 2 * 64 * 4096 + 2 * 64 * 8192) / 8e12 ≈ 8.58e-6 s
T_math = (2 * 64 * 4096 * 8192) / 2250e12 ≈ 1.90e-6 s
We see that comms take longer than math, hence we are comms bound. Increasing the batch size to 512, we are now FLOPs bound (T_math ≈ 15.3us vs T_comms ≈ 10.0us), and the time is ≈ 15us. In both cases we are calculating the theoretical LOWER BOUND. In reality we can only expect to get a fraction of the peak FLOPs and bandwidth, so the actual time will be longer than calculated.
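A roofline helper for this question, a sketch assuming the B200 dense fp16 peak (2250 TFLOPs/s), 8TB/s of HBM bandwidth, 2-byte elements, and that inputs and output each cross HBM exactly once:

```python
def matmul_roofline(b, d, f, flops=2250e12, hbm_bw=8e12, bytes_per_elem=2):
    """Lower-bound time of x[b, d] @ w[d, f] on one chip; returns (time, t_math, t_comms)."""
    t_math = 2 * b * d * f / flops
    t_comms = bytes_per_elem * (b * d + d * f + b * f) / hbm_bw   # read x, w; write out
    return max(t_math, t_comms), t_math, t_comms


for b in (64, 512):
    t, t_math, t_comms = matmul_roofline(b, 4096, 8192)
    bound = "compute" if t_math > t_comms else "memory"
    print(f"B={b}: {t * 1e6:.2f}us ({bound} bound)")
```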
Question 5 [L1 cache capacity]: What is the total L1/SMEM capacity for an H100? What about register memory? How does this compare to TPU VMEM capacity?
The L1/SMEM is a small, very fast memory located inside each SM. It is the on-chip cache, and its SMEM portion is programmer controlled. It is used for storing activations and inputs to the TensorCore matmuls, and it is 256KB per SM.
Each SM subpartition has its own register file containing 16,384 32-bit words (64KiB), totalling 256KiB per SM. Basically, we have as much register memory as we have SMEM. In total this is ~33MB of each per chip (132 SMs * 256KB). This is about half of a modern TPU's VMEM.
Question 6 [Calculating B200 clock frequency]: NVIDIA reports that a B200 can perform 80TFLOPs/s of vector fp32 compute. Given that each CUDA core can perform 2 FLOPs/cycle in a FMA (fused multiply add) op, estimate the peak clock cycle.
The B200 has 148 * 128 = 18,944 CUDA cores, each performing 2 FLOPs/cycle. At 80TFLOPs/s, that means our clock is 80e12 / (2 * 18,944) ≈ 2.1GHz.
Question 7 [Estimating H100 add runtime]: Using the figures above, calculate how long it ought to take to add two fp32[N] vectors together on a single H100. Calculate both Tmath and Tcomms. What is the arithmetic intensity of this operation? If you can get access, try running this operation in PyTorch or JAX as well for N = 1024 and N=1024 * 1024 * 1024. How does this compare?
Adding two fp32[N] vectors means we have to load both vectors from HBM, perform the operation in our CUDA cores, and then write the resulting fp32[N] vector back to HBM.
T_comms = 12N / 3.35e12
T_math = N / 33.5e12
The arithmetic intensity of this operation is total FLOPs / total bytes = N/12N = 1/12, which is abysmal.
The peak hardware intensity for vector ops is 33.5e12 / 3.35e12 = 10, meaning we are going to be horribly comms bound. That means the time of this operation is
T = max(T_comms, T_math) = T_comms = 12N / 3.35e12 = N/2.8e11.
At N = 1024 we expect T ≈ 1024 / 2.8e11 ≈ 3.7ns, and at N = 1024 * 1024 * 1024 we expect T ≈ 3.8ms.
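The roofline estimate as a function, assuming H100 HBM at 3.35TB/s and the rounded 33.5 TFLOPs/s boosted vector rate from above:

```python
HBM_BW = 3.35e12         # H100 HBM bytes/s
VECTOR_FLOPS = 33.5e12   # boosted fp32 vector FLOPs/s, rounded as above


def add_time(n):
    """Roofline lower bound for z = x + y with fp32[n] inputs."""
    t_comms = 12 * n / HBM_BW     # read 2 x 4n bytes, write 4n bytes
    t_math = n / VECTOR_FLOPS     # one add per element
    return max(t_comms, t_math)


for n in (1024, 1024 ** 3):
    print(n, f"{add_time(n):.3e} s")
```

A real measurement at N = 1024 will likely be dominated by fixed kernel-launch overhead rather than by this bound; the large case is where the comparison is meaningful.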
Networking
Node level
Performing an AllGather of bf16[D_X, F] in an H100 node: let's assume N = X. Each GPU has to communicate 2DF/N bytes per hop over its unidirectional bandwidth of 450GB/s, which takes 2DF/(N * W_uni). In a ring-style AllGather this is performed N-1 times, giving a total communication time of
(N-1)* 2DF / (N * W_uni)
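The ring AllGather formula as a function, a sketch with a hypothetical bf16[4096, 16384] array (shape chosen only for illustration) on an 8x H100 node at 450GB/s unidirectional per GPU:

```python
def ring_allgather_time(total_bytes, n, w_uni):
    """Ring AllGather: n - 1 hops, each moving one 1/n shard per GPU."""
    per_hop = total_bytes / (n * w_uni)
    return (n - 1) * per_hop


# Hypothetical example: bf16[4096, 16384] gathered across an 8x H100 node.
nbytes = 2 * 4096 * 16384
t = ring_allgather_time(nbytes, 8, 450e9)
print(f"{t * 1e6:.0f} us")
```

As n grows, (n - 1)/n approaches 1 and the cost approaches total_bytes / W_uni.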
Beyond the node level
TPU v5p has about 90GB/s egress bandwidth per link, 540GB/s along all axes of the 3D torus. Within the H100 node, we have 450GB/s from each GPU, while beyond the node, this drops to 400GB/s node-to-node.
The newly released GB200 NVL72 SuperPods drastically change our rooflines. This is the first time we are moving beyond 8 GPUs in our NVLink domain, which is now increased to 72! These NVLinks now do full 900GB/s of GPU to GPU bandwidth.
Question 1 [Fat tree topology]: Using the DGX H100 diagram above, calculate the bisection bandwidth of the entire 1024 GPU pod at the node level. Show that the bandwidth of each link is chosen to ensure full bisection bandwidth. Hint: make sure to calculate both the link bandwidth and switch bandwidth.
Any even partition of the pod puts 2 of the 4 SUs on each side. Let's first look at the node->leaf connections. There are 8 leaves per SU, each a 64-port NDR InfiniBand switch with 400Gbps (50GB/s) per port, but only 32 ports face the nodes. That means the total switch bandwidth of the SU is 32 * 50 * 8 = 12.8TB/s, or 12.8TB/s / 32 = 400GB/s per node. Now the link bandwidth: each node is connected to the leaves through 8 links of 50GB/s, which gives 400GB/s of total egress per node, i.e. 50GB/s per GPU.
At the spine level, each leaf is connected to the 16 spines via 2x 400Gbps (2 * 50GB/s) links. There are 8 leaves in the SU, so the SU is connected to the spines with 8 * 16 * 2 * 50 = 12.8TB/s (1.6TB/s per leaf, matching its 32 downlinks), which means per node we get 400GB/s to the spines.
The spines are 16 switches with 64 ports each. The total switch bandwidth is 16 * 64 * 50 = 51.2TB/s, which at 128 nodes is 400GB/s per node.
Thus if we bisect our nodes in any way, we will have 400GB/s per node (50GB/s per GPU) between the halves. Every component has exactly the requisite bandwidth to ensure full bisection bandwidth.
Question 2 [Scaling to a larger DGX pod]: Say we wanted to train on 2048 GPUs instead of 1024. What would be the simplest/best way to modify the above DGX topology to handle this? What about 4096? Hint: there’s no single correct answer, but try to keep costs down. Keep link capacity in mind. This documentation may be helpful.
We can't increase the number of GPUs in an SU because our leaf switches are at their maximum of 32 node-facing ports. We would have to double the number of scalable units to 8. Our spine switches, however, are already full: each spine's 64 ports are taken by 32 leaves * 2 links. This means we would also have to double to 32 spines.
Let's do the math. We have 8 SUs, each with 8 leaves and 32 nodes (256 GPUs).
Each leaf is connected to the 32 spines with 1 * 50GB/s each, which gives 1.6TB/s per leaf, or 400GB/s of per-node link BW. The spine switch bandwidth is 64 * 32 * 50 = 102.4TB/s, which is 400GB/s per node given that we have 256 nodes. This would work: we can use the same setup as before within SUs and just double the SUs and spines. The only difference is that we connect each leaf to each spine with 1x NDR instead of 2x.
For 4096 GPUs we run out of ports, so we would need to add another level of indirection, i.e. another level of hierarchy. NVIDIA calls this level core switches. One easily notes how much more complexity this tree-like structure adds compared to TPU pods.
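A small calculator for the per-node bandwidth at each tier, assuming 50GB/s NDR ports, 8 NICs per node, 32 nodes and 8 leaves per SU, and 64-port switches, as in the answers above. It checks both the 1024-GPU pod and the 2048-GPU variant (8 SUs, 32 spines, 1 link per leaf-spine pair).

```python
PORT_BW = 50e9   # NDR InfiniBand port: 400 Gb/s = 50 GB/s


def per_node_bandwidth(sus, spines, links_per_leaf_spine,
                       nodes_per_su=32, leaves_per_su=8, nics_per_node=8):
    """Bandwidth available per node at each tier of a DGX SuperPod-style fat tree."""
    nodes = sus * nodes_per_su
    return {
        "node links": nics_per_node * PORT_BW,
        "leaf downlinks": leaves_per_su * 32 * PORT_BW / nodes_per_su,
        "leaf uplinks": leaves_per_su * spines * links_per_leaf_spine * PORT_BW / nodes_per_su,
        "spine switches": spines * 64 * PORT_BW / nodes,
    }


print(per_node_bandwidth(sus=4, spines=16, links_per_leaf_spine=2))   # 1024 GPUs
print(per_node_bandwidth(sus=8, spines=32, links_per_leaf_spine=1))   # 2048 GPUs
```

Every tier coming out at 400GB/s per node is exactly the full-bisection condition.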
How do collectives work on GPUs?
GPUs can perform all the same collectives as TPUs: AllReduce, AllGather, ReduceScatter, and AllToAll. Unlike on TPUs, however, the way these work depends on whether they are performed at the node level (over NVLink) or above (over InfiniBand). NVIDIA implements these collectives in the NVSHMEM and NCCL libraries. NCCL uses a variety of implementations depending on latency requirements and topology, but in this section we discuss a theoretically optimal model over a switched tree fabric.
Intra-node collectives
AllGather or ReduceScatter: For an AllGather or ReduceScatter at the node level, we can perform them around a ring just like on a TPU, using the full GPU-to-GPU bandwidth at each hop. We can imagine this as each GPU sending B/N bytes per hop over its egress (unidirectional) bandwidth, where B is the total bytes across all devices. The cost of each hop is T_hop = B / (N * W_uni), so the overall cost is
T = (N-1) * B / (N * W_uni) ≈ B / W_uni
Note that this is exactly the same as on a TPU. For an AllReduce we combine AllGather + ReduceScatter as usual for twice the cost.
In cases where the array is very small, we can instead do a tree reduction: reduce within pairs of 2, then 4, then 8, for a total of log2(N) hops instead of N-1. The bandwidth cost is the same, but we pay fewer per-hop latencies.
Takeaway: the cost to AllGather or ReduceScatter an array of B bytes within a single node is T_comms = B * (8 - 1) / (8 * W_uni) ≈ B/W_uni. This is theoretically around B/450e9 on an H100 and B/900e9 on a B200. An AllReduce is 2x this cost.
Pop Quiz: T_comms = 2BF * 7 / (8 * 450e9) = 65us
AllToAll
As opposed to TPUs, GPUs within a node have all-to-all connectivity, making AllToAlls simple: each GPU sends directly to each destination GPU. Within a node, each GPU holds B/N bytes and sends B/N² bytes to each of its N-1 targets, for a total of
T_comms = B * (N - 1) / (W * N²) ≈ B/(WN)
Compare this to a TPU, where the cost is B/(4W). Within a single node (N = 8) we get B/(8W), a theoretical 2x speedup.
Pop Quiz: under dense (non-sparse) conditions this takes
T_comms = B/WN.
If we know that 4 out of 8 entries will be non zero, we get
T_comms = B / (2WN)
Takeaway: an AllToAll collective performed on an array of B bytes within a single node costs T = B/(8W_uni), 1/8th the cost of an AllGather. In comparison, it is B/(4W) on TPUs. For a ragged tensor (top-k), this decreases further to B*k/(64W_uni).
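The intra-node AllToAll cost as a function, a sketch assuming N = 8 GPUs at 450GB/s each; the ragged case approximates top-k routing as only k of the N destinations receiving data, which matches the 4-of-8 pop quiz:

```python
def intra_node_alltoall(total_bytes, n=8, w_uni=450e9, k=None):
    """AllToAll inside an all-to-all-connected node.

    Dense: each GPU sends total_bytes / n**2 to each of the n - 1 other GPUs.
    Ragged (k of n destinations get data): approximated as total_bytes * k / (n**2 * w).
    """
    if k is None:
        return total_bytes * (n - 1) / (w_uni * n ** 2)
    return total_bytes * k / (w_uni * n ** 2)


print(intra_node_alltoall(1e9), intra_node_alltoall(1e9, k=4))
```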
Cross-node collectives
As a recap, here is the cost to AllGather or ReduceScatter at the intra-node level on NVIDIA GPUs. At the intra-node level we have N GPUs and B bytes in total, and each device wants to communicate its B/N bytes. Because the node has direct connectivity between ALL devices, each device egresses B/N bytes per hop at the available GPU egress bandwidth, N-1 times. That means the cost of each hop is T_hop = B / (N * W_GPU egress), so the overall cost is
T = (N-1) * B / (N * W_GPU egress) ≈ B / W_GPU egress
which you will note is the same as for TPUs. Similarly, the cost of an AllReduce is the combination of RS + AG, at twice the cost.
AllGather and ReduceScatter
Now, on to cross-node collectives. When doing a reduction over a tree, you can think of reducing bottom up: first within the node, then at the leaf level, and then at the spine level. This has the nice effect that for an AllReduce we communicate less data overall, because we reduce at the node level first and only have to egress one reduced copy per node up to the leaf instead of one copy per GPU. Because we have full bisection bandwidth (the smallest bandwidth between any even partition of the network equals our full bandwidth), the cost of an AllGather or ReduceScatter is roughly the buffer size in bytes divided by the node egress bandwidth:
T ≈ B / W_node egress
You can imagine this as performing a ring reduction over every node in the cluster. Now you may be wondering: don't we have to perform the intra-node reduction first, before we can do the cross-node reduction? As is often the case, these two collectives are overlapped, and the intra-node reduction will (almost) never be the bottleneck, so we don't need to calculate it. In general, though, the cost is:
T ≈ max(B / W_GPU egress, B / W_node egress)
Precise calculation
Let's be even more precise in this calculation. As we've established, we're effectively doing a ring reduction at each level of the tree (network), which we can mostly overlap. That means the cost is whichever reduction takes the longest. A general way to write this is
T ≈ B * max_i [ (D_i - 1) / (D_i * W_i) ]
where D_i is the degree at depth i, that is, the number of children at depth i, and W_i is the egress bandwidth of each child at that depth. To determine which level of the tree determines our time/BW, we just have to solve the max() part of the formula.
Node: there are 8 GPUs with egress BW of 450GB/s, giving 7 / (8 * 450e9) ≈ 1.9ps per byte.
Leaf: there are 32 nodes in an SU with egress BW of 400GB/s, giving 31 / (32 * 400e9) ≈ 2.4ps per byte.
Spine: there are 4 SUs in total with egress BW of 12.8TB/s, giving 3 / (4 * 12.8e12) ≈ 0.06ps per byte.
As we can see, the bottleneck is at the leaf level.
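The max-over-levels formula as code, assuming (degree, egress bandwidth) pairs per tree level as used above: (8, 450GB/s) for the node, (32, 400GB/s) for the leaf and (4, 12.8TB/s) for the spine. The second call applies the same formula to the 2-node case from the quiz later in these notes.

```python
def tree_collective_time(total_bytes, levels):
    """Approximate AllGather/ReduceScatter time on a tree whose levels overlap.

    levels: list of (degree D_i, per-child egress bandwidth W_i in bytes/s).
    Returns (time in seconds, index of the bottleneck level).
    """
    costs = [(d - 1) / (d * w) for d, w in levels]
    worst = max(range(len(costs)), key=lambda i: costs[i])
    return total_bytes * costs[worst], worst


DGX_H100_POD = [(8, 450e9), (32, 400e9), (4, 12.8e12)]   # node, leaf, spine
TWO_NODES = [(8, 450e9), (2, 400e9)]

print(tree_collective_time(1e9, DGX_H100_POD))   # leaf level is the bottleneck
print(tree_collective_time(1e9, TWO_NODES))      # intra-node level is the bottleneck
```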
Other collectives
AllReduces are still 2x the above cost unless SHARP is enabled.
AllToAlls change a bit in the cross-node setting because they are not hierarchical in the way AllReduces are. If we want to send data from every GPU to every other GPU, we can't take advantage of the full bisection BW at the node level. That means if we have an N-way AllToAll that spans M = N/8 nodes, each node holds B/M bytes, keeps 1/M of it, and sends the rest to the other M-1 nodes over the node egress bandwidth, giving
T_AllToAll = (B/M) * ((M-1)/M) / W_node egress = B(M-1) / (M² * W_node egress)
That means, when moving from a single node to two nodes (M = 2), our AllToAll cost goes from B/(8 * W_GPU egress) to B/(4 * W_node egress). For the full fat tree (large M) this approaches
T_AllToAll ≈ B / (M * W_node egress)
Takeaway: beyond the node level, the cost of an AllGather or ReduceScatter on B bytes is roughly B/W_node egress, which is B/400e9 on a H100 DGX SuperPod.
Reductions when array is sharded over a separate axis
In TPU-land, performing reductions such as AllReduce_X(A[I_Y, J] {U_X}),
where we reduce over an array that has a dimension sharded over a separate axis, reduces the cost by a factor of Y. This makes sense because we move 1/Y as much data in each hop. Unfortunately, in GPU-land this is not as straightforward: the cost depends on which axis is the "inner" one (intra-node vs inter-node) and whether each shard spans more than a single node. Going back to the general formulation
First, look at the intra node setting, where N-1 is replaced with N for simplicity.
Where D is the degree of the node (8). Then the scale out
Quiz 4: Collectives
Question 1 [SU AllGather]: Consider only a single SU with M nodes and N GPUs per node. Precisely how many bytes are ingressed and egressed by the node level switch during an AllGather? What about the top-level switch?
Let's work through the components of the reduction.
Each GPU holds B/NM bytes of data. Within each node, each GPU sends its B/NM bytes to the node switch, for a total of N * B/NM = B/M bytes ingressed.
The switch egresses B/M bytes to the spine switch.
The spine switch ingresses M * B/M = B bytes. At this point, the spine switch holds the entire B bytes.
Now we need to send the data back down the tree. Every node already holds B/M of the data, so each node only needs what it's missing: B - B/M = B(M-1)/M. That means the spine switch egresses a total of M * B(M-1)/M = B(M-1) bytes, and each node switch ingresses B(M-1)/M.
The last step is for each node switch to egress downwards to its GPUs. Each GPU already holds B/NM of the total data, so each GPU needs B - B/NM. Distributing that to all N GPUs in the node gives a per-node egress of N(B - B/NM) = NB - B/M.
Let's now look at the totals:
GPU: egress B/NM, ingress B - B/NM
Node switch: egress B/M + (NB - B/M) = BN, ingress B/M + (B - B/M) = B
Spine switch: egress B(M-1), ingress B
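The byte accounting above can be checked with exact fractions. A sketch for arbitrary B, N and M; the example values are hypothetical:

```python
from fractions import Fraction


def su_allgather_bytes(total, n, m):
    """Bytes (egress, ingress) per component for an AllGather inside one SU.

    total: array size B; n GPUs per node; m nodes. Exact arithmetic via Fraction.
    """
    b = Fraction(total)
    shard = b / (n * m)                      # each GPU's slice
    up_per_node = n * shard                  # GPUs -> node switch -> spine: B/M
    down_per_node = b - b / m                # spine -> node switch: what the node lacks
    node_to_gpus = n * (b - shard)           # node switch -> its GPUs
    return {
        "gpu": (shard, b - shard),
        "node switch": (up_per_node + node_to_gpus, up_per_node + down_per_node),
        "spine switch": (m * down_per_node, m * up_per_node),
    }


print(su_allgather_bytes(1024, n=8, m=4))
```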
Question 2 [Single-node SHARP AR]: Consider a single node with N GPUs per node. Precisely how many bytes are ingressed and egressed by the switch during an AllReduce using SHARP (in-network reductions)?
Each GPU sends B(N-1)/N bytes to the node switch for a total of N * B(N-1)/N = B(N-1) bytes. Normally, at this point we would want to communicate the rest of the missing B bytes to each GPU such that they can perform the reduction B - B/N. In total the switch egress would be N(B-B/N) = BN - B. But, with SHARP we can perform partial reduction at the switch level meaning we only have to communicate the resulting B/N bytes to every GPU for a total of N * B/N = B bytes. Then, we do partial sum of residuals locally on the GPU and send this back to the switch, N * B/N = B bytes ingressed. We then capture all the shard and multicast them, sending B(N-1)/N to N destinations for a total of B(N-1) /N * N = B(N-1) egressed.
Therefore the switch totals are:
- Ingress: B(N-1) + B = BN bytes
- Egress: B + B(N-1) = BN bytes
This is consistent with the overall AllReduce taking exactly B/W_egress: the switch's BN bytes of egress are spread over N links of bandwidth W_egress each.
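The two SHARP phases can be tallied the same way. This is a minimal sketch that follows the accounting above; it assumes `w_egress` is the per-GPU link bandwidth and that the switch spreads its egress evenly over the N links. Both function names are hypothetical.

```python
from fractions import Fraction


def sharp_allreduce_switch_bytes(B, N):
    """(ingress, egress) bytes at the node switch for a SHARP AllReduce of B bytes
    across N GPUs, following the two phases described above."""
    B = Fraction(B)
    shard = B / N
    # Phase 1 (in-network ReduceScatter): each GPU sends the N-1 shards it does
    # not own; the switch reduces them and returns one reduced shard per GPU.
    ingress = N * (B - shard)
    egress = N * shard
    # Phase 2 (AllGather via multicast): each GPU sends its finished shard and the
    # switch multicasts the N-1 shards each GPU is missing.
    ingress += N * shard
    egress += N * (B - shard)
    return ingress, egress


def sharp_allreduce_time(B, N, w_egress):
    """Time if the switch egress is spread evenly over N links of w_egress bytes/s."""
    _, egress = sharp_allreduce_switch_bytes(B, N)
    return egress / (N * w_egress)


ingress, egress = sharp_allreduce_switch_bytes(B=800, N=8)
print(int(ingress), int(egress))  # 6400 6400
```

Dividing BN bytes of egress by N links gives B/W_egress, half of the roughly 2B/W an AllReduce costs without in-network reductions.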
Question 3 [Cross-node SHARP AR]: Consider an array bf16[DX, FY] sharded over a single node of N GPUs. How long does AllReduce(bf16[D, FY] { UX }) take? You can assume we do in-network reductions. How does this differ if we have more than a single node?
We can modify the previous answer by treating the N GPUs as an X by Y mesh (N = XY), with B = 2DF the size of the full bf16 array. Each GPU sends B(X-1)/XY bytes to the switch, the switch sends back B/XY to each GPU, each GPU sends the same amount back to the switch, and the switch then sends B(X-1)/XY back to each GPU. The totals are BN/Y bytes of ingress and egress, which means the total time is
T = BN / (Y * N * W_link) = 2DF * N / (Y * N * W_link) = 2DF / (Y * W_link)
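As a sanity check on the algebra, this sketch sums the per-GPU transfers listed above and divides by the aggregate link bandwidth. It assumes, as the answer does, one node of N = X * Y GPUs with per-GPU link bandwidth `w_link`; the function name is hypothetical.

```python
def sharded_sharp_allreduce_time(D, F, X, Y, w_link, bytes_per_elem=2):
    """Time for AllReduce(bf16[D, FY] {UX}) with SHARP on one node of N = X*Y GPUs,
    using the per-GPU byte counts above; w_link is the per-GPU link bandwidth."""
    B = bytes_per_elem * D * F  # full unsharded array, 2DF bytes
    N = X * Y
    # Each GPU sends B(X-1)/XY (reduce phase) plus B/XY (gather phase) to the switch.
    per_gpu_up = B * (X - 1) / (X * Y) + B / (X * Y)
    switch_ingress = N * per_gpu_up  # = B * N / Y
    return switch_ingress / (N * w_link)  # = 2DF / (Y * w_link)


print(sharded_sharp_allreduce_time(D=8192, F=28672, X=8, Y=1, w_link=450e9))
```

The N cancels, which is why the result only depends on Y and the per-link bandwidth.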
Question 5 [2-way AllGather cost]: Calculate the precise cost of an AllGather of B bytes over exactly 2 nodes. Make sure to calculate the precise cost and not the approximation, and consider both the intra-node and cross-node cost.
First, let's look at the intra-node cost:
Now the cross-node cost:
This means that we are bottlenecked by the intra-node AllGather, not the leaf level. This motivates 2-way DP.
Rooflines for LLM Scaling on GPUs
The idea of this chapter is to compare T_math and T_comms for different parallelism strategies and understand at what point T_comms > T_math. This tells us when a parallelism strategy has run its course and we've become bottlenecked by our communication collectives rather than our compute FLOPs. As before, we consider only the MLP block, with operations
MLP = x[B, D] * W_in[D, F] * W_out[F, D]
where B is the global batch size in tokens (i.e., B = batch size * sequence length).
| Node type | GPUs per node | GPU egress bandwidth (bytes/s) | Node egress bandwidth (bytes/s) |
|---|---|---|---|
| H100 | 8 | 450e9 | 400e9 |
| B200 | 8 | 900e9 | 400e9 |
| GB200 NVL72 | 72 | 900e9 | 3600e9 |
Both the GPU and node egress bandwidths determine rooflines for our LLMs. We use W_collective to describe either the GPU or the node bandwidth, depending on whether we are operating within or above the node level.
Data Parallelism
Consider the per-layer cost of pure DP or FSDP (without in-network reductions) in the backward pass, over an axis of size X. The backward pass has four matmuls, each of which requires 2BDF FLOPs. Thus for a single layer:
T_math = 4 * 2BDF / (X * C)
T_comms = 2 * 2 * 2DF / W_collective
Here we assume that our batch is sharded across X. Remember that the cost of an AllReduce depends on the number of bytes in the array being reduced and on the bandwidth: it is roughly 2 * bytes / bandwidth. In the backward pass we have to perform two of these AllReduces, one per weight matrix. The weights are in bf16, so each one moves 2 bytes per parameter, hence 2DF bytes.
For T_math > T_comms we need B/X > C/W_collective, where W_collective is either the GPU or node egress bandwidth depending on whether we are sharding within a node or across nodes. That is, the per-GPU token batch size must exceed the GPU's arithmetic intensity with respect to that bandwidth, C/W_collective.
- Within a node, we just need a per-GPU token batch size > 9.9e14/450e9 = 2200.
- Within an SU, or at the spine level, we need a per-GPU batch size > 990e12/400e9 = 2475.
This is quite a bit higher than on a TPU, where the number is 850 with all three axes. For instance, LLaMA-3, which was trained on 16,000 H100s, would need a batch size of at least 40M tokens (for reference, they used 16M). DeepSeek v3 was trained on 2048 H800 GPUs with a lower 300GB/s of bandwidth, which would need 3300 tokens per GPU, or about 6.7M tokens of batch (they used 4M).
In theory, because these are AllReduces we are talking about, enabling SHARP would double the AllReduce bandwidth and halve all of these numbers, but in practice the benefit is closer to 30%.
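The DP roofline reduces to a one-line ratio, so it is easy to reproduce the numbers above. This sketch implements the per-layer T_math and T_comms formulas from this section and the resulting threshold C/W_collective. It assumes, as the text implicitly does, that an H800 has the same 9.9e14 FLOPs/s as an H100; the function names are hypothetical.

```python
H100_FLOPS = 9.9e14  # bf16 FLOPs/s, the value used in the text (assumed for H800 too)


def dp_backward_layer_times(B, D, F, X, flops, w_collective):
    """Per-layer backward (T_math, T_comms) for pure DP/FSDP without SHARP."""
    t_math = 4 * 2 * B * D * F / (X * flops)  # 4 matmuls of 2BDF FLOPs, batch split X ways
    t_comms = 2 * 2 * (2 * D * F) / w_collective  # 2 AllReduces of 2DF bytes, ~2*bytes/W each
    return t_math, t_comms


def dp_min_tokens_per_gpu(flops, w_collective):
    """Smallest per-GPU token batch B/X with T_math >= T_comms."""
    return flops / w_collective


llama3_batch = 16_000 * dp_min_tokens_per_gpu(H100_FLOPS, 400e9)
deepseek_batch = 2048 * dp_min_tokens_per_gpu(H100_FLOPS, 300e9)
print(f"{llama3_batch / 1e6:.1f}M, {deepseek_batch / 1e6:.2f}M")  # 39.6M, 6.76M
```

Note that D and F cancel out of the comparison entirely: only the per-GPU batch B/X and the hardware ratio C/W matter.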
MoE models: For Mixture of Experts models with E experts and k experts per token, our costs change to
T_math = 4 * 2kBDF / (X * C)
T_comms = 2 * 2 * 2EDF / W_collective
because only k experts perform compute per token while the weights of all E experts must be moved. This inflates the required per-GPU token batch size by a factor of E/k:
B/X > E/k * C/W_collective
For example, with the stats from the new OpenAI open-weight model (E/k = 32), we get B/X > 32 * 2475 = 79200, which is a ridiculously high number.
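The MoE threshold is just the dense one scaled by E/k. A minimal sketch, using a hypothetical E = 128, k = 4 (ratio 32, which is what the 79200 figure implies) with H100 FLOPs and the 400e9 node bandwidth; the function name is hypothetical.

```python
def moe_dp_min_tokens_per_gpu(E, k, flops, w_collective):
    """Per-GPU token batch for compute-bound DP on an MoE layer: the dense
    threshold C/W inflated by E/k (all E experts' weights are AllReduced,
    but only k experts do FLOPs per token)."""
    return (E / k) * flops / w_collective


print(moe_dp_min_tokens_per_gpu(E=128, k=4, flops=9.9e14, w_collective=400e9))  # 79200.0
```

Setting E = k recovers the dense threshold of 2475.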
Takeaway: DP and ZeRO sharding require a per-GPU batch size of about 2500 tokens to be compute-bound on an H100 or B200, assuming perfect overlap and FLOPs utilization. For MoE models this increases by a factor of E/k, the ratio of total to activated parameters, because we only do FLOPs on a small fraction of the total parameters. When doing only a small amount of DP, such as 2-way DP, the critical batch size decreases.
Tensor Parallelism
Syntax: In[B, DY] * W_in[D, FY] * W_out[FY, D] -> Out[B, DY]
One way of implementing TP is by performing an AllReduce after each matmul:
Tmp[B, FY] = In[B, DY] * W_in[D, FY] (we calculate a partial sum of the final desired product)
But we can be smarter about this, since we are performing two matmuls back to back. Instead, we do an AllGather at the start, which lets us perform the first matmul locally:
In[B, D] = AllGather(In[B, DY]) (on critical path)
Tmp[B, FY] = In[B, D] * W_in[D, FY] (not sharded along the contracting dimension, so no comms)
Then we can perform the second matmul without any collectives as well, ending up with a partial sum, which we then ReduceScatter:
Out[B, D] {UY} = Tmp[B, FY] * W_out[FY, D]
Out[B, DY] = ReduceScatter(Out[B, D] {UY}) (on critical path)
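To convince ourselves that the AllGather, two local matmuls, ReduceScatter sequence computes the same thing as the unsharded MLP, here is a plain-Python simulation with Y "devices" as list entries. It is a sketch of the dataflow only (no real communication, no nonlinearity, matching the MLP above); all helper names are hypothetical.

```python
def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def col_shard(m, y, num):
    """Columns owned by device y when the last dim is split num ways."""
    width = len(m[0]) // num
    return [row[y * width:(y + 1) * width] for row in m]


def row_shard(m, y, num):
    """Rows owned by device y when the first dim is split num ways."""
    height = len(m) // num
    return m[y * height:(y + 1) * height]


def all_gather_cols(shards):
    """Concatenate column shards back into the full matrix."""
    return [sum((s[i] for s in shards), []) for i in range(len(shards[0]))]


def reduce_scatter_cols(partials, num):
    """Sum the per-device partial results, then hand each device its column shard."""
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
    return [col_shard(total, y, num) for y in range(num)]


def tp_mlp(in_shards, w_in_shards, w_out_shards):
    """In[B, DY] -> Out[B, DY] with one AllGather and one ReduceScatter."""
    num = len(in_shards)
    x = all_gather_cols(in_shards)  # In[B, D] = AllGather(In[B, DY])
    partials = []
    for y in range(num):  # each iteration is local work on device y
        tmp = matmul(x, w_in_shards[y])  # Tmp[B, FY], no comms
        partials.append(matmul(tmp, w_out_shards[y]))  # Out[B, D] {UY}
    return reduce_scatter_cols(partials, num)  # Out[B, DY]


# Small integer example with Y = 2 "devices".
B, D, F, Y = 2, 4, 6, 2
x = [[(i * D + j) % 5 - 2 for j in range(D)] for i in range(B)]
w_in = [[(i + 2 * j) % 7 - 3 for j in range(F)] for i in range(D)]
w_out = [[(3 * i + j) % 4 - 1 for j in range(D)] for i in range(F)]
out_shards = tp_mlp(
    [col_shard(x, y, Y) for y in range(Y)],
    [col_shard(w_in, y, Y) for y in range(Y)],
    [row_shard(w_out, y, Y) for y in range(Y)],
)
print(all_gather_cols(out_shards) == matmul(matmul(x, w_in), w_out))  # True
```

Integer inputs keep the comparison exact; the only cross-device steps are `all_gather_cols` at the start and `reduce_scatter_cols` at the end, mirroring the two collectives on the critical path.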
This saves us a decent amount on comms costs. The forward pass costs are
T_math = (2BDF/Y + 2BDF/Y) / C = 4BDF / (Y * C)
T_comms = (2BD + 2BD) / W_collective = 4BD / W_collective
which is compute bound when
Y < FW/C.
Within a node this gives Y < F/2200, and beyond a node Y < F/2475, which is very close to TPUs. For F = 28,000, as in LLaMA 3, this allows about 11-way TP (or, rounding down, 8-way, which is exactly the size of a node). That means we can shard across up to about 11 GPUs and remain compute-bound; above that, we are communication-bound.
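The TP roofline can be checked numerically with the two formulas above. This sketch assumes H100 FLOPs of 9.9e14 and ignores the (Y-1)/Y factor in the AllGather and ReduceScatter costs, as the text does; the function names are hypothetical.

```python
H100_FLOPS = 9.9e14  # bf16 FLOPs/s


def tp_forward_times(B, D, F, Y, flops, w_collective):
    """Forward (T_math, T_comms) for the AllGather -> matmul -> matmul -> ReduceScatter scheme."""
    t_math = (2 * B * D * F / Y + 2 * B * D * F / Y) / flops
    # AllGather + ReduceScatter of a bf16[B, D] activation, ~2BD bytes / W each.
    t_comms = (2 * B * D + 2 * B * D) / w_collective
    return t_math, t_comms


def tp_max_degree(F, flops, w_collective):
    """Largest Y with T_math >= T_comms, i.e. Y < F * W / C."""
    return F * w_collective / flops


print(int(tp_max_degree(28_000, H100_FLOPS, 450e9)), int(tp_max_degree(28_000, H100_FLOPS, 400e9)))  # 12 11
```

Because both costs scale with B * D, the batch size drops out and the bound depends only on F and the hardware ratio.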
Takeaway: tensor parallelism over an axis of size Y with feed-forward dimension F becomes communication-bound when Y > F/2475, which generally constrains us to intra-node TP or at most 2-node TP.
Expert Parallelism
Mixture of Experts models introduce problems because the model has E times more weights but only k times more FLOPs (k << E), making DP significantly harder. This can be somewhat mitigated by sharding our weights along the expert dimension, i.e. W_in[EZ, D, F]. Performing the MLP blocks then requires two AllToAll collectives to route our activations to and from the corresponding experts.
As noted above, the cost of AllToAll_Z->k([B, D, k]) when it spans multiple nodes is T_alltoall = 2BD(Z-8) / (W * Z) * min(8k/Z, 1). This means the overall costs are
T_math = 4BkDF / (Z * C)
T_comms = 4BD(Z-8) / (W * Z) * min(8k/Z, 1)
To be compute bound we need either
- k > Z/8 with F > a (Z-8)/k
- Z >> k and F > 8a
where a = C/W. This gives two regimes in which EP is viable: one with a small amount of expert parallelism (roughly 2 nodes) and a small F, or one with a large F and arbitrarily large Z. You'll see both cases in practice: either a small amount of EP (like DeepSeek v3, which has a very small F and relatively small, restricted cross-node EP) or models with a large F, in which case we can do significant EP alongside TP.
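The two regimes can be probed directly with the T_math and T_comms expressions above. This sketch assumes Z > 8 (the AllToAll spans multiple 8-GPU nodes), H100 FLOPs, and the 400e9 node bandwidth, so a = 2475; the function names and the chosen (k, Z, F) values are hypothetical.

```python
def ep_forward_times(B, D, F, k, Z, flops, w, gpus_per_node=8):
    """(T_math, T_comms) for expert parallelism over Z > gpus_per_node GPUs,
    using the AllToAll cost quoted above."""
    t_math = 4 * B * k * D * F / (Z * flops)
    alltoall = 2 * B * D * (Z - gpus_per_node) / (w * Z) * min(gpus_per_node * k / Z, 1)
    return t_math, 2 * alltoall  # two AllToAlls per MLP block


def ep_is_compute_bound(B, D, F, k, Z, flops, w):
    t_math, t_comms = ep_forward_times(B, D, F, k, Z, flops, w)
    return t_math > t_comms


C, W = 9.9e14, 400e9  # a = C / W = 2475
# Regime 1 (k >= Z/8, two nodes, k = 2): need F > a(Z-8)/k = 9900.
print(ep_is_compute_bound(1, 1, 12_000, 2, 16, C, W), ep_is_compute_bound(1, 1, 8_000, 2, 16, C, W))  # True False
# Regime 2 (Z >> k, Z = 256, k = 8): need F > 8a(Z-8)/Z, about 19181.
print(ep_is_compute_bound(1, 1, 20_000, 8, 256, C, W), ep_is_compute_bound(1, 1, 16_000, 8, 256, C, W))  # True False
```

B and D cancel in the comparison, so passing 1 for both is enough to test the thresholds.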
Takeaway: if F < 8C/W, EP can span 1-2 nodes at a cost similar to TP. If F > 8C/W, we can do a significant amount of EP, up to E nodes, at relatively low cost.
Pipeline Parallelism
PP splits layers across nodes with an extremely low communication cost, since we only send small microbatches of activations between stages every couple of layers. Historically PP has suffered from pipeline bubbles, but with new zero-bubble pipelining approaches it is typically possible to avoid them.
The overall communication cost of pipelining is tiny. With N_mb microbatches and N_stages we have
T_pp = 2BD / (W * N_mb) * (N_mb + N_stages - 2)
T_per_layer_comms = 1.5 * 2BD / (W * N_layers)
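To put rough numbers on these two expressions, this sketch evaluates them and compares the per-layer pipeline term against the forward TP comms per layer (4BD/W, from the TP section). The sizes are hypothetical, and treating the two per-layer terms as directly comparable is an assumption made for illustration.

```python
def pp_total_comms(B, D, w, n_mb, n_stages):
    """T_pp from above: activations of n_mb microbatches pushed through n_stages."""
    return 2 * B * D / (w * n_mb) * (n_mb + n_stages - 2)


def pp_per_layer_comms(B, D, w, n_layers):
    """Per-layer share of pipeline comms from above: 1.5 * 2BD / (W * N_layers)."""
    return 1.5 * 2 * B * D / (w * n_layers)


def tp_per_layer_comms(B, D, w):
    """Forward TP comms per layer (AllGather + ReduceScatter of bf16[B, D])."""
    return 4 * B * D / w


B, D, W = 4_000_000, 8192, 400e9  # hypothetical batch (tokens), model dim, bandwidth
ratio = pp_per_layer_comms(B, D, W, n_layers=80) / tp_per_layer_comms(B, D, W)
print(f"PP per-layer comms is {ratio:.4f}x TP's")
```

With 80 layers the pipeline term is 3/320 of the TP term, and it only shrinks as N_layers grows.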
Since we divide by N_layers, the comms cost is a lot smaller than for the other collectives. So from a communication standpoint, pipelining is basically free. So why don't we just always use it?
- Code complexity. Pipelining does not fit as nicely into automatic parallelism frameworks as other approaches, since microbatching changes the structure of the program.
- Pipelining makes DP and FSDP hard. This is probably the biggest reason. ZeRO-3 sharding in particular works badly, since it requires us to AllGather the weights on every microbatch, which doesn't work when we only have B/N_mb tokens to amortize the AllGather cost over.
Quiz 5: LLM rooflines
Question 1: When performing DP within a node we get a roofline of B/X > 2555 to be compute-bound, and in the inter-node setting B/X > 5750, making it harder to be compute-bound in the multi-node regime. For model parallelism within a node we get Y < F/2555, which is basically the same as before.
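These numbers follow from the same C/W ratios with the B200 row of the table above. The FLOPs figure here is an assumption: 2.3e15 bf16 FLOPs/s is the value implied by 2555 and 5750, not something stated in these notes. The function name is hypothetical.

```python
B200_FLOPS = 2.3e15  # assumed bf16 FLOPs/s, the value implied by the 2555 / 5750 figures
B200_GPU_EGRESS = 900e9  # bytes/s, from the table above
B200_NODE_EGRESS = 400e9  # bytes/s, from the table above

intra_node_dp = B200_FLOPS / B200_GPU_EGRESS  # per-GPU tokens for DP within a node
inter_node_dp = B200_FLOPS / B200_NODE_EGRESS  # per-GPU tokens for DP across nodes


def b200_tp_max_degree(F):
    """Intra-node TP bound Y < F * W / C on a B200 node."""
    return F * B200_GPU_EGRESS / B200_FLOPS


print(int(intra_node_dp), int(inter_node_dp))  # 2555 5750
```

The node egress bandwidth is unchanged from H100 while FLOPs roughly double, which is why the inter-node threshold grows so much more than the intra-node one.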
Question 2 [How to shard LLaMA-3 70B]: Consider LLaMA-3 70B, training in bfloat16 with fp32 optimizer state with Adam.
- Per parameter we need 2 bytes for the bf16 weights and 8 bytes for the fp32 Adam optimizer state, totalling 700GB. H100s have 80GB of HBM, which means we need at least 9 GPUs, i.e. at least 2x 8xH100 nodes.
- This is a simple calculation: the total number of required FLOPs divided by our achieved FLOPs per second. The number of FLOPs to train a model is 6 * num params * num tokens, and the available FLOPs/s are 4096 * 9.9e14 * 0.45 (45% MFU). For 15T tokens this would take 959 hours, or 40 days.
- The maximum amount of TP we can do is given by Y < FW/C, which at the node level gives Y < 11. So we cannot shard across more than 11 GPUs, which essentially means we can only do 8-way model parallelism without being comms-bound.
- This essentially means that we will have to do 512-way pure DP. First, let's check whether this is even possible, since it implies that we need to fit our model on a single node. Since our 700GB model is sharded across 8 GPUs, our per-GPU memory is 87.5GB, so it won't fit! We already established this in the first part: we need to shard across at least 2 nodes to fit.
- With ZeRO-3 and 8-way TP we'll be doing 512-way ZeRO-3. This won't be an issue for memory, because we are sharding everything aggressively across the nodes, as opposed to DP, where the model weights and optimizer state need to fit in each node. Our per-GPU batch size is 4e6 / 4096 = 976. This is quite low, below even our pure DP limit, and the ZeRO-3 limit is twice that because we also have to move our weights. So no, it is not possible to remain compute-bound.
- With PP, each model-parallel shard now spans 8 nodes. As we've seen, this reduces the cost of our leaf-level AllGathers by 8, so the overall AllReduce/AllGather bandwidth goes from 400GB/s to 3200GB/s. The roofline is then 990e12 / 3200e9 = 309, which our per-GPU batch size of 976 exceeds, so we should be good!
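The arithmetic across these answers can be collected in one place. This is a sketch under the same assumptions as the text (10 bytes per parameter, 4096 H100s at 45% MFU, a 4M-token batch, ZeRO-3 needing twice the pure-DP limit); the 15T token count is an assumption, chosen because it is what the 959-hour figure implies.

```python
import math

PARAMS = 70e9
BYTES_PER_PARAM = 2 + 8  # bf16 weights + fp32 Adam state
HBM_PER_GPU = 80e9  # bytes
TOKENS = 15e12  # assumption: the token count the 959-hour figure implies
NUM_GPUS = 4096
H100_FLOPS = 9.9e14
MFU = 0.45
GLOBAL_BATCH = 4e6  # tokens

# Memory: total footprint and the minimum number of GPUs to hold it.
total_bytes = PARAMS * BYTES_PER_PARAM
min_gpus = math.ceil(total_bytes / HBM_PER_GPU)
# Training time: 6 * params * tokens FLOPs at 45% MFU.
train_hours = 6 * PARAMS * TOKENS / (NUM_GPUS * H100_FLOPS * MFU) / 3600
# Pure DP: a full replica spread over one 8-GPU node.
per_gpu_bytes_one_node = total_bytes / 8
# ZeRO-3 and PP: per-GPU batch versus the two rooflines.
per_gpu_batch = GLOBAL_BATCH / NUM_GPUS
zero3_roofline = 2 * H100_FLOPS / 400e9  # "twice" the pure-DP limit, as above
pp_roofline = H100_FLOPS / 3200e9  # leaf AllGathers spread over 8 nodes

print(f"{total_bytes / 1e9:.0f}GB, {min_gpus} GPUs, {train_hours:.0f}h")
print(f"{per_gpu_bytes_one_node / 1e9:.1f}GB/GPU on one node")
print(f"batch {per_gpu_batch:.0f} vs ZeRO-3 {zero3_roofline:.0f} / PP {pp_roofline:.0f}")
```

Printing the two rooflines side by side makes the conclusion visible: the per-GPU batch sits well below the ZeRO-3 threshold but comfortably above the PP one.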