[FEAT][kernels]: implement fused GRPO loss with in-place group reward normalization #93
Implements the group-relative policy optimization loss as a dispatchable
"grpo_loss" op, replacing the excessive broadcasting/allocation of a naive
PyTorch implementation.
NativeGRPOLossOp (pure PyTorch, the correctness oracle):
- group-wise reward normalization (mean / population std, clamp_min),
accepting uniform groups (samples_per_prompt) or CSR-style group_boundaries;
- clipped surrogate objective + k3 reference-KL penalty over active tokens;
- gradient to the policy logps via autograd.
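The native recipe listed above can be sketched without any framework. This is a minimal illustration, not the code in the PR: the function names, the `eps` default, and the choice to average over all active tokens in the batch are assumptions; the `clip_eps` and `beta` defaults are taken from the benchmark CLI defaults quoted later in this conversation.

```python
import math


def group_normalize(rewards, samples_per_prompt, eps=1e-6):
    """Per-group (reward - mean) / max(population_std, eps) for uniform groups."""
    if len(rewards) % samples_per_prompt != 0:
        raise ValueError("len(rewards) must be a multiple of samples_per_prompt")
    advantages = []
    for start in range(0, len(rewards), samples_per_prompt):
        group = rewards[start:start + samples_per_prompt]
        mean = sum(group) / len(group)
        # Population variance: divide by n, not n - 1.
        var = sum((r - mean) ** 2 for r in group) / len(group)
        std = max(math.sqrt(var), eps)  # stands in for clamp_min
        advantages.extend((r - mean) / std for r in group)
    return advantages


def grpo_loss(cur, old, ref, advantages, mask, clip_eps=0.2, beta=0.04):
    """cur/old/ref/mask are [B][T] nested lists; advantages has length B."""
    policy_sum = 0.0
    kl_sum = 0.0
    active = 0
    for b, adv in enumerate(advantages):
        for t, keep in enumerate(mask[b]):
            if not keep:
                continue  # inactive tokens contribute nothing
            ratio = math.exp(cur[b][t] - old[b][t])
            clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
            policy_sum += -min(ratio * adv, clipped * adv)
            diff = ref[b][t] - cur[b][t]
            kl_sum += math.exp(diff) - diff - 1.0  # k3 estimator
            active += 1
    denom = max(active, 1)  # assumption: mean over all active tokens
    policy_loss = policy_sum / denom
    kl = kl_sum / denom
    return policy_loss + beta * kl, policy_loss, kl
```

A group whose rewards are all equal has zero variance, so the `eps` floor is what keeps its advantages finite (they come out as exactly zero).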
TritonGRPOLossOp (CUDA, fused):
- _group_norm_kernel: one program per group, reward mean/std in registers;
- _grpo_fwd_kernel / _grpo_bwd_kernel: token-parallel surrogate + KL forward
and analytic backward w.r.t. per-token logps, gathering each token's
advantage on the fly (seq_id = idx // completion_len) so no broadcasted
[B, T] advantage tensor is ever materialized;
- wrapped in a torch.autograd.Function for transparent backward.
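The on-the-fly gather relies on a simple index identity over the flattened `[B, T]` layout. The sketch below checks that identity in plain Python; the helper names are hypothetical and the lists stand in for device buffers.

```python
def gather_advantage(flat_idx, sample_advantages, completion_len):
    """Advantage for one flattened token index, looked up per sequence."""
    seq_id = flat_idx // completion_len  # row of the [B, T] layout
    return sample_advantages[seq_id]


def broadcast_advantages(sample_advantages, completion_len):
    """The [B * T] tensor a naive implementation would materialize."""
    return [a for a in sample_advantages for _ in range(completion_len)]
```

Both routes yield the same per-token values; the gather only ever reads the length-`B` advantage vector.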
It composes with the existing logp ops (logits -> [logp op] -> logp ->
[grpo_loss op] -> loss); the loss kernel does not re-implement log-softmax.
Register grpo_loss in the kernel registry: Triton backend first on CUDA/ROCm,
PyTorch native everywhere (and as the CPU/Triton-less fallback).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
tests/test_grpo_loss.py (19 tests) built on rl_engine.testing fixtures
(make_synthetic_rl_kernel_batch, selected_logprobs_reference, reference helpers):
- native op vs the reference group-norm + surrogate/KL recipe;
- Triton forward/backward/grad-scaling/per-sequence-apply vs the native op;
- logp -> grpo_loss pipeline composition (NativeLogpOp and the dispatched
CUDA fused logp), incl. differentiability to logits;
- loss-step checks: masked-token invariance and an SGD step that lowers the
loss (native + Triton);
- registry dispatch.
Triton tests skip without CUDA + Triton.
benchmarks/benchmark_grpo_loss.py: forward / forward+backward latency and peak
VRAM, native vs Triton, across group x sample x length shapes.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
📝 Walkthrough
Adds GRPO loss operator implementations (native PyTorch and Triton GPU), registry entries, extensive tests (CUDA-gated for Triton), a CUDA benchmark script measuring latency and VRAM, and operator documentation.
Changes
GRPO Loss Implementation & Validation
Sequence Diagram(s)
sequenceDiagram
participant Trainer
participant TritonGRPOLossOp
participant GroupNormKernel
participant GRPOLossFunction
participant GRPOFwdKernel
participant GRPOBwdKernel
Trainer->>TritonGRPOLossOp: forward(rewards, current_logps, old_logps, ref_logps, mask)
TritonGRPOLossOp->>GroupNormKernel: compute sample_advantages (bounds)
GroupNormKernel-->>TritonGRPOLossOp: sample_advantages
TritonGRPOLossOp->>GRPOLossFunction: apply(cur, old, ref, adv, mask)
GRPOLossFunction->>GRPOFwdKernel: launch forward kernel (partial accum)
GRPOFwdKernel-->>GRPOLossFunction: policy/kl partials
GRPOLossFunction-->>Trainer: loss, policy_loss, kl
Trainer->>GRPOLossFunction: backward(grad_loss)
GRPOLossFunction->>GRPOBwdKernel: launch backward kernel (grad buffer)
GRPOBwdKernel-->>Trainer: grad_current_logps
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/benchmark_grpo_loss.py`:
- Around line 147-164: The CLI parsing needs value and shape validation: after
parsing in your arg parsing function (where parser.add_argument sets --iters,
--warmup, --configs and args.configs is transformed), enforce that args.iters is
an int > 0 and args.warmup is an int >= 0 (and optionally args.warmup <
args.iters), and validate that when args.configs is provided each
semicolon-separated triple yields exactly three positive integers (e.g.
prompt,samples,length) before assigning to args.configs; if any check fails use
parser.error(...) or raise argparse.ArgumentTypeError with a clear message so
invalid inputs fail fast and prevent downstream divide-by-zero or
malformed-shape errors (reference: parser.add_argument, args.iters, args.warmup,
args.configs, DEFAULT_CONFIGS, and the timing computation that uses _time_ms).
In `@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py`:
- Around line 179-193: Validate that group_boundaries follow CSR semantics
before computing sizes: ensure group_boundaries is monotonic non-decreasing and
that boundaries[0] == 0 and boundaries[-1] == num_sequences; if any of these
fail raise a ValueError. Concretely, in the native fallback in grpo_loss.py (the
block that currently computes boundaries, sizes and returns
torch.repeat_interleave), add checks using the existing tensor
(group_boundaries/boundaries) such as checking torch.any(boundaries[1:] <
boundaries[:-1]) and comparing boundaries[0] and boundaries[-1] to 0 and
num_sequences respectively, and raise clear errors so the CPU/native path
matches TritonGRPOLossOp._build_bounds() behavior.
In `@rl_engine/kernels/ops/triton/triton_grpo_loss.py`:
- Around line 62-68: Validate tensor shapes and devices in
TritonGRPOLossOp.apply before launching kernels: ensure current_logps,
old_logps, ref_logps are CUDA tensors and have the same shape equal to
completion_mask.shape (i.e. [num_sequences, completion_len]), and ensure
sample_advantages is 1D with numel() == num_sequences; raise a clear ValueError
if any check fails. Also check dtype/contiguity if required by
_grpo_fwd_kernel/_grpo_bwd_kernel (they rely on flattened indexing using
n_elements and T) so that seq_id = offs // T and tl.load(adv_seq_ptr + seq_id,
...) cannot read out of bounds. Use the unique symbols current_logps, old_logps,
ref_logps, completion_mask, sample_advantages, TritonGRPOLossOp.apply,
_grpo_fwd_kernel and _grpo_bwd_kernel to locate where to add these guards and
error messages.
In `@tests/test_grpo_loss.py`:
- Around line 7-10: The test imports TritonGRPOLossOp at module scope causing
import-time triton import errors; change to a conditional import so environments
without Triton don't fail test collection: move the import of TritonGRPOLossOp
out of module scope and either wrap it in a try/except ImportError (setting a
sentinel like has_triton) or use pytest.importorskip inside the test(s) that
reference TritonGRPOLossOp, and ensure the tests use the sentinel or the
requires_triton_cuda decorator to skip when Triton is unavailable; update
references to TritonGRPOLossOp in the test functions to use the locally imported
symbol or skip accordingly.
📒 Files selected for processing (5)
- benchmarks/benchmark_grpo_loss.py
- rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
- rl_engine/kernels/ops/triton/triton_grpo_loss.py
- rl_engine/kernels/registry.py
- tests/test_grpo_loss.py
    parser.add_argument("--iters", type=int, default=30)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--clip-eps", type=float, default=0.2)
    parser.add_argument("--beta", type=float, default=0.04)
    parser.add_argument(
        "--configs",
        type=str,
        default=None,
        help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
    )
    args = parser.parse_args()
    if args.configs:
        args.configs = [
            tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
        ]
    else:
        args.configs = DEFAULT_CONFIGS
    return args
Validate CLI numeric bounds to prevent runtime crashes.
--iters can be 0 or negative, which makes _time_ms divide by zero (Line 56) or produce invalid timing behavior. --warmup and --configs also lack basic bounds/shape validation, so malformed input fails with unclear errors.
Suggested fix
 def parse_args():
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--iters", type=int, default=30)
     parser.add_argument("--warmup", type=int, default=10)
@@
     args = parser.parse_args()
+    if args.iters <= 0:
+        parser.error("--iters must be > 0")
+    if args.warmup < 0:
+        parser.error("--warmup must be >= 0")
+
     if args.configs:
-        args.configs = [
-            tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
-        ]
+        parsed = []
+        for triple in args.configs.split(";"):
+            parts = [int(x.strip()) for x in triple.split(",")]
+            if len(parts) != 3:
+                parser.error(f"Invalid --configs entry '{triple}': expected prompts,samples,len")
+            if any(v <= 0 for v in parts):
+                parser.error(f"Invalid --configs entry '{triple}': all values must be > 0")
+            parsed.append(tuple(parts))
+        args.configs = parsed
     else:
         args.configs = DEFAULT_CONFIGS
     return args
📝 Committable suggestion
    parser.add_argument("--iters", type=int, default=30)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--clip-eps", type=float, default=0.2)
    parser.add_argument("--beta", type=float, default=0.04)
    parser.add_argument(
        "--configs",
        type=str,
        default=None,
        help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
    )
    args = parser.parse_args()
    if args.iters <= 0:
        parser.error("--iters must be > 0")
    if args.warmup < 0:
        parser.error("--warmup must be >= 0")
    if args.configs:
        parsed = []
        for triple in args.configs.split(";"):
            parts = [int(x.strip()) for x in triple.split(",")]
            if len(parts) != 3:
                parser.error(f"Invalid --configs entry '{triple}': expected prompts,samples,len")
            if any(v <= 0 for v in parts):
                parser.error(f"Invalid --configs entry '{triple}': all values must be > 0")
            parsed.append(tuple(parts))
        args.configs = parsed
    else:
        args.configs = DEFAULT_CONFIGS
    return args
    boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long)
    if boundaries.ndim != 1 or boundaries.numel() < 2:
        raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.")
    sizes = boundaries[1:] - boundaries[:-1]

    if int(sizes.sum().item()) != num_sequences:
        raise ValueError(
            f"group sizes sum to {int(sizes.sum().item())} but there are "
            f"{num_sequences} sequences."
        )
    if bool((sizes < 1).any().item()):
        raise ValueError("each group must contain at least one sequence.")

    group_index = torch.arange(sizes.numel(), device=device)
    return torch.repeat_interleave(group_index, sizes)
Reject non-CSR group_boundaries in the native fallback.
Line 184 only validates the summed span, so inputs like [1, 3, 6] or [-1, 1, 4] still pass for num_sequences == 5, because the sizes telescope to boundaries[-1] - boundaries[0]. This path then normalizes the first sizes[0], sizes[1], ... rewards instead of honoring the actual offsets, while TritonGRPOLossOp._build_bounds() rejects the same input. As a result, CPU/native and CUDA/Triton disagree on the same API input.
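To make the gap concrete, here is a small plain-Python sketch contrasting a sum-only check with a CSR check. Both helper names are hypothetical, and lists stand in for the boundary tensor; this is not the code in `grpo_loss.py`.

```python
def sum_only_check(boundaries, num_sequences):
    """Accepts any offsets whose consecutive differences are >= 1 and sum up."""
    sizes = [hi - lo for lo, hi in zip(boundaries, boundaries[1:])]
    return sum(sizes) == num_sequences and all(s >= 1 for s in sizes)


def validate_group_boundaries(boundaries, num_sequences):
    """CSR check: start at 0, end at num_sequences, strictly increasing."""
    if len(boundaries) < 2:
        raise ValueError("group_boundaries must have length num_groups + 1")
    if boundaries[0] != 0 or boundaries[-1] != num_sequences:
        raise ValueError("group_boundaries must start at 0 and end at num_sequences")
    sizes = [hi - lo for lo, hi in zip(boundaries, boundaries[1:])]
    if any(s < 1 for s in sizes):
        raise ValueError("each group must contain at least one sequence")
    return sizes
```

Offsets such as `[1, 3, 6]` satisfy the sum-only check for five sequences but are rejected by the CSR check, which is the disagreement the comment describes.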
Suggested fix
     boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long)
     if boundaries.ndim != 1 or boundaries.numel() < 2:
         raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.")
     sizes = boundaries[1:] - boundaries[:-1]
+    if int(boundaries[0].item()) != 0 or int(boundaries[-1].item()) != num_sequences:
+        raise ValueError("group_boundaries must start at 0 and end at num_sequences.")
     if int(sizes.sum().item()) != num_sequences:
         raise ValueError(
             f"group sizes sum to {int(sizes.sum().item())} but there are "
             f"{num_sequences} sequences."
    seq_id = offs // T

    cur = tl.load(cur_ptr + offs, mask=bound, other=0.0).to(tl.float32)
    old = tl.load(old_ptr + offs, mask=bound, other=0.0).to(tl.float32)
    ref = tl.load(ref_ptr + offs, mask=bound, other=0.0).to(tl.float32)
    adv = tl.load(adv_seq_ptr + seq_id, mask=bound, other=0.0).to(tl.float32)
    active = tl.load(mask_ptr + offs, mask=bound, other=0).to(tl.float32)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/kernels/ops/triton/triton_grpo_loss.py"
# Show the relevant regions: apply(), kernels, and any validation/helpers.
sed -n '1,220p' "$FILE" | nl -ba | sed -n '1,220p'
echo "----"
sed -n '220,420p' "$FILE" | nl -ba | sed -n '220,420p'
echo "---- Kernels: search for seq_id / bounds / tl.load masks ----"
rg -n "seq_id\s*=|tl\.load\((adv_seq_ptr|mask_ptr)" "$FILE" || true
echo "---- Search for shape/device validation already present ----"
rg -n "ValueError|RuntimeError|ndim|shape|numel|device" "$FILE" || true
echo "---- Find where TritonGRPOLossOp.apply is called (optional context) ----"
rg -n "TritonGRPOLossOp\.apply\(" -S . || true
Repository: RL-Align/RL-Kernel
Length of output: 105
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/kernels/ops/triton/triton_grpo_loss.py"
echo "== apply()/wrapper overview =="
awk '{print NR "\t" $0}' "$FILE" | sed -n '1,260p'
echo "----"
awk '{print NR "\t" $0}' "$FILE" | sed -n '260,520p'
echo "== Kernels: search for seq_id / adv_seq_ptr / mask_ptr usage =="
rg -n "seq_id\s*=|tl\.load\(\s*(adv_seq_ptr|old_ptr|ref_ptr|mask_ptr)" "$FILE" || true
rg -n "mask\s*=bound|mask\s*=" "$FILE" || true
echo "== Shape/device validation in this file =="
rg -n "ValueError|RuntimeError|ndim|shape|numel|device|is_cuda" "$FILE" || true
echo "== Find where TritonGRPOLossOp.apply is called =="
rg -n "TritonGRPOLossOp\.apply\(" -S . || true
Repository: RL-Align/RL-Kernel
Length of output: 17340
Validate tensor shapes/devices before launching Triton GRPO loss kernels to avoid out-of-bounds device reads.
TritonGRPOLossOp.apply() only checks current_logps.is_cuda and completion_mask.ndim == 2, but does not validate that:
- current_logps, old_logps, ref_logps match completion_mask.shape ([num_sequences, completion_len]), and
- sample_advantages is 1D with numel() == num_sequences.
In _grpo_fwd_kernel/_grpo_bwd_kernel, n_elements = current_logps.numel() and T = completion_len; they compute seq_id = offs // T and then do tl.load(adv_seq_ptr + seq_id, mask=bound, ...) / tl.load(mask_ptr + offs, mask=bound, ...) where mask=bound only guards offs < n_elements, so shape mismatches can lead to out-of-range reads instead of a Python-side ValueError.
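The index arithmetic behind this concern can be checked without a GPU. The sketch below is illustrative only: the helper names are hypothetical and shapes are passed as plain tuples rather than tensors.

```python
def max_gather_index(n_elements, completion_len):
    """Largest seq_id computed for any offs in [0, n_elements)."""
    return (n_elements - 1) // completion_len


def gather_is_in_bounds(logps_shape, mask_shape, num_advantages):
    """True when every seq_id = offs // T indexes inside the advantage vector."""
    n_elements = logps_shape[0] * logps_shape[1]  # current_logps.numel()
    completion_len = mask_shape[1]                # T is taken from the mask
    return max_gather_index(n_elements, completion_len) < num_advantages
```

With matching shapes the largest index is `num_sequences - 1`; if `current_logps` has more rows than `completion_mask`, the `offs < n_elements` guard still passes while `seq_id` runs past the end of the advantages.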
Suggested fix
 def apply(
     self,
     current_logps: torch.Tensor,
     old_logps: torch.Tensor,
     ref_logps: torch.Tensor,
@@
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Evaluate the loss from per-sequence advantages (gathered per token)."""
     if not current_logps.is_cuda:
         raise RuntimeError("TritonGRPOLossOp requires CUDA tensors.")
     if completion_mask.ndim != 2:
         raise ValueError("completion_mask must be 2D [num_sequences, completion_len].")
+    expected_shape = completion_mask.shape
+    if current_logps.shape != expected_shape:
+        raise ValueError("current_logps must have shape [num_sequences, completion_len].")
+    if old_logps.shape != expected_shape or ref_logps.shape != expected_shape:
+        raise ValueError("old_logps and ref_logps must match current_logps.shape.")
+    if sample_advantages.ndim != 1 or sample_advantages.numel() != expected_shape[0]:
+        raise ValueError("sample_advantages must be 1D with one value per sequence.")
+    devices = {
+        current_logps.device,
+        old_logps.device,
+        ref_logps.device,
+        sample_advantages.device,
+        completion_mask.device,
+    }
+    if len(devices) != 1:
+        raise ValueError("All inputs must be on the same device.")
     completion_len = completion_mask.shape[1]
     return _GRPOLossFunction.apply(
         current_logps,
         old_logps,
         ref_logps,
         sample_advantages,
         completion_mask,
         completion_len,
         clip_eps,
         beta,
     )
from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
from rl_engine.testing import (
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="tests/test_grpo_loss.py"
# Show imports and the try/except block region (around where comment claims line ~7-28)
nl -ba "$FILE" | sed -n '1,120p'
# Focus on any triton/skip logic and TritonGRPOLossOp references
echo "---- occurrences ----"
rg -n "TritonGRPOLossOp|requires_triton|skipif|_HAS_TRITON|import triton" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 105
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="tests/test_grpo_loss.py"
echo "---- first 160 lines (numbered) ----"
python3 - <<'PY'
from pathlib import Path
p = Path("tests/test_grpo_loss.py")
for i, line in enumerate(p.read_text().splitlines(), 1):
    if i > 160: break
    print(f"{i:4d}\t{line}")
PY
echo "---- occurrences ----"
rg -n "TritonGRPOLossOp|requires_triton|skipif|_HAS_TRITON|import triton" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 7092
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate and print the triton GRPO loss module header/imports
FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"
python3 - <<'PY'
from pathlib import Path
p = Path("$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)")
text = p.read_text().splitlines()
for i,line in enumerate(text[:120], 1):
    print(f"{i:4d}\t{line}")
PY
# Also search inside for direct triton imports/guards
rg -n "import triton|from triton|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 361
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"
python3 - <<PY
import os
from pathlib import Path
p = Path(os.environ["FILE"])
text = p.read_text().splitlines()
for i, line in enumerate(text[:160], 1):
    print(f"{i:4d}\t{line}")
PY
echo "---- triton-related occurrences in module ----"
rg -n "import triton|from triton|triton\.|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 292
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"
python3 - <<PY
from pathlib import Path
p = Path(r"$FILE")
text = p.read_text().splitlines()
for i, line in enumerate(text[:200], 1):
    print(f"{i:4d}\t{line}")
PY
echo "---- triton-related occurrences in module ----"
rg -n "import triton|from triton|triton\.|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 7996
Make the Triton import conditional in this test module.
tests/test_grpo_loss.py imports TritonGRPOLossOp unconditionally at module scope, but rl_engine/kernels/ops/triton/triton_grpo_loss.py imports triton at import time—so environments without Triton fail during test collection before requires_triton_cuda can skip the Triton tests.
Suggested fix
 from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
 from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
-from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
 from rl_engine.testing import (
     compute_policy_ratio,
     compute_reference_kl,
     make_synthetic_rl_kernel_batch,
@@
 try:
     import triton  # noqa: F401
+    from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
     _HAS_TRITON = True
 except ImportError:  # pragma: no cover
+    TritonGRPOLossOp = None  # type: ignore[assignment]
     _HAS_TRITON = False
📝 Committable suggestion
from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
from rl_engine.testing import (
    compute_policy_ratio,
    compute_reference_kl,
    make_synthetic_rl_kernel_batch,
)
try:
    import triton  # noqa: F401
    from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
    _HAS_TRITON = True
except ImportError:  # pragma: no cover
    TritonGRPOLossOp = None  # type: ignore[assignment]
    _HAS_TRITON = False
Thank you for your contribution! Could you add a README for this kernel that explains how to use it? You can create it in the docs/operators directory.
Add docs/operators/grpo-loss.md following the operator doc template: purpose, entry point, group specification, backend table, tensor contract, dispatch behavior, reference semantics, benchmark numbers, tests, and limitations. Link it from the operators index.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Flink-ddd The docs have been added.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/grpo-loss.md`:
- Around line 12-14: The fenced code block containing the pipeline diagram
"logits --[logp op]--> logps --[grpo_loss op]--> loss" is missing a language
identifier; update the fence to include a language tag (e.g., use ```text) so
the block follows markdownlint MD040; locate the fence around that pipeline
diagram in docs/operators/grpo-loss.md and add the language identifier to the
opening backticks.
📒 Files selected for processing (2)
- docs/operators/README.md
- docs/operators/grpo-loss.md
✅ Files skipped from review due to trivial changes (1)
- docs/operators/README.md
```
logits --[logp op]--> logps --[grpo_loss op]--> loss
```
Add a language identifier to the fenced code block.
This fence is missing a language tag (markdownlint MD040). Use something like ```text for the pipeline diagram.
Suggested doc fix
-```
+```text
 logits --[logp op]--> logps --[grpo_loss op]--> loss
| `rewards` | `[B]` | float | One scalar per sequence. |
| `completion_mask` | `[B, T]` | bool / {0,1} | 2-D; `True` marks active tokens. |
| `loss` (output) | scalar | float32 | `policy_loss + beta * kl`. |
| `policy_loss`, `kl` (output) | scalar | float32 | Detached reporting values. |
Fix incorrect gradient contract for policy_loss and kl outputs.
The docs say these are “detached reporting values,” but the implementation returns them directly (no .detach()), so they remain connected to autograd. Please update the wording (or detach in code if that behavior is intentional).
Thanks for the contribution, this is a useful GRPO loss addition. One concern: the native path masks inactive tokens only after |
Thanks for the reminder! I will sanitize the exponents before exp by applying masked_fill to delta and diff, so masked positions get ratio = 1 and kl = 0 (both finite).
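A scalar sketch of the plan in this reply, assuming `delta = cur - old` and `diff = ref - cur` as the two exponents (the names come from the reply; their exact definitions in the PR are an assumption here). The conditional assignment stands in for `masked_fill`, and the function name is hypothetical.

```python
import math


def masked_ratio_and_kl(cur, old, ref, keep):
    """Zero the exponents of inactive tokens before calling exp."""
    delta = (cur - old) if keep else 0.0  # stand-in for masked_fill(~mask, 0)
    diff = (ref - cur) if keep else 0.0
    ratio = math.exp(delta)               # exactly 1.0 when masked
    kl = math.exp(diff) - diff - 1.0      # k3 estimator, exactly 0.0 when masked
    return ratio, kl


# Why masking after exp is not enough: an overflowed value times a zero mask
# is NaN, not zero, so it would still poison the reduction.
poisoned = float("inf") * 0.0
```

Masking the exponents keeps every intermediate finite, so extreme log-probs at padded positions cannot reach the loss or its gradient.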
LGTM, Thank you for your contribution! |
Flink-ddd
left a comment
Awesome work and great memory optimizations! Just a few numerical and performance edge cases to address before merging.
    bounds_ptr,  # int32[num_groups + 1], CSR-style group offsets
    adv_ptr,  # float32[N], per-sequence advantages (output)
    eps,
    GROUP_BLOCK: tl.constexpr,
in _group_norm_kernel launch config
I noticed you are launching _group_norm_kernel with GROUP_BLOCK = _next_pow2(max_group).
While this works perfectly for standard PPO/GRPO setups where samples_per_prompt is small (e.g., 8, 16, or 64), it will crash or fail to compile if a user provides highly skewed group_boundaries resulting in a max_group > 2048. Triton has hardware limits on block sizes.
Suggestion:
For this PR, we can keep the current design but please add an explicit check and throw a ValueError in TritonGRPOLossOp._build_bounds if max_group > 1024.
For future scaling, we might need a tiled reduction kernel here, but a safety assertion is enough for now to prevent cryptic CUDA errors.
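A sketch of the suggested guard, under stated assumptions: the 1024 limit is the threshold proposed in this comment, `next_pow2` is a hypothetical stand-in for the PR's `_next_pow2`, and `group_block_for` is an illustrative helper rather than the actual `_build_bounds` code.

```python
MAX_GROUP = 1024  # threshold proposed in the review comment


def next_pow2(n):
    """Smallest power of two >= n, for n >= 1."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()


def group_block_for(boundaries):
    """Pick GROUP_BLOCK from CSR offsets, rejecting oversized groups early."""
    max_group = max(hi - lo for lo, hi in zip(boundaries, boundaries[1:]))
    if max_group > MAX_GROUP:
        raise ValueError(
            f"largest group has {max_group} sequences; "
            f"the single-block kernel supports at most {MAX_GROUP}"
        )
    return next_pow2(max_group)
```

Failing in Python before the launch gives the caller a readable message instead of a compile-time or CUDA error.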
Ok, I will use try-except to catch it
    policy_term = tl.where(keep, policy_term, 0.0)
    kl_term = tl.where(keep, kl_term, 0.0)

    tl.atomic_add(partials_ptr + 0, tl.sum(policy_term, axis=0))
Atomic Contention in _grpo_fwd_kernel
tl.atomic_add(partials_ptr + 0, tl.sum(policy_term, axis=0))
tl.atomic_add(partials_ptr + 1, tl.sum(kl_term, axis=0))
Having all SM blocks hammer the exact same 2 global memory addresses with atomic adds will cause severe memory contention on high-end GPUs like H100s.
Suggestion:
Instead of using tl.atomic_add, it is generally much faster to have the kernel output a [grid_size] tensor of block-level sums, and then perform a simple torch.sum() over it in PyTorch (in _GRPOLossFunction.forward).
Since this only happens once per forward pass, it’s not a blocker for merging, but consider changing this if you plan to submit further optimizations! I'm happy to approve this PR once the sqrt numerical trap in the native op is addressed.
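The two-stage reduction suggested here can be emulated in plain Python to show its shape. This is only an illustration with hypothetical names: each list slot plays the role of one program's block-level sum, and the final `sum` plays the role of the host-side `torch.sum()`.

```python
def block_partials(values, block_size):
    """Stage 1: each 'program' reduces its own block into its own slot."""
    return [
        sum(values[start:start + block_size])
        for start in range(0, len(values), block_size)
    ]


def two_stage_sum(values, block_size):
    """Stage 2: one small reduction over the [grid_size] partials buffer."""
    partials = block_partials(values, block_size)
    return sum(partials), len(partials)
```

Because every program writes to a distinct slot, no two writers touch the same address; a side benefit is that the summation order is fixed, which atomics on a shared address do not guarantee.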
@Flink-ddd Thanks for the suggestions! The requested changes are done.
Flink-ddd
left a comment
LGTM now, Thank you for your contribution.
#46
Summary
Implements the GRPO loss as a dispatchable grpo_loss op, eliminating the broadcasting/allocation overhead of a naive PyTorch implementation.
- NativeGRPOLossOp: group-wise reward normalization (mean / population std, clamp_min), clipped surrogate objective + k3 reference-KL over active tokens.
- TritonGRPOLossOp: _group_norm_kernel (per-group reward stats in registers) + token-parallel _grpo_fwd_kernel / _grpo_bwd_kernel with an analytic backward w.r.t. per-token logps. Each token gathers its advantage on the fly (seq_id = idx // completion_len), so no broadcasted [B, T] advantage tensor is ever materialized. Wrapped in a torch.autograd.Function.
- grpo_loss resolves to the Triton backend on CUDA/ROCm, PyTorch native on CPU / when Triton is unavailable.
Tests (tests/test_grpo_loss.py)
Built on rl_engine.testing fixtures (make_synthetic_rl_kernel_batch, selected_logprobs_reference, reference helpers):
- logp → grpo_loss pipeline (NativeLogpOp + dispatched CUDA fused logp), incl. differentiability to logits;
Benchmark (benchmarks/benchmark_grpo_loss.py)
~3–4.7× faster forward, ~3× forward+backward, ~10× less peak VRAM.
Notes
- old_logps/ref_logps are precomputed per-token constants.