Describe the bug
I think there might be an issue with calculating the sequence length of cap_feat (which is the text encoder output), and masking it accordingly.
I'm going to use code links from before the omni commit, because it's easier to read - but the issue seems to exist both before and after the omni commit.
diffusers/src/diffusers/models/transformers/transformer_z_image.py
Lines 593 to 611 in 52766e6
```python
cap_item_seqlens = [len(_) for _ in cap_feats]
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
    self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
    cap_attn_mask[i, :seq_len] = 1
```
In these lines, the sequence length of each sample in cap_feats is stored in cap_item_seqlens and then used to build the attention mask on line 611.
The problem is that cap_feats has already been overwritten earlier, in these lines:
diffusers/src/diffusers/models/transformers/transformer_z_image.py
Lines 552 to 560 in 52766e6
```python
(
    x,
    cap_feats,
    x_size,
    x_pos_ids,
    cap_pos_ids,
    x_inner_pad_mask,
    cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
```
cap_feats is therefore no longer what the caller passed in; it has already been padded by patchify_and_embed. As a result, every entry of cap_item_seqlens is identical to cap_max_item_seqlen.
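To illustrate the mechanism, here is a minimal sketch (the tensors and lengths are made up, and pad_sequence stands in for the padding that patchify_and_embed applies internally):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical caption feature tensors of different lengths (seq_len, dim)
cap_feats = [torch.randn(5, 8), torch.randn(9, 8)]

# What the seqlen computation is meant to capture: the original lengths
print([len(_) for _ in cap_feats])  # [5, 9]

# After padding (as patchify_and_embed effectively does), the same
# expression reports the padded length for every sample
cap_feats = list(pad_sequence(cap_feats, batch_first=True, padding_value=0.0))
print([len(_) for _ in cap_feats])  # [9, 9]
```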
This leaves the padding positions of the text sequence unmasked during attention, which probably wasn't the intention here.
Reproduction
Put a breakpoint at the cap_attn_mask construction shown above (line 611).
All cap_item_seqlens are identical and cap_attn_mask is all True, even when prompts of different lengths are passed to the transformer.
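For illustration, the mask construction from the first snippet collapses to an all-True mask once the reported lengths are equal (the concrete numbers here are made up):

```python
import torch

# Hypothetical values observed at the breakpoint: two prompts that were
# originally 5 and 9 tokens long both report the padded length of 9
cap_item_seqlens = [9, 9]
cap_max_item_seqlen = max(cap_item_seqlens)
bsz = len(cap_item_seqlens)

cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool)
for i, seq_len in enumerate(cap_item_seqlens):
    cap_attn_mask[i, :seq_len] = 1

print(cap_attn_mask.all())  # tensor(True) -> no text position is ever masked out
```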