How You Can Achieve 2.5x Speedups via Mamba-2 Kernel Fusion
Optimize your Mamba-2 SSD modules with a fused Triton kernel for 1.50x-2.51x speedups on NVIDIA A100/H100 GPUs. Reduce memory bandwidth bottlenecks.
Editorial Note
Reviewed and analyzed by the ScoRpii Tech Editorial Team.
State-Space Models and Kernel Fusion: A Performance Leap
You can now achieve up to a 2.51x speedup in Mamba-2 inference performance through a newly developed fused Triton kernel. This optimization directly addresses the computational bottlenecks inherent in State-Space Models (SSMs) when processing extremely long sequences – exceeding 128K tokens – and offers a significant reduction in infrastructure costs for long-context applications. Traditional Transformer self-attention scales quadratically with sequence length, making it impractical for these workloads. Mamba-2, leveraging SSMs, scales linearly, but still benefits from low-level kernel optimization.
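The difference between quadratic and linear scaling can be seen with a back-of-the-envelope comparison. This is an illustrative sketch only; the 8K-token baseline and the function name are ours, not figures from the benchmark:

```python
# Illustrative comparison of how per-sequence cost grows with context
# length: self-attention scales quadratically, an SSM scan linearly.
# The 8K-token baseline is an arbitrary reference point.

def relative_cost(seq_len, base_len=8192):
    """Cost at seq_len relative to cost at base_len."""
    return {
        "self_attention": (seq_len / base_len) ** 2,  # O(L^2)
        "ssm_scan": seq_len / base_len,               # O(L)
    }

print(relative_cost(131072))  # 128K tokens: {'self_attention': 256.0, 'ssm_scan': 16.0}
```

At 128K tokens, attention costs 256x its 8K baseline while the linear scan costs only 16x, which is why kernel-level constant factors become the next target for optimization.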
The Mechanics of SSD Prefill Optimization
Previously, performance with the original Mamba-2 implementation was constrained by the sequential execution of five distinct kernels: Chunk Cumsum, BMM, Chunk State, State Passing, and Chunk Scan. This sequential approach forced constant data transfer between the GPU’s fast SRAM and slower DRAM, creating a major performance bottleneck. The new fused Triton kernel consolidates the entire SSD prefill computation for a layer into a single GPU launch.
This represents the first end-to-end Triton fusion of all five SSD kernels, allowing the GPU to treat these steps as a unified operation. A key challenge was managing the dependencies introduced by the State Passing step. The unfused version handled these dependencies with loops and threadblock splitting.
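The dataflow being fused can be sketched as a minimal NumPy reference for a single head, assuming the scalar per-step decay `a_t` used by Mamba-2; the real kernels operate on batched, multi-head tensors in Triton, but the five labeled steps map one-to-one onto the five original kernels:

```python
import numpy as np

def ssd_prefill_reference(x, a, B, C, chunk=4):
    """Chunked SSD prefill for one head: h_t = exp(a_t)*h_{t-1} + B_t x_t^T,
    y_t = C_t h_t.  Shapes: x (T,P), a (T,), B and C (T,N); T % chunk == 0."""
    T, P = x.shape
    N = B.shape[1]
    nc = T // chunk
    xs = x.reshape(nc, chunk, P)
    Bs = B.reshape(nc, chunk, N)
    Cs = C.reshape(nc, chunk, N)
    # 1. Chunk Cumsum: inclusive cumulative decay within each chunk
    cs = np.cumsum(a.reshape(nc, chunk), axis=1)
    # 2. BMM: C_i . B_j inner products per chunk, masked by the decay matrix
    L = np.tril(np.exp(cs[:, :, None] - cs[:, None, :]))  # decay from j to i
    G = np.einsum('cin,cjn->cij', Cs, Bs)
    y_diag = np.einsum('cij,cjp->cip', L * G, xs)         # intra-chunk output
    # 3. Chunk State: each chunk's contribution to the running state
    decay_to_end = np.exp(cs[:, -1:] - cs)
    states = np.einsum('cq,cqn,cqp->cnp', decay_to_end, Bs, xs)
    # 4. State Passing: sequential recurrence across chunks (the step the
    #    fused kernel serializes with atomics)
    S_in = np.zeros((nc, N, P))
    h = np.zeros((N, P))
    for c in range(nc):
        S_in[c] = h                       # state entering chunk c
        h = np.exp(cs[c, -1]) * h + states[c]
    # 5. Chunk Scan: add each chunk's contribution from the incoming state
    y_off = np.einsum('cq,cqn,cnp->cqp', np.exp(cs), Cs, S_in)
    return (y_diag + y_off).reshape(T, P)
```

In the unfused implementation, each numbered step is a separate kernel launch, so every intermediate (`cs`, `states`, `S_in`, the partial outputs) round-trips through DRAM; the fusion keeps them in on-chip memory within a single launch.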
The fused kernel addresses this by splitting the State Passing loop iterations into separate threadblocks, using atomics to enforce correct ordering. While this introduces a degree of serialization – a local slowdown for that specific step – Amdahl’s Law shows the overall cost is small. If State Passing accounts for 1/7th of total runtime, a 2x slowdown in that step raises overall runtime by only about 1.14x (6/7 + 2/7 = 8/7), and overlap with other compute reduces the penalty further.
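That trade-off is easy to check with Amdahl's law directly (the 1/7 fraction follows the example above; the helper function is ours):

```python
def amdahl(fraction, local_slowdown):
    """Overall runtime multiplier when `fraction` of the work is slowed
    by `local_slowdown` and the rest is unchanged (Amdahl's law)."""
    return (1 - fraction) + fraction * local_slowdown

# If State Passing is 1/7 of total runtime:
print(round(amdahl(1/7, 2), 3))  # 2x slower locally -> 1.143x overall
print(round(amdahl(1/7, 7), 3))  # 7x slower locally -> 1.857x overall
```

Even a substantial local slowdown in one small step costs far less overall than the DRAM round-trips the fusion eliminates.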
Hardware Utilization and Precision Economics
Benchmarking the fused kernel on NVIDIA H100 and A100 GPUs shows the workload is now effectively L2-bound rather than DRAM-bound. The fused SSD module achieves 40-50% compute utilization and 65-75% memory utilization, compared with the 85-96% compute utilization typical of the standard matrix multiplications (matmuls) in Mamba-2.
Current performance is limited by register pressure and shared memory, resulting in approximately 25% occupancy. Profiling with Nsight Compute indicates that fewer than 3% of warp stalls are caused by synchronization for State Passing; the majority originate from data loading and general computation. You can optimize throughput by carefully considering precision choices.
Using fp16 for states yields roughly a 16% performance gain compared to fp32. While the fused kernel supports both, the fp16 configuration delivers the reported 1.50x-2.51x speedup. Numerical accuracy with “relaxed dtypes” (fp16) shows approximately 1/3 of output elements differing from the fp32 baseline due to the non-associative nature of floating-point addition.
However, over 99.7% of elements match within a 1e-3 tolerance, and 99.9995% achieve accuracy at a standard 1% tolerance level. For applications demanding absolute precision, both the fused and original kernels perform best with a chunk_size of 256 when using fp32 states.
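The divergence comes from the non-associativity of floating-point addition, which is easy to demonstrate, along with the kind of tolerance check behind the reported match rates. The specific values and the helper below are illustrative, not taken from the benchmark harness:

```python
import numpy as np

# fp16 addition is not associative: summation order changes the result.
a, b, c = np.float16(2048), np.float16(1), np.float16(1)
left = (a + b) + c    # 2048 + 1 rounds back to 2048 (spacing is 2 here), then again
right = a + (b + c)   # 1 + 1 = 2 is exactly representable; 2048 + 2 = 2050
print(left, right)    # 2048.0 2050.0

# Fraction of elements matching an fp32 reference within a tolerance,
# the style of check used to report the 1e-3 and 1% match rates.
def match_fraction(ref, test, tol):
    return np.mean(np.isclose(test, ref, rtol=tol, atol=tol))
```

A fused kernel sums terms in a different order than five separate kernels, so bitwise-different outputs are expected even when both are numerically sound.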
Infrastructure Impact: What This Means For You
The efficiency of Mamba-2 becomes particularly apparent at context lengths of 128K tokens. Because compute and memory requirements only double with sequence length – unlike the quadratic scaling of self-attention – Mamba-2 offers substantial cost savings for scaling. This is already being demonstrated in production environments. For example, IBM’s Granite 4.0 model family employs a hybrid architecture, deploying nine Mamba-2 layers for every one attention layer.
The current fused SSD kernel is also applicable to linear attention, as the SSD formula simplifies to a linear attention update when the variable A equals 1. Further optimizations are possible. The current kernel does not yet leverage the Tensor Memory Accelerator (TMA) or thread block clusters available on Hopper GPUs, nor does it utilize Tensor Memory in Blackwell GPUs.
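That reduction is easy to verify numerically: with the decay fixed at 1, the SSD recurrence reproduces causal linear attention exactly. A minimal single-head sketch (shapes and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, P = 6, 3, 2
x = rng.standard_normal((T, P))   # inputs ("values")
B = rng.standard_normal((T, N))   # "keys"
C = rng.standard_normal((T, N))   # "queries"

# SSD recurrence with the decay A fixed at 1: h_t = h_{t-1} + B_t x_t^T
h = np.zeros((N, P))
y_ssd = []
for t in range(T):
    h = h + np.outer(B[t], x[t])
    y_ssd.append(C[t] @ h)
y_ssd = np.stack(y_ssd)

# Causal linear attention: y_i = sum over j <= i of (C_i . B_j) x_j
y_lin = np.tril(C @ B.T) @ x

assert np.allclose(y_ssd, y_lin)
```

The running state is simply the cumulative sum of key-value outer products, so any linear-attention model can reuse the same fused kernel.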
While fusing depth-wise convolutions and layernorms into the SSD operation has yielded limited benefits thus far, the current 8-13% end-to-end speedup for a 2.7B parameter model at 128K context represents a significant reduction in infrastructure costs for both long-context inference and training.
The Bottom Line for Developers
This Triton fusion represents a practical, immediately deployable optimization for Mamba-2. You should prioritize evaluating this fused kernel if you are working with long-context applications or seeking to reduce inference costs. The performance gains are substantial, and the accuracy trade-offs with fp16 are acceptable for many use cases. Keep an eye on future developments leveraging TMA and thread block clusters for even greater performance improvements as newer GPU architectures become more prevalent.
Originally reported by
PyTorch Blog