A Practical Guide to MoE Kernel Autotuning in vLLM

If you’ve ever started a vLLM server with a Mixture of Experts model and seen this warning:

WARNING: Using default MoE config. Performance might be sub-optimal!
Config file not found at .../E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition.json

You’re leaving performance on the table. I couldn’t find a proper guide on how to fix this, so I wrote one. This is what I learned doing it myself.

What is MoE kernel autotuning?

Mixture of Experts models route each token to a subset of specialised “expert” networks instead of activating all parameters. vLLM runs these expert computations with a fused MoE kernel: a single GPU kernel that handles expert routing, the matrix multiplications for all active experts, and the combination of their outputs.
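
To make the routing concrete, here’s a toy top-k gating step in plain Python. It’s purely illustrative: the gate scores are a made-up vector rather than a learned projection, and vLLM does all of this inside the fused GPU kernel rather than in Python.

```python
import math

def topk_route(gate_scores, k):
    """Pick the top-k experts for one token and softmax-normalise
    their scores into routing weights (illustrative, not vLLM's code)."""
    # Indices of the k highest-scoring experts.
    topk = sorted(range(len(gate_scores)),
                  key=lambda i: gate_scores[i], reverse=True)[:k]
    # Softmax over just the selected scores, as many MoE routers do.
    exps = [math.exp(gate_scores[i]) for i in topk]
    total = sum(exps)
    weights = [e / total for e in exps]
    return topk, weights

# One token's gate scores over 4 experts; route it to the top 2.
experts, weights = topk_route([0.1, 2.0, -1.0, 1.5], k=2)
print(experts)                 # [1, 3]
print(round(sum(weights), 6))  # 1.0
```

The token’s output is then the weighted sum of the selected experts’ outputs, which is exactly the “combining” step the fused kernel handles.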

This kernel is written in Triton, OpenAI’s GPU compiler framework. Triton kernels have configurable tile shapes: block sizes for the M, N, and K dimensions of the matrix multiply, the number of warps, the number of pipeline stages, and so on. Different combinations perform very differently depending on your GPU’s memory bandwidth, compute throughput, shared memory size, and cache behaviour.

vLLM ships with tuned configs for some GPU and model combinations. They’re mostly contributed by the community for specific datacenter hardware. Coverage is pretty incomplete though. Even common GPUs like the H100 will show the “sub-optimal” warning for plenty of models. If your specific GPU and model combination isn’t covered, vLLM falls back to conservative defaults that work everywhere but are optimal nowhere.

Autotuning is just the process of benchmarking every reasonable tile configuration on your actual hardware, for your model’s exact MoE dimensions, and saving the fastest one per batch size into a JSON config file.
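
The search loop fits in a few lines. This is a toy sketch, not vLLM’s code: `run_kernel` stands in for launching the Triton kernel with a given tile config, and `fake_kernel` is a stand-in whose cost depends on the config so the example has something to measure.

```python
import time

def autotune(run_kernel, candidate_configs, warmup=2, iters=10):
    """Time every candidate config and return the fastest.
    Toy sketch of what the tuning script does per batch size."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidate_configs:
        for _ in range(warmup):           # warm up JIT caches first
            run_kernel(cfg)
        start = time.perf_counter()
        for _ in range(iters):
            run_kernel(cfg)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg

# Fake "kernel" whose runtime scales with the block size.
def fake_kernel(cfg):
    time.sleep(cfg["BLOCK_SIZE_M"] / 1e4)

configs = [{"BLOCK_SIZE_M": m} for m in (16, 32, 64)]
print(autotune(fake_kernel, configs))  # {'BLOCK_SIZE_M': 16}
```

The real script does this on-GPU with proper synchronisation, once per batch size, and serialises the winners to JSON.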

When does this apply?

This only applies to Mixture of Experts models. If you’re running a dense model like Llama or Mistral, there are no MoE layers and none of this matters.

Models where it does matter include Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B, DeepSeek-V3, DeepSeek-R1, GPT-OSS, Nemotron 3 Super, Llama 4 Scout/Maverick, Mixtral, and any other MoE architecture.

You need to tune when you see the “sub-optimal” warning in your vLLM startup logs. If vLLM prints Using configuration from ... for MoE layer. instead, a tuned config already exists and you’re good.

What gets tuned?

The tuning script profiles different Triton kernel tile configurations for your model’s specific MoE dimensions:

  • E is the number of experts (e.g., 256 for Qwen3.5-35B-A3B)
  • N is the intermediate size per expert, possibly divided by TP (e.g., 512)
  • hidden_size is the model hidden dimension (e.g., 2048)
  • topk is the number of experts activated per token (e.g., 8)

For each batch size you specify, the script tries every combination of block sizes, warp counts, and pipeline stages. It times each one and picks the winner. The output is a JSON file mapping batch sizes to optimal kernel configs.

The tuned config is specific to:

  • Your GPU. An RTX PRO 6000 Blackwell config won’t be optimal on an A100.
  • The MoE dimensions. Different models with different E/N values need separate configs.
  • Your TP size. Tensor parallelism shards N, which changes the GEMM shape.
  • The dtype. FP8, BF16, and INT4 all have different optimal tile shapes.

You do not need to retune when you upgrade vLLM versions, change context length, change batch settings, or modify other serving parameters.

Prerequisites

You’ll need Docker with the NVIDIA runtime, your vLLM Docker image, and enough GPU memory to run the benchmark. The benchmark allocates tensors matching the MoE dimensions, not the full model, so you don’t need to fit the whole thing in VRAM. The ray package is required but not included in the vLLM image, so it gets installed at runtime.

Step 1: Identify your model’s MoE parameters

Check your vLLM startup logs for the warning. It tells you the expected filename:

Config file not found at .../E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition.json

This gives you E and N directly. If you need to dig deeper, inspect the model’s config.json on Hugging Face or locally. The relevant fields vary by model family, and this caught me out:

Model family                     | Experts field     | Intermediate field    | TopK field
Qwen3.5 MoE (inside text_config) | num_experts       | moe_intermediate_size | num_experts_per_tok
DeepSeek-V3 / R1                 | n_routed_experts  | moe_intermediate_size | num_experts_per_tok
Llama 4 Scout / Maverick         | num_local_experts | intermediate_size     | num_experts_per_tok
Mixtral                          | num_local_experts | intermediate_size     | num_experts_per_tok
GPT-OSS                          | num_local_experts | intermediate_size     | num_experts_per_tok

Field names are not standardised across model families, so always verify against the actual config.json for your model before tuning. The benchmark_moe.py script expects the Hugging Face standard fields (num_local_experts, intermediate_size) by default, which is why models using non-standard names require workarounds.

Also note that N in the config filename equals intermediate_size / tp_size for some models. Check the warning message to be sure.

For nested VLM configs like Qwen3.5, which wraps its MoE parameters inside text_config, you’ll need to look inside that nested config.
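
If you’d rather script it than eyeball the JSON, something like this works. It’s a sketch: the field-name variants it tries are the ones from the table above, not an exhaustive set, and the demo config at the bottom is synthetic.

```python
import json
import os
import tempfile

def moe_params(config_path, tp_size=1):
    """Pull E, topk, N and hidden_size out of a Hugging Face
    config.json, trying common field-name variants (not exhaustive)."""
    with open(config_path) as f:
        cfg = json.load(f)
    # VLMs like Qwen3.5 nest the language-model settings under text_config.
    cfg = cfg.get("text_config", cfg)

    def first(*names):
        for name in names:
            if name in cfg:
                return cfg[name]
        raise KeyError(f"none of {names} found in config")

    E = first("num_local_experts", "num_experts", "n_routed_experts")
    topk = first("num_experts_per_tok")
    intermediate = first("moe_intermediate_size", "intermediate_size")
    hidden = first("hidden_size")
    # N in the tuned-config filename is often intermediate_size / tp_size.
    return {"E": E, "topk": topk, "N": intermediate // tp_size,
            "hidden_size": hidden}

# Demo with a synthetic nested config mimicking Qwen3.5's layout.
path = os.path.join(tempfile.gettempdir(), "demo_config.json")
with open(path, "w") as f:
    json.dump({"text_config": {"num_experts": 256,
                               "num_experts_per_tok": 8,
                               "moe_intermediate_size": 512,
                               "hidden_size": 2048}}, f)
print(moe_params(path))
# {'E': 256, 'topk': 8, 'N': 512, 'hidden_size': 2048}
```

Point it at the config.json path you found in your Hugging Face cache and pass your actual TP size.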

Step 2: Find the model’s config path

If you’re using a cached Hugging Face model, find the config:

docker run --rm -it \
  --entrypoint bash \
  -v huggingface:/root/.cache/huggingface \
  vllm/vllm-openai:cu130-nightly \
  -c "find /root/.cache/huggingface -name 'config.json' -path '*YourModel*'"

Note the full path. You’ll need it in the next step.

Step 3: Run the tuning benchmark

For models where benchmark_moe.py can parse the config directly:

docker run --rm -it \
  --runtime nvidia \
  --ipc host \
  --entrypoint bash \
  -e NVIDIA_VISIBLE_DEVICES=0 \
  -v huggingface:/root/.cache/huggingface \
  -v ./moe-configs:/moe-configs \
  vllm/vllm-openai:cu130-nightly \
  -c '
pip install ray &&
python3 benchmarks/kernels/benchmark_moe.py \
  --model YourOrg/YourModel \
  --tp-size 1 \
  --dtype auto \
  --tune \
  --batch-size 1 2 4 8 16 32 64 \
  --save-dir /moe-configs
'

Replace --tp-size with your actual tensor parallel size and --dtype with fp8_w8a8 if you’re running an FP8 model.

When the config can’t be parsed

The benchmark_moe.py script expects specific field names (num_local_experts, intermediate_size) at the top level of the model config. Many newer models, especially VLMs with nested configs and models that use non-standard field names, will fail with AttributeError.

The fix is to bypass config parsing entirely by hardcoding the parameters. Use sed to replace the parsing call:

docker run --rm -it \
  --runtime nvidia \
  --ipc host \
  --entrypoint bash \
  -e NVIDIA_VISIBLE_DEVICES=0 \
  -v huggingface:/root/.cache/huggingface \
  -v ./moe-configs:/moe-configs \
  vllm/vllm-openai:cu130-nightly \
  -c '
pip install ray &&
sed -i "s/E, topk, intermediate_size, hidden_size = get_model_params(config)/E, topk, intermediate_size, hidden_size = 256, 8, 512, 2048/" benchmarks/kernels/benchmark_moe.py &&
python3 benchmarks/kernels/benchmark_moe.py \
  --model /path/to/your/cached/model/snapshot \
  --model-prefix text_config \
  --tp-size 1 \
  --dtype auto \
  --tune \
  --batch-size 1 2 4 8 16 32 64 \
  --save-dir /moe-configs
'

Replace the four values (256, 8, 512, 2048) with your model’s actual E, topk, intermediate_size, and hidden_size.

The --model-prefix text_config flag tells the script to look inside a nested config key. Use it for VLMs where MoE parameters live under text_config. For models with flat configs, omit it.

Handling Triton compiler crashes

On some GPU architectures, notably SM120 / Blackwell workstation GPUs, certain tile configurations will crash the Triton compiler with RuntimeError: PassManager::run failed. This is a known bug: specific tile shapes overflow shared memory or hit MLIR pass failures.

If tuning crashes partway through, reduce the batch sizes to only those that completed successfully:

--batch-size 1 2 4 8

For single-user workloads, batch sizes 1 to 8 cover the vast majority of real usage anyway.

Step 4: Deploy the tuned config

The tuning script writes a JSON file like:

E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition.json

To make vLLM use it, set the VLLM_TUNED_CONFIG_FOLDER environment variable to point to the directory containing the file. In Docker Compose:

volumes:
  - ./moe-configs:/moe-configs:ro
environment:
  - VLLM_TUNED_CONFIG_FOLDER=/moe-configs

On restart, your logs should show:

Using configuration from /moe-configs/E=256,N=512,device_name=... for MoE layer.

The “sub-optimal” warning should be gone.

What does the JSON config look like?

The file maps batch sizes to optimal Triton kernel parameters. This is an illustrative example; your actual values will differ:

{
  "1": {
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 1,
    "num_warps": 4,
    "num_stages": 2
  },
  "2": {
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 128,
    "GROUP_SIZE_M": 1,
    "num_warps": 4,
    "num_stages": 3
  }
}

At runtime vLLM looks up the current batch size and uses the corresponding tile shape. For batch sizes not present in the config, it uses the entry for the nearest batch size that is.
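
The lookup itself is simple. Here’s an illustrative sketch of the nearest-key behaviour, not vLLM’s actual code:

```python
import json

def config_for_batch(config_json, batch_size):
    """Pick the tuned kernel config whose batch-size key is nearest
    to the current batch size (sketch of the lookup behaviour)."""
    configs = {int(k): v for k, v in json.loads(config_json).items()}
    nearest = min(configs, key=lambda b: abs(b - batch_size))
    return configs[nearest]

# A trimmed-down tuned config with three batch-size entries.
tuned = '{"1": {"num_stages": 2}, "8": {"num_stages": 3}, "64": {"num_stages": 4}}'
print(config_for_batch(tuned, 6))    # {'num_stages': 3}  (8 is nearest)
print(config_for_batch(tuned, 100))  # {'num_stages': 4}  (64 is nearest)
```

This is why tuning only a handful of small batch sizes is still useful: nearby untuned batch sizes inherit the closest tuned entry.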

How long does tuning take?

Depends on the number of experts, batch sizes, and GPU. The search space is typically around 1920 configurations. That number comes from the range of BLOCK_M, BLOCK_N, BLOCK_K, num_warps, and num_stages combinations. With 256 experts:

  • Batch sizes 1 to 8 on an RTX PRO 6000 Blackwell took about 52 minutes total
  • The first batch size or two take the longest, maybe 20 to 25 minutes each, because Triton has to JIT compile each kernel variant from scratch
  • Subsequent batch sizes reuse cached compilations and complete much faster, like 2 to 5 minutes each
  • Most of the wall-clock time is CPU-bound (Triton JIT compilation), not GPU-bound
  • GPU power draw stays low during tuning since the GPU is idle most of the time waiting for the CPU to compile each kernel variant
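
The roughly 1920 figure is just the product of the candidate lists. The lists below are my reading of the script’s compute-bound search space at the time of writing; they may change between vLLM versions, but they reproduce the count:

```python
from itertools import product

# Candidate values searched per parameter (assumed lists; verify
# against the benchmark_moe.py shipped with your vLLM version).
space = {
    "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
    "BLOCK_SIZE_N": [32, 64, 128, 256],
    "BLOCK_SIZE_K": [64, 128, 256],
    "GROUP_SIZE_M": [1, 16, 32, 64],
    "num_warps": [4, 8],
    "num_stages": [2, 3, 4, 5],
}

n_configs = len(list(product(*space.values())))
print(n_configs)  # 1920
```

Every one of those combinations gets compiled and timed for every batch size you request, which is where the hours go.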

How much does it help?

Reported improvements vary. A community-contributed PR tuning DeepSeek-R1 on L20 GPUs measured a 5 to 8% decode throughput improvement. The actual gain depends on how far the defaults were from optimal for your specific GPU and model dimensions. The improvement is most noticeable during decode, when batch sizes are small and the MoE GEMM is a larger fraction of total latency.

Worth noting: for NVFP4 models on Blackwell GPUs, the main model’s MoE layers typically use CUTLASS kernels, not Triton. So the tuning primarily benefits the MTP speculative decoding drafter or any other component using the Triton MoE backend. Check your logs. If you see Using FLASHINFER_CUTLASS NvFp4 MoE backend, the main model’s MoE isn’t using Triton.

When to retune

You need to regenerate the config when:

  • You move to a different GPU
  • You change tensor parallel size
  • You switch to a model with different E or N dimensions
  • You change quantisation dtype (e.g., BF16 to FP8)

You do not need to retune when:

  • Upgrading vLLM versions
  • Changing context length, batch settings, or other serving parameters
  • Running a different model that happens to share the same E/N dimensions

Contributing configs upstream

If you generate a tuned config for a GPU that vLLM doesn’t ship configs for, consider submitting a PR. The config goes in vllm/model_executor/layers/fused_moe/configs/ and helps everyone running that model on your GPU type. Several community members have contributed configs for H20, L20, B200, and B300 GPUs this way.

Summary

MoE kernel autotuning is a one-time process that generates optimised Triton kernel configs for your specific GPU and model combination. The workflow: identify your MoE dimensions from the warning log, run benchmark_moe.py with --tune, save the JSON output, and point vLLM at it with VLLM_TUNED_CONFIG_FOLDER. It takes under an hour, and the config lasts until you change GPU or model architecture.

Versions tested

This guide was written and tested on 15 March 2026 with the following versions. Commands and script behaviour may change in future vLLM releases.

  • vLLM: 0.17.1rc1.dev156+g74fe80ee9 (cu130-nightly Docker image)
  • Docker image: vllm/vllm-openai:cu130-nightly
  • CUDA: 13.0
  • GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (96GB GDDR7, SM 12.0)
  • Model: Sehyo/Qwen3.5-35B-A3B-NVFP4 (E=256, N=512, topk=8, hidden_size=2048)
  • Ray: 2.54.0 (installed at runtime via pip)
  • Triton: bundled with vLLM nightly image
  • Python: 3.12

Disclaimer

The information provided on this website does not constitute professional advice, and should not be relied upon as such. No client relationship is formed by accessing or using this website. Users are advised to seek their own professional advice before acting on any information provided or generated herein. datacraftsman.com.au and its contributors accept no liability for any loss, injury or damage caused by reliance on the information provided or generated.