Bug Report: Potential Out-of-Bounds Accesses In MoE Kernels
Hey everyone,
I've been digging into the CUDA kernels in the vllm-project repository and found some potential out-of-bounds accesses that I wanted to bring to your attention. They appear in the `moe_wna16_gemm` and `moe_wna16_marlin_gemm` kernels. Let's break it down so we can squash these bugs together!
Current Environment
Before we get into the nitty-gritty, here's the output of `python collect_env.py` to give you a snapshot of my setup:
Your output of `python collect_env.py` here
This should give you some context about the environment where these potential bugs were identified.
Describe the bug
During static analysis of the CUDA kernels, I flagged several potential out-of-bounds accesses in both `moe_wna16_gemm` and `moe_wna16_marlin_gemm`. Let's dive into the specifics:
1. moe_wna16_gemm
This function has a few spots where an index can run past the end of the buffer it reads or writes. Let's take a closer look at each one.
(1) expert_ids[blockIdx.x]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/moe_wna16.cu#L39-L42
- Issue: The access `expert_ids[blockIdx.x]` can read out of bounds when `blockIdx.x` is greater than or equal to the number of elements in `expert_ids`. In that case the kernel reads memory it does not own, which is undefined behavior and can crash the application or silently produce incorrect results. The block index must stay within the bounds of `expert_ids`.
- Example Scenario:
  `blockIdx.x: 4`, `expert_ids.shape: [4]`, `BLOCK_SIZE_M: 16`, `top_k: 4`, `batch_size: 2`, `seq_len: 1`, `sorted_token_ids.shape: [128]`
  Here `blockIdx.x` is 4 while `expert_ids` has only 4 elements (valid indices 0-3), so the read touches an element that doesn't exist. Note that `sorted_token_ids` spans 128 / 16 = 8 blocks of `BLOCK_SIZE_M` rows, twice as many as there are entries in `expert_ids`. This is a classic out-of-bounds access; we need to ensure that `blockIdx.x` is always less than the size of `expert_ids`.
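To make the mismatch concrete, here is a minimal host-side model of the `expert_ids[blockIdx.x]` read, written in Python rather than CUDA. It assumes (I haven't confirmed this against the launcher code) that the grid has one x-block per `BLOCK_SIZE_M` entries of `sorted_token_ids`; all helper names are hypothetical.

```python
# Hypothetical host-side model of the expert_ids[blockIdx.x] read in moe_wna16_gemm.
# ASSUMPTION: the launch grid has one x-block per BLOCK_SIZE_M entries of
# sorted_token_ids. Helper names are illustrative, not vLLM code.


def num_token_blocks(sorted_token_ids_len: int, block_size_m: int) -> int:
    """Number of x-blocks if the grid covers all of sorted_token_ids (assumed)."""
    return (sorted_token_ids_len + block_size_m - 1) // block_size_m


def expert_id_read_in_bounds(block_idx_x: int, expert_ids_len: int) -> bool:
    """True if expert_ids[blockIdx.x] reads a valid element."""
    return 0 <= block_idx_x < expert_ids_len


def out_of_bounds_blocks(sorted_token_ids_len, block_size_m, expert_ids_len):
    """List the block indices whose expert_ids read would be out of bounds."""
    blocks = num_token_blocks(sorted_token_ids_len, block_size_m)
    return [b for b in range(blocks)
            if not expert_id_read_in_bounds(b, expert_ids_len)]


# Scenario from the report: sorted_token_ids.shape == [128],
# BLOCK_SIZE_M == 16, expert_ids.shape == [4].
print(out_of_bounds_blocks(128, 16, 4))  # [4, 5, 6, 7]
```

Under that assumption, a guard such as returning early when `blockIdx.x` is not less than the length of `expert_ids` would cover blocks 4-7 in this scenario.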
(2) sorted_token_ids[offset_m]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/moe_wna16.cu#L50-L52
- Issue: The read `sorted_token_ids[offset_m]` may also go out of bounds. The index is computed as `blockIdx.x * BLOCK_SIZE_M + m`, and this computed index can reach or exceed the length of `sorted_token_ids`. When it does, the kernel reads memory it does not own, leading to a crash or undefined behavior, so the computed index needs a bounds check.
- Example Scenario:
  `sorted_token_ids.shape: [17]`, `blockIdx.x: 1`, `BLOCK_SIZE_M: 16`, `m: 1`
  The index becomes `1 * 16 + 1 = 17`. Boom! Out of bounds: the valid range is 0-16, so reading `sorted_token_ids[17]` is invalid. It's super important to keep these indices within the array's boundaries.
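The same index arithmetic, modeled in Python with one possible guard. This is only a sketch: `safe_sorted_token_id` is a hypothetical helper, and returning `None` stands in for whatever "skip this row" would mean inside the kernel.

```python
# Hypothetical model of the sorted_token_ids[offset_m] read in moe_wna16_gemm.
# offset_m = blockIdx.x * BLOCK_SIZE_M + m, as described in the report.


def offset_m(block_idx_x: int, block_size_m: int, m: int) -> int:
    return block_idx_x * block_size_m + m


def safe_sorted_token_id(sorted_token_ids, block_idx_x, block_size_m, m):
    """Return sorted_token_ids[offset_m], or None when the read would be out of bounds."""
    idx = offset_m(block_idx_x, block_size_m, m)
    if idx >= len(sorted_token_ids):
        return None  # a kernel-side guard could skip this row instead of reading
    return sorted_token_ids[idx]


sorted_token_ids = list(range(17))  # shape [17], valid indices 0-16
print(offset_m(1, 16, 1))                                # 17
print(safe_sorted_token_id(sorted_token_ids, 1, 16, 1))  # None
print(safe_sorted_token_id(sorted_token_ids, 1, 16, 0))  # 16
```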
(3) reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/moe_wna16.cu#L112-L116
- Issue: The expression `reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp]` may read out of bounds if `scales_offset_tmp` exceeds the valid range of the reinterpreted `expert_scales` array. The offset depends on `offset_n`, `size_k`, `group_size`, and `GROUPS`; when it is too large, the kernel reads memory outside the allocation backing `expert_scales`, leading to potential crashes or incorrect results. `scales_offset_tmp` should be validated against the size of `expert_scales`. The relevant definitions and sizes are:
  `expert_scales = scales + expert_offset / group_size;`
  `scales_offset_tmp = (offset_n * size_k + offset_k) / group_size / GROUPS;`
  `scales.shape: [60, 2816, 16]`, `GROUPS = 2`, `group_size = 128`, `size_n = 2816`, `size_k = 2048`
- Example Scenario:
  `blockIdx.x: 0`, `blockIdx.y: 21`, `blockIdx.z: 0`, `threadIdx.x: 0`, `BLOCK_SIZE_N: 128`, `BLOCK_SIZE_K: 256`, `expert_ids[0]: 60`
  With these values the computed `scales_offset_tmp` can exceed the valid index range of `scales`, which leads to an out-of-bounds access. Note that `expert_ids[0]` is 60 while `scales` holds only 60 experts (valid expert indices 0-59). Ensuring `scales_offset_tmp` is within bounds is key to preventing memory errors.
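A simplified Python check of the expert-level slice of `scales`. It assumes a contiguous `[num_experts, size_n, size_k // group_size]` layout where expert `e` starts at `e * size_n * (size_k // group_size)`; that layout, and the helper names, are my assumptions. It deliberately ignores the `float` reinterpretation and `GROUPS`, so it only shows whether an expert's slice fits at all.

```python
# Hypothetical check that an expert's scales slice fits inside `scales`.
# ASSUMPTION (not confirmed against the kernel source): scales is laid out as
# [num_experts, size_n, size_k // group_size], and expert e's slice starts at
# e * size_n * (size_k // group_size) elements. The float reinterpretation and
# GROUPS are not modeled.
import math


def expert_scales_base(expert_id, size_n, size_k, group_size):
    """First element of expert `expert_id`'s scales slice (assumed layout)."""
    return expert_id * size_n * (size_k // group_size)


def scales_slice_in_bounds(expert_id, scales_shape, size_n, size_k, group_size):
    """True if every element of the expert's slice lies inside `scales`."""
    numel = math.prod(scales_shape)
    slice_len = size_n * (size_k // group_size)
    base = expert_scales_base(expert_id, size_n, size_k, group_size)
    return expert_id >= 0 and base + slice_len <= numel


shape = (60, 2816, 16)  # scales.shape from the report
print(scales_slice_in_bounds(59, shape, 2816, 2048, 128))  # True
print(scales_slice_in_bounds(60, shape, 2816, 2048, 128))  # False
```

Under this layout, expert 60 from the scenario starts exactly one element past the end of `scales`, so any read from its slice is out of bounds.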
(4) topk_weights[token_index]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/moe_wna16.cu#L210-L214
- Issue: There is a potential out-of-bounds access in `topk_weights[token_index]`. The shape of `topk_weights` is `[seq_len * batch_size, 4]`, and `top_k` is 8. If `token_index` goes beyond the bounds of `topk_weights`, we've got a problem, so `token_index` must be checked against the size of `topk_weights` for stable operation.
- Example Scenario:
  `blockIdx.x: 0`, `BLOCK_SIZE_M: 16`, `m: 1`, `sorted_token_ids[1]: 16`, `batch_size: 2`, `seq_len: 2`
  Here `token_index = sorted_token_ids[0 * 16 + 1] = 16`, but `topk_weights` has only `2 * 2 * 4 = 16` elements (valid indices 0-15). Any `token_index` greater than `seq_len * batch_size * 4 - 1` is an out-of-bounds read, so proper index validation is vital.
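A Python model of this read, treating `topk_weights` as a flat array of `seq_len * batch_size * 4` elements as described above. The helper names are hypothetical; only the index formula comes from the report.

```python
# Hypothetical model of the topk_weights[token_index] read in moe_wna16_gemm.
# token_index = sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m], as in the report;
# topk_weights is treated as a flat array of num_tokens * topk_cols elements.


def token_index_for(sorted_token_ids, block_idx_x, block_size_m, m):
    return sorted_token_ids[block_idx_x * block_size_m + m]


def topk_weight_read_in_bounds(token_index, num_tokens, topk_cols):
    """True if topk_weights[token_index] is a valid flat index."""
    return 0 <= token_index < num_tokens * topk_cols


# Scenario from the report: sorted_token_ids[1] == 16, batch_size == 2,
# seq_len == 2, topk_weights.shape == [seq_len * batch_size, 4].
sorted_token_ids = [0, 16] + [0] * 14  # only entry 1 matters here
idx = token_index_for(sorted_token_ids, 0, 16, 1)
print(idx)                                        # 16
print(topk_weight_read_in_bounds(idx, 2 * 2, 4))  # False
```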
(5) output[token_index * size_n + offset_n]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/moe_wna16.cu#L216-L217
- Issue: Writing to `output[token_index * size_n + offset_n]` might cause memory access violations. If the index computed from `token_index`, `size_n`, and `offset_n` exceeds the last valid index of `output`, the kernel writes to memory the program does not own, which can crash the application or corrupt other data. The computed index should be validated against the size of `output`.
- Example Scenario:
  `blockIdx.x: 0`, `blockIdx.y: 0`, `threadIdx.x: 0`, `m: 0`, `batch_size: 2`, `seq_len: 9`, `sorted_token_ids[0]: 73`
  Here `output` has shape `[seq_len * batch_size, 4, 2816]` = `[18, 4, 2816]`, i.e. 72 rows of `size_n = 2816`, and `token_index = sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m] = sorted_token_ids[0] = 73`. Row 73 is past the last row (71), so the write is out of bounds for any non-negative `offset_n`: `73 * 2816 = 205568`, while `output` has only `72 * 2816 = 202752` elements. That's an out-of-bounds write, which is not good.
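The arithmetic above as a small Python check. Treating `offset_n` as 0 for `blockIdx.y == 0` and `threadIdx.x == 0` is an assumption on my part, and the helper names are hypothetical.

```python
# Hypothetical model of the output[token_index * size_n + offset_n] write.
# ASSUMPTION: offset_n == 0 for blockIdx.y == 0 and threadIdx.x == 0.
import math


def output_flat_index(token_index, size_n, offset_n):
    """Flat index written by output[token_index * size_n + offset_n]."""
    return token_index * size_n + offset_n


def output_write_in_bounds(token_index, size_n, offset_n, output_shape):
    """True if the flat write index lies inside `output`."""
    index = output_flat_index(token_index, size_n, offset_n)
    return 0 <= index < math.prod(output_shape)


# Scenario from the report: batch_size == 2, seq_len == 9, so output has
# shape [18, 4, 2816]; sorted_token_ids[0] == 73.
shape = (2 * 9, 4, 2816)
print(output_flat_index(73, 2816, 0))                 # 205568
print(math.prod(shape))                               # 202752
print(output_write_in_bounds(73, 2816, 0, shape))     # False
print(output_write_in_bounds(71, 2816, 2815, shape))  # True (last element)
```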
2. moe_wna16_marlin_gemm
Now, let's shift our focus to `moe_wna16_marlin_gemm` and see what potential issues we can uncover.
(1) expert_ids_ptr[block_id]
- Location: https://github.com/vllm-project/vllm/blob/a00d6254e998be472d8df9dc590784d6facf8d85/csrc/moe/marlin_moe_wna16/marlin_template.h#L526-L528
- Issue: The access `expert_ids_ptr[block_id]` may read out of bounds. `block_id` is calculated from `slice_col_par` and `n_tiles`, and the shape of `expert_ids` is `[seq_len * batch_size + 112]`. If `block_id` reaches or exceeds that size, the kernel reads from memory it does not own, which can crash the application or produce incorrect results. `block_id` should be validated against the size of `expert_ids` to ensure program stability and correctness.
- Example Scenario:
  `seq_len: 2`, `batch_size: 1`, `blockIdx.x: 313`, `sms: 80`, `num_tokens_past_padded_ptr[0]: 928`
  With these parameters `expert_ids` has `2 * 1 + 112 = 114` elements, and the computed `block_id` can exceed that bound, leading to a crash. Double-checking this index calculation is super important.
Before submitting a new issue...
- [x] I made sure I searched for relevant issues, and I even chatted with the chatbot on the documentation page. It's always good to cover your bases!
I hope this detailed breakdown helps in addressing these potential issues. Let me know your thoughts, and let's work together to make vLLM even more robust!