Skip to content

[FEAT][kernels]: implement fused GRPO loss with in-place group reward normalization#93

Merged
Flink-ddd merged 6 commits into
RL-Align:mainfrom
KJLdefeated:feat/grpo-loss-pytorch-op
Jun 10, 2026
Merged

[FEAT][kernels]: implement fused GRPO loss with in-place group reward normalization#93
Flink-ddd merged 6 commits into
RL-Align:mainfrom
KJLdefeated:feat/grpo-loss-pytorch-op

Conversation

@KJLdefeated

@KJLdefeated KJLdefeated commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

#46

Summary

Implements the GRPO loss as a dispatchable grpo_loss op, eliminating the broadcasting/allocation overhead of a naive PyTorch implementation.

  1. NativeGRPOLossOp: group-wise reward normalization (mean / population std, clamp_min), clipped surrogate objective + k3 reference-KL over active tokens.
  2. TritonGRPOLossOp: _group_norm_kernel (per-group reward stats in registers) + token-parallel _grpo_fwd_kernel/_grpo_bwd_kernel with an analytic backward w.r.t. per-token logps. Each token gathers its advantage on the fly (seq_id = idx // completion_len), so no broadcasted [B, T] advantage tensor is ever materialized. Wrapped in a torch.autograd.Function.
  3. Registry dispatchgrpo_loss resolves to the Triton backend on CUDA/ROCm, PyTorch native on CPU / when Triton is unavailable.

Tests (tests/test_grpo_loss.py)

Built on rl_engine.testing fixtures (make_synthetic_rl_kernel_batch, selected_logprobs_reference, reference helpers):

  • native op vs the reference group-norm + surrogate/KL recipe;
  • Triton forward / backward / grad-scaling / per-sequence-apply vs native;
  • logp → grpo_loss pipeline (NativeLogpOp + dispatched CUDA fused logp), incl.
    differentiability to logits;
  • loss-step: masked-token invariance + an SGD step that lowers the loss;
  • registry dispatch.

Benchmark (benchmarks/benchmark_grpo_loss.py)

shape (P×S×L) tokens fwd speedup fwd+bwd speedup VRAM (native→triton)
64×8×512 0.26M 4.70× 2.98× 10 MB → 1 MB
128×8×1024 1.05M 3.16× 2.66× 40 MB → 4 MB
256×16×1024 4.19M 2.97× 2.99× 160 MB → 16 MB

~3–4.7× faster forward, ~3× forward+backward, ~10× less peak VRAM.

Notes

  • old_logps/ref_logps are precomputed per-token constants.
  • Group normalization stays a separate lightweight kernel by design (scalar per sequence; fusing it into the token kernel would trade away token-parallelism).

Summary by CodeRabbit

  • New Features

    • Added GRPO (Group Relative Policy Optimization) loss with both high-performance CUDA (Triton) and native PyTorch backends; runtime selects the best available implementation.
  • Benchmarks

    • Added a CLI benchmark tool measuring forward and forward+backward latency and peak VRAM across configurable shapes, outputting markdown tables.
  • Tests

    • Added comprehensive tests for loss values, gradients, grouping modes, backend parity, masked tokens, and integration.
  • Documentation

    • Added operator docs and index entry describing usage, tensor contracts, and performance notes.

KJLdefeated and others added 2 commits June 9, 2026 13:02
Implements the group-relative policy optimization loss as a dispatchable
"grpo_loss" op, replacing the excessive broadcasting/allocation of a naive
PyTorch implementation.

NativeGRPOLossOp (pure PyTorch, the correctness oracle):
  - group-wise reward normalization (mean / population std, clamp_min),
    accepting uniform groups (samples_per_prompt) or CSR-style group_boundaries;
  - clipped surrogate objective + k3 reference-KL penalty over active tokens;
  - gradient to the policy logps via autograd.

TritonGRPOLossOp (CUDA, fused):
  - _group_norm_kernel: one program per group, reward mean/std in registers;
  - _grpo_fwd_kernel / _grpo_bwd_kernel: token-parallel surrogate + KL forward
    and analytic backward w.r.t. per-token logps, gathering each token's
    advantage on the fly (seq_id = idx // completion_len) so no broadcasted
    [B, T] advantage tensor is ever materialized;
  - wrapped in a torch.autograd.Function for transparent backward.

It composes with the existing logp ops (logits -> [logp op] -> logp ->
[grpo_loss op] -> loss); the loss kernel does not re-implement log-softmax.

Register grpo_loss in the kernel registry: Triton backend first on CUDA/ROCm,
PyTorch native everywhere (and as the CPU/Triton-less fallback).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
tests/test_grpo_loss.py (19 tests) built on rl_engine.testing fixtures
(make_synthetic_rl_kernel_batch, selected_logprobs_reference, reference helpers):
  - native op vs the reference group-norm + surrogate/KL recipe;
  - Triton forward/backward/grad-scaling/per-sequence-apply vs the native op;
  - logp -> grpo_loss pipeline composition (NativeLogpOp and the dispatched
    CUDA fused logp), incl. differentiability to logits;
  - loss-step checks: masked-token invariance and an SGD step that lowers the
    loss (native + Triton);
  - registry dispatch.
Triton tests skip without CUDA + Triton.

benchmarks/benchmark_grpo_loss.py: forward / forward+backward latency and peak
VRAM, native vs Triton, across group x sample x length shapes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Jun 9, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 827c74a4-b469-44c7-8614-e1a114d8a4b4

📥 Commits

Reviewing files that changed from the base of the PR and between a4ca453 and f1bc93a.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
  • rl_engine/kernels/ops/triton/triton_grpo_loss.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
  • rl_engine/kernels/ops/triton/triton_grpo_loss.py

📝 Walkthrough

Walkthrough

Adds GRPO loss operator implementations (native PyTorch and Triton GPU), registry entries, extensive tests (CUDA-gated for Triton), a CUDA benchmark script measuring latency and VRAM, and operator documentation.

Changes

GRPO Loss Implementation & Validation

Layer / File(s) Summary
Native PyTorch GRPO Loss Implementation
rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
NativeGRPOLossOp computes per-sequence normalized advantages from rewards (grouped by samples_per_prompt or group_boundaries), expands them to per-token advantages, applies PPO-style clipped surrogate policy loss, computes a masked k3 reference-KL term, and returns (loss_with_beta_kl, policy_loss, kl). Includes input validation and masked-mean helpers.
Triton GPU GRPO Loss Implementation
rl_engine/kernels/ops/triton/triton_grpo_loss.py
Three Triton kernels (group-normalization, fused forward accumulation, backward gradient) plus _GRPOLossFunction implement normalized advantages, fused forward reduction with atomic partials, and backward gradient computation. TritonGRPOLossOp exposes group_advantages, apply, and forward with CSR-style group bounds.
Registry Integration & Comprehensive Test Suite
rl_engine/kernels/registry.py, tests/test_grpo_loss.py
Registry gains grpo_loss op type with TRITON/PYTORCH backends and platform priority mapping. Tests cover native correctness (advantage normalization, loss terms, gradients, validation), Triton vs native forward/backward equivalence (CUDA-gated), integration composition with logp ops, masked-token invariance, SGD-like loss reduction, and registry dispatch behavior.
Performance Benchmark Suite
benchmarks/benchmark_grpo_loss.py
CUDA-only benchmark comparing Native vs Triton across configurable shapes; measures forward latency (CUDA events), forward+backward latency (torch.autograd.grad), and peak extra VRAM (CUDA memory peak deltas). CLI args support iterations, warmup, clip_eps, beta, and custom shape lists; outputs a GitHub-formatted table.
Docs
docs/operators/README.md, docs/operators/grpo-loss.md
Adds operator index entry and a detailed page documenting the grpo_loss tensor contract, backends, semantics, matching tolerances, tests, and benchmark usage.

Sequence Diagram(s)

sequenceDiagram
  participant Trainer
  participant TritonGRPOLossOp
  participant GroupNormKernel
  participant GRPOLossFunction
  participant GRPOFwdKernel
  participant GRPOBwdKernel
  Trainer->>TritonGRPOLossOp: forward(rewards, current_logps, old_logps, ref_logps, mask)
  TritonGRPOLossOp->>GroupNormKernel: compute sample_advantages (bounds)
  GroupNormKernel-->>TritonGRPOLossOp: sample_advantages
  TritonGRPOLossOp->>GRPOLossFunction: apply(cur, old, ref, adv, mask)
  GRPOLossFunction->>GRPOFwdKernel: launch forward kernel (partial accum)
  GRPOFwdKernel-->>GRPOLossFunction: policy/kl partials
  GRPOLossFunction-->>Trainer: loss, policy_loss, kl
  Trainer->>GRPOLossFunction: backward(grad_loss)
  GRPOLossFunction->>GRPOBwdKernel: launch backward kernel (grad buffer)
  GRPOBwdKernel-->>Trainer: grad_current_logps
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Poem

🐇 I hopped through tensors, masks, and clips,

grouped rewards into tiny sips,
Triton raced while PyTorch kept pace,
gradients marched and tests found their place,
two kernels now share the same loss trace.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 25.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and accurately describes the main change: implementing a fused GRPO loss operator with in-place group reward normalization, which is the primary feature across all modified files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/benchmark_grpo_loss.py`:
- Around line 147-164: The CLI parsing needs value and shape validation: after
parsing in your arg parsing function (where parser.add_argument sets --iters,
--warmup, --configs and args.configs is transformed), enforce that args.iters is
an int > 0 and args.warmup is an int >= 0 (and optionally args.warmup <
args.iters), and validate that when args.configs is provided each
semicolon-separated triple yields exactly three positive integers (e.g.
prompt,samples,length) before assigning to args.configs; if any check fails use
parser.error(...) or raise argparse.ArgumentTypeError with a clear message so
invalid inputs fail fast and prevent downstream divide-by-zero or
malformed-shape errors (reference: parser.add_argument, args.iters, args.warmup,
args.configs, DEFAULT_CONFIGS, and the timing computation that uses _time_ms).

In `@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py`:
- Around line 179-193: Validate that group_boundaries follow CSR semantics
before computing sizes: ensure group_boundaries is monotonic non-decreasing and
that boundaries[0] == 0 and boundaries[-1] == num_sequences; if any of these
fail raise a ValueError. Concretely, in the native fallback in grpo_loss.py (the
block that currently computes boundaries, sizes and returns
torch.repeat_interleave), add checks using the existing tensor
(group_boundaries/boundaries) such as checking torch.any(boundaries[1:] <
boundaries[:-1]) and comparing boundaries[0] and boundaries[-1] to 0 and
num_sequences respectively, and raise clear errors so the CPU/native path
matches TritonGRPOLossOp._build_bounds() behavior.

In `@rl_engine/kernels/ops/triton/triton_grpo_loss.py`:
- Around line 62-68: Validate tensor shapes and devices in
TritonGRPOLossOp.apply before launching kernels: ensure current_logps,
old_logps, ref_logps are CUDA tensors and have the same shape equal to
completion_mask.shape (i.e. [num_sequences, completion_len]), and ensure
sample_advantages is 1D with numel() == num_sequences; raise a clear ValueError
if any check fails. Also check dtype/contiguity if required by
_grpo_fwd_kernel/_grpo_bwd_kernel (they rely on flattened indexing using
n_elements and T) so that seq_id = offs // T and tl.load(adv_seq_ptr + seq_id,
...) cannot read out of bounds. Use the unique symbols current_logps, old_logps,
ref_logps, completion_mask, sample_advantages, TritonGRPOLossOp.apply,
_grpo_fwd_kernel and _grpo_bwd_kernel to locate where to add these guards and
error messages.

In `@tests/test_grpo_loss.py`:
- Around line 7-10: The test imports TritonGRPOLossOp at module scope causing
import-time triton import errors; change to a conditional import so environments
without Triton don't fail test collection: move the import of TritonGRPOLossOp
out of module scope and either wrap it in a try/except ImportError (setting a
sentinel like has_triton) or use pytest.importorskip inside the test(s) that
reference TritonGRPOLossOp, and ensure the tests use the sentinel or the
requires_triton_cuda decorator to skip when Triton is unavailable; update
references to TritonGRPOLossOp in the test functions to use the locally imported
symbol or skip accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fd5edb10-3987-4fc7-8cbe-c6426334b1bd

📥 Commits

Reviewing files that changed from the base of the PR and between 12fc220 and 25e19bf.

📒 Files selected for processing (5)
  • benchmarks/benchmark_grpo_loss.py
  • rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
  • rl_engine/kernels/ops/triton/triton_grpo_loss.py
  • rl_engine/kernels/registry.py
  • tests/test_grpo_loss.py

Comment on lines +147 to +164
parser.add_argument("--iters", type=int, default=30)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--clip-eps", type=float, default=0.2)
parser.add_argument("--beta", type=float, default=0.04)
parser.add_argument(
"--configs",
type=str,
default=None,
help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
)
args = parser.parse_args()
if args.configs:
args.configs = [
tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
]
else:
args.configs = DEFAULT_CONFIGS
return args

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate CLI numeric bounds to prevent runtime crashes.

--iters can be 0 or negative, which makes _time_ms divide by zero (Line 56) or produce invalid timing behavior. --warmup and --configs also lack basic bounds/shape validation, so malformed input fails with unclear errors.

Suggested fix
 def parse_args():
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--iters", type=int, default=30)
     parser.add_argument("--warmup", type=int, default=10)
@@
     args = parser.parse_args()
+    if args.iters <= 0:
+        parser.error("--iters must be > 0")
+    if args.warmup < 0:
+        parser.error("--warmup must be >= 0")
+
     if args.configs:
-        args.configs = [
-            tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
-        ]
+        parsed = []
+        for triple in args.configs.split(";"):
+            parts = [int(x.strip()) for x in triple.split(",")]
+            if len(parts) != 3:
+                parser.error(f"Invalid --configs entry '{triple}': expected prompts,samples,len")
+            if any(v <= 0 for v in parts):
+                parser.error(f"Invalid --configs entry '{triple}': all values must be > 0")
+            parsed.append(tuple(parts))
+        args.configs = parsed
     else:
         args.configs = DEFAULT_CONFIGS
     return args
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
parser.add_argument("--iters", type=int, default=30)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--clip-eps", type=float, default=0.2)
parser.add_argument("--beta", type=float, default=0.04)
parser.add_argument(
"--configs",
type=str,
default=None,
help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
)
args = parser.parse_args()
if args.configs:
args.configs = [
tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
]
else:
args.configs = DEFAULT_CONFIGS
return args
parser.add_argument("--iters", type=int, default=30)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--clip-eps", type=float, default=0.2)
parser.add_argument("--beta", type=float, default=0.04)
parser.add_argument(
"--configs",
type=str,
default=None,
help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
)
args = parser.parse_args()
if args.iters <= 0:
parser.error("--iters must be > 0")
if args.warmup < 0:
parser.error("--warmup must be >= 0")
if args.configs:
parsed = []
for triple in args.configs.split(";"):
parts = [int(x.strip()) for x in triple.split(",")]
if len(parts) != 3:
parser.error(f"Invalid --configs entry '{triple}': expected prompts,samples,len")
if any(v <= 0 for v in parts):
parser.error(f"Invalid --configs entry '{triple}': all values must be > 0")
parsed.append(tuple(parts))
args.configs = parsed
else:
args.configs = DEFAULT_CONFIGS
return args
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_grpo_loss.py` around lines 147 - 164, The CLI parsing
needs value and shape validation: after parsing in your arg parsing function
(where parser.add_argument sets --iters, --warmup, --configs and args.configs is
transformed), enforce that args.iters is an int > 0 and args.warmup is an int >=
0 (and optionally args.warmup < args.iters), and validate that when args.configs
is provided each semicolon-separated triple yields exactly three positive
integers (e.g. prompt,samples,length) before assigning to args.configs; if any
check fails use parser.error(...) or raise argparse.ArgumentTypeError with a
clear message so invalid inputs fail fast and prevent downstream divide-by-zero
or malformed-shape errors (reference: parser.add_argument, args.iters,
args.warmup, args.configs, DEFAULT_CONFIGS, and the timing computation that uses
_time_ms).

Comment on lines +179 to +193
boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long)
if boundaries.ndim != 1 or boundaries.numel() < 2:
raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.")
sizes = boundaries[1:] - boundaries[:-1]

if int(sizes.sum().item()) != num_sequences:
raise ValueError(
f"group sizes sum to {int(sizes.sum().item())} but there are "
f"{num_sequences} sequences."
)
if bool((sizes < 1).any().item()):
raise ValueError("each group must contain at least one sequence.")

group_index = torch.arange(sizes.numel(), device=device)
return torch.repeat_interleave(group_index, sizes)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-CSR group_boundaries in the native fallback.

Line 184 only validates the summed span, so inputs like [1, 3, 6] or [-1, 1, 5] still pass for num_sequences == 5. This path then normalizes the first sizes[0], sizes[1], ... rewards instead of honoring the actual offsets, while TritonGRPOLossOp._build_bounds() rejects the same input. That makes CPU/native and CUDA/Triton disagree on the same API input.

Suggested fix
         boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long)
         if boundaries.ndim != 1 or boundaries.numel() < 2:
             raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.")
         sizes = boundaries[1:] - boundaries[:-1]
+        if int(boundaries[0].item()) != 0 or int(boundaries[-1].item()) != num_sequences:
+            raise ValueError("group_boundaries must start at 0 and end at num_sequences.")
 
         if int(sizes.sum().item()) != num_sequences:
             raise ValueError(
                 f"group sizes sum to {int(sizes.sum().item())} but there are "
                 f"{num_sequences} sequences."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py` around lines 179 - 193,
Validate that group_boundaries follow CSR semantics before computing sizes:
ensure group_boundaries is monotonic non-decreasing and that boundaries[0] == 0
and boundaries[-1] == num_sequences; if any of these fail raise a ValueError.
Concretely, in the native fallback in grpo_loss.py (the block that currently
computes boundaries, sizes and returns torch.repeat_interleave), add checks
using the existing tensor (group_boundaries/boundaries) such as checking
torch.any(boundaries[1:] < boundaries[:-1]) and comparing boundaries[0] and
boundaries[-1] to 0 and num_sequences respectively, and raise clear errors so
the CPU/native path matches TritonGRPOLossOp._build_bounds() behavior.

Comment on lines +62 to +68
seq_id = offs // T

cur = tl.load(cur_ptr + offs, mask=bound, other=0.0).to(tl.float32)
old = tl.load(old_ptr + offs, mask=bound, other=0.0).to(tl.float32)
ref = tl.load(ref_ptr + offs, mask=bound, other=0.0).to(tl.float32)
adv = tl.load(adv_seq_ptr + seq_id, mask=bound, other=0.0).to(tl.float32)
active = tl.load(mask_ptr + offs, mask=bound, other=0).to(tl.float32)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/kernels/ops/triton/triton_grpo_loss.py"

# Show the relevant regions: apply(), kernels, and any validation/helpers.
sed -n '1,220p' "$FILE" | nl -ba | sed -n '1,220p'
echo "----"
sed -n '220,420p' "$FILE" | nl -ba | sed -n '220,420p'

echo "---- Kernels: search for seq_id / bounds / tl.load masks ----"
rg -n "seq_id\s*=|tl\.load\((adv_seq_ptr|mask_ptr)" "$FILE" || true

echo "---- Search for shape/device validation already present ----"
rg -n "ValueError|RuntimeError|ndim|shape|numel|device" "$FILE" || true

echo "---- Find where TritonGRPOLossOp.apply is called (optional context) ----"
rg -n "TritonGRPOLossOp\.apply\(" -S . || true

Repository: RL-Align/RL-Kernel

Length of output: 105


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/kernels/ops/triton/triton_grpo_loss.py"

echo "== apply()/wrapper overview =="
awk '{print NR "\t" $0}' "$FILE" | sed -n '1,260p'
echo "----"
awk '{print NR "\t" $0}' "$FILE" | sed -n '260,520p'

echo "== Kernels: search for seq_id / adv_seq_ptr / mask_ptr usage =="
rg -n "seq_id\s*=|tl\.load\(\s*(adv_seq_ptr|old_ptr|ref_ptr|mask_ptr)" "$FILE" || true
rg -n "mask\s*=bound|mask\s*=" "$FILE" || true

echo "== Shape/device validation in this file =="
rg -n "ValueError|RuntimeError|ndim|shape|numel|device|is_cuda" "$FILE" || true

echo "== Find where TritonGRPOLossOp.apply is called =="
rg -n "TritonGRPOLossOp\.apply\(" -S . || true

Repository: RL-Align/RL-Kernel

Length of output: 17340


Validate tensor shapes/devices before launching Triton GRPO loss kernels to avoid out-of-bounds device reads.

TritonGRPOLossOp.apply() only checks current_logps.is_cuda and completion_mask.ndim == 2, but does not validate that:

  • current_logps, old_logps, ref_logps match completion_mask.shape ([num_sequences, completion_len]), and
  • sample_advantages is 1D with numel() == num_sequences.

In _grpo_fwd_kernel/_grpo_bwd_kernel, n_elements = current_logps.numel() and T = completion_len; they compute seq_id = offs // T and then do tl.load(adv_seq_ptr + seq_id, mask=bound, ...) / tl.load(mask_ptr + offs, mask=bound, ...) where mask=bound only guards offs < n_elements—so shape mismatches can lead to out-of-range reads instead of a Python-side ValueError.

Suggested fix
     def apply(
         self,
         current_logps: torch.Tensor,
         old_logps: torch.Tensor,
         ref_logps: torch.Tensor,
@@
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Evaluate the loss from per-sequence advantages (gathered per token)."""
         if not current_logps.is_cuda:
             raise RuntimeError("TritonGRPOLossOp requires CUDA tensors.")
         if completion_mask.ndim != 2:
             raise ValueError("completion_mask must be 2D [num_sequences, completion_len].")
+        expected_shape = completion_mask.shape
+        if current_logps.shape != expected_shape:
+            raise ValueError("current_logps must have shape [num_sequences, completion_len].")
+        if old_logps.shape != expected_shape or ref_logps.shape != expected_shape:
+            raise ValueError("old_logps and ref_logps must match current_logps.shape.")
+        if sample_advantages.ndim != 1 or sample_advantages.numel() != expected_shape[0]:
+            raise ValueError("sample_advantages must be 1D with one value per sequence.")
+        devices = {
+            current_logps.device,
+            old_logps.device,
+            ref_logps.device,
+            sample_advantages.device,
+            completion_mask.device,
+        }
+        if len(devices) != 1:
+            raise ValueError("All inputs must be on the same device.")
         completion_len = completion_mask.shape[1]
         return _GRPOLossFunction.apply(
             current_logps,
             old_logps,
             ref_logps,
             sample_advantages,
             completion_mask,
             completion_len,
             clip_eps,
             beta,
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/triton_grpo_loss.py` around lines 62 - 68,
Validate tensor shapes and devices in TritonGRPOLossOp.apply before launching
kernels: ensure current_logps, old_logps, ref_logps are CUDA tensors and have
the same shape equal to completion_mask.shape (i.e. [num_sequences,
completion_len]), and ensure sample_advantages is 1D with numel() ==
num_sequences; raise a clear ValueError if any check fails. Also check
dtype/contiguity if required by _grpo_fwd_kernel/_grpo_bwd_kernel (they rely on
flattened indexing using n_elements and T) so that seq_id = offs // T and
tl.load(adv_seq_ptr + seq_id, ...) cannot read out of bounds. Use the unique
symbols current_logps, old_logps, ref_logps, completion_mask, sample_advantages,
TritonGRPOLossOp.apply, _grpo_fwd_kernel and _grpo_bwd_kernel to locate where to
add these guards and error messages.

Comment thread tests/test_grpo_loss.py
Comment on lines +7 to +10
from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
from rl_engine.testing import (

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="tests/test_grpo_loss.py"

# Show imports and the try/except block region (around where comment claims line ~7-28)
nl -ba "$FILE" | sed -n '1,120p'

# Focus on any triton/skip logic and TritonGRPOLossOp references
echo "---- occurrences ----"
rg -n "TritonGRPOLossOp|requires_triton|skipif|_HAS_TRITON|import triton" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 105


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="tests/test_grpo_loss.py"

echo "---- first 160 lines (numbered) ----"
python3 - <<'PY'
from pathlib import Path
p = Path("tests/test_grpo_loss.py")
for i, line in enumerate(p.read_text().splitlines(), 1):
    if i > 160: break
    print(f"{i:4d}\t{line}")
PY

echo "---- occurrences ----"
rg -n "TritonGRPOLossOp|requires_triton|skipif|_HAS_TRITON|import triton" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 7092


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate and print the triton GRPO loss module header/imports
FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"

python3 - <<'PY'
from pathlib import Path
p = Path("$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)")
text = p.read_text().splitlines()
for i,line in enumerate(text[:120], 1):
    print(f"{i:4d}\t{line}")
PY

# Also search inside for direct triton imports/guards
rg -n "import triton|from triton|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 361


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"

python3 - <<PY
import os
from pathlib import Path
p = Path(os.environ["FILE"])
text = p.read_text().splitlines()
for i, line in enumerate(text[:160], 1):
    print(f"{i:4d}\t{line}")
PY

echo "---- triton-related occurrences in module ----"
rg -n "import triton|from triton|triton\.|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 292


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="$(fd -a "triton_grpo_loss\.py$" rl_engine | head -n 1)"
echo "Found: $FILE"

python3 - <<PY
from pathlib import Path
p = Path(r"$FILE")
text = p.read_text().splitlines()
for i, line in enumerate(text[:200], 1):
    print(f"{i:4d}\t{line}")
PY

echo "---- triton-related occurrences in module ----"
rg -n "import triton|from triton|triton\.|Triton|_HAS_TRITON|requires_triton|cuda" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 7996


Make the Triton import conditional in this test module.

tests/test_grpo_loss.py imports TritonGRPOLossOp unconditionally at module scope, but rl_engine/kernels/ops/triton/triton_grpo_loss.py imports triton at import time—so environments without Triton fail during test collection before requires_triton_cuda can skip the Triton tests.

Suggested fix
 from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
 from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
-from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
 from rl_engine.testing import (
     compute_policy_ratio,
     compute_reference_kl,
     make_synthetic_rl_kernel_batch,
@@
 try:
     import triton  # noqa: F401
+    from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
 
     _HAS_TRITON = True
 except ImportError:  # pragma: no cover
+    TritonGRPOLossOp = None  # type: ignore[assignment]
     _HAS_TRITON = False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
from rl_engine.testing import (
from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp
from rl_engine.testing import (
compute_policy_ratio,
compute_reference_kl,
make_synthetic_rl_kernel_batch,
)
try:
import triton # noqa: F401
from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
_HAS_TRITON = True
except ImportError: # pragma: no cover
TritonGRPOLossOp = None # type: ignore[assignment]
_HAS_TRITON = False
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_grpo_loss.py` around lines 7 - 10, The test imports
TritonGRPOLossOp at module scope causing import-time triton import errors;
change to a conditional import so environments without Triton don't fail test
collection: move the import of TritonGRPOLossOp out of module scope and either
wrap it in a try/except ImportError (setting a sentinel like has_triton) or use
pytest.importorskip inside the test(s) that reference TritonGRPOLossOp, and
ensure the tests use the sentinel or the requires_triton_cuda decorator to skip
when Triton is unavailable; update references to TritonGRPOLossOp in the test
functions to use the locally imported symbol or skip accordingly.

@Flink-ddd

Copy link
Copy Markdown
Collaborator

Thank you for your contribution! Could you add a readme for this kernel to explain how to use it? You can create your operator in the docs/operators directory.

KJLdefeated and others added 2 commits June 9, 2026 16:40
Add docs/operators/grpo-loss.md following the operator doc template: purpose,
entry point, group specification, backend table, tensor contract, dispatch
behavior, reference semantics, benchmark numbers, tests, and limitations.
Link it from the operators index.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@KJLdefeated

Copy link
Copy Markdown
Contributor Author

@Flink-ddd Documents are added.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/grpo-loss.md`:
- Around line 12-14: The fenced code block containing the pipeline diagram
"logits --[logp op]--> logps --[grpo_loss op]--> loss" is missing a language
identifier; update the fence to include a language tag (e.g., use ```text) so
the block follows markdownlint MD040; locate the fence around that pipeline
diagram in docs/operators/grpo-loss.md and add the language identifier to the
opening backticks.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae723212-9f13-4f1b-bd50-a7a521b7588e

📥 Commits

Reviewing files that changed from the base of the PR and between 25e19bf and e851c48.

📒 Files selected for processing (2)
  • docs/operators/README.md
  • docs/operators/grpo-loss.md
✅ Files skipped from review due to trivial changes (1)
  • docs/operators/README.md

Comment on lines +12 to +14
```
logits --[logp op]--> logps --[grpo_loss op]--> loss
```

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language identifier to the fenced code block.

This fence is missing a language tag (markdownlint MD040). Use something like ```text for the pipeline diagram.

Suggested doc fix
-```
+```text
 logits --[logp op]--> logps --[grpo_loss op]--> loss
</details>

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 markdownlint-cli2 (0.22.1)</summary>

[warning] 12-12: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/grpo-loss.md around lines 12 - 14, The fenced code block
containing the pipeline diagram "logits --[logp op]--> logps --[grpo_loss op]-->
loss" is missing a language identifier; update the fence to include a language
tag (e.g., use ```text) so the block follows markdownlint MD040; locate the
fence around that pipeline diagram in docs/operators/grpo-loss.md and add the
language identifier to the opening backticks.


</details>

<!-- fingerprinting:phantom:triton:hawk -->

<!-- cr-comment:v1:13c3393cb81ae72a6c408c62 -->

_Source: Linters/SAST tools_

<!-- This is an auto-generated comment by CodeRabbit -->

| `rewards` | `[B]` | float | One scalar per sequence. |
| `completion_mask` | `[B, T]` | bool / {0,1} | 2-D; `True` marks active tokens. |
| `loss` (output) | scalar | float32 | `policy_loss + beta * kl`. |
| `policy_loss`, `kl` (output) | scalar | float32 | Detached reporting values. |

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix incorrect gradient contract for policy_loss and kl outputs.

The docs say these are “detached reporting values,” but the implementation returns them directly (no .detach()), so they remain connected to autograd. Please update the wording (or detach in code if that behavior is intentional).

@inaniloquentee

Copy link
Copy Markdown
Collaborator

Thanks for the contribution, this is a useful GRPO loss addition. One concern: the native path masks inactive tokens only after exp/surrogate/KL computation, so extreme padding values can create inf/nan in the autograd graph and leak nan gradients. Could we sanitize masked positions or operate only on active tokens before those numerically sensitive ops?

@KJLdefeated

Copy link
Copy Markdown
Contributor Author

Thanks for remind! I will sanitize the exponents before exp, masked_fill on delta and diff, so masked positions get ratio=1 / kl=0 (finite).

@inaniloquentee

Copy link
Copy Markdown
Collaborator

Thanks for remind! I will sanitize the exponents before exp, masked_fill on delta and diff, so masked positions get ratio=1 / kl=0 (finite).

LGTM, Thank you for your contribution!

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work and great memory optimizations! Just a few numerical and performance edge cases to address before merging.

Comment thread rl_engine/kernels/ops/pytorch/loss/grpo_loss.py Outdated
bounds_ptr, # int32[num_groups + 1], CSR-style group offsets
adv_ptr, # float32[N], per-sequence advantages (output)
eps,
GROUP_BLOCK: tl.constexpr,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in _group_norm_kernel launch config

I noticed you are launching _group_norm_kernel with GROUP_BLOCK = _next_pow2(max_group).

While this works perfectly for standard PPO/GRPO setups where samples_per_prompt is small (e.g., 8, 16, or 64), it will crash or fail to compile if a user provides highly skewed group_boundaries resulting in a max_group > 2048. Triton has hardware limits on block sizes.

Suggestion:
For this PR, we can keep the current design but please add an explicit check and throw a ValueError in TritonGRPOLossOp._build_bounds if max_group > 1024.

For future scaling, we might need a tiled reduction kernel here, but a safety assertion is enough for now to prevent cryptic CUDA errors.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will use try-except to catch it

policy_term = tl.where(keep, policy_term, 0.0)
kl_term = tl.where(keep, kl_term, 0.0)

tl.atomic_add(partials_ptr + 0, tl.sum(policy_term, axis=0))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Atomic Contention in _grpo_fwd_kernel

tl.atomic_add(partials_ptr + 0, tl.sum(policy_term, axis=0))
tl.atomic_add(partials_ptr + 1, tl.sum(kl_term, axis=0))

Having all SM blocks hammer the exact same 2 global memory addresses with atomic adds will cause severe memory contention on high-end GPUs like H100s.

Suggestion:
Instead of using tl.atomic_add, it is generally much faster to have the kernel output a [grid_size] tensor of block-level sums, and then perform a simple torch.sum() over it in PyTorch (in _GRPOLossFunction.forward).

Since this only happens once per forward pass, it’s not a blocker for merging, but consider changing this if you plan to submit further optimizations! I'm happy to approve this PR once the sqrt numerical trap in the native op is addressed.

@KJLdefeated

Copy link
Copy Markdown
Contributor Author

@Flink-ddd Thx for suggestion! The requested changes are done.

@KJLdefeated KJLdefeated requested a review from Flink-ddd June 10, 2026 03:58

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now, Thank you for your contribution.

@Flink-ddd Flink-ddd merged commit ac357b7 into RL-Align:main Jun 10, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants