
Commit

Rename Phi-3 rope scaling type (huggingface#31436)
* renamed phi3 rope_scaling type

* fixed trailing whitespaces

* fixed test

* added warning

* fixed format
garg-amit authored and zucchini-nlp committed Jul 24, 2024
1 parent a3fe765 commit fd8d051
Showing 3 changed files with 69 additions and 11 deletions.
20 changes: 17 additions & 3 deletions src/transformers/models/phi3/configuration_phi3.py
@@ -78,7 +78,7 @@ class Phi3Config(PretrainedConfig):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
             The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
         bos_token_id (`int`, *optional*, defaults to 1):
@@ -155,6 +155,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self._rope_scaling_adjustment()
         self._rope_scaling_validation()
         self.sliding_window = sliding_window

@@ -166,6 +167,19 @@ def __init__(
             **kwargs,
         )

+    def _rope_scaling_adjustment(self):
+        """
+        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
+        """
+        if self.rope_scaling is None:
+            return
+
+        rope_scaling_type = self.rope_scaling.get("type", None)
+
+        # For backward compatibility if previous version used "su" or "yarn"
+        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
+            self.rope_scaling["type"] = "longrope"
+
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -181,8 +195,8 @@ def _rope_scaling_validation(self):
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
         rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
         if not (
             isinstance(rope_scaling_short_factor, list)
             and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
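
Not part of the diff, just a quick sketch of what the backward-compatibility path above means in practice, assuming a transformers build that includes this commit. The model geometry and factor values below are made up for illustration; the factor lists only need length hidden_size // num_attention_heads // 2 (here 64 // 8 // 2 = 4):

    from transformers import Phi3Config

    # Hypothetical tiny geometry: hidden_size=64, num_attention_heads=8 -> 4 factors per list
    legacy_rope_scaling = {
        "type": "su",  # "yarn" is remapped the same way
        "short_factor": [1.0, 1.0, 1.0, 1.0],
        "long_factor": [1.5, 1.5, 1.5, 1.5],
    }

    config = Phi3Config(hidden_size=64, num_attention_heads=8, rope_scaling=legacy_rope_scaling)

    # _rope_scaling_adjustment() runs before _rope_scaling_validation(), so the legacy
    # type is silently renamed and still passes the stricter check.
    print(config.rope_scaling["type"])  # "longrope"
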
58 changes: 51 additions & 7 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -16,6 +16,7 @@
 """PyTorch Phi-3 model."""

 import math
+import warnings
 from typing import List, Optional, Tuple, Union

 import torch
@@ -106,6 +107,51 @@ def forward(self, x, position_ids, seq_len=None):

 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
+            " use Phi3LongRoPEScaledRotaryEmbedding instead.",
+            FutureWarning,
+        )
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
+            FutureWarning,
+        )
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

         self.short_factor = config.rope_scaling["short_factor"]
@@ -138,14 +184,14 @@ def forward(self, x, position_ids, seq_len=None):
             if scale <= 1.0:
                 scaling_factor = 1.0
             else:
-                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+                scaling_factor = 0.1 * math.log(scale) + 1.0

             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
@@ -179,7 +225,7 @@ def forward(self, x, position_ids, seq_len=None):
             if scale <= 1.0:
                 scaling_factor = 1.0
             else:
-                scaling_factor = 0.1 * math.log(scale) + 1.0
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
@@ -300,10 +346,8 @@ def _init_rope(self):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
-            if scaling_type == "su":
-                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
-            elif scaling_type == "yarn":
-                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+            if scaling_type == "longrope":
+                self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

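
Also outside the diff: the renamed Phi3LongRoPEScaledRotaryEmbedding keeps the behaviour the old `su` path had. `long_factor` is applied once the sequence grows past `original_max_position_embeddings`, `short_factor` otherwise, and cos/sin get a fixed attention-scaling factor. A rough back-of-the-envelope check of that factor, using 4k -> 128k context figures as assumed illustrative values (Phi-3-mini-128k-style):

    import math

    original_max_position_embeddings = 4096   # assumed original context length
    max_position_embeddings = 131072          # assumed extended context length

    scale = max_position_embeddings / original_max_position_embeddings  # 32.0
    if scale <= 1.0:
        scaling_factor = 1.0
    else:
        # same formula as Phi3LongRoPEScaledRotaryEmbedding.forward above
        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

    print(f"{scaling_factor:.4f}")  # ~1.1902, i.e. cos/sin are scaled up slightly for long inputs
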
2 changes: 1 addition & 1 deletion tests/models/phi3/test_modeling_phi3.py
@@ -362,7 +362,7 @@ def test_phi3_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @parameterized.expand([("su",), ("yarn",)])
+    @parameterized.expand([("longrope",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
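
One more sketch that is not part of the diff: the old class names survive only as deprecated aliases, so user code that constructed them directly should now see a FutureWarning. The dummy config below is hypothetical and carries just the attributes the constructor reads (module layout as of this commit; the deprecated classes are slated for removal in a later major version):

    import warnings
    from types import SimpleNamespace

    from transformers.models.phi3.modeling_phi3 import Phi3SuScaledRotaryEmbedding

    dim = 8  # hypothetical head dimension
    dummy_config = SimpleNamespace(
        max_position_embeddings=8192,
        original_max_position_embeddings=4096,
        rope_theta=10000.0,
        rope_scaling={"short_factor": [1.0] * (dim // 2), "long_factor": [1.5] * (dim // 2)},
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Phi3SuScaledRotaryEmbedding(dim, dummy_config)

    assert any(issubclass(w.category, FutureWarning) for w in caught)
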
