Development History#
Origin#
Sparse Attention is an optimized attention mechanism that maps a query vector and a set of key-value pairs to an output vector. Unlike standard (dense) single-head and multi-head attention, it does not compute the similarity between the query vector and every key vector. Instead, it computes the similarity only between the query vector and a subset of key vectors, which reduces computation and memory consumption. The concept of sparse attention first appeared in the 2019 paper "Generating Long Sequences with Sparse Transformers," which proposed a Transformer-based long-sequence generation model that used sparse attention to handle text sequences exceeding 8000 words.
In fact, in trained Transformer models, the attention matrix is often sparse, meaning not every token needs to attend to all other tokens. Some interactions between tokens may contribute little to the final output and can be ignored.
The sparsity pattern can be fixed (such as local windows), content-based (such as selecting the positions most relevant to the current position), or learned through training.
Based on the criteria for determining sparse connections, methods are divided into two categories: position-based sparse attention and content-based sparse attention.
Position-Based Sparse Attention#
- Global Attention: Introduces global nodes, which attend to all other nodes and are attended to by all of them. In other words, these nodes act as relays for communication among all nodes. The essence of sparse attention is that not every pair of nodes needs to communicate point-to-point, somewhat like going from P2P to P2S.
- Band Attention: Exploits the locality of the data, similar to a sliding window: each node attends only to its neighboring nodes, restricting interactions to a local neighborhood (this truncation breaks the long-range receptive field that attention normally has).
- Dilated Attention: Builds on Band Attention by leaving gaps, i.e., adding a stride (dilation) to the sliding window so that a node can reach more distant positions (after truncating the window and finding the result unsatisfactory, one thinks of stretching it, but it still feels clumsy).
- Random Attention: Randomly samples some edges for each query. It feels like leaving things to pure luck, so the effect probably isn't great.
- Block Attention: Divides the input sequence into several non-overlapping query blocks and assigns each query block a local memory block, for efficient processing of long sequences. (I don't quite get this one: isn't it just attention with both the length and the width shrunk by one size? 🤓)
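To make the five position-based patterns concrete, here is a minimal NumPy sketch that builds each one as a boolean mask (True means the query attends to that key) and applies it to ordinary scaled dot-product attention. The function names, the parameters (window half-width `w`, dilation `r`, random sample count `k`), and the simplification that each query block's memory block is just its own block are illustrative assumptions, not taken from any specific paper's code.

```python
import numpy as np

def band_mask(n, w):
    """Band attention: query i may attend to key j only if |i - j| <= w."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def dilated_mask(n, w, r):
    """Dilated attention: w hops on each side with gaps, i.e. offsets 0, r, 2r, ..., w*r."""
    idx = np.arange(n)
    off = np.abs(idx[:, None] - idx[None, :])
    return (off <= w * r) & (off % r == 0)

def global_mask(n, global_ids):
    """Global attention: global nodes attend to every node and every node attends to them."""
    m = np.zeros((n, n), dtype=bool)
    m[global_ids, :] = True
    m[:, global_ids] = True
    return m

def random_mask(n, k, seed=0):
    """Random attention: each query attends to k randomly sampled keys."""
    rng = np.random.default_rng(seed)
    m = np.zeros((n, n), dtype=bool)
    for i in range(n):
        m[i, rng.choice(n, size=k, replace=False)] = True
    return m

def block_mask(n, b):
    """Block attention (simplified): each query block attends only to its own key block."""
    blk = np.arange(n) // b
    return blk[:, None] == blk[None, :]

def masked_attention(q, k, v, mask):
    """Scaled dot-product attention restricted to True entries (every row needs at least one)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    s = np.where(mask, s, -np.inf)                 # disallowed pairs get zero weight
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v
```

Masks can be OR-ed together, e.g. `band_mask(n, 2) | global_mask(n, [0]) | random_mask(n, 2)` (BigBird, for instance, combines global, band, and random attention). Real implementations never materialize a full n×n mask; this is only for illustration.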
Content-Based Sparse Attention#
-
Maximum Inner Product Search (MIPS)
To efficiently construct content-based sparse graphs, solutions to the Maximum Inner Product Search (MIPS) problem can be utilized. The goal of MIPS is to find the key with the maximum dot product with the query without calculating the dot product between the query and all keys.
NSA (Inference Efficiency and Training Feasibility)#
Origin#
Modeling long texts is extremely important for large models, but the traditional attention mechanism has quadratic computational complexity in the sequence length. Increasing the context length adds a large amount of extra computation: doubling the length quadruples the attention computation.
State-of-the-art sparse attention methods mainly fall into two categories: KV-cache eviction, and KV-cache selection (blockwise selection, sampling, and clustering methods). However, neither performs as well as claimed.
The main issues are local sparsity, incompatibility with advanced attention architectures, and the lack of end-to-end training.
- Local Sparsity: Methods like H2O apply sparsity only during the autoregressive decoding phase, while still requiring computation-intensive preprocessing during prefill. MInference, conversely, uses sparse attention only during prefill. Because these methods are not sparse in all phases, they perform poorly on prefill-dominated tasks such as book summarization and code completion, or on decode-dominated tasks such as long reasoning chains. In other words, there is no unified architecture that can be trained end to end for downstream tasks.
- Incompatibility with Attention: Most sparse attention methods are designed around MHA and may not fit architectures such as MQA and GQA. For example, Quest selects KV-cache entries independently for each attention head, whereas in MQA and GQA several query heads share one KV cache, so the union of all heads' selections must still be loaded.
- End-to-End Training: Most current sparse attention methods target inference, but sparse attention designed for training is also needed. Models trained with dense attention perform poorly under sparse inference, because the top 20% of attention entries cover only about 70% of the total attention score. Moreover, works like ClusterKV and MagicPIG introduce discrete, non-differentiable operations into the computation graph, which prevent proper backpropagation. Non-contiguous memory access also prevents effective use of fast attention techniques like FlashAttention, which rely on contiguous memory access and blockwise computation for high throughput.
The NSA report mentions that the main issues to be addressed are twofold:
- One is inference optimization in conjunction with hardware; in the prefill and decode phases, turning theoretical optimizations into actual acceleration requires hardware-friendly algorithms, mainly focusing on memory access and hardware bottleneck scheduling.
- The second is training-aware algorithm design, supporting end-to-end learning of sparse patterns, avoiding performance loss from traditional methods of "train first, prune later."
The proposed solution consists of three branches:
- Compressed Coarse-Grained Tokens (cmp): Aggregates contiguous key/value blocks into block-level representations, capturing coarse-grained semantic information and reducing the computational burden. In simple terms, it merges many consecutive KV tokens along the sequence dimension into one token. For example, with a compression ratio of 16, a sequence of 1024 KV tokens becomes 64 compressed tokens.
- Selectively Retained Fine-Grained Tokens (slc): Selectively retains important tokens to compensate for potential information loss from compression. In simple terms, similar to MIPS, it finds the most relevant tokens to attend to, while the remaining tokens are deemed not worth attending to.
- Sliding Windows (win): A dedicated sliding-window branch for local context, which keeps local patterns from dominating the learning process. In simple terms, this is the aforementioned Band Attention (it turns out it is indeed useful 🤯).
Demo#
We can illustrate this process with a simple example:
Suppose our input sequence has length 64, so $Q, K, V \in \mathbb{R}^{64 \times d_k}$, and we compress with a block length of 8. Since $K$ and $V$ are handled symmetrically, take $K$ as an example and divide it into 8 blocks $K_1, \dots, K_8$ of 8 tokens each. Compression turns each block into a single vector of the same size as one key, giving $\tilde{K}^{\text{cmp}} \in \mathbb{R}^{8 \times d_k}$. In simple terms, we squeeze many tokens of $K$ into one to save memory and accelerate computation. We then compute attention between the original $Q$ and the compressed $\tilde{K}^{\text{cmp}}, \tilde{V}^{\text{cmp}}$ to obtain the compressed attention output $o^{\text{cmp}}$.
The middle branch is selection (slc). From the compression step we already have the compressed KV blocks and their attention scores, so we pick the top-$n$ blocks by score. Choose $n = 2$ and suppose the winners are $K_3$ and $K_7$, i.e., the third and seventh blocks. We then map these selected compressed blocks back to their original tokens, expanding them into $2 \times 8 = 16$ original keys and values, and compute attention over them to obtain the selected attention output $o^{\text{slc}}$.
On the right is the sliding window, where we take the most recent 8 tokens of the original $K, V$ to obtain the sliding window attention output $o^{\text{win}}$.
Finally, a gating function combines the three outputs:
$$o^* = g^{\text{cmp}} \cdot o^{\text{cmp}} + g^{\text{slc}} \cdot o^{\text{slc}} + g^{\text{win}} \cdot o^{\text{win}}$$
Counting the KV tokens used: originally there were 64; compressed attention uses 8, selected attention uses 16, and sliding window attention uses 8, for a total of only 32, half of the original.
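The token accounting in this demo can be checked with a few lines of Python. `nsa_kv_budget` is a hypothetical helper; like the demo, it assumes the compression stride equals the block length and does not deduplicate tokens that appear in more than one branch.

```python
def nsa_kv_budget(seq_len, cmp_block, n_selected, slc_block, window):
    """Count the KV tokens each NSA branch touches for one query (toy accounting).

    Assumes compression stride == block length and ignores overlap between branches.
    """
    cmp_tokens = seq_len // cmp_block      # one compressed token per block
    slc_tokens = n_selected * slc_block    # selected blocks restored to full resolution
    win_tokens = min(window, seq_len)      # most recent tokens only
    total = cmp_tokens + slc_tokens + win_tokens
    return {"cmp": cmp_tokens, "slc": slc_tokens, "win": win_tokens,
            "total": total, "fraction_of_dense": total / seq_len}

budget = nsa_kv_budget(seq_len=64, cmp_block=8, n_selected=2, slc_block=8, window=8)
print(budget)  # {'cmp': 8, 'slc': 16, 'win': 8, 'total': 32, 'fraction_of_dense': 0.5}
```

Because only the compressed branch grows with the sequence length, the fraction of KV tokens touched keeps shrinking as the context gets longer.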
Background#
Attention#
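For a new incoming query $q_t$, it needs to attend to all previous key-value pairs:
$$o_t = \operatorname{Attn}(q_t, k_{:t}, v_{:t}) = \sum_{i=1}^{t} \frac{\alpha_{t,i}\, v_i}{\sum_{j=1}^{t} \alpha_{t,j}}, \qquad \alpha_{t,i} = e^{\frac{q_t^\top k_i}{\sqrt{d_k}}}$$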
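A direct NumPy transcription of this formula for a single query, as a dense reference point. The function name `attend` is illustrative; subtracting the maximum logit before exponentiating is a standard numerical-stability trick that leaves the ratio unchanged.

```python
import numpy as np

def attend(q_t, K, V):
    """o_t = sum_i alpha_{t,i} v_i / sum_j alpha_{t,j}, alpha_{t,i} = exp(q_t . k_i / sqrt(d_k)).

    K and V hold k_1..k_t and v_1..v_t as rows, so the cost per query grows linearly with t.
    """
    logits = K @ q_t / np.sqrt(q_t.shape[-1])
    alpha = np.exp(logits - logits.max())    # shift by the max for numerical stability
    return (alpha[:, None] * V).sum(axis=0) / alpha.sum()
```

Summed over all $t$ queries of a sequence, this per-query linear cost is what produces the quadratic total.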
Arithmetic Intensity#
$T_{\text{mem}}$: the memory access time equals the number of bytes accessed divided by the processor's memory bandwidth.
$T_{\text{math}}$: the math time equals the number of operations divided by the processor's math bandwidth.
If $T_{\text{math}} > T_{\text{mem}}$, the algorithm is math-limited.
This condition can be rewritten as $\frac{\#\text{ops}}{\#\text{bytes}} > \frac{BW_{\text{math}}}{BW_{\text{mem}}}$. The left side, the ratio of operations performed to bytes accessed, is the algorithm's arithmetic intensity; the right side, the ratio of the processor's math bandwidth to its memory bandwidth, is the byte ratio.
- If the algorithm's arithmetic intensity exceeds the GPU's byte ratio, the algorithm is limited by compute, also called math bound: performance is limited by compute throughput (FLOPS), and such operators are compute-intensive.
- If the algorithm's arithmetic intensity is lower than the GPU's byte ratio, the algorithm is limited by memory, also called memory bound: performance is limited by memory bandwidth, and such operators are memory-access-intensive.
To make full use of the GPU's compute, the arithmetic intensity of an algorithm or network layer should be as high as possible relative to the GPU's byte ratio.
In the prefill phase, causal self-attention consists largely of batched matrix multiplications with high arithmetic intensity, so performance is bound by compute. In the autoregressive decode phase, each generated token must read the entire KV cache, so performance is bound by memory bandwidth. This difference leads to different optimization goals: reduce computation during training and prefill, and reduce memory access during decoding.
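A rough calculator for the two cases. The cost model is an assumption for illustration: it counts the FLOPs of both matmuls in $O = \operatorname{softmax}(QK^\top)V$ (ignoring the causal mask and the softmax itself), counts bytes only for reading $Q, K, V$ and writing $O$ in fp16 as a fused FlashAttention-style kernel would, and uses a hypothetical GPU with round numbers (1e15 FLOP/s, 3e12 B/s), not any specific product.

```python
def fused_attention_cost(n_q, n_kv, d, bytes_per_el=2):
    """FLOPs and bytes for O = softmax(Q K^T) V in a fused kernel.

    2 * n_q * n_kv * d FLOPs for each of the two matmuls; bytes only for reading
    Q, K, V and writing O (the score matrix stays on chip). fp16 by default.
    """
    flops = 4 * n_q * n_kv * d
    bytes_moved = bytes_per_el * (2 * n_q * d + 2 * n_kv * d)
    return flops, bytes_moved

def limiting_factor(flops, bytes_moved, peak_flops, mem_bw):
    """Compare T_math = ops / BW_math against T_mem = bytes / BW_mem."""
    t_math = flops / peak_flops
    t_mem = bytes_moved / mem_bw
    return "math bound" if t_math > t_mem else "memory bound"

# Hypothetical round-number GPU: byte ratio of about 1e15 / 3e12 = 333 FLOPs per byte.
PEAK, BW = 1e15, 3e12
for name, n_q in [("prefill", 4096), ("decode", 1)]:
    f, b = fused_attention_cost(n_q, n_kv=4096, d=128)
    print(name, round(f / b, 1), limiting_factor(f, b, PEAK, BW))
# prefill 2048.0 math bound
# decode 1.0 memory bound
```

With 4096 KV tokens and head dimension 128, a full prefill reuses every loaded key and value thousands of times, while a single decode step uses each loaded byte roughly once, which is why decoding is limited by memory bandwidth.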
Method#
Two directions: algorithm design and kernel optimization.
Overall Overview#
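NSA replaces the original key-value pairs $k_{:t}, v_{:t}$ with a more compact and information-dense set $\tilde{K}_t, \tilde{V}_t$, constructed dynamically based on the current query $q_t$:
$$\tilde{K}_t = f_K(q_t, k_{:t}, v_{:t}), \quad \tilde{V}_t = f_V(q_t, k_{:t}, v_{:t}), \quad o_t^* = \operatorname{Attn}\left(q_t, \tilde{K}_t, \tilde{V}_t\right)$$
There are three mapping strategies, the aforementioned cmp, slc, and win, whose outputs are combined by gating factors:
$$o_t^* = \sum_{c \in \mathcal{C}} g_t^c \cdot \operatorname{Attn}\left(q_t, \tilde{K}_t^c, \tilde{V}_t^c\right), \qquad \mathcal{C} = \{\text{cmp}, \text{slc}, \text{win}\}$$
where $g_t^c \in [0, 1]$ is the gate score of branch $c$, which the paper derives from the input features via an MLP with a sigmoid activation.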
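The gated combination can be sketched as follows. This is a minimal NumPy sketch, not NSA's implementation: each branch's keys/values and the gate logits are passed in directly, whereas in NSA the gates come from an MLP over the input features; the function names are hypothetical.

```python
import numpy as np

def attn(q, K, V):
    """Single-query scaled dot-product attention."""
    s = K @ q / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def gated_nsa_output(q, branches, gate_logits):
    """o*_t = sum_c g_t^c * Attn(q_t, K_c, V_c) over the given branches.

    branches: dict name -> (K_c, V_c); gate_logits: dict name -> scalar logit.
    A sigmoid maps each logit to a gate in (0, 1); gates are NOT normalized to sum to 1.
    """
    out = 0.0
    for name, (K_c, V_c) in branches.items():
        g = 1.0 / (1.0 + np.exp(-gate_logits[name]))
        out = out + g * attn(q, K_c, V_c)
    return out
```

Since each gate is an independent sigmoid rather than a softmax over branches, the model can turn several branches up (or down) at the same time.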
Compressed Coarse-Grained Tokens (Compression)#
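$$\tilde{K}_t^{\text{cmp}} = f_K^{\text{cmp}}(k_{:t}) = \left\{ \varphi\left(k_{id+1 : id+l}\right) \;\middle|\; 0 \le i \le \left\lfloor \frac{t - l}{d} \right\rfloor \right\}$$
where $l$ is the block length, $d$ is the sliding stride between adjacent blocks, and $\varphi$ is a learnable MLP (with intra-block position encoding) that maps the keys within a block to a single compressed key. The compressed values $\tilde{V}_t^{\text{cmp}}$ are formed the same way. The paper notes that $d$ is generally chosen smaller than $l$ to alleviate information fragmentation.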
The original text states that compressed representations capture coarser-grained, higher-level semantic information and reduce the computational burden of attention. That is just how it is; if the experiments work out, that's what matters (🤓).
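A minimal sketch of the compression step with block length `l` and stride `d`, using 0-based indexing. The default `phi` is a plain mean over the block, a parameter-free stand-in for NSA's learnable MLP (this sketch does not model the MLP or any position encoding); `compress` is a hypothetical name.

```python
import numpy as np

def compress(K, l, d, phi=None):
    """Compressed tokens phi(K[i*d : i*d + l]) for 0 <= i <= floor((t - l) / d), 0-indexed.

    phi defaults to a mean over the block, a stand-in for NSA's learnable MLP.
    The same function applies to V.
    """
    t = K.shape[0]
    if phi is None:
        phi = lambda block: block.mean(axis=0)
    n_blocks = max((t - l) // d + 1, 0)        # number of complete blocks seen so far
    blocks = [phi(K[i * d : i * d + l]) for i in range(n_blocks)]
    return np.stack(blocks) if blocks else np.empty((0, K.shape[1]))
```

With `l = d = 16`, 1024 keys become 64 compressed keys. With `d < l`, adjacent blocks overlap, so tokens near a block boundary also appear in a neighboring block instead of being split between two.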
Selectively Retained Fine-Grained Tokens (Selection)#
Using only the coarse-grained tokens mentioned above is certainly insufficient, as it loses a lot of fine-grained information. We also need fine-grained blocks to help the model understand better.
Block Selection:
Block-level selection is motivated by hardware efficiency and by the inherent distribution of attention scores, and it is crucial for efficient computation on modern GPUs. Modern GPUs have far higher throughput for contiguous block reads than for random index-based reads, and blockwise computation makes full use of the GPU's Tensor Cores. Attention scores also typically exhibit spatial continuity: adjacent keys tend to have similar importance, as DeepSeek's (DS) experiments later in the paper show. In the figure, lighter areas indicate higher attention values.
Important Attention Score Calculation:
Computing all attention scores is clearly expensive, but this overhead can be avoided by reusing the attention scores already computed on the compressed tokens:
$$p_t^{\text{cmp}} = \operatorname{Softmax}\left(q_t^\top \tilde{K}_t^{\text{cmp}}\right)$$
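However, these scores refer to compression blocks. In general, we define the selection block length as $l'$. When $l' = l = d$, we directly have $p_t^{\text{slc}} = p_t^{\text{cmp}}$. When the blocking schemes differ, given $d \mid l$ and $d \mid l'$, the selection scores are derived from the spatial relationship between the blocks:
$$p_t^{\text{slc}}[j] = \sum_{m=0}^{\frac{l'}{d}-1} \sum_{n=0}^{\frac{l}{d}-1} p_t^{\text{cmp}}\left[\frac{l'}{d} j - m - n\right]$$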
However, the scores $p_t^{\text{slc},(h)}$ of different query heads are generally not the same, so NSA makes the selection consistent across heads. In GQA and MQA, query heads that share the same KV cache use a shared importance score, the sum of the scores over the $H$ heads in the group: $p_t^{\text{slc}'} = \sum_{h=1}^{H} p_t^{\text{slc},(h)}$. Since every head in the group then selects the same KV blocks, this saves a lot of memory access.
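The scoring steps can be sketched as below. Instead of transcribing the index formula, `slc_scores` sums the score of every compression block that overlaps selection block $j$ in token space, which is one way to read "derive the scores from the spatial relationship"; its boundary conventions are an assumption here and may differ from NSA's exact indexing. When $l = l' = d$ it reduces to $p_t^{\text{slc}} = p_t^{\text{cmp}}$. `shared_group_scores` implements the GQA/MQA sum over heads. All names are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def cmp_scores(q, K_cmp):
    """p_t^cmp = Softmax(q_t^T K_cmp): one score per compressed token."""
    return softmax(K_cmp @ q)

def slc_scores(p_cmp, l, d, l_sel, t):
    """Score selection block j (tokens [j*l_sel, (j+1)*l_sel)) by summing the scores of all
    compression blocks i (tokens [i*d, i*d + l)) that overlap it. 0-indexed."""
    n_sel = -(-t // l_sel)                                   # ceil(t / l_sel)
    p_slc = np.zeros(n_sel)
    for i, p in enumerate(p_cmp):
        first = (i * d) // l_sel                             # first selection block touched
        last = min((i * d + l - 1) // l_sel, n_sel - 1)      # last selection block touched
        p_slc[first:last + 1] += p
    return p_slc

def shared_group_scores(per_head_p_slc):
    """GQA/MQA: p_t^slc' = sum over the H heads of a group, so all heads pick the same blocks."""
    return np.asarray(per_head_p_slc).sum(axis=0)
```

A compression block that straddles two selection blocks contributes to both, so the selection scores need not sum to 1; only their ranking matters for the next step.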
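Selecting the Top-n Blocks:
Select the $n$ blocks with the largest importance scores; note that the ranking is over blocks, not individual tokens:
$$\mathcal{I}_t = \left\{ i \;\middle|\; \operatorname{rank}\left(p_t^{\text{slc}'}[i]\right) \le n \right\}$$
where $\operatorname{rank}(\cdot)$ is the position in descending order of score (rank 1 is the highest).
The selected block indices are then mapped back to the original tokens:
$$\tilde{K}_t^{\text{slc}} = \operatorname{Cat}\left[\left\{ k_{il'+1 : (i+1)l'} \;\middle|\; i \in \mathcal{I}_t \right\}\right]$$
and $\tilde{V}_t^{\text{slc}}$ is formed the same way.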
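Top-$n$ selection and gathering, using the demo's setting (blocks 3 and 7 in 1-based numbering are indices 2 and 6 here). The helper name is hypothetical; sorting the chosen indices back into positional order is a choice made here for readability and does not affect the attention result.

```python
import numpy as np

def select_blocks(p_slc, n, K, V, l_sel):
    """I_t = indices of the top-n selection blocks; K_slc/V_slc concatenate their raw tokens."""
    n = min(n, len(p_slc))
    top = np.argsort(-p_slc, kind="stable")[:n]   # highest scores first, ties broken by position
    idx = np.sort(top)                             # keep blocks in positional order
    K_slc = np.concatenate([K[i * l_sel:(i + 1) * l_sel] for i in idx])
    V_slc = np.concatenate([V[i * l_sel:(i + 1) * l_sel] for i in idx])
    return idx, K_slc, V_slc
```

With 64 tokens, `l_sel = 8`, and `n = 2`, the result holds 16 full-resolution keys and values, matching the demo's count.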
Sliding Windows#
In the attention mechanism, local patterns typically adapt faster and can dominate the learning process, which may prevent the model from learning effectively from the compression and selection tokens. To address this, a dedicated sliding-window branch explicitly handles the local context, $\tilde{K}_t^{\text{win}} = k_{t-w:t}$ and $\tilde{V}_t^{\text{win}} = v_{t-w:t}$ with window size $w$, allowing the other two branches (compression and selection) to focus on learning their respective features.
The three branches use independent keys and values. This design achieves stable learning by preventing gradient interference between local and long-range pattern learning, while introducing minimal overhead. After obtaining the three key-value pairs (six tensors) $\tilde{K}_t^c, \tilde{V}_t^c$, the branch outputs are combined with the gating mechanism.
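The sliding-window branch is the simplest of the three: ordinary attention over the last `w` tokens. A minimal sketch with 0-based positions, where the window includes position `t` itself (the exact boundary convention is an assumption); the function names are hypothetical.

```python
import numpy as np

def window_kv(K, V, t, w):
    """K_win / V_win: the w most recent keys/values up to and including position t (0-indexed)."""
    start = max(0, t - w + 1)
    return K[start:t + 1], V[start:t + 1]

def window_attention(q_t, K, V, t, w):
    """Dense scaled dot-product attention, but only over the local window."""
    K_win, V_win = window_kv(K, V, t, w)
    s = K_win @ q_t / np.sqrt(q_t.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V_win
```

Its cost per query is fixed by `w`, regardless of how long the context grows.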
Kernel Design#
To achieve FlashAttention-level speedups during training and prefill, NSA implements hardware-aligned sparse attention kernels in Triton.
Both the compression phase and sliding window can effectively utilize FlashAttention for optimization, so the kernel optimization mentioned here mainly targets the computation of discrete attention sequences generated during the selection phase.
For GQA and MQA: if we followed FlashAttention's strategy of loading temporally contiguous query blocks into SRAM, memory access would be inefficient, because the queries within a block may require disjoint KV blocks. To solve this, NSA loads all query heads in a GQA group, which share the same sparse KV blocks, into SRAM together.
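Group-Centric Data Loading:
In each inner loop, load the queries $Q \in \mathbb{R}^{[h, d_k]}$ of all $h$ heads in the group at position $t$, together with their shared sparse KV block indices $\mathcal{I}_t$.
Shared KV Fetching:
In the inner loop, sequentially load the contiguous key/value blocks indexed by $\mathcal{I}_t$ into SRAM as $K \in \mathbb{R}^{[B_k, d_k]}$ and $V \in \mathbb{R}^{[B_k, d_v]}$, where $B_k$ is the kernel block size satisfying $B_k \mid l'$.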
In the kernel figure, each green block represents the computation between one query and a segment of KV. Note that as the sequence length grows, the number of selected KV blocks stays fixed (at most 3 in the illustrated example), so the longer the sequence, the more significant NSA's speedup.
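To illustrate the loop order, here is a NumPy emulation (not Triton, and not NSA's kernel) of the selection kernel for one position $t$: all $h$ query heads of a GQA group are processed together, each selected block is streamed in contiguous chunks of `B_k` tokens that every head reuses, and the softmax is accumulated online across chunks in the FlashAttention style, so the full score row is never materialized. Names and shapes are assumptions for illustration.

```python
import numpy as np

def selected_attention_group(Q_group, K, V, block_ids, l_sel, B_k):
    """Emulate the selection kernel's loop order for one position t.

    Q_group: (h, d_k) queries of all heads in one GQA group, loaded together.
    block_ids: selection-block indices I_t shared by the whole group.
    Each selected block is streamed in chunks of B_k tokens (B_k must divide l_sel).
    """
    assert l_sel % B_k == 0
    h, d_k = Q_group.shape
    m = np.full(h, -np.inf)             # running max logit per head
    denom = np.zeros(h)                 # running softmax denominator per head
    acc = np.zeros((h, V.shape[1]))     # running weighted sum of values per head
    for b in block_ids:
        for s in range(b * l_sel, (b + 1) * l_sel, B_k):
            Kc, Vc = K[s:s + B_k], V[s:s + B_k]      # one contiguous chunk, shared by all heads
            scores = Q_group @ Kc.T / np.sqrt(d_k)   # (h, B_k)
            m_new = np.maximum(m, scores.max(axis=1))
            scale = np.exp(m - m_new)                # rescale earlier partial sums
            p = np.exp(scores - m_new[:, None])
            denom = denom * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ Vc
            m = m_new
    return acc / denom[:, None]
```

The result matches dense attention over the concatenated selected tokens, and the work per position depends only on the number of selected blocks, not on the sequence length.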
MoBA (Simplicity is Key, Plug and Play)#
Background#
Regarding sparsity, the paper points not only to the sparsity of attention scores but also to the sparse connectivity observed in brain regions related to memory storage (feels like it could be submitted to ACL 🤓).
Traditional sparse attention has two major flaws: first, it relies on predefined structures tailored to specific tasks, which generalize poorly; second, it dynamically selects tokens for sparse computation only at inference time, which generally does nothing to reduce training cost.
Like NSA, which targets both inference speed and trainability, MoBA also aims to accelerate inference while remaining trainable. It uses a Mixture of Block Attention (MoBA) mechanism that carries the Mixture of Experts (MoE) idea over from the MLP layers to attention.
The core idea is to dynamically select the relevant historical blocks for each query token $q$.
Method#
The core methods of MoBA are block partitioning and selection strategy (sounds similar to NSA's compressed blocks and selection 🤔).
Overall#
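The method of MoBA is straightforward:
$$\operatorname{MoBA}(q, K, V) = \operatorname{Softmax}\left(q K[I]^\top\right) V[I]$$
where $I \subseteq [N]$ is the set of selected keys and values. For a context of length $N$, the input is divided into $n$ blocks of size $B = N / n$, and block $i$ is given the index range $I_i = [(i-1) \times B + 1,\ i \times B]$, which is used for the subsequent block selection.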
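Then compute the gate value $g_i$ of $q$ for each block $i$, keeping the $k$ blocks with the highest scores:
$$g_i = \begin{cases} 1, & s_i \in \operatorname{Topk}\left(\{s_j \mid j \in [n]\}, k\right) \\ 0, & \text{otherwise} \end{cases} \qquad s_i = \left\langle q, \operatorname{mean\_pool}\left(K[I_i]\right) \right\rangle$$
Here the outer angle brackets denote the inner product, and mean_pool takes the average over the block, i.e., the mean of the keys in block $i$. The selected set is the union $I = \bigcup_{g_i > 0} I_i$.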
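A minimal NumPy sketch of MoBA's block gating and attention for a single query, ignoring causality. It assumes the context length is a multiple of `B`, and it includes the usual $1/\sqrt{d}$ scaling inside the softmax; the function names are hypothetical.

```python
import numpy as np

def moba_gates(q, K, B, k):
    """s_i = <q, mean_pool(K[I_i])>; g_i = 1 for the k highest-scoring blocks, else 0."""
    N, dim = K.shape
    assert N % B == 0, "this sketch assumes N is a multiple of B"
    n = N // B
    block_means = K.reshape(n, B, dim).mean(axis=1)   # mean_pool over each block
    s = block_means @ q
    g = np.zeros(n, dtype=int)
    g[np.argsort(-s, kind="stable")[:k]] = 1
    return s, g

def moba_attention(q, K, V, B, k):
    """Softmax(q K[I]^T / sqrt(d)) V[I], where I is the union of the selected blocks."""
    _, g = moba_gates(q, K, B, k)
    I = np.concatenate([np.arange(i * B, (i + 1) * B) for i in np.flatnonzero(g)])
    logits = K[I] @ q / np.sqrt(q.shape[-1])
    p = np.exp(logits - logits.max())
    return (p / p.sum()) @ V[I]
```

The gate itself has no learned parameters: it is just an inner product with block-averaged keys, the same top-k routing idea as MoE but over KV blocks.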
MoBA also emphasizes maintaining causality in autoregressive language models, ensuring that a query $q$ cannot route to any future block. A special case is the "current block", defined as the block containing the query token itself: routing to it may also violate causality, because the mean pool of the entire block may include information from future tokens. To address this, MoBA forces each token to route to its own current block and applies a causal mask during attention within the current block.
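The causal routing rule can be sketched the same way. This is an illustration under stated assumptions, not MoBA's code: positions are 0-based, only fully past blocks are scored, and the current block is assumed to count toward the `k` routed blocks, so `k - 1` slots remain for past blocks. The function names are hypothetical.

```python
import numpy as np

def moba_causal_blocks(q, K, t, B, k):
    """Blocks the query at position t may use: never a future block, always the current one,
    plus the k - 1 highest-scoring fully past blocks."""
    cur = t // B
    if cur == 0 or k <= 1:
        return [cur]
    past_means = K[: cur * B].reshape(cur, B, -1).mean(axis=1)   # mean_pool of past blocks only
    s = past_means @ q
    best = np.argsort(-s, kind="stable")[: k - 1]
    return sorted(best.tolist()) + [cur]

def moba_causal_attention(q, K, V, t, B, k):
    """Attend over the routed blocks; inside the current block, keys after t are masked out."""
    blocks = moba_causal_blocks(q, K, t, B, k)
    idx = np.concatenate([np.arange(i * B, min((i + 1) * B, t + 1)) for i in blocks])
    logits = K[idx] @ q / np.sqrt(q.shape[-1])
    p = np.exp(logits - logits.max())
    return (p / p.sum()) @ V[idx], idx
```

Because the current block is never scored by its mean, no information from tokens after `t` can leak into the routing decision.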
Final Thoughts#
The core differences between MoBA and NSA are:
- MoBA partitions the context into blocks and selects some of them for computation, while NSA compresses blocks, selects blocks, and additionally adds a sliding window. The selection logic also differs: MoBA's top-k selection is based on the inner product with mean-pooled keys and needs no gradients, while NSA's selection scores come from the compression branch, so backpropagation through that branch refines them.
- NSA focuses on fine-grained KV blocks shared by all heads in a GQA group, while MoBA lets different query heads access different blocks. The two have different emphases, and neither can do what the other does.
Author's Questions:
- In MoBA, the attention scores can be calculated using FlashAttention after partitioning into small blocks; why can't NSA use this after selecting small blocks?