-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
addressed missed out comments in #1, except checkpointing
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
- Loading branch information
Showing
20 changed files
with
225 additions
and
407 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 0 additions & 8 deletions
8
src/instructlab/dolomite/hf_models/modeling_utils/embedding.py
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 0 additions & 24 deletions
24
src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py
This file was deleted.
Oops, something went wrong.
76 changes: 76 additions & 0 deletions
76
src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# ---------------------------------------------------------------- | ||
# Extracted from https://github.com/ibm-granite/dolomite-engine | ||
# ---------------------------------------------------------------- | ||
|
||
# Third Party | ||
import torch | ||
|
||
# Standard | ||
import numbers | ||
|
||
# ---------------- LayerNorm --------------- | ||
|
||
_LAYERNORM_MODULES = { | ||
"torch": torch.nn.LayerNorm, | ||
} | ||
|
||
def get_layernorm(
    normalized_shape: int,
    eps: float,
    normalization_implementation: str = "torch",
) -> torch.nn.LayerNorm:
    """Instantiate the LayerNorm class registered for a given implementation.

    Args:
        normalized_shape: Passed through as ``normalized_shape`` to the
            selected class.
        eps: Passed through as ``eps`` to the selected class.
        normalization_implementation: Key looked up in
            ``_LAYERNORM_MODULES``.

    Returns:
        A new instance of the class registered under
        ``normalization_implementation``.

    Raises:
        ValueError: If ``normalization_implementation`` is not a key of
            ``_LAYERNORM_MODULES``.
    """
    # Reject unknown implementations up front so the happy path stays flat.
    if normalization_implementation not in _LAYERNORM_MODULES:
        raise ValueError(
            f"unexpected `normalization_implementation` {normalization_implementation}"
        )

    layernorm_cls = _LAYERNORM_MODULES[normalization_implementation]
    return layernorm_cls(normalized_shape=normalized_shape, eps=eps)
|
||
# --------------- RMS Norm --------------- | ||
# ---------------------------------------------------------------- | ||
# Extracted from https://github.com/ibm-granite/dolomite-engine | ||
# ---------------------------------------------------------------- | ||
|
||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable elementwise scale.

    The input is divided by ``sqrt(mean(input ** 2) + eps)``, where the mean
    is taken over the trailing ``len(normalized_shape)`` axes, and the result
    is multiplied by ``weight``. The statistics are computed in float32 and
    the normalized tensor is cast back to the input dtype before scaling.

    Args:
        normalized_shape: Size of the normalized trailing axis. A sequence of
            ints is also accepted, in which case ``weight`` takes that shape
            and all of those trailing axes are normalized together.
        eps: Constant added to the mean square before the reciprocal square
            root.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
        super().__init__()

        # ``torch.ones`` accepts either an int or a shape sequence.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

        # Store the shape as a tuple when a bare integer was given.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``weight * input / sqrt(mean(input ** 2) + eps)``."""
        input_dtype = input.dtype

        # Reduce over every axis covered by ``normalized_shape`` so that a
        # multi-dimensional weight is matched by a multi-dimensional
        # statistic. Fall back to the last axis for an empty shape.
        reduce_dims = tuple(range(-len(self.normalized_shape), 0)) or (-1,)

        # Compute the statistic in float32 regardless of the input dtype.
        input = input.to(torch.float32)
        variance = input.pow(2).mean(reduce_dims, keepdim=True)
        input = input * torch.rsqrt(variance + self.eps)

        # Cast back before scaling; the product follows torch type promotion
        # between ``weight`` and the input dtype.
        return self.weight * input.to(input_dtype)

    def extra_repr(self) -> str:
        """Return the shape and epsilon shown in the module's repr."""
        return f"{self.normalized_shape}, eps={self.eps}"

    def reset_parameters(self) -> None:
        """Reset ``weight`` to all ones."""
        torch.nn.init.ones_(self.weight)
|
||
_RMSNORM_MODULES = {"torch": RMSNorm} | ||
|
||
def get_rmsnorm(
    normalized_shape: int,
    eps: float,
    normalization_implementation: str = "torch",
) -> torch.nn.Module:
    """Instantiate the RMSNorm class registered for a given implementation.

    Args:
        normalized_shape: Passed through as ``normalized_shape`` to the
            selected class.
        eps: Passed through as ``eps`` to the selected class.
        normalization_implementation: Key looked up in ``_RMSNORM_MODULES``.

    Returns:
        A new instance of the class registered under
        ``normalization_implementation``. The return type is annotated as
        ``torch.nn.Module`` because the registered RMSNorm class derives from
        ``torch.nn.Module``, not from ``torch.nn.LayerNorm``.

    Raises:
        ValueError: If ``normalization_implementation`` is not a key of
            ``_RMSNORM_MODULES``.
    """
    if normalization_implementation in _RMSNORM_MODULES:
        return _RMSNORM_MODULES[normalization_implementation](
            normalized_shape=normalized_shape, eps=eps
        )

    raise ValueError(
        f"unexpected `normalization_implementation` {normalization_implementation}"
    )
25 changes: 0 additions & 25 deletions
25
src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py
This file was deleted.
Oops, something went wrong.
35 changes: 0 additions & 35 deletions
35
src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.