Accelerating LLM Inference: Fast Sampling with Gumbel-Max Trick

Community Article Published October 24, 2024

Introduction

Large Language Model (LLM) inference speed is heavily influenced by the token sampling step. At each generation step, we need to sample the next token from a probability distribution over the entire vocabulary (typically 32K to 100K tokens). The standard approach, applying softmax to the logits and then calling torch.multinomial, can become a notable bottleneck in the inference pipeline.

The Problem with Traditional LLM Sampling

The traditional sampling process in LLM inference looks like this (a minimal code sketch follows the lists below):

  1. Get logits from the model
  2. Apply softmax to convert logits to probabilities
  3. Use torch.multinomial to sample from the probability distribution

This approach has two main bottlenecks:

  • Computing softmax over large vocabulary sizes is expensive
  • The multinomial sampling operation itself is relatively slow
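
For concreteness, here is a minimal sketch of this traditional path in PyTorch; the batch and vocabulary sizes are illustrative placeholders, and in practice the logits come from the model's forward pass:

import torch

# Illustrative shapes only; real logits come from the model (step 1).
batch_size, vocab_size = 32, 32000
logits = torch.randn(batch_size, vocab_size)

probs = torch.softmax(logits, dim=-1)                 # step 2: softmax over the full vocabulary
next_token = torch.multinomial(probs, num_samples=1)  # step 3: categorical sampling, shape (batch_size, 1)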

The Key Insight: Gumbel-Max Sampling

The core innovation in our approach comes from two key observations about the Gumbel-Max trick:

  1. Taking the argmax of the logits plus Gumbel noise is mathematically equivalent to sampling from the categorical distribution defined by softmax(logits) (a quick empirical check follows this list):

    # Instead of:
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch_size, 1)

    # We can do (no softmax needed):
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))  # standard Gumbel(0, 1) noise
    next_token = torch.argmax(logits + gumbel_noise, dim=-1)        # shape: (batch_size,)
    
  2. The Critical Optimization: Gumbel noise can be pre-computed:

    • The noise tensor is independent of the logits
    • We can prepare it before receiving model outputs
    • This removes it from the critical path of token generation
    • We avoid computing softmax entirely
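
To build confidence in observation 1, a small sanity check (illustrative only, not part of the linked benchmark code) can compare the empirical frequencies of Gumbel-Max samples against the softmax probabilities of the same logits:

import torch

# Sanity check: Gumbel-Max sample frequencies should match softmax(logits).
torch.manual_seed(0)
vocab_size = 8                        # tiny "vocabulary" so frequencies converge quickly
logits = torch.randn(vocab_size)
probs = torch.softmax(logits, dim=-1)

num_samples = 200_000
# Fresh, independent Gumbel noise for every sample: G = -log(-log(U)), U ~ Uniform(0, 1)
uniform = torch.rand(num_samples, vocab_size).clamp_(min=1e-20)  # clamp avoids log(0)
gumbel = -torch.log(-torch.log(uniform))
samples = torch.argmax(logits + gumbel, dim=-1)

empirical = torch.bincount(samples, minlength=vocab_size).float() / num_samples
print(torch.allclose(empirical, probs, atol=0.01))  # expected: True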

Performance Results on A100

Our benchmarks on an A100 80GB GPU show significant speedups across different scales. The complete benchmark code and implementation can be found at: https://github.com/NonvolatileMemory/fast_llm_sampling/tree/main

Small Scale (batch_size=32, vocab_size=32000)

  • Traditional: 0.600 ms ± 0.058 ms
  • Gumbel-Max: 0.214 ms ± 0.004 ms
  • 2.8x speedup

Medium Scale (batch_size=128, vocab_size=50000)

  • Traditional: 4.549 ms ± 2.609 ms
  • Gumbel-Max: 1.294 ms ± 0.009 ms
  • 3.5x speedup

Large Scale (batch_size=512, vocab_size=100000)

  • Traditional: 64.386 ms ± 2.748 ms
  • Gumbel-Max: 30.544 ms ± 1.725 ms
  • 2.1x speedup
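
A simplified timing harness along the following lines can reproduce this kind of comparison. This is a sketch that assumes a CUDA device is available; the exact methodology lives in the linked repository:

import torch

def time_sampling(fn, logits, warmup=10, iters=100):
    # Average GPU time per call in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn(logits)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(logits)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

batch_size, vocab_size = 32, 32000            # the "small scale" configuration
logits = torch.randn(batch_size, vocab_size, device="cuda")
noise = -torch.log(-torch.log(torch.rand_like(logits).clamp_(min=1e-20)))  # pre-computed Gumbel noise

traditional = lambda x: torch.multinomial(torch.softmax(x, dim=-1), num_samples=1)
gumbel_max = lambda x: torch.argmax(x + noise, dim=-1)

print(f"traditional: {time_sampling(traditional, logits):.3f} ms")
print(f"gumbel-max:  {time_sampling(gumbel_max, logits):.3f} ms")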

Implementation Details

The key to an efficient implementation is proper noise pre-computation:

import torch

class GumbelSampler:
    def __init__(self, batch_size, vocab_size, device):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.device = device
        # Pre-compute the first noise tensor so it is ready before the logits arrive
        self.noise = self._prepare_gumbel_noise()

    def _prepare_gumbel_noise(self):
        # Standard Gumbel noise: G = -log(-log(U)), U ~ Uniform(0, 1)
        # (clamp guards against log(0) if torch.rand returns exactly 0)
        uniform_noise = torch.rand(self.batch_size, self.vocab_size, device=self.device)
        return -torch.log(-torch.log(uniform_noise.clamp_(min=1e-20)))

    def sample(self, logits):
        # Direct sampling without softmax: argmax(logits + Gumbel noise)
        next_token = torch.argmax(logits + self.noise, dim=-1)
        # Refresh the noise off the critical path so every step uses fresh, independent noise
        self.noise = self._prepare_gumbel_noise()
        return next_token
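
A minimal usage sketch might look like the following; fake_model is a hypothetical stand-in for a real LLM forward pass, used here only so the snippet is self-contained:

import torch

batch_size, vocab_size = 4, 100
device = "cuda" if torch.cuda.is_available() else "cpu"

def fake_model(input_ids):
    # Hypothetical placeholder: a real model would return next-token logits here.
    return torch.randn(input_ids.shape[0], vocab_size, device=device)

sampler = GumbelSampler(batch_size, vocab_size, device)
input_ids = torch.randint(0, vocab_size, (batch_size, 1), device=device)

for _ in range(8):  # generate 8 tokens
    logits = fake_model(input_ids)                                    # (batch_size, vocab_size)
    next_token = sampler.sample(logits)                               # (batch_size,)
    input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)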