Improve TE Group MLP CPU Overhead #2991

Open
zhongbozhu wants to merge 9 commits into NVIDIA:main from zhongbozhu:improve_te_group_mlp_cpu_overhead
Conversation

@zhongbozhu
Collaborator

@zhongbozhu zhongbozhu commented May 14, 2026

Description

Reduce the CPU overhead of the TE grouped MLP when CUDA graphs are not enabled.

This is for issue: #2897

E2E model: Qwen3.5 35B-A3B, a nano-scale model, which makes it more prone to CPU overhead.
What we measure: with CUDA graphs disabled, split_sizes lives on the CPU. We measure the CPU-side time from the end of the H2D copy of split_sizes to the launch of the grouped quantize kernel.

before: 355us

after: 167us

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zhongbozhu and others added 2 commits May 13, 2026 23:35
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
zhongbozhu and others added 4 commits May 14, 2026 19:18
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu marked this pull request as ready for review May 15, 2026 02:25
@zhongbozhu
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps Bot commented May 15, 2026

Greptile Summary

This PR reduces CPU overhead in the TE grouped MLP path by replacing four sequential Python/CUDA operations (type cast, splits_to_offsets, int cast, scalar multiply) with a single fused CUDA kernel prepare_grouped_splits_kernel that computes all split metadata (split_sizes_i64, base_offsets, split_points, tensor_offsets) in one pass.

  • Adds nvte_prepare_grouped_splits as a public C API, the corresponding prepare_grouped_splits C++ extension, and a Python binding; forward_grouped_mlp.py is updated to call the fused path.
  • The kernel handles both int32 and int64 input split sizes and supports num_groups > 256 via a chunked single-block prefix-scan loop, mirroring the existing splits_to_offsets_kernel design.
  • A new benchmark_graph_safe_grouped_linear.py and parametric unit tests covering zero-sized groups, MoE-scale group counts, and 16-byte alignment requirements are added alongside the implementation.
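
As a reading aid, the four outputs named in the summary can be modeled in plain Python. This is only a sketch, not the kernel's code: the function names are hypothetical, and it assumes that base_offsets is a prefix sum with a leading zero, that split_points holds each group's end offset, and that tensor_offsets is base_offsets scaled by logical_last_dim (the "scalar multiply" in the summary suggests this but does not state it). The second function mimics the chunked scan for num_groups > 256 by carrying the running total from one 256-wide chunk to the next.

```python
from itertools import accumulate

INT32_MAX = 2**31 - 1


def prepare_grouped_splits_reference(split_sizes, logical_last_dim):
    """Pure-Python model of the split metadata (hypothetical helper)."""
    # Step 1: widen every split size to a plain integer (int64 in the kernel).
    split_sizes_i64 = [int(s) for s in split_sizes]
    # Step 2: prefix sum with a leading zero, length num_groups + 1.
    base_offsets = [0] + list(accumulate(split_sizes_i64))
    # Step 3: end offset of each group; these must fit in int32.
    split_points = base_offsets[1:]
    if split_points and split_points[-1] > INT32_MAX:
        raise OverflowError("split_points would overflow int32")
    # Step 4 (assumption): element offsets are row offsets times the row width.
    tensor_offsets = [off * logical_last_dim for off in base_offsets]
    return split_sizes_i64, base_offsets, split_points, tensor_offsets


def chunked_base_offsets(split_sizes, chunk=256):
    """Same prefix sum, computed chunk by chunk with a carried total."""
    base_offsets = [0]
    carry = 0
    for start in range(0, len(split_sizes), chunk):
        block = split_sizes[start:start + chunk]
        # Scan inside the chunk, then shift by the total of earlier chunks.
        for inclusive in accumulate(block):
            base_offsets.append(carry + inclusive)
        carry = base_offsets[-1]
    return base_offsets


sizes_i64, base, points, tensor = prepare_grouped_splits_reference([3, 0, 5], 4)
print(base)    # [0, 3, 3, 8]
print(points)  # [3, 3, 8]
print(tensor)  # [0, 12, 12, 32]
```

The zero-sized middle group shows up as a repeated offset, which is the case the new unit tests are described as covering.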

Confidence Score: 4/5

Safe to merge after fixing the missing .contiguous() call in the CUDA input path of prepare_grouped_splits.

The CUDA path in misc.cpp assigns the input tensor directly to split_sizes_for_kernel without calling .contiguous(). makeTransformerEngineTensor then passes tensor.data_ptr() to the kernel which reads elements at consecutive raw memory positions with no stride awareness. A non-contiguous CUDA tensor would silently produce wrong split sizes and corrupt the entire grouped GEMM dispatch. The sibling splits_to_offsets guards against this with an explicit .contiguous() call.

transformer_engine/pytorch/csrc/extensions/misc.cpp - the CUDA input branch of prepare_grouped_splits needs a .contiguous() call before passing the tensor to makeTransformerEngineTensor.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/csrc/extensions/misc.cpp | New prepare_grouped_splits fused metadata function; CUDA input path skips .contiguous() which could silently corrupt outputs for strided/sliced CUDA tensors. |
| transformer_engine/common/common.cu | New prepare_grouped_splits_kernel fuses int32/int64 split sizes into int32 split_points + int64 base_offsets + int64 tensor_offsets in one single-block kernel. |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Registers prepare_grouped_splits Python binding without py::call_guard<py::gil_scoped_release>(); validation and bulk_allocate run while holding the GIL. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Replaces four separate Python/CUDA operations with a single fused prepare_grouped_splits call; straightforward and correct. |
| tests/pytorch/test_grouped_tensor.py | Adds thorough parametrised tests; one test case comment overstates coverage of the tensor_offsets > 2^31 edge case. |
| transformer_engine/common/include/transformer_engine/transformer_engine.h | Adds well-documented nvte_prepare_grouped_splits public API declaration. |
| transformer_engine/pytorch/csrc/extensions.h | Adds declaration for prepare_grouped_splits alongside splits_to_offsets; no issues. |
| benchmarks/linear/benchmark_graph_safe_grouped_linear.py | New benchmark script for graph-safe grouped MLP with MXFP8/SwiGLU; well-structured. |

Reviews (3): Last reviewed commit: "fix"

Comment on lines +500 to +502
m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
"Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
py::arg("num_groups"), py::arg("logical_last_dim"));
Contributor

P2 prepare_grouped_splits is registered without py::call_guard<py::gil_scoped_release>(), unlike the analogous splits_to_offsets binding above it. The GIL is released only for the CUDA kernel launch inside NVTE_SCOPED_GIL_RELEASE, but the input validation and bulk_allocate call still hold the GIL. Given that this PR's purpose is to reduce CPU overhead, holding the GIL over the bulk_allocate + makeTransformerEngineTensor setup work unnecessarily serialises Python threads. NVTE_SCOPED_GIL_RELEASE handles the "already released" case via PyGILState_Check(), so adding the call guard here is safe and consistent.

Suggested change
- m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
- "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
- py::arg("num_groups"), py::arg("logical_last_dim"));
+ m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
+ "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
+ py::arg("num_groups"), py::arg("logical_last_dim"),
+ py::call_guard<py::gil_scoped_release>());

base_offsets[idx + 1] = prefix;
// cuDNN grouped GEMM expects padded split end offsets as int32. TE
// GroupedTensor metadata keeps the full int64 base_offsets/tensor_offsets.
split_points[idx] = static_cast<int32_t>(prefix);
Contributor

P2 Silent int32 truncation for large token counts — prefix is int64_t, and the cast to int32_t silently wraps when the cumulative row count exceeds INT32_MAX (~2.1 B tokens). cuDNN itself consumes int32 padded offsets, so the limit is real, but it would be better to surface it as an explicit error than to pass corrupted offsets to cuDNN and get a hard-to-diagnose runtime failure. A guard of the form NVTE_CHECK(prefix <= INT32_MAX, "split_points overflow int32: total tokens exceed INT32_MAX") placed before the cast would make the failure mode visible.
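
To make the failure mode concrete, the wrap-around can be reproduced in plain Python. This is a sketch under stated assumptions: ctypes.c_int32 stands in for the C++ cast (it keeps the low 32 bits, which is how the cast behaves on typical platforms), and checked_split_point is a hypothetical host-side guard, not the check the PR would add.

```python
import ctypes

INT32_MAX = 2**31 - 1


def cast_to_int32(value):
    """Mimic static_cast<int32_t>: keep the low 32 bits as a signed value."""
    return ctypes.c_int32(value).value


def checked_split_point(prefix):
    """Hypothetical guard: fail loudly instead of wrapping silently."""
    if not 0 <= prefix <= INT32_MAX:
        raise OverflowError(
            f"split_points overflow int32: cumulative row count {prefix} "
            f"exceeds {INT32_MAX}"
        )
    return prefix


# One row past the limit wraps to the most negative int32.
print(cast_to_int32(INT32_MAX))      # 2147483647
print(cast_to_int32(INT32_MAX + 1))  # -2147483648
```

A negative offset of this kind is what would reach cuDNN without the guard, which is why an explicit error is easier to diagnose.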

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
zhongbozhu and others added 3 commits May 15, 2026 01:43
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Comment on lines +45 to +48
if (split_sizes.is_cuda()) {
NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
split_sizes_for_kernel = split_sizes;
Contributor

P1 Missing .contiguous() call for CUDA input in the CUDA path. makeTransformerEngineTensor(at::Tensor) uses tensor.data_ptr() directly without respecting strides — confirmed at common.cpp:140. A non-contiguous 1D CUDA tensor (e.g., a strided view or slice) will pass data_ptr() to the kernel, which reads elements at positions [0, 1, 2, ...] in raw memory, silently producing wrong split_sizes_i64, base_offsets, and all downstream outputs. The sibling splits_to_offsets avoids this with an explicit .contiguous() call.

Suggested change
- if (split_sizes.is_cuda()) {
- NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
- device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
- split_sizes_for_kernel = split_sizes;
+ if (split_sizes.is_cuda()) {
+ NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
+ device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
+ split_sizes_for_kernel = split_sizes.contiguous();
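
The stride issue can be illustrated without a GPU. The sketch below models a tensor as a flat storage list plus an offset and a stride; all three helpers are hypothetical and only imitate the behavior described in the comment, namely that a kernel given just data_ptr() reads consecutive elements regardless of strides, while a .contiguous()-style copy makes the two reads agree.

```python
def read_strided(storage, offset, stride, n):
    """Read n elements the way tensor indexing does, honoring the stride."""
    return [storage[offset + i * stride] for i in range(n)]


def read_raw(storage, offset, n):
    """Read n consecutive elements from the start, ignoring the stride."""
    return storage[offset:offset + n]


def make_contiguous(storage, offset, stride, n):
    """Copy the view into a dense buffer with offset 0 and stride 1."""
    return read_strided(storage, offset, stride, n), 0, 1


# A view taking every other element, e.g. split sizes interleaved with
# unrelated values (99) in the same buffer.
storage = [4, 99, 2, 99, 6, 99]
offset, stride, n = 0, 2, 3

print(read_strided(storage, offset, stride, n))  # [4, 2, 6]
print(read_raw(storage, offset, n))              # [4, 99, 2]

dense, dense_offset, dense_stride = make_contiguous(storage, offset, stride, n)
print(read_raw(dense, dense_offset, n))          # [4, 2, 6]
```

In this toy case the raw read picks up 99 as a split size, which would corrupt every offset after it.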
