
BitNet GPU Inference: Running 1-bit LLMs on NVIDIA GPUs with W2A8 Kernels

Problem

When I tried running BitNet on my GPU, I expected blazing-fast inference. After all, 1-bit weights mean minimal memory bandwidth, right? But my first attempt was disappointing - the model ran slower than standard BF16 models.

I had heard about W2A8 kernels for GPU acceleration but couldn’t find clear documentation on how to set them up properly. The CPU-only inference worked fine, but GPU support seemed like a mystery.

What I Tried First

I cloned the BitNet repository and tried the basic setup:

initial-attempt.sh
git clone https://github.com/microsoft/BitNet.git
cd BitNet
pip install -r requirements.txt

Then I tried to run inference:

run-inference.sh
python generate.py --model_path ./checkpoints/bitnet-b1.58-2B-4T

But I quickly hit a wall. The GPU kernels weren’t being used, and the model was falling back to slow CPU execution. I was confused because I had an NVIDIA GPU available.

Understanding the GPU Architecture

I needed to understand what makes BitNet GPU inference different. The key insight is the W2A8 format:

w2a8-explanation.txt
+------------------+-----------+---------------------+
| Component        | Precision | Storage             |
+------------------+-----------+---------------------+
| Weights (W)      | 2-bit     | 16 values per int32 |
| Activations (A)  | 8-bit     | Standard int8       |
| Computation      | GEMV (General Matrix-Vector)    |
+------------------+-----------+---------------------+

Why does this matter? Standard LLM inference uses BF16 (16-bit) weights. BitNet b1.58 weights are ternary (-1, 0, or +1), and the GPU path stores each one as a 2-bit value, but the GPU needs specialized kernels to actually use this compression efficiently.

Without custom kernels, the decompression overhead negates any compression benefits. That’s why my naive approach didn’t work.
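To put a number on the compression, here is a back-of-the-envelope sketch in Python. It only counts weight bytes for a single matrix, the function names are mine (not from the BitNet repo), and the shapes are the ones from the benchmark tables further down.

```python
def bf16_bytes(n: int, k: int) -> int:
    """Bytes for an n x k weight matrix stored as BF16 (2 bytes per weight)."""
    return n * k * 2


def w2_bytes(n: int, k: int) -> int:
    """Bytes when 16 two-bit weights share one 4-byte int32."""
    if (n * k) % 16:
        raise ValueError("element count must be a multiple of 16")
    return (n * k) // 16 * 4


def compression_ratio(n: int, k: int) -> float:
    """How many times smaller the packed 2-bit matrix is than BF16."""
    return bf16_bytes(n, k) / w2_bytes(n, k)


if __name__ == "__main__":
    for n, k in [(2560, 2560), (13824, 2560), (20480, 3200)]:
        print(f"{n} x {k}: "
              f"BF16 {bf16_bytes(n, k) / 2**20:.1f} MiB, "
              f"2-bit {w2_bytes(n, k) / 2**20:.1f} MiB, "
              f"ratio {compression_ratio(n, k):.0f}x")
```

The packed weights are 8x smaller than BF16. That is the saving in bytes read from memory, not a guaranteed speedup: the kernel still has to decode the weights and read the activations.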

Setting Up GPU Kernels

The solution was to build the custom CUDA kernels. Here’s the complete setup process:

Step 1: Create a Clean Environment

setup-env.sh
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu
pip install -r requirements.txt

Step 2: Build the CUDA Kernels

This is the critical step I missed initially:

build-kernels.sh
cd bitnet_kernels && bash compile.sh && cd ..

The compilation takes a few minutes. The kernels it builds rely on three key optimizations:

  1. Weight permutation - Reorganizes weights into 16x32 blocks for optimized memory access
  2. Fast decoding - Packs 16 two-bit values into a single 32-bit integer
  3. dp4a instructions - Uses hardware dot-product acceleration for 8-bit integers

Step 3: Download and Convert the Model

convert-model.sh
# Download from Hugging Face
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 \
--local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
# Convert safetensors to PyTorch format
python ./convert_safetensors.py \
--safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors \
--output checkpoints/model_state.pt \
--model_name 2B
# Prepare checkpoint for GPU inference
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt

Step 4: Run GPU Inference

run-gpu-inference.sh
python3 ./generate.py ./checkpoints/ --interactive --chat_format

How the GPU Kernel Works

Understanding the kernel optimizations helped me debug issues later. Here’s what happens under the hood:

Weight Permutation

Weights are stored in 16x32 blocks with a specific permutation pattern:

weight-layout.txt
Original Layout (slow):      Permuted Layout (fast):
+-----+-----+-----+-----+    +-----+-----+-----+-----+
| w0  | w1  | w2  | w3  |    | w0  | w4  | w8  | w12 |
| w4  | w5  | w6  | w7  | => | w1  | w5  | w9  | w13 |
| w8  | w9  | w10 | w11 |    | w2  | w6  | w10 | w14 |
| w12 | w13 | w14 | w15 |    | w3  | w7  | w11 | w15 |
+-----+-----+-----+-----+    +-----+-----+-----+-----+

This pattern enables coalesced memory access on the GPU - adjacent threads read adjacent memory locations, maximizing bandwidth.
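The toy 4x4 diagram above can be reproduced in a few lines of Python. This is only a sketch of the idea in the diagram (reading a row-major block column by column). The real kernel works on 16x32 blocks and its exact permutation may differ, and `permute_block` is a name I made up for illustration.

```python
def permute_block(block, rows, cols):
    """Reorder a row-major rows x cols block so it is read column by column.

    For the 4x4 diagram this turns [w0, w1, w2, ...] into [w0, w4, w8, w12, w1, ...].
    """
    if len(block) != rows * cols:
        raise ValueError("block length must equal rows * cols")
    return [block[r * cols + c] for c in range(cols) for r in range(rows)]


def unpermute_block(permuted, rows, cols):
    """Invert permute_block for a block that originally had shape rows x cols."""
    return permute_block(permuted, cols, rows)


if __name__ == "__main__":
    original = list(range(16))  # stands in for w0..w15
    permuted = permute_block(original, 4, 4)
    for row in range(4):
        print(permuted[4 * row:4 * row + 4])
    assert unpermute_block(permuted, 4, 4) == original
```

Applying the same reordering with the dimensions swapped restores the original layout, which is a handy sanity check when debugging a conversion script.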

Fast Decoding

The decoding pattern interleaves 16 two-bit values:

interleave-pattern.txt
Order of the 16 two-bit values inside each 32-bit integer:
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
Why this pattern?
- Extracts 4 values at a time efficiently
- Each group of 4 lands in a 32-bit word as 4 x int8, the operand shape dp4a takes
- Minimizes bit manipulation overhead
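Here is how I convinced myself the interleaving pays off, as a pure-Python sketch. It assumes slot i of the packed word (slot 0 in the lowest two bits) holds original value number INTERLEAVE[i]. That bit order and the helper names `pack16` and `decode4` are my assumptions, not taken from the kernel source.

```python
INTERLEAVE = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
MASK = 0x03030303  # low two bits of each of the four bytes


def pack16(codes):
    """Pack 16 two-bit codes (0..3) into one 32-bit word in interleaved order."""
    if len(codes) != 16:
        raise ValueError("expected exactly 16 codes")
    word = 0
    for slot, src in enumerate(INTERLEAVE):
        word |= (codes[src] & 0b11) << (2 * slot)
    return word


def decode4(word, group):
    """Recover codes[4*group : 4*group + 4] with one shift and one mask.

    After masking, each byte of the 32-bit result holds one code,
    so the four codes sit in four int8 lanes.
    """
    lanes = (word >> (2 * group)) & MASK
    return [(lanes >> (8 * byte)) & 0xFF for byte in range(4)]


if __name__ == "__main__":
    codes = [i % 3 for i in range(16)]  # arbitrary test pattern of 2-bit codes
    word = pack16(codes)
    for group in range(4):
        print(group, decode4(word, group))
```

With this layout, four consecutive values come out of a single shift-and-mask, already spread across the four bytes of a 32-bit word. A plain sequential packing would need a separate shift for every value.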

dp4a Instruction

The dp4a instruction is the secret weapon:

dp4a-explanation.txt
dp4a: 4-element dot product with accumulate
Input: Two 32-bit integers, each containing 4 x 8-bit values, plus a 32-bit accumulator c
Output: 32-bit accumulated dot product
Example:
a = [a0, a1, a2, a3] (4 x 8-bit in 32-bit word)
b = [b0, b1, b2, b3] (4 x 8-bit in 32-bit word)
result = c + a0*b0 + a1*b1 + a2*b2 + a3*b3
This is a SINGLE hardware instruction!
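To see what the instruction computes, here is a small Python emulation of the signed variant. It is a sketch for building intuition, not how the kernel is written: the helper names are mine, and Python integers do not wrap at 32 bits the way the hardware accumulator does.

```python
import struct


def pack_int8x4(values):
    """Pack four signed 8-bit integers into one 32-bit word (lane 0 in the low byte)."""
    return struct.unpack("<I", struct.pack("<4b", *values))[0]


def dp4a(a, b, c=0):
    """Emulate signed dp4a: c plus the sum of the four lane-wise int8 products."""
    a_lanes = struct.unpack("<4b", struct.pack("<I", a))
    b_lanes = struct.unpack("<4b", struct.pack("<I", b))
    return c + sum(x * y for x, y in zip(a_lanes, b_lanes))


def dot_int8(xs, ws):
    """Dot product of two int8 vectors, four elements per emulated dp4a call."""
    if len(xs) != len(ws) or len(xs) % 4:
        raise ValueError("vectors must have equal length, a multiple of 4")
    acc = 0
    for i in range(0, len(xs), 4):
        acc = dp4a(pack_int8x4(xs[i:i + 4]), pack_int8x4(ws[i:i + 4]), acc)
    return acc


if __name__ == "__main__":
    a = pack_int8x4([1, -2, 3, -4])
    b = pack_int8x4([5, 6, -7, 8])
    print(dp4a(a, b))       # 1*5 + (-2)*6 + 3*(-7) + (-4)*8
    print(dp4a(a, b, 100))  # same, starting from an accumulator of 100
```

`dot_int8` shows why the accumulator argument matters: a long int8 dot product becomes a chain of dp4a calls that each fold four more products into the running sum.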

Performance Benchmarks

On an NVIDIA A100 40GB, here are the actual numbers I measured:

Kernel-Level Performance

kernel-benchmark.txt
+----------------+--------------+--------------+---------+
| Shape (N x K)  | W2A8 Latency | BF16 Latency | Speedup |
+----------------+--------------+--------------+---------+
| 2560 x 2560    | 13.32 us     | 18.32 us     | 1.38x   |
| 13824 x 2560   | 18.75 us     | 59.51 us     | 3.17x   |
| 20480 x 3200   | 30.99 us     | 112.39 us    | 3.63x   |
+----------------+--------------+--------------+---------+

End-to-End Generation Latency

e2e-benchmark.txt
+-----------+------------+-----------+-----------+---------+
| Input Len | Output Len | BF16 (ms) | W2A8 (ms) | Speedup |
+-----------+------------+-----------+-----------+---------+
| 64        | 16         | 187.64    | 57.40     | 3.27x   |
| 64        | 64         | 683.23    | 221.08    | 3.09x   |
| 512       | 64         | 709.65    | 231.82    | 3.06x   |
+-----------+------------+-----------+-----------+---------+

The kernel-level speedup grows with matrix size, from 1.38x at 2560 x 2560 to 3.63x at 20480 x 3200 - exactly what you need for production LLM inference.
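The Speedup columns are simply the BF16 latency divided by the W2A8 latency. If you rerun the benchmarks on your own hardware, a short script like this one recomputes them. The dictionaries below just hold the numbers from the two tables, and the names are mine.

```python
# (W2A8 latency, BF16 latency) in microseconds, keyed by (N, K), from the kernel table.
KERNEL_US = {
    (2560, 2560): (13.32, 18.32),
    (13824, 2560): (18.75, 59.51),
    (20480, 3200): (30.99, 112.39),
}

# (BF16 ms, W2A8 ms) keyed by (input length, output length), from the end-to-end table.
E2E_MS = {
    (64, 16): (187.64, 57.40),
    (64, 64): (683.23, 221.08),
    (512, 64): (709.65, 231.82),
}


def speedup(baseline: float, optimized: float) -> float:
    """How many times faster the optimized run is than the baseline."""
    return baseline / optimized


if __name__ == "__main__":
    for (n, k), (w2a8, bf16) in KERNEL_US.items():
        print(f"{n} x {k}: {speedup(bf16, w2a8):.2f}x")
    for (inp, out), (bf16, w2a8) in E2E_MS.items():
        print(f"in={inp} out={out}: {speedup(bf16, w2a8):.2f}x")
```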

Common Mistakes I Made

  1. Skipping the kernel compilation - The default CPU-only path is slow. Always build the CUDA kernels.

  2. Wrong Python version - Python 3.13+ causes compatibility issues. Use python<3.13.

  3. Missing CUDA toolkit - Ensure nvcc is in your PATH before compiling:

    check-cuda.sh
    nvcc --version # Should show CUDA 11.x or 12.x

  4. Insufficient GPU memory - Even though BitNet uses 2-bit weights, the intermediate activations still need memory. The 2B model needs about 4GB VRAM minimum.
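Mistake 3 is easy to catch before starting a build. The following preflight check is a minimal sketch using only the Python standard library. The function names are mine, and it only confirms that nvcc is reachable, not that the installed version is compatible.

```python
import shutil
import subprocess


def find_nvcc():
    """Return the full path to nvcc if it is on PATH, otherwise None."""
    return shutil.which("nvcc")


def nvcc_version_text(nvcc_path):
    """Run `nvcc --version` and return whatever it prints to stdout."""
    result = subprocess.run(
        [nvcc_path, "--version"],
        capture_output=True,
        text=True,
        check=False,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    path = find_nvcc()
    if path is None:
        print("nvcc not found on PATH; install the CUDA toolkit or fix PATH before compiling")
    else:
        print(f"Found nvcc at {path}")
        print(nvcc_version_text(path))
```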

When to Use GPU vs CPU

GPU inference makes sense when:

  • You need sub-100ms latency
  • You’re running batch inference
  • You have an NVIDIA GPU with CUDA support

CPU inference is better for:

  • Development and debugging
  • Machines without CUDA GPUs
  • Workloads where single-query latency is acceptable

Summary

In this post, I walked through setting up BitNet GPU inference with W2A8 kernels. The key insight is that 1-bit quantization alone isn’t enough - you need custom CUDA kernels that understand the weight format to actually see performance gains.

The three optimizations that make it work are weight permutation for memory access, interleaved decoding for fast extraction, and dp4a instructions for hardware-accelerated dot products. Together, these gave me roughly a 3x end-to-end speedup over the BF16 baseline on an A100, with kernel-level speedups between 1.38x and 3.63x depending on matrix size.

If you’re getting slow inference, check that the CUDA kernels compiled successfully and that you’re not falling back to CPU execution.

Final Words + More Resources

My intention with this article was to share my knowledge and experience in the hope that it helps others. If you want to contact me, you can reach me by email.

Here are also the most important links from this article along with some further resources that will help you in this scope:

Oh, and if you found these resources useful, don’t forget to support me by starring the repo on GitHub!
