
Conversation


@dchou1618 dchou1618 commented Jan 5, 2026

User description

Description

  • practice with pooling and general torch functions

PR Type

Enhancement, Tests


Description

  • Implement VariableSortedHistoryPooling module for jagged event embeddings

  • Pool variable-length user event histories using scatter operations

  • Add comprehensive test suite validating pooling correctness


Diagram Walkthrough

flowchart LR
  A["Event Indices<br/>Offsets"] -->|"Embedding Lookup"| B["Event Embeddings"]
  B -->|"Scatter Add<br/>by User ID"| C["Aggregated User<br/>Embeddings"]
  C -->|"Normalize by<br/>User Length"| D["Pooled Output"]
  E["Reference Loop<br/>Implementation"] -->|"Validation"| F["Test Assertions"]
  D -->|"Compare"| F
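As a concrete illustration of the flow above, here is a minimal usage sketch. The tensor values, sizes, and import path are hypothetical and only meant to show the expected shapes:

import torch
# Hypothetical import path; adjust to however LLMs/torch_examples/pooling.py is packaged.
from LLMs.torch_examples.pooling import VariableSortedHistoryPooling

# Jagged batch: 6 events shared by B = 3 users with history lengths 3, 1, 2.
event_indices = torch.tensor([4, 1, 7, 2, 5, 0])   # event ids, all < n_samples
offsets = torch.tensor([0, 3, 4, 6])                # B + 1 cumulative boundaries

pool = VariableSortedHistoryPooling(n_samples=10, emb_dim=8)
out = pool(event_indices, offsets)                  # lookup -> scatter_add -> normalize
print(out.shape)                                    # torch.Size([3, 8]): one vector per user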

File Walkthrough

Relevant files
Enhancement
pooling.py
Variable-length history pooling module implementation       

LLMs/torch_examples/pooling.py

  • Implement VariableSortedHistoryPooling PyTorch module for aggregating
    variable-length event sequences
  • Use embedding layer to encode event indices
  • Apply scatter_add operation to pool embeddings by user ID based on
    offset boundaries
  • Normalize pooled embeddings by user history length
+18/-0   
Tests
test_pooling.py
Pooling module correctness and shape validation tests       

tests/test_pooling.py

  • Create reference loop-based pooling implementation for validation (see the
    sketch below)
  • Test VariableSortedHistoryPooling with multiple users and variable
    history lengths
  • Verify output shape matches expected dimensions (B users × embedding
    dimension)
  • Assert vectorized implementation matches reference loop implementation
    within tolerance
+36/-0   
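A rough sketch of the loop-based reference check described above; the function and variable names are illustrative, and the actual test may differ in details such as tolerances:

import torch

def reference_pooling(event_embs: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Mean-pool each user's contiguous slice of event embeddings, one user at a time.
    # Assumes every user has at least one event, so no empty slice is averaged.
    pooled = []
    for start, end in zip(offsets[:-1].tolist(), offsets[1:].tolist()):
        pooled.append(event_embs[start:end].mean(dim=0))
    return torch.stack(pooled)

# Shape check and comparison against the vectorized module on the same inputs:
# out_vec = model(event_indices, offsets)                        # (B, emb_dim)
# out_ref = reference_pooling(model.emb(event_indices), offsets)
# assert out_vec.shape == (len(offsets) - 1, out_ref.shape[1])
# torch.testing.assert_close(out_vec, out_ref)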

jagged histories in pooling embeddings

practice with pooling and general torch functions
@dchou1618 dchou1618 marked this pull request as ready for review January 6, 2026 03:34

qodo-code-review bot commented Jan 6, 2026

PR Compliance Guide 🔍

Below is a summary of compliance checks for this PR:

Security Compliance
🟢
No security concerns identified. No security vulnerabilities were detected by AI analysis; human verification is advised for critical code.
Ticket Compliance
🎫 No ticket provided
  • Create ticket/issue
Codebase Duplication Compliance
Codebase context is not defined

Follow the guide to enable codebase context checks.

Custom Compliance
🟢
Generic: Comprehensive Audit Trails

Objective: To create a detailed and reliable record of critical system actions for security analysis
and compliance.

Status: Passed

Learn more about managing compliance generic rules or creating your own custom rules

Generic: Secure Error Handling

Objective: To prevent the leakage of sensitive system information through error messages while
providing sufficient detail for internal debugging.

Status: Passed

Learn more about managing compliance generic rules or creating your own custom rules

Generic: Secure Logging Practices

Objective: To ensure logs are useful for debugging and auditing without exposing sensitive
information like PII, PHI, or cardholder data.

Status: Passed

Learn more about managing compliance generic rules or creating your own custom rules

🔴
Generic: Meaningful Naming and Self-Documenting Code

Objective: Ensure all identifiers clearly express their purpose and intent, making code
self-documenting

Status:
Misleading parameter name: The constructor parameter n_samples is documented as "n events" but is used as
the embedding's vocabulary/num-embeddings size, which is misleading and harms
self-documentation.

Referred Code
def __init__(self, n_samples: int, emb_dim: int):
    super(VariableSortedHistoryPooling, self).__init__()
    # n samples are n events, where it's consecutive events belonging to a given user
    # The n samples can be segmented into B users.
    self.emb = torch.nn.Embedding(n_samples, emb_dim)

Learn more about managing compliance generic rules or creating your own custom rules

Generic: Robust Error Handling and Edge Case Management

Objective: Ensure comprehensive error handling that provides meaningful context and graceful
degradation

Status:
Missing input validation: forward() performs no validation that offsets are sorted, within bounds, and consistent
with event_indices length (e.g., sum of user_lengths), which can cause runtime errors or
incorrect pooling on edge cases.

Referred Code
def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    event_embs = self.emb(event_indices)
    # diffs of cumulative offsets gives user lengths (number of events in history per user)
    user_lengths = offsets[1:] - offsets[:-1]
    user_ids = torch.repeat_interleave(torch.arange(len(user_lengths),
                                                    device=offsets.device), 
                                        user_lengths)
    target = torch.zeros(len(user_lengths), event_embs.shape[1], device=event_embs.device)
    target = target.scatter_add(dim=0, index=user_ids.unsqueeze(1).expand_as(event_embs), src=event_embs)

Learn more about managing compliance generic rules or creating your own custom rules

Generic: Security-First Input Validation and Data Handling

Objective: Ensure all data inputs are validated, sanitized, and handled securely to prevent
vulnerabilities

Status:
Unvalidated tensor inputs: The module accepts event_indices and offsets without checking dtype/range constraints
(e.g., event_indices within embedding vocab, offsets monotonic/non-negative), and it is
unclear from the diff whether upstream guarantees make this safe.

Referred Code
def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    event_embs = self.emb(event_indices)
    # diffs of cumulative offsets gives user lengths (number of events in history per user)
    user_lengths = offsets[1:] - offsets[:-1]
    user_ids = torch.repeat_interleave(torch.arange(len(user_lengths),
                                                    device=offsets.device), 
                                        user_lengths)
    target = torch.zeros(len(user_lengths), event_embs.shape[1], device=event_embs.device)
    target = target.scatter_add(dim=0, index=user_ids.unsqueeze(1).expand_as(event_embs), src=event_embs)
    return target / user_lengths.clamp(min=1).unsqueeze(1) 

Learn more about managing compliance generic rules or creating your own custom rules

Compliance status legend
🟢 - Fully Compliant
🟡 - Partially Compliant
🔴 - Not Compliant
⚪ - Requires Further Human Verification
🏷️ - Compliance label


qodo-code-review bot commented Jan 6, 2026

PR Code Suggestions ✨

Explore these optional code suggestions:

Category / Suggestion / Impact
High-level
Use torch.nn.EmbeddingBag for pooling

Replace the custom VariableSortedHistoryPooling module with the standard
torch.nn.EmbeddingBag using mode='mean'. This change simplifies the
implementation to a single, optimized layer call.

Examples:

LLMs/torch_examples/pooling.py [3-18]
class VariableSortedHistoryPooling(torch.nn.Module):
    def __init__(self, n_samples: int, emb_dim: int):
        super(VariableSortedHistoryPooling, self).__init__()
        # n samples are n events, where it's consecutive events belonging to a given user
        # The n samples can be segmented into B users.
        self.emb = torch.nn.Embedding(n_samples, emb_dim)
    def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        event_embs = self.emb(event_indices)
        # diffs of cumulative offsets gives user lengths (number of events in history per user)
        user_lengths = offsets[1:] - offsets[:-1]

 ... (clipped 6 lines)

Solution Walkthrough:

Before:

class VariableSortedHistoryPooling(torch.nn.Module):
    def __init__(self, n_samples: int, emb_dim: int):
        self.emb = torch.nn.Embedding(n_samples, emb_dim)
    
    def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        event_embs = self.emb(event_indices)
        user_lengths = offsets[1:] - offsets[:-1]
        user_ids = torch.repeat_interleave(torch.arange(len(user_lengths)), user_lengths)
        target = torch.zeros(len(user_lengths), event_embs.shape[1], ...)
        target = target.scatter_add(0, user_ids.unsqueeze(1).expand_as(event_embs), event_embs)
        return target / user_lengths.clamp(min=1).unsqueeze(1)

After:

# The entire VariableSortedHistoryPooling class can be removed.
# It can be replaced by a single line in the model definition.

# In tests/test_pooling.py or any usage site:

# Before:
# model = VariableSortedHistoryPooling(vocab_size, emb_dim)

# After:
model = torch.nn.EmbeddingBag(vocab_size, emb_dim, mode='mean')

# The forward pass call remains the same and produces the same result:
out_vec = model(event_indices, offsets)
Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies that the custom module reinvents torch.nn.EmbeddingBag, and replacing it greatly simplifies the code, improves performance, and aligns with idiomatic PyTorch practices.

High
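One hedged caveat on this swap, not part of the bot's suggestion: the custom module expects offsets with B + 1 cumulative boundaries (it computes offsets[1:] - offsets[:-1]), whereas torch.nn.EmbeddingBag by default takes B start offsets, one per bag. Keeping the existing call signature would therefore likely require include_last_offset=True, roughly:

import torch

# Sketch only: assumes offsets carries B + 1 cumulative boundaries ending at the
# total event count, as in the custom module's forward(); vocab_size and emb_dim
# are placeholder sizes.
vocab_size, emb_dim = 10, 8
model = torch.nn.EmbeddingBag(vocab_size, emb_dim, mode='mean', include_last_offset=True)
# out_vec = model(event_indices, offsets)   # -> (B, emb_dim), mean-pooled per user

If I recall the documented behaviour correctly, empty bags pool to zero vectors in mean mode, which lines up with the clamp(min=1) normalization in the custom module.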
Possible issue
Validate offsets tensor

Add validation checks for the offsets tensor to ensure it is a 1D,
non-decreasing tensor with at least two elements.

LLMs/torch_examples/pooling.py [12]

+if offsets.dim() != 1 or offsets.size(0) < 2:
+    raise ValueError("offsets must be a 1D tensor with at least two elements")
+if not torch.all(offsets[1:] >= offsets[:-1]):
+    raise ValueError("offsets must be non-decreasing")
 user_lengths = offsets[1:] - offsets[:-1]
Suggestion importance[1-10]: 7


Why: This is a good suggestion for making the module more robust by adding input validation for the offsets tensor, which can prevent cryptic runtime errors for invalid inputs.

Medium
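To make the failure mode concrete, here is a hypothetical misordered offsets tensor; the values are illustrative:

import torch

offsets = torch.tensor([0, 3, 2, 6])        # not non-decreasing: 3 -> 2
user_lengths = offsets[1:] - offsets[:-1]   # tensor([ 3, -1,  4])
# torch.repeat_interleave(torch.arange(3), user_lengths) then fails with a
# RuntimeError about negative repeats, far from the real cause; the proposed
# check raises a clear ValueError up front instead.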
Enforce long dtype for indices

Ensure the event_indices tensor is of torch.long dtype before passing it to the
embedding layer by casting it if necessary.

LLMs/torch_examples/pooling.py [10]

+if event_indices.dtype != torch.long:
+    event_indices = event_indices.long()
 event_embs = self.emb(event_indices)
Suggestion importance[1-10]: 7


Why: This suggestion improves the robustness of the module by ensuring event_indices has the correct torch.long dtype required by torch.nn.Embedding, preventing potential runtime errors.

Medium
General
Rename a misleading parameter name

Rename the n_samples parameter in __init__ to num_embeddings to accurately
reflect that it represents the vocabulary size, not the number of samples.

LLMs/torch_examples/pooling.py [3-8]

 class VariableSortedHistoryPooling(torch.nn.Module):
-    def __init__(self, n_samples: int, emb_dim: int):
+    def __init__(self, num_embeddings: int, emb_dim: int):
         super(VariableSortedHistoryPooling, self).__init__()
-        # n samples are n events, where it's consecutive events belonging to a given user
-        # The n samples can be segmented into B users.
-        self.emb = torch.nn.Embedding(n_samples, emb_dim)
+        # num_embeddings is the total number of unique events (i.e. vocabulary size).
+        self.emb = torch.nn.Embedding(num_embeddings, emb_dim)
Suggestion importance[1-10]: 6


Why: The suggestion correctly identifies that n_samples is a misleading parameter name, as it represents the vocabulary size, and proposes a clearer name aligning with PyTorch conventions, which significantly improves code readability and maintainability.

Low
