Skip to content

Commit 3c16c1a

Browse files
authored
Use indices as position_ids in modernebert (#41789)
* Use indices as position_ids in modernebert * Move position_ids init to the branch
1 parent b9f90dc commit 3c16c1a

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/transformers/models/modernbert/modeling_modernbert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,8 @@ def forward(
905905
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
906906
inputs=inputs_embeds, attention_mask=attention_mask
907907
)
908+
if position_ids is None:
909+
position_ids = indices.unsqueeze(0)
908910
else:
909911
if position_ids is None:
910912
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

src/transformers/models/modernbert/modular_modernbert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,8 @@ def forward(
10141014
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
10151015
inputs=inputs_embeds, attention_mask=attention_mask
10161016
)
1017+
if position_ids is None:
1018+
position_ids = indices.unsqueeze(0)
10171019
else:
10181020
if position_ids is None:
10191021
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

0 commit comments

Comments
 (0)