[JAX] Improve JAX tutorial documentation#2976
Conversation
Greptile SummaryThis PR reworks the JAX TransformerEngine tutorial by replacing the old
Confidence Score: 5/5Documentation-only change with no impact on TE runtime library code; CI additions point to the correct paths. All changes are documentation, tutorial Python scripts, and CI shell additions. The actual pytest invocations reference the correct docs/examples/jax/ directory and test skip logic is structurally sound. The only findings are a truncated sentence and a missing literalinclude block — neither affects runtime behavior or CI correctness. docs/examples/jax/dense.rst — the DENSE_IMPORTS literalinclude block is absent, leaving a copy-paste gap for tutorial readers Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[docs/index.rst] --> B[docs/examples/te_jax_integration.rst\nLanding page]
B --> C[docs/examples/jax/dense.rst\nAvailable]
B --> D[docs/examples/jax/collective_gemm.rst\nComing soon]
B --> E[docs/examples/jax/attention.rst\nComing soon]
B --> F[docs/examples/jax/expert_parallelism.rst\nComing soon]
C -->|literalinclude markers| G[docs/examples/jax/dense.py\nTutorial source]
C -->|literalinclude output| H[docs/examples/jax/dense.out\nCaptured benchmark output]
G -->|imported by| I[docs/examples/jax/test_dense.py\nPytest entry points]
G -->|uses| J[docs/examples/jax/quickstart_jax_utils.py]
I -->|uses| J
K[qa/L0_jax_unittest/test.sh] -->|pytest docs/examples/jax/| I
L[qa/L1_jax_distributed_unittest/test.sh] -->|pytest -k multi_gpu| I
Reviews (8): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
|
/te-ci L1 L0 |
KshitijLakhani
left a comment
There was a problem hiding this comment.
Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.
In general it looks good there might be some working around needed on item placements but I think that's going to be an evolving process.
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Attention with TransformerEngine | ||
| ===================================== | ||
|
|
||
| **TODO — Coming soon.** | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ |
There was a problem hiding this comment.
Unrelated to attention but looks like you are renaming the dir to examples/jax_examples whereas I think the pytorch side is examples/pytorch ?
I think we could stick with examples/jax - thoughts ?
There was a problem hiding this comment.
Good point, updated to examples/jax
| `Haiku/Flax interop | ||
| <https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on | ||
| a different stack.) | ||
| * **Baseline dtype.** bf16 for inputs and parameters. |
There was a problem hiding this comment.
Should we add GB200 (arch) details here rather than adding it in the example module or is that by choice ?
I think there's value in having all examples run on the same arch for consistency.
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Pytest conftest for docs/examples/jax_examples. |
There was a problem hiding this comment.
I see, in our main tests in tests/jax/ we use pytest. In our examples/jax we do use unittest instead, but then run those tests in CI with pytest examples/jax/.... because pytest can also run unittest tests.
I'm ok with standardizing and using pytest everywhere. We already have requirements.txt files for running the examples/jax/mnist or encoder tests, so we could add the pytest dependency there too.
| @@ -0,0 +1,22 @@ | |||
| # Numbers below are illustrative (captured on a GB200). Regenerate with: | |||
| # python3 docs/examples/jax_examples/dense.py > dense.out | |||
| # after substantial code changes. | |||
There was a problem hiding this comment.
"after substantial code changes" ?
There was a problem hiding this comment.
Removed, thanks! This was an artifact of the agent that helped refactor from .ipynb -> .rst/.py
| and your performance comparison will not be accurate. | ||
|
|
||
|
|
||
| 6. Multi-GPU: DP=2 / TP=2 on a single Dense |
There was a problem hiding this comment.
- Single GPU performane
4,5 ? - Multi-GPU: DP=2 / TP=2 on a single Dense
There was a problem hiding this comment.
Good catch, I had it broken into more sections and forgot to update the latest section numbers. Fixed now
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
7464325 to
5432ec6
Compare
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
48884cd to
168cc63
Compare
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
54b1a9c to
4c1fec9
Compare
for more information, see https://pre-commit.ci
|
/te-ci |
Description
Reworks tutorial to focus on individual operations and their usage+performance. This will make it clearer to users the impact of each operation and they can focus on trying them out one-at-a-time depending on which are bottlenecks in their models.
Additionally, this switches from notebook
.ipynbfiles to.rstand separate.pyfiles for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.Type of change
Changes
Checklist: