Choosing Fast: From Softmax to FlashSampling
By the time an LLM chooses its next token, it is tempting to think the hard part is over.
The model has already built contextual representations and produced logits for the whole vocabulary. But drawing one exact sample from that distribution can still be expensive.
At the scale of a real LLM, next-token sampling happens over a very large vocabulary at every decode step. The math is simple. The memory traffic is not.
Three terms
HBM means high-bandwidth memory: the large, off-chip GPU memory system. A tile is just a chunk of the vocabulary; instead of processing the whole vocabulary at once, a GPU kernel can process it in smaller pieces that fit in fast on-chip memory. And when people talk about kernel-level decode workloads, they mean benchmarking the low-level GPU routine itself, not only whole-application latency.
The standard pipeline
The final projection of a language model, often called the LM head, is just a matrix multiply:

$$Z = H W$$

where $H$ is the $[B, d]$ batch of hidden states and $W$ is the $[d, V]$ vocabulary-sized output matrix. The result $Z$ holds one logit per vocabulary item for each row of the batch.
A conventional exact-sampling pipeline then does roughly this:
- Compute logits.
- Write them to HBM.
- Read them back.
- Apply temperature, masking, or other transforms.
- Compute softmax.
- Build a cumulative distribution.
- Draw a uniform random number.
- Search for the first cumulative mass that crosses it.
Mathematically, nothing is wrong with that. Systems-wise, it is wasteful. If you only need one sampled index, materializing the full [B, V] logits tensor means a lot of memory traffic for something that will immediately be thrown away.
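The steps listed above can be sketched in NumPy. This is a minimal CPU illustration, not how any particular inference stack implements it; the function name `softmax_cdf_sample` and its parameters are hypothetical, and temperature stands in for the whole family of logit transforms.

```python
import numpy as np

def softmax_cdf_sample(hidden, weight, temperature=1.0, rng=None):
    """Sample one index per batch row via softmax plus inverse-CDF search."""
    rng = np.random.default_rng() if rng is None else rng
    logits = hidden @ weight                     # [B, V] logits, fully materialized
    scaled = logits / temperature                # temperature transform
    scaled -= scaled.max(axis=1, keepdims=True)  # stabilize the exponentials
    probs = np.exp(scaled)
    probs /= probs.sum(axis=1, keepdims=True)    # softmax
    cdf = np.cumsum(probs, axis=1)               # cumulative distribution per row
    u = rng.random((logits.shape[0], 1))         # one uniform draw per row
    # Counting entries with cdf <= u gives the first index whose cumulative
    # mass exceeds u; the clamp guards against rounding in the final entry.
    idx = np.minimum((cdf <= u).sum(axis=1), logits.shape[1] - 1)
    return idx
```

Every intermediate here (`logits`, `probs`, `cdf`) has shape [B, V], yet the function returns only B integers. That mismatch is the waste the rest of this post is about.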
Why HBM matters so much
On-chip memory is tiny but extremely fast. HBM is much larger, but much slower.
If you only need one sampled index, why keep hauling the whole logits vector back and forth through slow memory?
Probabilities are cheap on paper and expensive in bandwidth.
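A back-of-the-envelope calculation makes the cost concrete. The batch size, vocabulary size, and fp32 logits below are hypothetical values chosen for illustration, and the helper `logits_traffic_bytes` is not from any library. It counts only one write and one read of the logits tensor; a multi-pass softmax would move more.

```python
def logits_traffic_bytes(batch, vocab, bytes_per_logit=4):
    """Bytes moved when a [batch, vocab] logits tensor is written once and read once."""
    tensor_bytes = batch * vocab * bytes_per_logit
    return 2 * tensor_bytes  # one write to HBM plus one read back

# Hypothetical decode step: 64 sequences, 128,000-token vocabulary, fp32 logits.
traffic = logits_traffic_bytes(batch=64, vocab=128_000)
result_bytes = 64 * 4  # what is actually needed: one 4-byte index per sequence

print(f"{traffic / 1e6:.1f} MB moved to produce {result_bytes} bytes of output")
# prints "65.5 MB moved to produce 256 bytes of output"
```

Under these assumptions the intermediate traffic is several hundred thousand times larger than the result, and it recurs at every decode step.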
The mathematical escape hatch
The key move is to stop thinking “softmax, then sample” and start thinking “argmax after the right random perturbation.”
This is the Gumbel-Max trick. If $z_i$ are transformed logits and $g_i$ are independent standard Gumbel random variables, then

$$\arg\max_i \,(z_i + g_i)$$

is an exact sample from the categorical distribution with probabilities proportional to $e^{z_i}$.
Why does this work? The Gumbel distribution is an extreme-value distribution: it arises as the limiting distribution of the maximum of many independent exponential-tailed random variables. When you add independent Gumbel noise to each logit and take the argmax, the probability that index $i$ wins is exactly proportional to $e^{z_i}$, the exponentiated logit of that index, which is the same distribution softmax would produce. The engineering consequence is the important part: sampling becomes an argmax.
With that change, sampling is no longer a normalization problem. It is a reduction problem: find the maximum perturbed score and return its index.
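The equivalence is easy to check empirically. The sketch below draws many Gumbel-perturbed argmax samples and compares the observed frequencies with softmax probabilities. The logits, sample count, and helper names are arbitrary choices for illustration; `Generator.gumbel` is NumPy's standard Gumbel sampler.

```python
import numpy as np

def softmax(logits):
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    e = np.exp(shifted)
    return e / e.sum()

def gumbel_max_frequencies(logits, num_samples, rng):
    """Empirical distribution of argmax(logits + Gumbel noise)."""
    logits = np.asarray(logits, dtype=np.float64)
    noise = rng.gumbel(size=(num_samples, logits.size))  # standard Gumbel(0, 1)
    winners = np.argmax(logits + noise, axis=1)          # one winner per sample
    counts = np.bincount(winners, minlength=logits.size)
    return counts / num_samples

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0, -1.0])
target = softmax(logits)
observed = gumbel_max_frequencies(logits, 200_000, rng)
max_gap = np.max(np.abs(observed - target))
print("softmax :", np.round(target, 3))
print("observed:", np.round(observed, 3))
```

With this many samples the two rows should agree to within ordinary sampling noise, and no normalization was computed on the Gumbel-Max side.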
Why tiling works
Once sampling becomes “find the max,” tiling becomes natural.
FlashSampling processes one vocabulary tile at a time. For each tile, it computes logits on chip, adds Gumbel noise on chip, and keeps only the best candidate from that tile. Then it performs a small second-stage reduction over tile winners.
This is exact, not approximate. Maxima decompose over partitions: if a global winner exists, it must also be the winner of its own tile. So a “max of tile-maxima” gives the same answer as a max over the whole vocabulary.
The tile is what lets the math and the memory hierarchy line up.
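The decomposition claim can be verified directly: given the same perturbed scores, a two-stage reduction over tiles must return the same index as one global argmax. This is a small NumPy check with a hypothetical helper name and arbitrary sizes, not a model of the actual kernel.

```python
import numpy as np

def argmax_of_tile_maxima(scores, tile_size):
    """Stage 1: one (score, global index) winner per tile. Stage 2: max over winners."""
    tile_winners = []
    for start in range(0, len(scores), tile_size):
        tile = scores[start:start + tile_size]
        local = int(np.argmax(tile))
        tile_winners.append((float(tile[local]), start + local))
    _, best_index = max(tile_winners, key=lambda winner: winner[0])
    return best_index

rng = np.random.default_rng(1)
logits = rng.normal(size=1000)
perturbed = logits + rng.gumbel(size=logits.shape)  # noise is drawn once, then shared

# Any tile size gives the same answer as the untiled argmax.
for tile_size in (1, 7, 64, 1000):
    assert argmax_of_tile_maxima(perturbed, tile_size) == int(np.argmax(perturbed))
```

The second stage only ever sees one candidate per tile, which is why the full score vector never has to exist in one place.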
A toy version
```python
import numpy as np

def gumbel(shape):
    # Standard Gumbel noise via the inverse CDF of a uniform draw.
    u = np.random.rand(*shape)
    return -np.log(-np.log(u))

def gumbel_max_sample(logits):
    # Untiled reference: perturb every logit, then take one global argmax.
    scores = np.array(logits, dtype=np.float64)
    scores += gumbel(scores.shape)
    return int(np.argmax(scores))

def tiled_gumbel_max_sample(logits, tile_size=4):
    logits = np.array(logits, dtype=np.float64)
    best_score = -np.inf
    best_index = -1

    for start in range(0, len(logits), tile_size):
        end = min(start + tile_size, len(logits))
        tile = logits[start:end]
        scores = tile + gumbel(tile.shape)  # perturb only this tile

        # Keep a single winner per tile.
        local_index = int(np.argmax(scores))
        local_score = float(scores[local_index])

        # Running reduction over the tile winners.
        if local_score > best_score:
            best_score = local_score
            best_index = start + local_index

    return best_index
```

Not fast. Not fused. Not a GPU kernel. But it preserves the conceptual structure of FlashSampling: perturb, keep one winner per tile, reduce the winners. That is enough to understand the idea without going into GPU kernels.
Why this matters
Sampling started in this series as a fair choice from a finite set. Then it turned into streams, approximate sets, search, classifier outputs, and latent variables. Here it becomes a systems problem: choose one next token from a huge discrete distribution while respecting GPU memory hierarchy.
That arc is useful because it shows one continuous line from “pick a fair random element” to “fuse exact sampling into the LM-head epilogue.”
Across all of these examples, randomness is not ornamental. It is one of the ways software handles scale, uncertainty, and cost.
This is the final post in the series Why Randomness Keeps Sneaking Into Software. If you enjoyed it, the best starting point for a friend is The Fairest Possible Choice.