Flash Attention

FlashAttention is an IO-aware exact attention algorithm that restructures the standard attention computation into memory-efficient tiled blocks, dramatically reducing GPU memory usage and wall-clock time for transformer models on long sequences.

6 min read | Last updated June 2026 | Foundations

FlashAttention is a fast, memory-efficient algorithm for computing exact self-attention in transformer models. Standard attention implementations materialise a full N-by-N matrix (where N is the sequence length) in GPU high-bandwidth memory (HBM) to perform the softmax and value-weighting steps. For long sequences this is both slow and memory-intensive, because the number of HBM reads and writes grows quadratically with N. FlashAttention, introduced in 2022 by researchers at Stanford University and the University at Buffalo, reorders these operations using a technique called tiling so that intermediate results stay in the faster on-chip SRAM. This minimises expensive HBM traffic and yields significant speedups and memory savings without approximating the attention computation.

The IO Bottleneck in Standard Attention

Modern GPU architectures have two distinct levels of memory: a large but relatively slow HBM (typically tens of gigabytes) and a small but fast on-chip SRAM (on an NVIDIA A100, about 192 KB per streaming multiprocessor, or roughly 20 MB across the whole chip). Standard attention implementations perform multiple separate kernel passes over HBM: they write the N-by-N score matrix to HBM after the query-key product, read it back to apply the softmax, write the resulting probabilities, and then read them again to weight the value vectors. For long contexts, this round-trip cost dominates total runtime. Standard attention is therefore memory-bandwidth-bound: the bottleneck is not the number of floating-point operations but the number of HBM reads and writes.
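To make the round-trip cost concrete, the following sketch counts how many matrix elements cross the HBM boundary when attention runs as three unfused kernels, as described above. The function name and the per-kernel breakdown are illustrative assumptions rather than measurements of any real implementation, and the count ignores caching, masking, and dropout.

```python
def standard_attention_hbm_elements(seq_len: int, head_dim: int) -> dict:
    """Count elements moved to or from HBM by an unfused three-kernel attention."""
    n, d = seq_len, head_dim
    scores = n * n  # size of the N-by-N score (and probability) matrix
    traffic = {
        # Kernel 1: read Q and K, write the score matrix S = Q K^T.
        "qk_product": 2 * n * d + scores,
        # Kernel 2: read S, write the softmax probabilities P.
        "softmax": 2 * scores,
        # Kernel 3: read P and V, write the output O.
        "value_weighting": scores + 2 * n * d,
    }
    traffic["total"] = sum(traffic.values())
    return traffic


if __name__ == "__main__":
    for n in (1024, 8192, 65536):
        t = standard_attention_hbm_elements(n, head_dim=64)
        quadratic_share = 4 * n * n / t["total"]
        print(f"N={n:>6}: {t['total']:.3e} elements, "
              f"{quadratic_share:.1%} from the N-by-N matrix")
```

Under these assumptions the total is 4Nd + 4N^2 elements per head, so once N is much larger than the head dimension d, almost all traffic comes from writing and re-reading the N-by-N matrix.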

FlashAttention addresses this by recognising that the entire attention operation can be fused into a single kernel if the computation is tiled into blocks small enough to fit in SRAM. By processing one tile of queries against one tile of keys and values at a time, FlashAttention never writes the full N-by-N matrix to HBM. The softmax normalisation, which normally requires seeing every score in a row before any of them can be normalised, is handled with a numerically stable online softmax algorithm that updates a running maximum and a running normaliser as each tile is processed.
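The online softmax idea can be illustrated on a single row of scores. The sketch below is a minimal pure-Python version, assuming the scores arrive in fixed-size blocks; the function names and the `block_size` parameter are hypothetical and chosen only for this example.

```python
import math


def online_softmax(scores, block_size=4):
    """Softmax over `scores`, reading the list one block at a time."""
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # sum of exp(score - running_max) seen so far
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(s - new_max) for s in block
        )
        running_max = new_max
    # A second pass is used here only to return the probabilities; a fused
    # kernel would instead fold this rescaling into its output accumulation.
    return [math.exp(s - running_max) / running_sum for s in scores]


def reference_softmax(scores):
    """Standard two-pass softmax that needs every score before normalising."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    row = [0.5, -1.0, 3.0, 2.0, 10.0, -4.0, 0.0]
    streamed = online_softmax(row, block_size=3)
    exact = reference_softmax(row)
    print(max(abs(a - b) for a, b in zip(streamed, exact)))
```

The key step is the multiplication by `exp(running_max - new_max)`: whenever a later block raises the maximum, the previously accumulated sum is rescaled so that both versions agree up to floating-point rounding.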

Algorithm Overview

The FlashAttention forward pass proceeds as follows. The query matrix Q, key matrix K, and value matrix V are split into blocks. For each block of queries, the algorithm iterates over all blocks of keys and values, computing partial attention scores, updating the running softmax statistics, and accumulating a partial output that is rescaled whenever the running maximum changes. After the last key/value block, the accumulated output is divided by the normalisation factor to yield the exact softmax-weighted value sum. The backward pass (gradient computation) similarly avoids storing the full attention matrix: it recomputes the attention tiles on the fly from the saved inputs and the saved per-row softmax statistics, trading additional compute for a large reduction in memory.
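The forward pass described above can be sketched in plain Python and checked against a naive reference. This is a readability-first illustration under simplifying assumptions (single head, no masking or dropout, nested lists instead of GPU tiles); the names `tiled_attention`, `q_block`, and `kv_block` are hypothetical and do not correspond to any library API.

```python
import math


def naive_attention(Q, K, V):
    """Reference: builds every row of the N-by-N score matrix explicitly."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j] for w, v in zip(weights, V)) / total
            for j in range(len(V[0]))
        ])
    return out


def tiled_attention(Q, K, V, q_block=2, kv_block=2):
    """Blockwise forward pass that only ever holds one tile of scores."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for qs in range(0, len(Q), q_block):          # loop over query blocks
        q_tile = Q[qs:qs + q_block]
        row_max = [-math.inf] * len(q_tile)       # running maximum per query
        row_sum = [0.0] * len(q_tile)             # running normaliser per query
        acc = [[0.0] * d_v for _ in q_tile]       # unnormalised output per query
        for ks in range(0, len(K), kv_block):     # loop over key/value blocks
            k_tile = K[ks:ks + kv_block]
            v_tile = V[ks:ks + kv_block]
            for i, q in enumerate(q_tile):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
                new_max = max(row_max[i], max(s))
                # Factor that rescales earlier partial results to the new maximum.
                correction = math.exp(row_max[i] - new_max)
                p = [math.exp(x - new_max) for x in s]
                row_sum[i] = row_sum[i] * correction + sum(p)
                acc[i] = [
                    acc[i][j] * correction
                    + sum(w * v[j] for w, v in zip(p, v_tile))
                    for j in range(d_v)
                ]
                row_max[i] = new_max
        # Final division by the normaliser gives the softmax-weighted sum.
        out.extend([[a / row_sum[i] for a in acc[i]] for i in range(len(q_tile))])
    return out


if __name__ == "__main__":
    Q = [[math.sin(i + j) for j in range(3)] for i in range(5)]
    K = [[math.cos(i * j + 1) for j in range(3)] for i in range(5)]
    V = [[float(i - j) for j in range(2)] for i in range(5)]
    ref = naive_attention(Q, K, V)
    out = tiled_attention(Q, K, V, q_block=2, kv_block=3)
    print(max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o)))
```

Block sizes that do not divide the sequence length are handled by Python slicing; a real kernel would instead pick block sizes from the available SRAM and pad or mask the final tile.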

Versions and Evolution

FlashAttention-2 (2023) refined the original implementation with better parallelism, improved work partitioning across GPU thread blocks and warps, and fewer non-matrix-multiplication operations, achieving roughly 2x the throughput of the original FlashAttention and reaching about 50-73% of the theoretical peak FLOPs/s on A100 GPUs. FlashAttention-3 (2024) extended the technique to NVIDIA Hopper architecture GPUs (H100, H200), exploiting new hardware features such as asynchronous data movement and low-precision (FP8) tensor core operations to push throughput further. FlashAttention kernels are available in major open-source libraries, including PyTorch (as a backend of torch.nn.functional.scaled_dot_product_attention) and Hugging Face Transformers, and are widely used for training and serving open-weight model families such as Llama and Mistral.

Impact on Context Length

One of the most significant practical consequences of FlashAttention is the expansion of feasible context lengths. Before FlashAttention, training transformer models on sequences longer than a few thousand tokens was often impractical due to memory constraints. Because FlashAttention's intermediate storage grows linearly with sequence length (O(N) rather than O(N^2)), it helped researchers train and serve models with context windows of 32,000, 100,000, and eventually millions of tokens. Long-context models used for legal document analysis, scientific literature review, and code understanding are made possible in large part by FlashAttention.
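The gap between quadratic and linear intermediate storage can be shown with a short calculation. The helper below is a hypothetical illustration that assumes, per attention head and per sequence, a 2-byte (fp16) element for the materialised matrix and a single 4-byte softmax statistic per query row for the tiled approach; real implementations may store slightly different statistics.

```python
def extra_memory_bytes(seq_len: int, bytes_per_element: int = 2) -> dict:
    """Per-head, per-sequence intermediate storage under two strategies."""
    return {
        # Standard attention keeps the full N-by-N attention matrix.
        "materialised_matrix": seq_len * seq_len * bytes_per_element,
        # Tiled attention keeps only per-row softmax statistics
        # (assumed here: one float32 value per query row).
        "softmax_statistics": seq_len * 4,
    }


if __name__ == "__main__":
    for n in (2048, 32768, 1_000_000):
        m = extra_memory_bytes(n)
        print(f"N={n:>9,}: matrix {m['materialised_matrix'] / 2**30:10.3f} GiB, "
              f"statistics {m['softmax_statistics'] / 2**20:8.3f} MiB")
```

Under these assumptions, a single head at N = 32,768 needs 2 GiB for the materialised matrix but only 128 KiB of softmax statistics, and the matrix figure is multiplied again by the number of heads and the batch size.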

Relation to Approximate Attention

FlashAttention is an exact algorithm: it computes the same result as standard attention, up to floating-point rounding, just with better IO efficiency. It should not be confused with approximate attention methods such as linear attention or the sparse attention patterns of Longformer and BigBird, which alter the mathematical definition of attention to reduce asymptotic complexity at the cost of some approximation error. FlashAttention achieves its efficiency through algorithmic and hardware-level optimisation of the exact computation.

References

  1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Advances in Neural Information Processing Systems (NeurIPS) 2022. arXiv:2205.14135.
  2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. Proceedings of ICLR 2024. arXiv:2307.08691.
  3. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
  4. Raschka, S. (2024). Understanding and Coding the KV Cache in LLMs from Scratch. Ahead of AI newsletter. https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms