@@ -315,7 +315,7 @@ def output_bias(self):
def bert_config(self):
    """Return the ``config`` attribute of the object stored in ``self.bert``.

    The value is handed back as-is; no copy or validation is performed.
    """
    wrapped_model = self.bert
    return wrapped_model.config
317317
318- def exec_bert (self , tokens , mask , slices_begin = None , slices_end = None ):
318+ def exec_bert (self , tokens , mask , slices_begin , slices_end ):
319319 res = self .bert .forward (tokens , mask , output_hidden_states = True )
320320 if self .output_lm_embeds :
321321 lm_embeds = list (res .logits )
@@ -329,8 +329,7 @@ def exec_bert(self, tokens, mask, slices_begin=None, slices_end=None):
329329 token_features = torch .stack (token_features [- self .n_layers :], dim = 2 )
330330
331331 results = (token_features , * lm_embeds )
332- if slices_begin is not None :
333- results = multi_dim_slice (results , slices_begin , slices_end )
332+ results = multi_dim_slice (results , slices_begin , slices_end )
334333 return results
335334
336335 def forward (self , batch ):
@@ -350,12 +349,18 @@ def forward(self, batch):
350349 flat_mask = mask [mask .any (- 1 )]
351350 flat_slices_begin = batch ['slice_begin' ][mask .any (- 1 )] if 'slice_begin' in batch else torch .zeros (len (flat_mask ), device = device , dtype = torch .long )
352351 flat_slices_end = batch ['slice_end' ][mask .any (- 1 )] if 'slice_end' in batch else flat_mask .long ().sum (1 )
352+
353+ if 'slice_end' in batch :
354+ slices_lengths = (batch ['slice_end' ] - batch ['slice_begin' ])[..., None ]
355+ mask = torch .arange (slices_lengths .max () if 0 not in slices_lengths .shape else 0 , device = device ) < slices_lengths
353356 else :
354357 needs_concat = False
355358 flat_tokens = tokens
356359 flat_mask = mask
357360 flat_slices_begin = batch ['slice_begin' ] if 'slice_begin' in batch else torch .zeros (len (flat_mask ), device = device , dtype = torch .long )
358361 flat_slices_end = batch ['slice_end' ] if 'slice_end' in batch else flat_mask .long ().sum (1 )
362+
363+
359364 if self .do_cache :
360365 keys = [hash ((tuple (row [:length ]), begin , end )) for row , length , begin , end in zip (flat_tokens .tolist (), flat_mask .sum (1 ).tolist (), flat_slices_begin .tolist (), flat_slices_end .tolist ())]
361366
0 commit comments