@@ -94,7 +94,6 @@ def forward(self, hidden_states):
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -118,6 +117,9 @@ def cos_cached(self):
         return self._cos_cached

     def forward(self, x, position_ids, seq_len=None):
+        if seq_len is not None:
+            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
+
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
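The rest of the rewritten `forward` falls outside this hunk. As a minimal, self-contained sketch of what the two expanded tensors above are set up for (the shapes and the batched matmul below are assumptions for illustration, not lines from this diff):

import torch

# assumed illustration of how cos/sin can follow from position_ids alone
bs, seq_len, dim = 2, 16, 128
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))        # (dim//2,)
position_ids = torch.arange(seq_len)[None, :].expand(bs, -1)               # (bs, seq_len)

inv_freq_expanded = inv_freq[None, :, None].float().expand(bs, -1, 1)      # (bs, dim//2, 1)
position_ids_expanded = position_ids[:, None, :].float()                   # (bs, 1, seq_len)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)        # (bs, seq_len, dim//2)
emb = torch.cat((freqs, freqs), dim=-1)                                    # (bs, seq_len, dim)
cos, sin = emb.cos(), emb.sin()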
@@ -138,16 +140,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
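For intuition, a small sketch (not part of the diff) of what the linear-scaling subclass now does: position ids are divided by `scaling_factor` before the parent class turns them into angles, so a context longer than the pretraining window is mapped back into it.

import torch

scaling_factor = 2.0
position_ids = torch.arange(4096)                # twice a 2048-token pretraining range
scaled = position_ids.float() / scaling_factor   # what the new forward() feeds to super()
print(scaled.max())                              # tensor(2047.5000), back inside [0, 2048)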
@@ -157,23 +154,20 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 def rotate_half(x):
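A worked sketch of the dynamic-NTK branch above, with assumed values (dim=128, base=10000, max_position_embeddings=2048, scaling_factor=1.0): doubling the sequence length roughly doubles the base, which slows the rotation frequencies instead of compressing the positions.

import torch

dim, base, max_pos, scaling_factor = 128, 10000.0, 2048, 1.0
seq_len = 4096  # prompt twice the pretraining length

new_base = base * (
    (scaling_factor * seq_len / max_pos) - (scaling_factor - 1)
) ** (dim / (dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
print(round(new_base))  # ~20221, about twice the original base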
@@ -183,17 +177,16 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
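A usage sketch of the updated call pattern (assuming a transformers version that already contains this change; shapes are illustrative): the rotary module is called with `position_ids` only, and `apply_rotary_pos_emb` no longer needs them.

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

bs, heads, seq, head_dim = 1, 8, 16, 64
q = torch.randn(bs, heads, seq, head_dim)
k = torch.randn(bs, heads, seq, head_dim)
position_ids = torch.arange(seq)[None, :]

rotary = LlamaRotaryEmbedding(head_dim)
cos, sin = rotary(q, position_ids)                   # no seq_len argument
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # no position_ids argument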
@@ -360,8 +353,8 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -447,8 +440,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)

@@ -645,8 +638,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)
