Fix GAE masked critic values bootstrapping valid advantages #447
Open
haoyang9804 wants to merge 1 commit into
Summary
ROLL's `roll.utils.functionals.compute_advantage()` can let critic values from masked response positions change valid-token GAE advantages. The trigger is `adv_estimator="gae"` with a `response_mask=0` padding or filtered token whose critic `values` entry is non-zero or invalid. The function stored `values * response_mask` back into the batch, but still passed the original unmasked `values` tensor to `compute_gae_advantage_return()`, so a masked slot could bootstrap earlier valid tokens without a crash.

This patch applies `response_mask` to `values` before calling the GAE helper. The new test covers the concrete boundary where a masked padding value of `100.0` used to change both valid-token advantages.

Concrete triggering example
Buggy output from upstream `main`:

```json
{
  "stored_values_after_compute_advantage": [[0.0, 0.0, 0.0]],
  "advantages": [[5.699999809265137, 6.0, 0.0]],
  "returns": [[5.699999809265137, 6.0, 0.0]]
}
```

Wrong intermediate value: the masked critic value `100.0` at the final `response_mask=0` slot is still used as `nextvalues` while computing the previous valid token, so the valid advantages become `5.699999809265137` and `6.0`.

Fixed output:

```json
{
  "stored_values_after_compute_advantage": [[0.0, 0.0, 0.0]],
  "advantages": [[0.949999988079071, 1.0, 0.0]],
  "returns": [[0.949999988079071, 1.0, 0.0]]
}
```

Fixed value: after masking critic values before GAE, the helper sees `[[0.0, 0.0, 0.0]]` for `values`, so the valid advantages are `0.949999988079071` and `1.0`.

The shared invariant is that masked values must be selected out before they enter reward, advantage, return, or loss arithmetic. Multiplying only for storage is not enough when a later helper still receives the unmasked tensor.
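Both outputs can be checked by hand. The arithmetic below assumes the usual reverse-time GAE recursion, which is an assumption about the helper's exact form but agrees with every number reported in this PR. It uses the three-token scenario from this PR: rewards $r = (0, 1, 0)$, $\gamma = 1$, $\lambda = 0.95$, and a value of 0 past the last position.

$$
\begin{aligned}
\delta_t &= r_t + \gamma V_{t+1} - V_t, \qquad A_t = \delta_t + \gamma\lambda A_{t+1}, \qquad V_3 = A_3 = 0 \\[4pt]
\text{unmasked } V = (0, 0, 100):\quad A_2 &= (0 + 0 - 100) + 0 = -100 \\
A_1 &= (1 + 100 - 0) + 0.95 \cdot (-100) = 6 \\
A_0 &= (0 + 0 - 0) + 0.95 \cdot 6 = 5.7 \\[4pt]
\text{masked } V = (0, 0, 0):\quad A_2 &= 0 \\
A_1 &= (1 + 0 - 0) + 0.95 \cdot 0 = 1 \\
A_0 &= (0 + 0 - 0) + 0.95 \cdot 1 = 0.95
\end{aligned}
$$

The masked slot's own $A_2 = -100$ is zeroed by the final advantage mask, so only the two valid tokens are visibly affected. Their shifts are $6 - 1 = 5$ and $5.7 - 0.95 = 4.75$, which matches the reported `max_abs_valid_advantage_delta` of `5.0`. The long decimals in the JSON are the float32 representations of 5.7 and 0.95.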
Real rollout reproduction
This was also validated with a local real model rollout. The runner uses `Qwen/Qwen2.5-0.5B-Instruct` from the local Hugging Face cache, performs a real `AutoModelForCausalLM.generate()`, then sends the generated tokens through ROLL's real `postprocess_generate()`, `get_sample_level_mask()`, `reward_postprocess()`, `compute_token_reward()`, and `compute_advantage()` path. The only hook is the fault construction step after rollout: it injects a finite critic value into a rollout padding slot where `response_mask=0`.

Recipe:

```json
{
  "kind": "rl_sentinel_validation_recipe",
  "schema_version": 1,
  "bug_id": "ROLL-GAE-MASKED-VALUE-BOOTSTRAP",
  "target": "roll",
  "validation_mode": "real_hf_model_rollout_plus_roll_training_signal_hook",
  "model": "${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}",
  "requirements": {
    "target_repo": "${TARGET_REPO}",
    "output_dir": "${OUTPUT_DIR}",
    "required_modules": [
      "roll.utils.functionals",
      "roll.distributed.scheduler.protocol",
      "transformers",
      "torch",
      "tensordict"
    ]
  },
  "preserved_infrastructure": [
    "real local HF model generation",
    "roll.utils.functionals.postprocess_generate",
    "roll.utils.functionals.get_sample_level_mask",
    "roll.utils.functionals.reward_postprocess",
    "roll.utils.functionals.compute_token_reward",
    "roll.utils.functionals.compute_advantage",
    "roll.distributed.scheduler.protocol.DataProto",
    "tensordict.TensorDict"
  ],
  "hooked_boundary": "after real model rollout, before critic values enter compute_advantage",
  "constructed_scenario": {
    "prompt": "Answer with one short word: ok",
    "max_new_tokens": 2,
    "extra_response_padding_slots": 3,
    "response_level_reward": 1.0,
    "masked_critic_value": 100.0,
    "gamma": 1.0,
    "lambd": 0.95,
    "adv_estimator": "gae"
  },
  "replaced_component": null
}
```

Runner:
Hook:
Real rollout output on unpatched `alibaba/ROLL` `main`:

```json
{
  "target_commit": "c09bc8bc9f43",
  "status": "reproduced",
  "candidate_bug_reproduced": true,
  "real_rollout": {
    "backend": "transformers.AutoModelForCausalLM.generate",
    "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    "response": "\n\nSure",
    "final_response_mask": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]]
  },
  "observed": {
    "advantages": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.699999809265137, 6.0, -0.0, 0.0, 0.0]]
  },
  "expected_sanitized_values": {
    "advantages": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.949999988079071, 1.0, 0.0, 0.0, 0.0]]
  },
  "attack_effect": {
    "max_abs_valid_advantage_delta": 5.0,
    "quiet_non_crash": true,
    "finite_signal_corruption": true
  }
}
```

Real rollout output on this branch:

```json
{
  "target_commit": "090ac94658ad",
  "status": "fixed",
  "candidate_bug_reproduced": false,
  "observed": {
    "advantages": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.949999988079071, 1.0, 0.0, 0.0, 0.0]]
  },
  "expected_sanitized_values": {
    "advantages": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.949999988079071, 1.0, 0.0, 0.0, 0.0]]
  },
  "attack_effect": {
    "max_abs_valid_advantage_delta": 0.0,
    "quiet_non_crash": true,
    "finite_signal_corruption": false
  }
}
```

Reproduction recipe
```json
{
  "kind": "rl_sentinel_validation_recipe",
  "schema_version": 1,
  "target": "roll",
  "validation_mode": "real_functionals_boundary_hook",
  "hooked_boundary": "roll.utils.functionals.compute_advantage",
  "requirements": {
    "roll_repo": "${ROLL_REPO}",
    "output_dir": "${OUTPUT_DIR}",
    "required_modules": [
      "roll.utils.functionals",
      "roll.distributed.scheduler.protocol",
      "torch",
      "tensordict"
    ]
  },
  "constructed_scenario": {
    "response_mask": [[1.0, 1.0, 0.0]],
    "token_level_rewards": [[0.0, 1.0, 0.0]],
    "values": [[0.0, 0.0, 100.0]],
    "gamma": 1.0,
    "lambd": 0.95,
    "adv_estimator": "gae"
  },
  "expected_unpatched": {
    "advantages": [[5.699999809265137, 6.0, 0.0]]
  },
  "expected_fixed": {
    "advantages": [[0.949999988079071, 1.0, 0.0]]
  },
  "replaced_component": null
}
```

Validation runner
Save this as `run_gae_masked_value_validation.sh` and run it with `ROLL_REPO` pointing at either an unpatched or patched ROLL checkout.

Observed output
On unpatched `alibaba/ROLL` `main`, the same boundary reproduced the bug:

```json
{
  "status": "reproduced",
  "observed_unpatched": {
    "stored_values_after_compute_advantage": [[0.0, 0.0, 0.0]],
    "advantages": [[5.699999809265137, 6.0, -0.0]],
    "returns": [[5.699999809265137, 6.0, 0.0]]
  },
  "expected_fixed": {
    "advantages": [[0.949999988079071, 1.0, 0.0]],
    "returns": [[0.949999988079071, 1.0, 0.0]]
  },
  "attack_effect": {
    "max_abs_valid_advantage_delta": 5.0,
    "quiet_non_crash": true,
    "finite_signal_corruption": true
  }
}
```

On this branch:

```json
{
  "status": "fixed",
  "observed_fixed": {
    "stored_values_after_compute_advantage": [[0.0, 0.0, 0.0]],
    "advantages": [[0.949999988079071, 1.0, 0.0]],
    "raw_advantages_before_final_mask": [[0.949999988079071, 1.0, 0.0]],
    "returns": [[0.949999988079071, 1.0, 0.0]],
    "all_finite": true
  },
  "attack_effect": {
    "candidate_bug_reproduced": false,
    "candidate_bug_fixed": true,
    "max_abs_valid_advantage_delta": 0.0,
    "max_abs_valid_return_delta": 0.0,
    "stored_values_match_response_mask": true
  }
}
```

Root cause
`compute_advantage()` did this in the GAE branch:

The batch stored masked values, but `compute_gae_advantage_return()` still received the raw `values`. Because GAE uses `nextvalues = values[:, t + 1]`, a masked value at a later padding or filtered token can bootstrap into earlier valid tokens.

Fix
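The difference between the old and new call patterns can be sketched in plain Python. This is an illustration only: `gae`, `compute_advantage_sketch`, `scenario`, and the dict-based `batch` are hypothetical stand-ins rather than ROLL's API, and the recursion assumes the usual reverse-time GAE form with `nextvalues = values[t + 1]`.

```python
def gae(rewards, values, gamma, lambd):
    """Reverse-time GAE for a single sequence (assumed form, not ROLL's code)."""
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        # The next position's value bootstraps the current one; 0.0 past the end.
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lambd * last_gae
        advantages[t] = last_gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def compute_advantage_sketch(batch, gamma, lambd, mask_before_gae):
    """Hypothetical stand-in for the GAE branch; `batch` is a plain dict."""
    mask = batch["response_mask"]
    raw_values = batch["values"]
    masked_values = [v * m for v, m in zip(raw_values, mask)]
    batch["values"] = masked_values  # both variants store the masked copy
    # The bug: the helper received the raw values instead of the masked copy.
    helper_values = masked_values if mask_before_gae else raw_values
    advantages, returns = gae(batch["token_level_rewards"], helper_values, gamma, lambd)
    batch["advantages"] = [a * m for a, m in zip(advantages, mask)]
    batch["returns"] = returns
    return batch


def scenario():
    """The three-token boundary case from this PR, flattened to one sequence."""
    return {
        "response_mask": [1.0, 1.0, 0.0],
        "token_level_rewards": [0.0, 1.0, 0.0],
        "values": [0.0, 0.0, 100.0],
    }


buggy = compute_advantage_sketch(scenario(), gamma=1.0, lambd=0.95, mask_before_gae=False)
fixed = compute_advantage_sketch(scenario(), gamma=1.0, lambd=0.95, mask_before_gae=True)
print([round(a, 6) for a in buggy["advantages"]])
print([round(a, 6) for a in fixed["advantages"]])
```

In both variants the stored `values` come out fully masked, which is why inspecting the batch alone does not reveal the problem. Only the advantages differ: roughly 5.7 and 6.0 on the valid tokens when the raw values reach the helper, versus 0.95 and 1.0 when the masked copy does.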
The GAE branch now masks `values` before both storage and helper invocation:

The regression test constructs the exact tensor boundary above and asserts `values`, `advantages`, and `returns` all match the masked-value oracle.

Tests and checks
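One way to generalise the masked-value oracle idea is an invariance check: valid-token advantages should not depend on whatever finite number sits in a masked slot. The sketch below is not the test added in this PR. `gae_advantages`, `invariant_to_masked_slot`, and the chosen junk values are hypothetical, and the recursion again assumes the usual reverse-time GAE form.

```python
import math


def gae_advantages(rewards, values, mask, gamma, lambd, mask_values_first):
    """Masked GAE advantages for one sequence (assumed reverse-time form)."""
    if mask_values_first:
        values = [v * m for v, m in zip(values, mask)]
    advantages, last_gae = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lambd * last_gae
        advantages[t] = last_gae
    # Final mask: positions with mask 0 never carry an advantage.
    return [a * m for a, m in zip(advantages, mask)]


def invariant_to_masked_slot(mask_values_first, junk_values=(100.0, -50.0, 1e6)):
    """True if valid-token advantages ignore the value in the masked slot."""
    rewards, mask = [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]
    oracle = gae_advantages(rewards, [0.0, 0.0, 0.0], mask, 1.0, 0.95, mask_values_first)
    for junk in junk_values:
        got = gae_advantages(rewards, [0.0, 0.0, junk], mask, 1.0, 0.95, mask_values_first)
        if not all(math.isclose(g, o, abs_tol=1e-6) for g, o in zip(got, oracle)):
            return False
    return True


# Masking before the recursion passes; passing raw values does not.
assert invariant_to_masked_slot(mask_values_first=True)
assert not invariant_to_masked_slot(mask_values_first=False)
```

The check is limited to finite junk values, matching the finite-corruption scope of this PR.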
```bash
PYTHONPATH="${REPAIR_REPO}:${PYTHONPATH}" python3 -m pytest -q tests/utils/test_functionals.py
```

Output:
Output:
The same pre-commit hooks also passed during `git commit`.

Contribution and duplicate checks
Target upstream repo: `alibaba/ROLL`.

Contribution files checked locally:

- `README.md`
- `.pre-commit-config.yaml`
- `pyproject.toml`

No root `CONTRIBUTING.md` was present in the checkout. The relevant local tooling is black, isort, autoflake, and flake8 through pre-commit, with 119-character line length.

Duplicate checks performed:

- `BUG_FINDINGS.md` ledger for GAE, masked critic values, bootstrap, `response_mask`, `compute_advantage`, and related terms.
- `gae` or masked-value fixes.
- `pr_drafts/` for related ROLL masked-value PR drafts.
- `ROLL-GAE-MASKED-VALUE-BOOTSTRAP` and related boundary names.
- `alibaba/ROLL` PRs and issues for the boundary and symptom terms.
- `values` into `compute_gae_advantage_return()` after assigning masked values back to the batch.

Result: no exact upstream issue, PR, branch, PR draft, or ledger duplicate was found for ROLL's finite masked critic-value GAE bootstrap bug.