Skip to content

Add intermediate value captures (extends #4925)#5257

Merged
copybara-service[bot] merged 1 commit intogoogle:mainfrom
samanklesaria:capture_intms
Mar 19, 2026
Merged

Add intermediate value captures (extends #4925)#5257
copybara-service[bot] merged 1 commit intogoogle:mainfrom
samanklesaria:capture_intms

Conversation

@samanklesaria
Copy link
Copy Markdown
Collaborator

@samanklesaria samanklesaria commented Feb 16, 2026

This PR adds similar functionality to #4925, but uses flax Variables instead of hijax Boxes.

  • We add a capture decorator that makes the given function returns captured intermediate values along with the ordinary return value.
  • Module methods sow and perturb can be used within capture to store associate values or their gradients with names without adding new attributes on the module. They are stored within variables of the appropriate type in the __captures__ tuple.
  • Perturb just adds values read from the from the __captures__ tuple to the given arguments. This means that if the init argument is passed to capture (used to pre-specify values within __captures__), the values in the init argument will be used for every perturb call in the code, and we can take the gradient with respect to these values to inspect values in the backward pass.
  • We also add a sow_output decorator that adds sow instructions to every method of a module and its submodules. This is useful for debugging.
  • This code is tested under the transformations vmap, jit, and grad.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @samanklesaria, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a simplified and flexible mechanism for capturing intermediate values and their gradients within JAX and Flax models. It provides a context manager and helper functions that enable users to inspect internal states during both forward and backward passes, significantly aiding in debugging, analysis, and understanding model behavior.

Highlights

  • Intermediate Value Capture: Introduced a capture_intms context manager that allows capturing intermediate values in arbitrary JAX code, independent of Flax.
  • Forward and Backward Capturing: Added capture_fwd and capture_bwd functions to store associate values or their gradients with specified names during the forward and backward passes, respectively.
  • Flax Module Integration: Integrated capture_fwd and capture_bwd methods directly into Flax Modules, which automatically prepend their __module_path__ attributes to the provided names for better organization in nested modules.
  • Simplified Implementation: The entire implementation is concise, under 50 lines of code, offering a much simpler approach compared to previous attempts for similar functionality.
  • Comprehensive Testing: Included new tests that cover forward-only, backward-only, combined capturing, and scenarios involving nested modules, ensuring robust functionality.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • flax/nnx/init.py
    • Imported capture_intms, capture_fwd, get_intms, and capture_bwd from the new capture module.
  • flax/nnx/capture.py
    • Added a new file containing the core logic for intermediate value capture.
    • Implemented capture_intms as a context manager to manage the scope of intermediate value capturing.
    • Defined capture_fwd to record intermediate values during the forward pass, using jax.lax.stop_gradient.
    • Implemented get_intms to retrieve the currently captured intermediate values.
    • Created capture_bwd using jax.custom_vjp to record gradients during the backward pass.
  • flax/nnx/module.py
    • Imported capture_fwd and capture_bwd from the new capture module.
    • Added capture_fwd method to the Module class, allowing module-aware forward value capturing.
    • Added capture_bwd method to the Module class, enabling module-aware backward gradient capturing.
    • Implemented _save_paths and _del_paths methods to dynamically add and remove __module_path__ attributes for sub-modules within a capture_intms context.
  • tests/nnx/capture_test.py
    • Added a new test file to validate the functionality of the intermediate value capture system.
    • Included tests for capture_fwd in various JIT compilation scenarios.
    • Provided tests for capture_bwd to ensure correct gradient capturing.
    • Verified combined capture_fwd and capture_bwd functionality.
    • Tested the behavior of intermediate value capturing within nested Flax modules, ensuring correct path-based naming.
Activity
  • No human activity (comments, reviews) has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a simplified mechanism for capturing intermediate values in Flax, using hijax.Box. The implementation is concise and the new functionality is well-tested. However, there are a few areas for improvement. The use of a global list for managing contexts is not thread-safe and should be replaced with thread-local storage. There are also some naming inconsistencies (intms instead of intermediates) and potential AttributeErrors that could arise from incorrect usage. My review includes specific suggestions to address these points to improve robustness and clarity.

Comment thread flax/nnx/capture.py Outdated
Comment thread flax/nnx/module.py Outdated
Comment thread flax/nnx/module.py Outdated
Comment thread flax/nnx/capture.py Outdated
Comment thread flax/nnx/capture.py Outdated
Comment thread flax/nnx/capture.py Outdated
Comment thread flax/nnx/__init__.py Outdated
Comment thread tests/nnx/capture_test.py Outdated
@samanklesaria samanklesaria force-pushed the capture_intms branch 2 times, most recently from 195c144 to 86b57f2 Compare February 16, 2026 21:09
@samanklesaria
Copy link
Copy Markdown
Collaborator Author

Note that this approach works with vmap when using a jax branch in which Box supports vmap: jax-ml/jax#35276

@samanklesaria samanklesaria force-pushed the capture_intms branch 3 times, most recently from fe5e00b to 95b4602 Compare February 27, 2026 15:40
@samanklesaria samanklesaria changed the title Add intermediate value captures based on hijax.Box (simplification of #4925) Add intermediate value captures (extends #4925) Feb 27, 2026
@samanklesaria samanklesaria force-pushed the capture_intms branch 2 times, most recently from db32210 to 4a185ab Compare February 27, 2026 15:54
Copy link
Copy Markdown
Collaborator

@cgarciae cgarciae left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Samuel! My main comment is that we should not introduce new APIs, instead we should make sow and perturb work with capture_intermediates (when enabled) so current code can work with tree-mode and hijax.

Comment thread flax/nnx/capture.py Outdated
Comment thread flax/nnx/module.py Outdated
@samanklesaria samanklesaria force-pushed the capture_intms branch 2 times, most recently from f38d427 to a6fb425 Compare February 27, 2026 20:49
Comment thread flax/nnx/transforms/autodiff.py Outdated
Comment thread flax/nnx/capture.py Outdated
@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Comment thread flax/nnx/module.py Outdated
Comment thread flax/nnx/transforms/autodiff.py Outdated
@samanklesaria samanklesaria requested a review from cgarciae March 3, 2026 20:36
@samanklesaria samanklesaria force-pushed the capture_intms branch 4 times, most recently from 9f1f49a to 22bdf6b Compare March 4, 2026 15:39
Comment thread docs_nnx/guides/extracting_intermediates.md Outdated
Comment thread docs_nnx/guides/extracting_intermediates.md Outdated
Comment thread docs_nnx/guides/extracting_intermediates.md Outdated
Comment thread docs_nnx/guides/extracting_intermediates.md Outdated
Comment thread tests/nnx/capture_test.py Outdated
@samanklesaria samanklesaria force-pushed the capture_intms branch 12 times, most recently from 9427203 to 6f63e59 Compare March 17, 2026 20:45
Comment thread docs_nnx/guides/extracting_intermediates.md Outdated
Comment thread docs_nnx/guides/extracting_intermediates.md
Comment thread docs_nnx/guides/extracting_intermediates.md
@samanklesaria samanklesaria force-pushed the capture_intms branch 4 times, most recently from a93d9a3 to e7805e3 Compare March 17, 2026 23:05
@copybara-service copybara-service Bot merged commit e5a6ff6 into google:main Mar 19, 2026
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants