17 changes: 17 additions & 0 deletions fast_llm/layers/language_model/loss/config.py
@@ -1,3 +1,4 @@
import enum
import typing
import warnings

@@ -193,6 +194,12 @@ def loss_class(self) -> "type[LanguageModelZLoss]":
return LanguageModelZLoss


class GRPOMetricsLevel(enum.StrEnum):
none = "none"
basic = "basic"
with_entropy = "with_entropy"


@config_class(dynamic_type={LanguageModelLossConfig: "grpo"})
class LanguageModelGRPOLossConfig(LanguageModelLossConfig):

@@ -205,6 +212,16 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
desc="Enable triton implementation. Default: use if available.",
hint=FieldHint.expert,
)
metrics: GRPOMetricsLevel = Field(
default=GRPOMetricsLevel.none,
desc=(
"Additional GRPO metrics to log. "
"`basic`: per-token ratio, KL, and advantage statistics. "
"`with_entropy`: also log per-token entropy. "
"Not supported with pipeline_parallel > 1."
),
hint=FieldHint.feature,
)
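
For reference, a minimal usage sketch: selecting this loss through its dynamic type and enabling the extra metrics. The `"grpo"` key and the field names come from this file; the `from_dict` entry point is assumed from the usual `config_class` machinery:

```python
# Hypothetical usage; `from_dict` is assumed from the config_class pattern.
loss_config = LanguageModelLossConfig.from_dict(
    {
        "type": "grpo",      # dynamic type registered for this class
        "metrics": "basic",  # or "with_entropy" / "none"
    }
)
assert loss_config.metrics == GRPOMetricsLevel.basic
```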

@property
def loss_class(self) -> "type[LanguageModelGRPOLoss]":
192 changes: 189 additions & 3 deletions fast_llm/layers/language_model/loss/grpo.py
@@ -3,16 +3,69 @@

import torch

from fast_llm.engine.base_model.config import LossDef
from fast_llm.core.distributed import ReduceOp, all_reduce
from fast_llm.engine.base_model.config import LossDef, ReductionType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
from fast_llm.functional.utils import reduce_losses
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
from fast_llm.layers.language_model.loss.config import (
GRPOMetricsLevel,
LanguageModelGRPOLossConfig,
LanguageModelLossKwargs,
)
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.utils import Assert


class GRPOMetrics(typing.NamedTuple):
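    # Aggregated over the local (micro-)batch: the ratio/KL/advantage-style fields
    # are sums weighted by 1 / label_counts (i.e. sums of per-document means), the
    # *_sum and num_tokens fields are raw token sums, and max/min_advantage are
    # extrema over valid tokens.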
old_logprobs: torch.Tensor
ratio_new_old: torch.Tensor
ratio_new_old_sum: torch.Tensor
ratio_new_old_squared_sum: torch.Tensor
kl_new_old: torch.Tensor
clipped_ratio_fraction: torch.Tensor
advantage: torch.Tensor
max_advantage: torch.Tensor
min_advantage: torch.Tensor
num_tokens: torch.Tensor
entropy: torch.Tensor | None


class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
name: str,
prediction_distance: int = 1,
prediction_heads: int = 1,
vocab_parallel: bool = False,
num_splits: int = 1,
logits_scale_factor: float = 1.0,
weight: float = 1.0,
register_loss: bool = False,
):
super().__init__(
config,
distributed_config,
name=name,
prediction_distance=prediction_distance,
prediction_heads=prediction_heads,
vocab_parallel=vocab_parallel,
num_splits=num_splits,
logits_scale_factor=logits_scale_factor,
weight=weight,
register_loss=register_loss,
)
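        # The extra metrics are registered as plain losses and, per the config
        # docs, are not supported with pipeline_parallel > 1.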
Assert.custom(
lambda metrics, pipeline_parallel: metrics == GRPOMetricsLevel.none or pipeline_parallel == 1,
config.metrics,
distributed_config.pipeline_parallel,
)

def _forward_backward(
self,
logits: "torch.Tensor",
@@ -51,10 +104,88 @@ def _forward_backward
self._register_loss(
self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
)

# Skip the extra softmax pass when there is nothing to register.
if losses is not None and self._config.metrics != GRPOMetricsLevel.none:
self._register_extra_metrics(logits, kwargs, losses, split_index)

return loss, grad

def _register_extra_metrics(
self,
logits: torch.Tensor,
kwargs: dict[str, typing.Any],
losses: dict | None,
split_index: int,
) -> None:
metrics = compute_grpo_metrics(
logits,
self._get_labels(kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index),
self._config.epsilon_low,
self._config.epsilon_high,
self._logits_scale_factor,
group=self._parallel_dim.group if self._vocab_parallel else None,
compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy,
)

num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch]

for attr in (
"old_logprobs",
"ratio_new_old",
"kl_new_old",
"clipped_ratio_fraction",
"advantage",
):
self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr) / num_documents, losses)

for attr in (
"ratio_new_old_sum",
"ratio_new_old_squared_sum",
"num_tokens",
):
self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr), losses)

self._register_loss(
f"{self._name}_max_advantage",
metrics.max_advantage,
losses,
reduce_op=torch.distributed.ReduceOp.MAX,
)
self._register_loss(
f"{self._name}_min_advantage",
metrics.min_advantage,
losses,
reduce_op=torch.distributed.ReduceOp.MIN,
)

if metrics.entropy is not None:
self._register_loss(f"{self._name}_entropy", metrics.entropy / num_documents, losses)

def get_loss_definitions(self) -> list[LossDef]:
return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)]
defs = super().get_loss_definitions()
defs.append(LossDef(self._logprob_metric_name))
if self._config.metrics != GRPOMetricsLevel.none:
defs.extend(
[
LossDef(f"{self._name}_old_logprobs"),
LossDef(f"{self._name}_ratio_new_old"),
LossDef(f"{self._name}_ratio_new_old_sum"),
LossDef(f"{self._name}_ratio_new_old_squared_sum"),
LossDef(f"{self._name}_kl_new_old"),
LossDef(f"{self._name}_clipped_ratio_fraction"),
LossDef(f"{self._name}_advantage"),
LossDef(f"{self._name}_max_advantage", reduction=ReductionType.maximum),
LossDef(f"{self._name}_min_advantage", reduction=ReductionType.minimum),
LossDef(f"{self._name}_num_tokens"),
]
)
if self._config.metrics == GRPOMetricsLevel.with_entropy:
defs.append(LossDef(f"{self._name}_entropy"))
return defs

def get_preprocessing_config(
self,
@@ -66,6 +197,61 @@ def _logprob_metric_name(self) -> str:
return f"{self._name}_new_logprobs"


@torch.compile
def compute_grpo_metrics(
logits: torch.Tensor, # (*batch, vocab_local)
target: torch.Tensor, # (*batch,)
advantages: torch.Tensor, # (*batch,)
old_log_probabilities: torch.Tensor, # (*batch,)
label_counts: torch.Tensor, # (*batch,) — global per-sequence count broadcast per token
epsilon_low: float = 0.2,
epsilon_high: float = 0.2,
logits_scale_factor: float = 1.0,
group: torch.distributed.ProcessGroup | None = None,
compute_entropy: bool = False,
) -> GRPOMetrics:
loss_mask = target >= 0
mask = loss_mask.float()
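    # Weight each valid token by 1 / its document's token count, so masked sums
    # below yield sums of per-document means (divided by num_documents on
    # registration).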
masked = mask / label_counts.float().clamp(min=1)

logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group)
new_log_probs = predicted_logits - sum_exp_logits.log()

log_ratio = new_log_probs - old_log_probabilities
ratio = log_ratio.exp()
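    # PPO-style clipping indicator: flag tokens whose importance ratio leaves the
    # trust region [1 - epsilon_low, 1 + epsilon_high].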
clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high)
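    # Per-token "k3" KL estimate (Schulman's approximation): r - log(r) - 1,
    # nonnegative for every r > 0.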
kl = ratio - log_ratio - 1.0

neg_inf = advantages.new_full((), float("-inf"))
pos_inf = advantages.new_full((), float("inf"))

entropy: torch.Tensor | None = None
if compute_entropy:
# exp_logits and logits_norm are local vocab slices — sum over the local slice, then all-reduce
# across the tensor-parallel group to recover the global E_p[logit_norm] before dividing by the
# already-global sum_exp_logits.
weighted_logits_sum = (exp_logits * logits_norm).sum(-1)
if group is not None:
all_reduce(weighted_logits_sum, op=ReduceOp.SUM, group=group)
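        # H = log(Z) - E_p[z]: the log-partition minus the probability-weighted
        # mean of the (scaled) logits, with Z = sum_exp_logits.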
entropy_per_token = sum_exp_logits.log() - weighted_logits_sum / sum_exp_logits
entropy = (entropy_per_token * masked).sum()

return GRPOMetrics(
old_logprobs=(old_log_probabilities * masked).sum(),
ratio_new_old=(ratio * masked).sum(),
ratio_new_old_sum=(ratio * mask).sum(),
ratio_new_old_squared_sum=(ratio * ratio * mask).sum(),
kl_new_old=(kl * masked).sum(),
clipped_ratio_fraction=(clipped.float() * masked).sum(),
advantage=(advantages * masked).sum(),
max_advantage=torch.where(loss_mask, advantages, neg_inf).max(),
min_advantage=torch.where(loss_mask, advantages, pos_inf).min(),
num_tokens=mask.sum(),
entropy=entropy,
)
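
As a consumer-side illustration, a minimal sketch of recovering the global ratio mean and standard deviation from the reduced sums. The helper is hypothetical; the key names follow the `LossDef` registrations above, with `name` standing in for the loss's configured `self._name`:

```python
# Hypothetical post-processing of the logged sums; not part of this diff.
def ratio_mean_std(metrics: dict[str, float], name: str) -> tuple[float, float]:
    n = metrics[f"{name}_num_tokens"]
    mean = metrics[f"{name}_ratio_new_old_sum"] / n
    # Var[r] = E[r^2] - E[r]^2, computed from the two reduced token sums.
    var = metrics[f"{name}_ratio_new_old_squared_sum"] / n - mean**2
    return mean, max(var, 0.0) ** 0.5
```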


@torch.compile
def fused_grpo_loss_forward_backward(
logits: torch.Tensor, # (*batch, vocab)