Commit a2718fa: addressed missed out comments in #1, except checkpointing
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Jun 22, 2024
1 parent 6d0760e commit a2718fa
Showing 20 changed files with 225 additions and 407 deletions.
127 changes: 0 additions & 127 deletions src/instructlab/dolomite/hf_models/config.py

This file was deleted.

4 changes: 0 additions & 4 deletions src/instructlab/dolomite/hf_models/defaults.py

This file was deleted.

5 changes: 2 additions & 3 deletions src/instructlab/dolomite/hf_models/modeling_utils/__init__.py
@@ -14,7 +14,6 @@
    repeat_key_value,
    split_query_key_value_tensor_for_attention,
)
from .embedding import Embedding
from .linear import Linear
from .normalization import RMSNorm, get_normalization_function
from .position_embedding import Alibi, RoPE, YaRNScaledRoPE, apply_rotary_pos_emb
from .position_embedding import Alibi, RoPE, apply_rotary_pos_emb

@@ -9,7 +9,7 @@
import torch

# Local
from ...config import CommonConfig
from ...models.gpt_dolomite.config import GPTDolomiteConfig
from ...enums import AttentionHeadType
from .base import Attention
from .flash import FlashAttention2
@@ -48,7 +48,7 @@


def get_attention_module(
    config: CommonConfig,
    config: GPTDolomiteConfig,
    causal: bool,
    attention_implementation: str,
    use_padding_free_transformer: bool,
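The body of get_attention_module is collapsed in this diff view. As a rough sketch of the kind of dispatch such a factory performs, using the names imported above (the registry, the padding-free check, and the error text are assumptions, not the actual implementation):

# Illustrative sketch only -- not the collapsed function body.
_ATTENTION_MODULES = {
    "eager": Attention,
    "flash_attention_2": FlashAttention2,
}

def _get_attention_module_sketch(
    config: GPTDolomiteConfig,
    causal: bool,
    attention_implementation: str,
    use_padding_free_transformer: bool,
    layer_idx: int = None,
) -> torch.nn.Module:
    # Assumption: the padding-free path is only meaningful with flash attention.
    if use_padding_free_transformer and attention_implementation != "flash_attention_2":
        raise ValueError("padding-free transformer is assumed to require flash attention")
    if attention_implementation in _ATTENTION_MODULES:
        return _ATTENTION_MODULES[attention_implementation](
            config, causal=causal, layer_idx=layer_idx
        )
    raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")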
@@ -8,18 +8,18 @@
from transformers import DynamicCache
import torch
import torch.nn.functional as F
from torch.nn import Linear # replaces ParameterizedLinear

# Local
from ...config import CommonConfig
from ...models.gpt_dolomite.config import GPTDolomiteConfig
from ...enums import AttentionHeadType, PositionEmbeddingType
from ..linear import Linear
from ..position_embedding import apply_rotary_pos_emb
from .utils import repeat_key_value


class Attention(torch.nn.Module):
    def __init__(
        self, config: CommonConfig, causal: bool, layer_idx: int = None
        self, config: GPTDolomiteConfig, causal: bool, layer_idx: int = None
    ) -> None:
        super().__init__()

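The import block above swaps the repository's custom Linear (the inline comment notes it replaced ParameterizedLinear, and modeling_utils/linear.py is deleted elsewhere in this commit) for the stock torch.nn.Linear. A standalone sanity check that the stock module fills the same role, with made-up sizes:

# Drop-in check with arbitrary sizes; the projection name and shapes are
# illustrative, not taken from the collapsed __init__ body.
import torch
from torch.nn import Linear

qkv_proj = Linear(in_features=768, out_features=3 * 768, bias=True)
hidden = torch.randn(2, 16, 768)  # (batch, seq, hidden)
print(qkv_proj(hidden).shape)     # torch.Size([2, 16, 2304])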

This file was deleted.

8 changes: 0 additions & 8 deletions src/instructlab/dolomite/hf_models/modeling_utils/linear.py

This file was deleted.

@@ -5,15 +5,13 @@
import torch

# Local
from .layernorm import get_layernorm
from .rmsnorm import RMSNorm, get_rmsnorm
from .norms import RMSNorm, get_layernorm, get_rmsnorm

_NORMALIZATION_FUNCTIONS = {
    "layernorm": get_layernorm,
    "rmsnorm": get_rmsnorm,
}


def get_normalization_function(
    name: str,
    normalized_shape: int,
@@ -30,3 +28,4 @@ def get_normalization_function(
    raise ValueError(
        f"unexpected `normalization_implementation` {normalization_implementation}"
    )
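Only the signature fragment and the fallback raise of get_normalization_function are visible here; the lookup through _NORMALIZATION_FUNCTIONS is collapsed. A self-contained miniature of the same registry-dispatch pattern (the function name, the eps keyword, and the error text below are illustrative, not this module's API):

# Miniature of the registry-dispatch pattern, for illustration only.
import torch

_NORM_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
}

def make_norm(name: str, normalized_shape: int, eps: float = 1e-5) -> torch.nn.Module:
    if name in _NORM_REGISTRY:
        return _NORM_REGISTRY[name](normalized_shape, eps=eps)
    raise ValueError(f"unexpected normalization `{name}`")

print(make_norm("layernorm", 8)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])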

This file was deleted.

@@ -0,0 +1,76 @@
# ----------------------------------------------------------------
# Extracted from https://github.com/ibm-granite/dolomite-engine
# ----------------------------------------------------------------

# Third Party
import torch

# Standard
import numbers

# ---------------- LayerNorm ---------------

_LAYERNORM_MODULES = {
    "torch": torch.nn.LayerNorm,
}

def get_layernorm(
    normalized_shape: int,
    eps: float,
    normalization_implementation: str = "torch",
) -> torch.nn.LayerNorm:
    if normalization_implementation in _LAYERNORM_MODULES:
        return _LAYERNORM_MODULES[normalization_implementation](
            normalized_shape=normalized_shape, eps=eps
        )

    raise ValueError(
        f"unexpected `normalization_implementation` {normalization_implementation}"
    )

# --------------- RMS Norm ---------------
# ----------------------------------------------------------------
# Extracted from https://github.com/ibm-granite/dolomite-engine
# ----------------------------------------------------------------

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
        super().__init__()

        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input_dtype = input.dtype

        input = input.to(torch.float32)
        variance = input.pow(2).mean(-1, keepdim=True)
        input = input * torch.rsqrt(variance + self.eps)

        return self.weight * input.to(input_dtype)

    def extra_repr(self) -> str:
        return f"{self.normalized_shape}, eps={self.eps}"

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)

_RMSNORM_MODULES = {"torch": RMSNorm}

def get_rmsnorm(
    normalized_shape: int,
    eps: float,
    normalization_implementation: str = "torch",
) -> torch.nn.LayerNorm:
    if normalization_implementation in _RMSNORM_MODULES:
        return _RMSNORM_MODULES[normalization_implementation](
            normalized_shape=normalized_shape, eps=eps
        )

    raise ValueError(
        f"unexpected `normalization_implementation` {normalization_implementation}"
    )
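With the whole implementation added above, a quick standalone sanity check is easy to sketch. Sizes below are arbitrary; RMSNorm is the class defined in this new file (which the normalization __init__ change imports via from .norms import RMSNorm):

# Uses the RMSNorm class defined above; sizes are arbitrary.
import torch

norm = RMSNorm(normalized_shape=16, eps=1e-6)

x = torch.randn(2, 4, 16)
y = norm(x)

print(y.shape)  # torch.Size([2, 4, 16])
print(norm)     # RMSNorm((16,), eps=1e-06)

# With the default all-ones weight, the output is x scaled by the reciprocal
# RMS of its last dimension.
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(y, expected))  # True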

This file was deleted.

This file was deleted.

@@ -3,4 +3,4 @@
# ----------------------------------------------------------------
# Local
from .alibi import Alibi
from .rope import RoPE, YaRNScaledRoPE, apply_rotary_pos_emb
from .rope import RoPE, apply_rotary_pos_emb
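The re-export above drops YaRNScaledRoPE and keeps RoPE and apply_rotary_pos_emb. As a reminder of what a rotary-embedding helper computes, here is a generic, self-contained sketch of textbook RoPE (half-rotation convention); it is not the signature of the helper exported by this module:

# Generic RoPE illustration; not this module's API.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (rotate_half(x) * sin)

head_dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq, head_dim // 2)
emb = torch.cat((angles, angles), dim=-1)                      # (seq, head_dim)
q = torch.randn(1, 2, seq_len, head_dim)                       # (batch, heads, seq, dim)
print(apply_rope(q, emb.cos(), emb.sin()).shape)               # torch.Size([1, 2, 4, 8])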