Add LogitProcessor interface for pre-sampling logit transforms (#19517)#19517
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19517
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 6 Pending, 1 Unclassified FailureAs of commit 7d57b0a with merge base 4c474af ( UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104767967. |
This PR needs a
|
|
@claude review |
|
Claude finished @kirklandsign's task in 1m 31s —— View job Code Review: Add LogitProcessor interface for pre-sampling logit transforms
Overall this is a clean, well-scoped addition. The interface is minimal, the integration point is correct, and tests cover the key behaviors. A few observations: Design — looks good
Potential issues
Buck targets
Nits
SummaryClean PR, well-designed interface. The main suggestion is adding an integration test for the 3D tensor offset logic in |
There was a problem hiding this comment.
Pull request overview
This PR introduces a LogitProcessor interface to enable in-place mutation of model logits between the decoder forward pass and sampling, allowing features like constrained decoding and logit biasing without changing the core generation loop.
Changes:
- Adds
LogitProcessorinterface underextension/llm/sampler/. - Extends
TextTokenGeneratorwith a configurable processor chain applied pre-sampling. - Updates Buck targets to export the new header and adds a unit test for the interface.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| extension/llm/sampler/test/test_logit_processor.cpp | Adds unit tests validating basic LogitProcessor behavior and ordering semantics. |
| extension/llm/sampler/test/targets.bzl | Adds a Buck test target for the new logit processor tests. |
| extension/llm/sampler/targets.bzl | Exports logit_processor.h from the sampler library target. |
| extension/llm/sampler/logit_processor.h | Introduces the LogitProcessor pure virtual interface. |
| extension/llm/runner/text_token_generator.h | Adds processor registration APIs and applies processor chain to logits before sampling. |
| extension/llm/runner/targets.bzl | Adds runner dependency on the sampler target (for LogitProcessor). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| const auto vocab_size = logits_tensor.size(logits_tensor.dim() - 1); | ||
| if (logits_tensor.dim() == 3) { | ||
| const auto num_tokens = logits_tensor.size(1); |
| auto* logits = logits_tensor.mutable_data_ptr<float>(); | ||
| const auto vocab_size = logits_tensor.size(logits_tensor.dim() - 1); | ||
| if (logits_tensor.dim() == 3) { | ||
| const auto num_tokens = logits_tensor.size(1); | ||
| logits += (num_tokens - 1) * vocab_size; | ||
| } | ||
| for (auto& processor : logit_processors_) { | ||
| processor->process(logits, static_cast<int32_t>(vocab_size)); | ||
| } |
| ET_CHECK_OR_RETURN_ERROR( | ||
| logits_tensor.scalar_type() == ::executorch::aten::ScalarType::Float, | ||
| InvalidArgument, | ||
| "LogitProcessor chain only supports Float logits; got dtype %d", | ||
| static_cast<int>(logits_tensor.scalar_type())); |
| if (!logit_processors_.empty()) { | ||
| ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors_(logits_tensor)); | ||
| } |
| * @param vocab_size Number of logits in the buffer (size of the model's | ||
| * output vocabulary for the current step). | ||
| */ | ||
| virtual void process(float* logits, int32_t vocab_size) = 0; |
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): pure virtual interface with a single `process(float*, int32_t)` method, placed in `extension/llm/sampler/`. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that validates Float dtype, advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` — concurrent modification during generation is not safe. Differential Revision: D104767967
3b3862f to
6ebfdf6
Compare
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): pure virtual interface with a single `process(float*, int32_t)` method, placed in `extension/llm/sampler/`. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that validates Float dtype, advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` — concurrent modification during generation is not safe. Differential Revision: D104767967
6ebfdf6 to
d03e9db
Compare
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): pure virtual interface with a single `process(float*, int32_t)` method, placed in `extension/llm/sampler/`. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that validates Float dtype, advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` — concurrent modification during generation is not safe. Differential Revision: D104767967
d03e9db to
cac76f9
Compare
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): abstract class with a constructor that takes `vocab_size` and a pure virtual `process(float*)` method, placed in `extension/llm/sampler/`. The `vocab_size` is fixed per model and stored as a member, avoiding redundant per-call arguments. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. Supports Float, Half, BFloat16, and UInt16 dtypes: Float logits are processed in place (zero-copy); for other dtypes, logits are cast to a temporary float buffer, processed, then cast back to the original dtype. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` -- concurrent modification during generation is not safe. Differential Revision: D104767967
cac76f9 to
7676c50
Compare
|
@claude review |
|
Claude finished @kirklandsign's task in 1m 5s —— View job Code Review: Add LogitProcessor interface for pre-sampling logit transforms
Clean, well-scoped PR. The interface is minimal and the integration point is correct. A few observations: Strengths
Issues
Nits
Buck targets — correct
SummaryThe main actionable item is the |
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): abstract class with a constructor that takes `vocab_size` and a pure virtual `process(float*)` method, placed in `extension/llm/sampler/`. The `vocab_size` is fixed per model and stored as a member, avoiding redundant per-call arguments. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. Supports Float, Half, BFloat16, and UInt16 dtypes: Float logits are processed in place (zero-copy); for other dtypes, logits are cast to a temporary float buffer, processed, then cast back to the original dtype. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` -- concurrent modification during generation is not safe. Differential Revision: D104767967
a7cdd2e to
a9add25
Compare
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): abstract class with a constructor that takes `vocab_size` and a pure virtual `process(float*)` method, placed in `extension/llm/sampler/`. The `vocab_size` is fixed per model and stored as a member, avoiding redundant per-call arguments. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. Supports Float, Half, BFloat16, and UInt16 dtypes: Float logits are processed in place (zero-copy); for other dtypes, logits are cast to a temporary float buffer, processed, then cast back to the original dtype. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` -- concurrent modification during generation is not safe. Differential Revision: D104767967
a9add25 to
b25bbf5
Compare
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop. Changes: - `LogitProcessor` (new): abstract class with a constructor that takes `vocab_size` and a pure virtual `process(float*)` method, placed in `extension/llm/sampler/`. The `vocab_size` is fixed per model and stored as a member, avoiding redundant per-call arguments. - `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before. - `apply_logit_processors_()`: private helper that advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order. Supports Float, Half, BFloat16, and UInt16 dtypes: Float logits are processed in place (zero-copy); for other dtypes, logits are cast to a temporary float buffer, processed, then cast back to the original dtype. - Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added. Processors must be configured before calling `generate()` -- concurrent modification during generation is not safe. Differential Revision: D104767967
b25bbf5 to
fa26f94
Compare
Summary:
Introduces a `LogitProcessor` abstract interface that lets callers mutate
logits in place between the model forward pass and the sampler. Enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar transforms without touching the core generation loop.
Interface (`extension/llm/sampler/logit_processor.h`, ~15 lines):
- Single virtual method `process(::executorch::aten::Tensor logits)` that
returns `Error::Ok` or aborts the chain on a non-Ok return.
- Tensor passed by value (handle-typed ATen idiom; mutations propagate
through the shared underlying buffer).
- Each implementation declares its own dtype expectations -- the chain
runner does not cast or copy the tensor. Typical implementations check
`logits.scalar_type()` and either dispatch to a kernel or return
InvalidArgument.
- Tensor shape contract (rank 2 = `[batch, vocab]`, rank 3 =
`[batch, seq, vocab]` advanced to last sequence position) mirrors
`sample_from_logits`.
Wiring (`extension/llm/runner/text_token_generator.h`):
- New public methods `add_logit_processor`, `clear_logit_processors`,
`num_logit_processors`.
- Inside `generate()`, between `step()` and `logits_to_token()`, the loop
invokes each registered processor:
for (auto& processor : logit_processors_) {
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
}
- Empty chain is the existing fast path; no behavior change for callers
that don't register processors.
Configure processors before calling `generate()` -- concurrent
modification during generation is not safe.
Differential Revision: D104767967
fa26f94 to
b416b30
Compare
Summary:
Introduces a `LogitProcessor` abstract interface that lets callers mutate
logits in place between the model forward pass and the sampler. Enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar transforms without touching the core generation loop.
Interface (`extension/llm/sampler/logit_processor.h`, ~15 lines):
- Single virtual method `process(::executorch::aten::Tensor logits)` that
returns `Error::Ok` or aborts the chain on a non-Ok return.
- Tensor passed by value (handle-typed ATen idiom; mutations propagate
through the shared underlying buffer).
- Each implementation declares its own dtype expectations -- the chain
runner does not cast or copy the tensor. Typical implementations check
`logits.scalar_type()` and either dispatch to a kernel or return
InvalidArgument.
- Tensor shape contract (rank 2 = `[batch, vocab]`, rank 3 =
`[batch, seq, vocab]` advanced to last sequence position) mirrors
`sample_from_logits`.
Wiring (`extension/llm/runner/text_token_generator.h`):
- New public methods `add_logit_processor`, `clear_logit_processors`,
`num_logit_processors`.
- Inside `generate()`, between `step()` and `logits_to_token()`, the loop
invokes each registered processor:
for (auto& processor : logit_processors_) {
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
}
- Empty chain is the existing fast path; no behavior change for callers
that don't register processors.
Configure processors before calling `generate()` -- concurrent
modification during generation is not safe.
Differential Revision: D104767967
b416b30 to
c08e1e3
Compare
Summary:
Introduces a `LogitProcessor` abstract interface that lets callers mutate
logits in place between the model forward pass and the sampler. Enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar transforms without touching the core generation loop.
Interface (`extension/llm/sampler/logit_processor.h`, ~15 lines):
- Single virtual method `process(::executorch::aten::Tensor logits)` that
returns `Error::Ok` or aborts the chain on a non-Ok return.
- Tensor passed by value (handle-typed ATen idiom; mutations propagate
through the shared underlying buffer).
- Each implementation declares its own dtype expectations -- the chain
runner does not cast or copy the tensor. Typical implementations check
`logits.scalar_type()` and either dispatch to a kernel or return
InvalidArgument.
- Tensor shape contract (rank 2 = `[batch, vocab]`, rank 3 =
`[batch, seq, vocab]` advanced to last sequence position) mirrors
`sample_from_logits`.
Wiring (`extension/llm/runner/text_token_generator.h`):
- New public methods `add_logit_processor`, `clear_logit_processors`,
`num_logit_processors`.
- Inside `generate()`, between `step()` and `logits_to_token()`, the loop
invokes each registered processor:
for (auto& processor : logit_processors_) {
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
}
- Empty chain is the existing fast path; no behavior change for callers
that don't register processors.
Configure processors before calling `generate()` -- concurrent
modification during generation is not safe.
Reviewed By: Gasoonjia
Differential Revision: D104767967
c08e1e3 to
0e37b09
Compare
Summary:
Introduces a `LogitProcessor` abstract interface that lets callers mutate
logits in place between the model forward pass and the sampler. Enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar transforms without touching the core generation loop.
Interface (`extension/llm/sampler/logit_processor.h`, ~15 lines):
- Single virtual method `process(::executorch::aten::Tensor logits)` that
returns `Error::Ok` or aborts the chain on a non-Ok return.
- Tensor passed by value (handle-typed ATen idiom; mutations propagate
through the shared underlying buffer).
- Each implementation declares its own dtype expectations -- the chain
runner does not cast or copy the tensor. Typical implementations check
`logits.scalar_type()` and either dispatch to a kernel or return
InvalidArgument.
- Tensor shape contract (rank 2 = `[batch, vocab]`, rank 3 =
`[batch, seq, vocab]` advanced to last sequence position) mirrors
`sample_from_logits`.
Wiring (`extension/llm/runner/text_token_generator.h`):
- New public methods `add_logit_processor`, `clear_logit_processors`,
`num_logit_processors`.
- Inside `generate()`, between `step()` and `logits_to_token()`, the loop
invokes each registered processor:
for (auto& processor : logit_processors_) {
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
}
- Empty chain is the existing fast path; no behavior change for callers
that don't register processors.
Configure processors before calling `generate()` -- concurrent
modification during generation is not safe.
Reviewed By: Gasoonjia
Differential Revision: D104767967
0e37b09 to
7d57b0a
Compare
Summary:
Introduces a
LogitProcessorabstract interface that lets callers mutatelogits in place between the model forward pass and the sampler. Enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar transforms without touching the core generation loop.
Interface (
extension/llm/sampler/logit_processor.h, ~15 lines):process(::executorch::aten::Tensor logits)thatreturns
Error::Okor aborts the chain on a non-Ok return.through the shared underlying buffer).
runner does not cast or copy the tensor. Typical implementations check
logits.scalar_type()and either dispatch to a kernel or returnInvalidArgument.
[batch, vocab], rank 3 =[batch, seq, vocab]advanced to last sequence position) mirrorssample_from_logits.Wiring (
extension/llm/runner/text_token_generator.h):add_logit_processor,clear_logit_processors,num_logit_processors.generate(), betweenstep()andlogits_to_token(), the loopinvokes each registered processor:
for (auto& processor : logit_processors_) {
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
}
that don't register processors.
Configure processors before calling
generate()-- concurrentmodification during generation is not safe.
Reviewed By: Gasoonjia
Differential Revision: D104767967
cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng