Skip to content

Commit f99bc20

Browse files
committed
fix: some edge cases with bert sentence slicing
1 parent edb0d62 commit f99bc20

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

nlstruct/models/common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def output_bias(self):
315315
def bert_config(self):
316316
return self.bert.config
317317

318-
def exec_bert(self, tokens, mask, slices_begin=None, slices_end=None):
318+
def exec_bert(self, tokens, mask, slices_begin, slices_end):
319319
res = self.bert.forward(tokens, mask, output_hidden_states=True)
320320
if self.output_lm_embeds:
321321
lm_embeds = list(res.logits)
@@ -329,8 +329,7 @@ def exec_bert(self, tokens, mask, slices_begin=None, slices_end=None):
329329
token_features = torch.stack(token_features[-self.n_layers:], dim=2)
330330

331331
results = (token_features, *lm_embeds)
332-
if slices_begin is not None:
333-
results = multi_dim_slice(results, slices_begin, slices_end)
332+
results = multi_dim_slice(results, slices_begin, slices_end)
334333
return results
335334

336335
def forward(self, batch):
@@ -350,12 +349,18 @@ def forward(self, batch):
350349
flat_mask = mask[mask.any(-1)]
351350
flat_slices_begin = batch['slice_begin'][mask.any(-1)] if 'slice_begin' in batch else torch.zeros(len(flat_mask), device=device, dtype=torch.long)
352351
flat_slices_end = batch['slice_end'][mask.any(-1)] if 'slice_end' in batch else flat_mask.long().sum(1)
352+
353+
if 'slice_end' in batch:
354+
slices_lengths = (batch['slice_end'] - batch['slice_begin'])[..., None]
355+
mask = torch.arange(slices_lengths.max() if 0 not in slices_lengths.shape else 0, device=device) < slices_lengths
353356
else:
354357
needs_concat = False
355358
flat_tokens = tokens
356359
flat_mask = mask
357360
flat_slices_begin = batch['slice_begin'] if 'slice_begin' in batch else torch.zeros(len(flat_mask), device=device, dtype=torch.long)
358361
flat_slices_end = batch['slice_end'] if 'slice_end' in batch else flat_mask.long().sum(1)
362+
363+
359364
if self.do_cache:
360365
keys = [hash((tuple(row[:length]), begin, end)) for row, length, begin, end in zip(flat_tokens.tolist(), flat_mask.sum(1).tolist(), flat_slices_begin.tolist(), flat_slices_end.tolist())]
361366

0 commit comments

Comments (0)