Skip to content

GGEMM+srelu kernels for MxFP8 Nemotron#2981

Open
sraman-rgb wants to merge 6 commits into
NVIDIA:mainfrom
sraman-rgb:fc1-srelu-main
Open

GGEMM+srelu kernels for MxFP8 Nemotron#2981
sraman-rgb wants to merge 6 commits into
NVIDIA:mainfrom
sraman-rgb:fc1-srelu-main

Conversation

@sraman-rgb
Copy link
Copy Markdown

@sraman-rgb sraman-rgb commented May 12, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman
Copy link
Copy Markdown
Member

/te-ci pytorch

@ksivaman
Copy link
Copy Markdown
Member

Please sign-off your commits @sraman-rgb

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 12, 2026

Greptile Summary

This PR refactors the fused GroupedMLP kernel hierarchy into a shared base class and adds ScaledSReLU (squared-ReLU with per-row post-scaling) as a second supported activation alongside the existing GLU variants, wiring up new cuDNN FE grouped_gemm_srelu_wrapper_sm100 / grouped_gemm_dsrelu_wrapper_sm100 kernels.

  • New ScaledSReLU op (activation.py): standard BasicOperation with num_extra_inputs=1, implements both unfused and fused forward/backward paths.
  • Refactored fused forward/backward: common logic moved to abstract base classes; GLU and Unary concrete subclasses wire their respective cuDNN FE kernels.
  • Fusion plumbing (_common.py): fuse_grouped_mlp_ops parameterised with activation_op_types; validate_grouped_mlp_dims extended for unary activations; separate forward/backward fusion functions registered for each activation family.

Confidence Score: 5/5

The refactor is well-structured and the SReLU kernel wiring follows the established GLU pattern closely; the two flagged items are clarifying questions rather than confirmed failures.

The class hierarchy generalisation is clean, dscales_tensor is always an allocated tensor, the recompute-FC2-input path is guarded by multiple independent checks, and test coverage spans both unit-level ScaledSReLU and the full grouped-MLP integration.

forward_grouped_mlp.py (prob_tensor dtype) and _common.py (_nvidia_cudnn_frontend_supports_wgrad guard)

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/activation.py Adds ScaledSReLU with correct unfused fuser_forward/fuser_backward; dtype handling and grad accumulation look sound.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Base class refactor is clean; prob_tensor dtype (BF16/FP16 vs float32 fallback and backward) is an inconsistency worth confirming against the SReLU kernel spec.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py dSReLU backward kernel wiring, recompute path, and grad_scales handling are logically correct; dscales_tensor is always allocated.
transformer_engine/pytorch/ops/_common.py validate_grouped_mlp_dims and fuse_grouped_mlp_ops generalised cleanly; _nvidia_cudnn_frontend_supports_wgrad is a thin alias with no distinct version check.
transformer_engine/pytorch/ops/fused/init.py Export list updated to expose the four new concrete fused-op classes; no issues.
transformer_engine/pytorch/ops/basic/init.py Adds ScaledSReLU to the public API; straightforward.
tests/pytorch/test_fusible_ops.py New test_scaled_srelu unit test and scaled_srelu parametrize for test_grouped_mlp look correct; reference implementation matches expected SReLU*scales semantics.

Sequence Diagram

sequenceDiagram
    participant Fuser
    participant GLUFwd as ForwardGroupedMLP_CuTeGEMMGLU_MXFP8
    participant SReLUFwd as ForwardGroupedMLP_CuTeGEMMUnary_MXFP8
    participant SReLUBwd as BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8
    participant cuDNN as cuDNN FE Kernels

    Fuser->>GLUFwd: fuse_forward_ops GLU pattern
    GLUFwd->>cuDNN: grouped_gemm_glu_wrapper_sm100
    cuDNN-->>GLUFwd: fc2_in scales and activation_in

    Fuser->>SReLUFwd: fuse_forward_srelu_ops SReLU pattern
    SReLUFwd->>cuDNN: grouped_gemm_srelu_wrapper_sm100
    cuDNN-->>SReLUFwd: fc2_in scales and activation_in
    Note over SReLUFwd: Save activation_in and scales
    Note over SReLUFwd: optionally skip saving fc2_x

    Fuser->>SReLUBwd: fuse_backward_srelu_ops
    SReLUBwd->>cuDNN: grouped_gemm_dsrelu_wrapper_sm100
    cuDNN-->>SReLUBwd: FC1 dy tensors and grad_scales
    cuDNN-->>SReLUBwd: optional recomputed FC2 input
    SReLUBwd->>cuDNN: grouped_gemm_wgrad for FC1 and FC2
Loading

Reviews (8): Last reviewed commit: "Address grouped MLP ScaledSReLU review c..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/basic/activation.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Signed-off-by: sraman-rgb <sraman@nvidia.com>
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.

Comment on lines +309 to +310
swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
srelu: Optional[ScaledSReLU] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not have a single arg?

Suggested change
swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
srelu: Optional[ScaledSReLU] = None,
activation: Optional[FusibleOperation] = None,

It seems like we're adding one activation function after another, so we want interfaces that scale gracefully. Also, fused ops are basically internal to TE and these ops in particular are experimental, so backward compatibility is not a major concern.

The forward fused op should have a similar design. Changing to a consistent arg name would also let us get rid of the kwarg name messiness in the op fusion function.

return fc2_out, [(), (), ()]


class ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8(ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an awkward class hierarchy. It would be better to have a virtual base class that both the GLU and non-GLU functions inherit from. The backward fused ops should have a similar design.

While we're messing with the existing classes, we should reconsider the names. The "SwiGLU" op is actually used for both SwiGLU and ClampedQGeGLU, so a name like "GLU" would be better. And there's no reason to expect "SReLU" won't be applied to other activations later, so maybe "Unary" would be more general.

Comment thread tests/pytorch/test_fusible_ops.py Outdated
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization != "mxfp8":
pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8")
if activation == "scaled_srelu" and glu_interleave_size is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This is assuming that activations are GLUs by default, and SReLU is weird. Isn't that kind of backward? In any case, it would be more logical to have a single point where we check is_glu_activation, and then use that everywhere.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 18, 2026

Want your agent to iterate on Greptile's feedback? Try greploops.

*,
fc1: GroupedLinear,
swiglu: ScaledSwiGLU | ScaledClampedQGeGLU,
activation: Optional[FusibleOperation] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Python supports kwargs without defaults.

Suggested change
activation: Optional[FusibleOperation] = None,
activation: Optional[FusibleOperation],

fc2_ctx.dtype = dtype
fc2_ctx.input_requires_grad = input_requires_grad
fc2_ctx.weight_requires_grad = weight_requires_grad
fc2_ctx.recompute_input_from_dsrelu = recompute_srelu_fc2_x
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This option isn't supported in the unfused GroupedLinear op. This is a problem because the forward and backward fusions are performed indendently, so everything needs to be compatible with the unfused op interfaces in case there are different forward and backward fusions. However, I also don't want to include this in the unfused op because this is so hyper-specific to this particular fusion.

The requirement that the fused and unfused ops are interchangeable has causing some trouble with the grouped MLP block. It may be worth relaxing, but we would need to have some guarantee that the forward and backward fusions match exactly. I propose we change the op fuser to operate in three stages: fuse forward-backward ops together, fuse forward ops, fuse backward ops. For fused ops with matching forward and backward, we can tolerate tighter forward-backward integration.

Siddhartha Raman S and others added 5 commits May 18, 2026 14:46
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants