[JAX] Improve JAX tutorial documentation #2976

Open
jberchtold-nvidia wants to merge 13 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/improve-jax-tutorial


@jberchtold-nvidia jberchtold-nvidia commented May 11, 2026

Description

Reworks the tutorial to focus on individual operations and their usage and performance. This makes the impact of each operation clearer to users, who can try the operations one at a time depending on which are bottlenecks in their models.

Additionally, this switches from notebook .ipynb files to .rst and separate .py files for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Rework existing tutorial and replace with new Dense-specific tutorial
  • Placeholders for Attention and MoE
  • Refactor .ipynb notebooks to .rst and .py files for similar appearance in docs but better testability in CI by running .py files

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR reworks the JAX TransformerEngine tutorial by replacing the old te_jax_integration.ipynb notebook with a new RST+Python structure, adding a focused Dense GEMM tutorial with single-GPU and multi-GPU benchmarking, placeholder stubs for Attention/Collective GEMM/Expert Parallelism, and CI hooks that run the tutorial as pytest tests.

  • New tutorial format: dense.py uses # DENSE_*_START/END markers so code blocks are pulled verbatim into dense.rst via literalinclude, and test_dense.py imports lazily from dense to avoid triggering MXFP8 init before skip marks are applied.
  • CI integration: Both L0 and L1 test scripts gain pytest invocations targeting docs/examples/jax/, keeping tutorial code exercised against the live TE version.
  • quickstart_jax_utils.py is moved into the new jax/ subdirectory and extended with a compare_fwd_bwd helper used in numeric correctness tests.
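To illustrate the marker convention described in the first bullet, here is a minimal Python sketch of pulling out the text between a `# NAME_START` and `# NAME_END` pair. It is an assumption that the RST side relies on Sphinx's `literalinclude` options `start-after` and `end-before` for this; the helper `extract_block`, the sample source, and the `DENSE_MODEL` marker name are hypothetical and are not taken from the PR.

```python
def extract_block(source: str, name: str) -> str:
    """Return the lines strictly between '# <name>_START' and '# <name>_END'."""
    start, end = f"# {name}_START", f"# {name}_END"
    lines = source.splitlines()
    # Index of the opening marker line.
    begin = next(i for i, line in enumerate(lines) if line.strip() == start)
    # Index of the first closing marker line after the opening one.
    finish = next(
        i for i, line in enumerate(lines) if i > begin and line.strip() == end
    )
    return "\n".join(lines[begin + 1 : finish])


# Hypothetical tutorial source; it is only a string and is never executed.
SAMPLE = """\
# DENSE_IMPORTS_START
import jax
import jax.numpy as jnp
# DENSE_IMPORTS_END

# DENSE_MODEL_START
def forward(x, w):
    return x @ w
# DENSE_MODEL_END
"""

imports_block = extract_block(SAMPLE, "DENSE_IMPORTS")
model_block = extract_block(SAMPLE, "DENSE_MODEL")
```

Keeping the markers as ordinary comments means the tutorial file stays a normal, runnable script while the docs show only the selected region.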

Confidence Score: 5/5

Documentation-only change with no impact on TE runtime library code; CI additions point to the correct paths.

All changes are documentation, tutorial Python scripts, and CI shell additions. The actual pytest invocations reference the correct docs/examples/jax/ directory and test skip logic is structurally sound. The only findings are a truncated sentence and a missing literalinclude block — neither affects runtime behavior or CI correctness.

docs/examples/jax/dense.rst — the DENSE_IMPORTS literalinclude block is absent, leaving a copy-paste gap for tutorial readers

Important Files Changed

| Filename | Overview |
| --- | --- |
| docs/examples/jax/dense.py | New tutorial script demonstrating quantized Dense GEMM; well-structured marker blocks for RST literalinclude; MXFP8 init runs at module scope (handled in test_dense.py by deferring imports) |
| docs/examples/jax/dense.rst | New RST tutorial; missing literalinclude for the DENSE_IMPORTS_START block, leaving readers without the required JAX/Flax import preamble |
| docs/examples/jax/test_dense.py | New pytest tests; defers dense imports into test bodies to avoid MXFP8 init during collection; skip guards and numeric tolerances look correct |
| docs/examples/te_jax_integration.rst | New landing-page RST replacing the deleted notebook; one Benchmarking bullet is incomplete (truncated after "warmup") |
| docs/examples/jax/quickstart_jax_utils.py | Moved from docs/examples/ and extended with compare_fwd_bwd helper; logic looks correct for the MXFP8 use-case |
| qa/L0_jax_unittest/test.sh | Adds docs tutorial tests to CI; actual pytest path is correct (jax/), only the comment text retains the stale "jax_examples" name |
| qa/L1_jax_distributed_unittest/test.sh | Adds multi-GPU tutorial test to distributed CI; same stale "jax_examples" in comment only; functional path is correct |
| docs/examples/jax/dense.out | Captured GB200 benchmark output used by RST literalinclude; regeneration command matches file location |
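The deferred-import pattern noted for test_dense.py can be sketched as follows. This is a generic illustration, not the PR's code: `fake_tutorial` is a hypothetical stand-in module written to a temporary directory, and its `INIT_RAN` assignment stands in for any module-scope initialization. The point is that an import placed inside a test body does not run when the test function is merely defined (as happens during collection), only when the test is called.

```python
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical stand-in for a tutorial module that does work at import time.
MODULE_SRC = textwrap.dedent(
    """
    INIT_RAN = True  # stands in for module-scope initialization

    def double(x):
        return 2 * x
    """
)

tmp_dir = tempfile.mkdtemp()
Path(tmp_dir, "fake_tutorial.py").write_text(MODULE_SRC)
sys.path.insert(0, tmp_dir)


def test_double():
    # The import lives in the test body, so it runs only when the test runs.
    import fake_tutorial

    assert fake_tutorial.double(3) == 6


# Defining the test did not import the module.
imported_at_definition = "fake_tutorial" in sys.modules

test_double()

# Calling the test triggered the import.
imported_after_call = "fake_tutorial" in sys.modules
```

With this layout, a skip decision made before the test body runs (for example, by a skip mark) also avoids the import and its side effects.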

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[docs/index.rst] --> B[docs/examples/te_jax_integration.rst\nLanding page]
    B --> C[docs/examples/jax/dense.rst\nAvailable]
    B --> D[docs/examples/jax/collective_gemm.rst\nComing soon]
    B --> E[docs/examples/jax/attention.rst\nComing soon]
    B --> F[docs/examples/jax/expert_parallelism.rst\nComing soon]
    C -->|literalinclude markers| G[docs/examples/jax/dense.py\nTutorial source]
    C -->|literalinclude output| H[docs/examples/jax/dense.out\nCaptured benchmark output]
    G -->|imported by| I[docs/examples/jax/test_dense.py\nPytest entry points]
    G -->|uses| J[docs/examples/jax/quickstart_jax_utils.py]
    I -->|uses| J
    K[qa/L0_jax_unittest/test.sh] -->|pytest docs/examples/jax/| I
    L[qa/L1_jax_distributed_unittest/test.sh] -->|pytest -k multi_gpu| I

Reviews (8): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/attention.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/moe.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/dense.py
@jberchtold-nvidia
Collaborator Author

/te-ci L1 L0

Collaborator

@KshitijLakhani KshitijLakhani left a comment

Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.

In general it looks good. Some reworking of item placement might be needed, but I think that's going to be an evolving process.

Comment on lines +1 to +11
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
Collaborator

Unrelated to attention, but it looks like you are renaming the dir to examples/jax_examples, whereas I think the PyTorch side is examples/pytorch?
I think we could stick with examples/jax. Thoughts?

Collaborator Author

Good point, updated to examples/jax

`Haiku/Flax interop
<https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on
a different stack.)
* **Baseline dtype.** bf16 for inputs and parameters.
Collaborator

Should we add the GB200 (arch) details here rather than in the example module, or is that by choice?
I think there's value in having all examples run on the same arch for consistency.

Comment thread docs/examples/jax/attention.rst
Comment thread docs/examples/jax/conftest.py Outdated
#
# See LICENSE for license information.

"""Pytest conftest for docs/examples/jax_examples.
Collaborator

I agree with the usage of pytest in general; however, I think examples/mnist currently uses the built-in Python unittest module for the test example.
@phu0ngng and @tdophung it might be good to standardize and use pytest there too. Thoughts?

Collaborator Author

I see. In our main tests in tests/jax/ we use pytest. In examples/jax we use unittest instead, but we then run those tests in CI with pytest examples/jax/.... because pytest can also run unittest tests.

I'm ok with standardizing and using pytest everywhere. We already have requirements.txt files for running the examples/jax/mnist or encoder tests, so we could add the pytest dependency there too.
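As a small illustration of the point above that pytest can also run unittest tests: pytest collects `unittest.TestCase` subclasses, so a file like the sketch below runs under either runner. The class and test names are hypothetical and are not taken from the examples in this repository.

```python
import unittest


class TestAdd(unittest.TestCase):
    """A unittest-style test case; pytest collects TestCase subclasses too."""

    def test_add(self):
        self.assertEqual(1 + 2, 3)


def run_with_unittest() -> unittest.TestResult:
    # Run the same test case through the standard-library runner machinery.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestAdd)
    result = unittest.TestResult()
    suite.run(result)
    return result


result = run_with_unittest()
```

Saved as a `test_*.py` file, the same class would be picked up by a plain `pytest` invocation, which is why a mixed unittest/pytest setup can share one CI command.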

Comment thread docs/examples/jax_examples/dense.out Outdated
@@ -0,0 +1,22 @@
# Numbers below are illustrative (captured on a GB200). Regenerate with:
# python3 docs/examples/jax_examples/dense.py > dense.out
# after substantial code changes.
Collaborator

"after substantial code changes" ?

Collaborator Author

Removed, thanks! This was an artifact of the agent that helped refactor from .ipynb -> .rst/.py

Comment thread docs/examples/jax_examples/dense.rst Outdated
and your performance comparison will not be accurate.


6. Multi-GPU: DP=2 / TP=2 on a single Dense
Collaborator

  1. Single GPU performance
    4,5 ?
  2. Multi-GPU: DP=2 / TP=2 on a single Dense

Collaborator Author

Good catch, I had it broken into more sections and forgot to update the latest section numbers. Fixed now

Comment thread qa/L0_jax_unittest/test.sh
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 7464325 to 5432ec6 Compare May 15, 2026 16:43
Comment thread qa/L0_jax_unittest/test.sh Outdated
Comment thread qa/L1_jax_distributed_unittest/test.sh Outdated
jberchtold-nvidia and others added 2 commits May 15, 2026 09:48
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Comment thread docs/examples/jax/dense.rst Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/test_dense.py Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 48884cd to 168cc63 Compare May 15, 2026 18:36
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 54b1a9c to 4c1fec9 Compare May 15, 2026 19:02
@jberchtold-nvidia
Collaborator Author

/te-ci

3 participants