TPUs and Sharding
While GPUs have long been popular for training and deploying AI models, Google's TPUs offer a compelling alternative, especially for large-scale applications. I recently had the chance to explore Google's TPUs on a fairly large scale. My impression is quite positive. They feel solid, fast, reliable, and cost-efficient. This post explores the basics of how they work.
A Tensor Processing Unit (TPU) is a type of application-specific integrated circuit (ASIC) specialized in performing the common operations needed for modern deep learning - matrix multiplication, elementwise vector computations, and embedding lookups. Unlike GPUs, which can be considered general-purpose processors for tasks that benefit from massive parallelism, TPUs are very restricted in their applications. Their main use is matrix multiplication. This restriction has made it easier to design TPUs that are extremely efficient at what they do, and that compose easily into larger systems that scale almost effortlessly.
We start by looking at a basic TPU core. It has the following components:
- An MXU (matrix multiply unit) - the main part of the TPU, responsible for performing matrix multiplication. It operates using a systolic array architecture.
- A VPU (vector processing unit). It performs general operations such as computing activation functions, elementwise multiplications and additions, aggregations, and broadcasting. Essentially anything that is not suitable for matrix multiplication.
- A VMEM (vector memory) - a type of on-chip memory used as a scratchpad for the VPU computations. It is located very close to the compute units and has a role similar to that of the L1/L2 cache on CPUs. It is programmer-controlled.
- An SPU (scalar processing unit) which processes all instructions, executes transfers from HBM to VMEM, and executes scalar control logic, loops, and indexing calculations.
- An SMEM (scalar memory) - used to store loop indices, offsets, and small constants.
- An HBM (high bandwidth memory) - a large chunk (e.g. 16GB for v4 TPUs) of memory for storing tensors and variables. It directly determines how big a model can be loaded onto the TPU and how big the batch size can be.
The way the TPU works is as follows. Suppose we want to perform elementwise multiplication of two matrices stored in the HBM. Their bytes are streamed in chunks to the VMEM and from there to a number of VREGs (vector registers), which are closest to the VPU and the MXU. The VPU has many ALUs which perform the computations in chunks of fixed tiled shapes like (8, 128), read from and written to the VREGs. All operations are pipelined and overlapped so that the VPU processes the current chunk while the next one is loaded into the VREGs. As the individual chunks become ready, they are streamed back to the HBM, where the result is stored.
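The chunked flow described above can be mimicked in plain NumPy. This is only a sketch: it processes tiles one after another with no pipelining or overlap, the helper name `tiled_elementwise_multiply` is made up for illustration, and it assumes the matrix shape is an exact multiple of the (8, 128) tile.

```python
import numpy as np

TILE = (8, 128)  # tile shape quoted in the text; real hardware tiling is more involved


def tiled_elementwise_multiply(a, b, tile=TILE):
    """Multiply a and b elementwise, one tile at a time (hypothetical helper)."""
    assert a.shape == b.shape
    rows, cols = a.shape
    tr, tc = tile
    assert rows % tr == 0 and cols % tc == 0, "pad inputs to a multiple of the tile"
    out = np.empty_like(a)  # stands in for the result buffer in HBM
    for r in range(0, rows, tr):
        for c in range(0, cols, tc):
            a_tile = a[r:r + tr, c:c + tc]  # "load" one chunk of each operand
            b_tile = b[r:r + tr, c:c + tc]
            out[r:r + tr, c:c + tc] = a_tile * b_tile  # compute, then write back
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((16, 256)).astype(np.float32)
b = rng.standard_normal((16, 256)).astype(np.float32)
result = tiled_elementwise_multiply(a, b)
```

The result is identical to `a * b`; the only difference is the order in which the work is done.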
A brief word on the MXU: it uses a systolic array which can process (for v4 TPUs) one bf16[8,128] @ bf16[128,128] -> f32[8,128] multiplication per 8 clock cycles. It is a grid of 128x128 ALUs, each performing one multiply-and-add operation. Data flows through the array rhythmically in waves (hence the name), not all at once. Weights \(\mathbf{W}\) of shape [128,128] approach from above, while inputs \(\mathbf{X}\) of shape [8,128] approach from the left. At each clock cycle (when fully pipelined) each ALU multiplies the corresponding weight with the input element and adds this to the previous result from above. Irregular inputs are usually padded to the desired fixed tiles. Overall, matrix multiplications are sped up dramatically. A simple schematic is shown in Fig. 1.
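To make the padding and fixed-tile idea concrete, here is a rough NumPy sketch. It does not simulate the systolic data flow cycle by cycle. It only shows how an irregular matmul can be padded and decomposed into [8,128] @ [128,128] tile products that are accumulated in float32. The function names are hypothetical, and float32 stands in for bf16, which NumPy does not provide.

```python
import numpy as np


def pad_to_multiple(x, row_mult, col_mult):
    """Zero-pad a 2D array so both dimensions are multiples of the tile sizes."""
    rows, cols = x.shape
    return np.pad(x, ((0, (-rows) % row_mult), (0, (-cols) % col_mult)))


def tiled_matmul(x, w, m_tile=8, k_tile=128, n_tile=128):
    """Compute x @ w as a sum of fixed-size tile products (hypothetical helper)."""
    m, k = x.shape
    k2, n = w.shape
    assert k == k2
    xp = pad_to_multiple(x, m_tile, k_tile)
    wp = pad_to_multiple(w, k_tile, n_tile)
    out = np.zeros((xp.shape[0], wp.shape[1]), dtype=np.float32)  # f32 accumulator
    for i in range(0, xp.shape[0], m_tile):
        for j in range(0, wp.shape[1], n_tile):
            for p in range(0, xp.shape[1], k_tile):
                # One [8,128] @ [128,128] -> [8,128] step, as on the MXU.
                out[i:i + m_tile, j:j + n_tile] += (
                    xp[i:i + m_tile, p:p + k_tile] @ wp[p:p + k_tile, j:j + n_tile]
                )
    return out[:m, :n]  # drop the padding again


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 200)).astype(np.float32)
w = rng.standard_normal((200, 130)).astype(np.float32)
y = tiled_matmul(x, w)
```

Zero padding is harmless here because the extra rows and columns contribute only zeros to every partial sum.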
Depending on the TPU version, each core can have multiple MXUs in different configurations. The HBM sizes and the HBM \(\leftrightarrow\) VMEM bandwidth also vary considerably from version to version. In any case, cores exist within chips. With v4, a TPU chip contains two cores, and they share their HBM. The chips are then attached, four at a time, to a TPU board called a tray, so each tray holds four chips or eight cores. Naturally, these chips need to be reached by the instructions and commands given by the user, so the tray is attached to a host system controlled by a CPU. The 4 chips are connected to the CPU host over PCI Express (PCIe).
Next, we need virtual machines running on the host. Google Cloud's TPU-VMs are built using KVM (Kernel-based Virtual Machine). This is not traditional virtualization of the TPU chips themselves: the VM virtualizes the host CPU resources, while the dedicated, physically attached TPU chips are accessed directly through PCIe or a similar pass-through mechanism.
Next, the TPU chips are also connected to other TPU chips, forming a TPU pod. The chips are arranged in a 3D grid and each one is connected to its six nearest neighbors. These connections do not go through the hosts; they are direct links between chips, called ICI (inter-chip interconnects). Wraparound links at the edges turn the grid into a 3D torus, as in Fig. 2. This shape keeps bandwidth uniform everywhere in the chip mesh, while reducing the maximum distance between any two chips along an axis of \(N\) chips from about \(N\) to \(N/2\).
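The effect of the wraparound links on distance can be checked with a few lines of Python. This is a sketch that counts hops along one axis only; the function names are made up for illustration.

```python
def max_hops_line(n):
    """Worst-case hop count between two chips on an axis of n chips without wraparound."""
    return n - 1


def max_hops_ring(n):
    """Worst-case hop count on the same axis once a wraparound link closes the ring."""
    # From any chip, a chip d positions away can be reached in min(d, n - d) hops.
    return max(min(d, n - d) for d in range(n))


def max_hops_torus(shape):
    """Worst-case hop count in a torus: the per-axis worst cases add up."""
    return sum(max_hops_ring(n) for n in shape)


line_8 = max_hops_line(8)             # 7 hops without wraparound
ring_8 = max_hops_ring(8)             # 4 hops with wraparound
torus_448 = max_hops_torus((4, 4, 8)) # 2 + 2 + 4 = 8 hops
```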
Users can request slices from the TPU pod, such as 2x2x1 or 4x4x8. These are contiguous blocks of physically adjacent chips which are idle and reservable. With 4x4x8 one would get 128 chips, each with two cores and its own 16GB of HBM. A scheduler would then map those chips to their respective CPU host machines - 32 of them. Each host runs a TPU-VM (or multiple VMs depending on slicing granularity). Since this is a 3D topology (X=4, Y=4, Z=8), the TPU pod may need to enable wraparound links (e.g., between chip 0 and chip 7 on the Z-axis). These are handled via optical links that "close the torus" along each axis. Routing tables for the ICI network are then configured. At the end, each VM is assigned a slice of the TPU hardware and has the XLA/TPU runtime and drivers preinstalled. And then we're finally ready to launch a scary big JAX script.
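The slice arithmetic is easy to reproduce. The sketch below uses the per-chip figures quoted in this post (two cores per chip, four chips per host, 16GB of HBM per chip) as default parameters; `slice_summary` is a hypothetical helper, not part of any TPU tooling.

```python
from math import prod


def slice_summary(shape, cores_per_chip=2, chips_per_host=4, hbm_gb_per_chip=16):
    """Resource counts for a slice of the given shape, using the figures from the text."""
    chips = prod(shape)
    return {
        "chips": chips,
        "cores": chips * cores_per_chip,
        "hosts": chips // chips_per_host,
        "total_hbm_gb": chips * hbm_gb_per_chip,
    }


small = slice_summary((2, 2, 1))  # a single tray: 4 chips on 1 host
large = slice_summary((4, 4, 8))  # 128 chips spread over 32 hosts
```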
Now, consider that the time for training and inference is spent either on compute or on communication. Often these are overlapped, which means we can lower-bound training and inference time by the maximum of the compute and communication times, and upper-bound it by their sum. We can compute an algorithm's arithmetic intensity as the number of FLOPs needed divided by the number of bytes that need to be communicated, either within or between chips.
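For example, the dot product between two bf16 vectors of size \(N\) has to read \(2N + 2N\) bytes, write \(2\) bytes, and compute \(N\) multiplications and \(N-1\) additions. Therefore, its arithmetic intensity is \((N + N-1)/(2N + 2N + 2) \rightarrow 1/2\) as \(N\rightarrow\infty\): for each byte read or written, it performs 0.5 FLOPs. Now consider when the dot product is bound by the math and when by the communication of the data. With \(T_\text{math} = \text{FLOPs} / \text{(accelerator FLOPs/s)}\) and \(T_\text{comms} = \text{bytes} / \text{(bandwidth in bytes/s)}\), we have \(T_\text{math} > T_\text{comms}\) exactly when the algorithm's arithmetic intensity exceeds the accelerator's own intensity, i.e. its peak FLOPs/s divided by its bandwidth.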
The intensity of the v5 VPU is around 3, higher than the dot product's 0.5, which means that the dot product will never be compute-bound. It will always be the case that some clock cycles are wasted when the VPU is waiting for data to be transferred from HBM to VMEM. For accelerators, it is considered better to be compute-bound, which occurs when \(T_\text{math} > T_\text{comms}\), because at least you are not wasting any clock cycles waiting for data transfers.
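The reasoning above fits in a few lines of Python. This is a sketch under stated assumptions: `PEAK_FLOPS_PER_S` and `BANDWIDTH_BYTES_PER_S` are made-up values, chosen only so that the accelerator intensity comes out as 3 to match the "around 3" quoted for the VPU. They are not measured hardware numbers.

```python
def dot_product_intensity(n, bytes_per_element=2):
    """FLOPs per byte for a dot product of two length-n vectors (bf16 by default)."""
    flops = n + (n - 1)                                       # n multiplies, n - 1 adds
    bytes_moved = 2 * n * bytes_per_element + bytes_per_element  # read both, write one scalar
    return flops / bytes_moved


def time_bounds(flops, bytes_moved, peak_flops_per_s, bandwidth_bytes_per_s):
    """Return (t_math, t_comms, lower_bound, upper_bound) in seconds."""
    t_math = flops / peak_flops_per_s
    t_comms = bytes_moved / bandwidth_bytes_per_s
    return t_math, t_comms, max(t_math, t_comms), t_math + t_comms


# Hypothetical accelerator with intensity 3e12 / 1e12 = 3 FLOPs per byte.
PEAK_FLOPS_PER_S = 3e12
BANDWIDTH_BYTES_PER_S = 1e12

n = 10**6
intensity = dot_product_intensity(n)  # approaches 0.5 for large n
t_math, t_comms, lower, upper = time_bounds(
    flops=2 * n - 1,
    bytes_moved=4 * n + 2,
    peak_flops_per_s=PEAK_FLOPS_PER_S,
    bandwidth_bytes_per_s=BANDWIDTH_BYTES_PER_S,
)
is_compute_bound = t_math > t_comms   # False: 0.5 is below the accelerator's 3
```

Because the intensity of the dot product does not grow with \(N\), making the vectors longer never changes the outcome.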
Overall, there are four bandwidths that could bottleneck the training/inference process: HBM bandwidth between a core and its associated HBM, ICI bandwidth between a TPU chip and its nearest 4 or 6 neighbors, PCIe bandwidth between a CPU host and its associated tray(s) of chips, and DCN (data center networking) bandwidth between multiple CPU hosts, typically hosts not connected by ICI. An example of the last are multislices - connected slices from different pods.
Once we have a few TPU chips or GPUs available, JAX views them simply as different devices. Yet bigger models or bigger data batches may not fit on one device, in which case they have to be sharded. Suppose the devices are organized in a 2D mesh with named axes \(x\) and \(y\). A tensor (representing either data or model weights) of shape \((N, D)\) is sharded over the device mesh like so: \((N_x, D_y)\). This means that the first axis of the tensor is sharded over the \(x\) axis of the devices, and the second axis over the \(y\) axis of the devices. All in all, each device will hold \(1/(|x|\cdot|y|)\) of the data. Likewise, \((N_x, D)\) means that the first axis of the tensor is sharded over the \(x\) axis of the devices, while the second axis is replicated across the \(y\) axis. \((N, D_{xy})\) means that the second axis is sharded over all \(|x||y|\) devices.
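To get a feel for the notation, the sketch below computes the shape of the block each device holds for the three shardings just mentioned. It is plain Python, not a JAX API: `local_shard_shape` is a hypothetical helper, and the mesh sizes \(|x| = 2\), \(|y| = 4\) and the tensor shape \((8, 16)\) are example values.

```python
def local_shard_shape(global_shape, spec, mesh_shape):
    """Per-device block shape.

    spec has one entry per tensor axis: None (not sharded) or a tuple of
    mesh-axis names that the tensor axis is sharded over.
    """
    local = []
    for size, axes in zip(global_shape, spec):
        n_devices = 1
        for name in (axes or ()):
            n_devices *= mesh_shape[name]
        assert size % n_devices == 0, "sharded axis must divide evenly"
        local.append(size // n_devices)
    return tuple(local)


mesh_shape = {"x": 2, "y": 4}  # example mesh with |x| = 2 and |y| = 4
N, D = 8, 16                   # example tensor shape

shape_nx_dy = local_shard_shape((N, D), (("x",), ("y",)), mesh_shape)  # (N_x, D_y)
shape_nx_d = local_shard_shape((N, D), (("x",), None), mesh_shape)     # (N_x, D)
shape_n_dxy = local_shard_shape((N, D), (None, ("x", "y")), mesh_shape)  # (N, D_xy)
```

With \((N_x, D_y)\) each device holds a \(4 \times 4\) block, which is \(1/8\) of the \(8 \times 16\) tensor, as the formula above predicts.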
In JAX, such sharding is very intuitive. We define the layout of the devices, with the axis names \(x\) and \(y\). Then we define the sharding, which indicates how the tensor is sharded over the device mesh. The code P(None, ('x', 'y')) corresponds to the sharding \((N, D_{xy})\).
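import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# A 2x4 device mesh with named axes 'x' and 'y' (requires 8 devices).
mesh = jax.make_mesh((2, 4), ('x', 'y'))
# Leave the first tensor axis unsharded; shard the second over all |x||y| = 8 devices.
sharding = jax.sharding.NamedSharding(mesh, P(None, ('x', 'y')))
# The second axis has size 8 so that it divides evenly across the 8 devices.
X = jnp.arange(0, 32, 1).reshape(4, 8)
X = jax.device_put(X, device=sharding)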
Naturally, this will work only if each sharded axis of X is divisible by the number of devices it is sharded over: \(|x|\), \(|y|\), or \(|x||y|\), depending on the partition spec. Now that the input is sharded across the devices, how do we perform computations over it? Any elementwise computation is easy, because it does not involve communication between different devices. There is no overhead in this case. However, if we're doing aggregations or pretty much anything non-elementwise, we often need to copy data from one device to another. The case of distributed matrix multiplication has been studied in detail. It is also the most important operation for ML.
With a non-distributed matmul we have \(\mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\). This would be the case if we have a single device or we have \(\mathbf{A}\) and \(\mathbf{B}\) replicated on all devices. In any other situation, there are four main cases:
- Neither multiplicand has a sharded contracting dimension. This would be for example \(\mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow \mathbf{C}[I_X, K_Y]\). Since the contracting dimension \(J\) is not sharded, each device simply performs its own block matrix multiplication without any communication with other devices.
- One multiplicand has a sharded contracting dimension - e.g. \(\mathbf{A}[I, J_X] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\). Here each device has only part of the \(J\) axis. Therefore we need to \(\textbf{AllGather}\) the shards across all devices in the \(X\) mesh axis. The \(\textbf{AllGather}\) operation communicates all shards until each device has a full copy of the data, after which the matmul can be performed.
- Both multiplicands have sharded contracting dimensions - \(\mathbf{A}[I, J_X] \cdot \mathbf{B}[J_X, K] \rightarrow \mathbf{C}[I, K]\). Here each device has some columns from \(\mathbf{A}\) and the matching rows from \(\mathbf{B}\), so it can perform a local matrix multiplication on those indices that it has. After this is done, each device ends up with a partial result. We perform an \(\textbf{AllReduce}\) to communicate all partial results to all devices and to sum-reduce them, so that each device has the final full matrix \(\mathbf{C}\).
- Both multiplicands have a non-contracting dimension sharded along the same axis - \(\mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_X] \rightarrow \mathbf{C}[I_X, K_X]\). In this case \(\mathbf{C}[I_X, K_X]\) is nonsensical, because a single mesh axis cannot shard two different dimensions of the same tensor. We first need to perform an \(\textbf{AllGather}\) on one of the inputs and then do either \(\mathbf{A}[I, J] \cdot \mathbf{B}[J, K_X] \rightarrow \mathbf{C}[I, K_X]\) or \(\mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I_X, K]\).
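The four cases can be imitated on a single machine by treating list entries as "devices". The NumPy sketch below is only a simulation under that assumption: concatenation stands in for AllGather, a plain sum stands in for AllReduce, and the mesh sizes and matrix shapes are arbitrary example values.

```python
import numpy as np

rng = np.random.default_rng(0)
I, J, K = 4, 8, 6
A = rng.standard_normal((I, J))
B = rng.standard_normal((J, K))
C = A @ B                      # reference, non-distributed result
n_x, n_y = 2, 3                # example mesh axis sizes

# Case 1: A[I_X, J] @ B[J, K_Y] -> C[I_X, K_Y]. No communication needed.
A_rows = np.split(A, n_x, axis=0)
B_cols_y = np.split(B, n_y, axis=1)
C_case1 = np.block([[a @ b for b in B_cols_y] for a in A_rows])

# Case 2: A[I, J_X] @ B[J, K]. AllGather the column shards of A first.
A_col_shards = np.split(A, n_x, axis=1)
A_gathered = np.concatenate(A_col_shards, axis=1)   # stands in for AllGather
C_case2 = A_gathered @ B

# Case 3: A[I, J_X] @ B[J_X, K]. Local matmuls give partial sums, then AllReduce.
B_row_shards = np.split(B, n_x, axis=0)
partials = [a @ b for a, b in zip(A_col_shards, B_row_shards)]
C_case3 = sum(partials)                             # stands in for AllReduce

# Case 4: A[I_X, J] @ B[J, K_X]. AllGather A, then compute C[I, K_X] per device.
A_full = np.concatenate(A_rows, axis=0)             # stands in for AllGather
B_cols_x = np.split(B, n_x, axis=1)
C_case4 = np.concatenate([A_full @ b for b in B_cols_x], axis=1)
```

In case 3 each partial product already has the full shape of \(\mathbf{C}\), which is why a sum-reduction rather than a gather is needed.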
Communication primitives. TPUs rely on a small number of communication primitives, which we saw above. In more detail, they are:
- AllGather - each device sends its part of the data to all other devices. Thus, all devices end up with the full data. The runtime is simply the bytes sent divided by the bandwidth.
- ReduceScatter - each device chunks its data and sends chunk \(i\) to device \(i\), \(\forall i\). Device \(i\) receives all chunks from the other devices and reduces them, ending up with a full result for only its chunk. The runtime is similar to that of AllGather.
- AllReduce - each device sends the full local data to all other devices and reduces all incoming data with its own. All devices end up with the same reduced data. The runtime is roughly twice that of AllGather.
- AllToAll - each device sends a unique chunk of local data to every other device, and also receives a unique chunk from every other device. It essentially changes which axis is sharded, \(\mathbf{A}[I_X, J] \rightarrow \mathbf{A}[I, J_X]\). Runtime is about a quarter of that of AllGather. AllToAll is very common in mixture-of-experts models, where each token has to be routed to its chosen expert, which may live on a different device.
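The semantics of the four primitives can be written down as functions on a list of per-device arrays. This NumPy sketch says nothing about how the data moves over ICI or how long it takes; it only shows what each device ends up holding. The function names are illustrative, not a real collective API. AllReduce is written as a ReduceScatter followed by an AllGather, a common decomposition that is consistent with the "roughly twice" runtime noted above.

```python
import numpy as np


def all_gather(shards):
    """Every device ends up with the concatenation of all shards."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]


def reduce_scatter(local_arrays):
    """Device i ends up with the sum, over all devices, of chunk i."""
    n = len(local_arrays)
    chunks = [np.split(a, n) for a in local_arrays]  # chunks[d][i]: chunk i on device d
    return [sum(chunks[d][i] for d in range(n)) for i in range(n)]


def all_reduce(local_arrays):
    """Every device ends up with the full sum: ReduceScatter, then AllGather."""
    return all_gather(reduce_scatter(local_arrays))


def all_to_all(row_shards):
    """Turn A[I_X, J] into A[I, J_X]: device d sends column chunk i to device i."""
    n = len(row_shards)
    chunks = [np.split(a, n, axis=1) for a in row_shards]
    return [np.concatenate([chunks[d][i] for d in range(n)], axis=0) for i in range(n)]


n_devices = 4
local = [np.arange(8.0) * (d + 1) for d in range(n_devices)]  # one array per device
total = sum(local)

scattered = reduce_scatter(local)  # device i holds total[2*i : 2*i + 2]
reduced = all_reduce(local)        # every device holds total

A = np.arange(24.0).reshape(4, 6)
resharded = all_to_all(np.split(A, 2, axis=0))  # rows sharded -> columns sharded
```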
Overall, TPUs are not magic, they're actually pretty simple. But they’re engineered for scale. If you're serious about training massive models efficiently, it’s worth investing time to understand how XLA, partitioning strategies, and collective ops like AllToAll and ReduceScatter interact. The tooling is becoming more mature, and the gains are real. Exciting times indeed.