Fix for NaNs in Smooth Quant (#1872)
Satrat committed Dec 1, 2023
1 parent abdded5 commit c722fc3
Showing 3 changed files with 33 additions and 27 deletions.
5 changes: 5 additions & 0 deletions src/sparseml/modifiers/smoothquant/pytorch.py
@@ -26,6 +26,8 @@

_LOGGER = logging.getLogger(__name__)

+ MINIMUM_SMOOTHING_SCALE = 1e-5

__all__ = ["SmoothQuantModifierPyTorch"]


@@ -156,6 +158,9 @@ def _apply_smoothing(self):
balance_layers = mapping.balance_layers

scales = self._calculate_smoothing_scales(balance_layers, activation_scales)
+ scales = torch.maximum(
+ scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
+ )

# invert the smoothing in the following layers
for layer in balance_layers:
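For context on the fix itself: SmoothQuant migrates quantization difficulty by dividing one side of a matmul by per-channel smoothing scales and multiplying the other side by them, so a channel whose activation statistics are all zero can produce a zero scale, and the later division by that scale yields inf/NaN weights. The snippet below is a minimal, self-contained sketch of that failure and of the clamp added above; the tensors are illustrative, not taken from the repository.

import torch

MINIMUM_SMOOTHING_SCALE = 1e-5

# Per-channel smoothing scales; a dead activation channel can produce an exact zero.
scales = torch.tensor([0.5, 2.0, 0.0])
weights = torch.ones(3, 3)

# Without the clamp, dividing by the zero scale yields inf, and any later
# multiplication of that inf by a zero activation yields NaN.
unguarded = weights / scales
print(torch.isinf(unguarded).any())   # tensor(True)

# The fix: floor the scales before they are inverted.
scales = torch.maximum(scales, torch.tensor([MINIMUM_SMOOTHING_SCALE]))
guarded = weights / scales
print(torch.isfinite(guarded).all())  # tensor(True)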
20 changes: 14 additions & 6 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -46,6 +46,7 @@ def one_shot(
model_path: str,
dataset_name: str,
num_samples: int = 128,
+ sequence_length: Optional[int] = None,
device: str = "cuda:0",
deploy_dir: Optional[str] = ".",
recipe_file: Optional[str] = None,
@@ -59,6 +60,7 @@ def one_shot(
:param model_path: path to Hugging Face stub
:param dataset_name: Dataset to extract calibration data from
:param num_samples: Number of samples to extract from the dataset
+ :param sequence_length: Maximum input sequence length to the model
:param device: Device (cuda:index or cpu) to use for computation
:param deploy_dir: The output directory to save the model to
:param recipe_file: recipe containing SparseGPT configuration
@@ -88,16 +90,15 @@ def one_shot(
if "opt" in model_type:
model_loader_fn = SparseCausalLM.opt_model_from_pretrained
forward_fn = opt_forward
elif "llama" in model_type:
model_loader_fn = SparseCausalLM.llama_model_from_pretrained
forward_fn = llama_forward
elif "mistral" in model_type:
elif "llama" in model_type or "mistral" in model_type:
model_loader_fn = SparseCausalLM.auto_model_from_pretrained
forward_fn = llama_forward
else:
raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}")
torch_dtype = _parse_dtype(precision)
- model = model_loader_fn(model_path, torch_dtype=torch_dtype)
+ model = model_loader_fn(
+ model_path, sequence_length=sequence_length, torch_dtype=torch_dtype
+ )

if dataset_name not in SUPPORTED_DATASETS:
raise ValueError(
@@ -106,7 +107,7 @@
dataset = TransformersDataset.load_from_registry(
dataset_name,
model=model_path,
- seqlen=model.seqlen,
+ seqlen=sequence_length,
nsamples=num_samples,
seed=0,
split="train",
@@ -191,6 +192,12 @@ def _fallback_to_cpu(device):
parser.add_argument(
"--nsamples", type=int, default=512, help="Number of calibration data samples"
)
+ parser.add_argument(
+ "--seqlen",
+ type=int,
+ default=None,
+ help="Maximum input sequence length to the model",
+ )
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--deploy-dir", type=str, default=".")
parser.add_argument("--recipe", type=str, default=None)
@@ -215,6 +222,7 @@ def _fallback_to_cpu(device):
dataset_name=args.dataset,
deploy_dir=args.deploy_dir,
num_samples=args.nsamples,
+ sequence_length=args.seqlen,
device=args.device,
recipe_file=args.recipe,
precision=args.precision,
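With the new argument, the calibration sequence length no longer has to come from the checkpoint's max_position_embeddings. Below is a hedged usage sketch of the updated one_shot entrypoint; the model stub, dataset name, and recipe path are placeholders rather than values taken from this commit.

from sparseml.transformers.sparsification.obcq.obcq import one_shot

# Cap calibration samples at 2048 tokens; if sequence_length were left as None,
# the loader would fall back to the model config's max_position_embeddings.
one_shot(
    model_path="hf-org/llama-style-model",   # placeholder Hugging Face stub or local path
    dataset_name="open_platypus",            # placeholder; must be one of SUPPORTED_DATASETS
    num_samples=128,
    sequence_length=2048,
    device="cuda:0",
    recipe_file="recipe.yaml",               # placeholder SparseGPT recipe
)

The command-line script exposes the same value through the new --seqlen flag, which defaults to None.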
35 changes: 14 additions & 21 deletions src/sparseml/transformers/utils/model.py
@@ -25,7 +25,6 @@
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
- LlamaForCausalLM,
OPTForCausalLM,
)
from transformers.file_utils import WEIGHTS_NAME
@@ -430,12 +429,15 @@ class SparseCausalLM:

@staticmethod
def opt_model_from_pretrained(
- model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
+ model_path: str,
+ sequence_length: Optional[int] = None,
+ torch_dtype: Union[str, torch.dtype] = "auto",
) -> torch.nn.Module:
"""
Load a pretrained OPT model from the specified hugging face path
:param model_path: hugging face or local path to model
+ :param sequence_length: maximum allowable tokens in input sequence
:param torch_dtype: precision to load model weights in as
:return: loaded pretrained model
"""
@@ -449,41 +451,32 @@ def skip(*args, **kwargs):

model = OPTForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
model.eval()
- model.seqlen = model.config.max_position_embeddings
- return model
-
- @staticmethod
- def llama_model_from_pretrained(
- model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
- ) -> torch.nn.Module:
- """
- Load a pretrained Llama model from the specified hugging face path
- :param model_path: hugging face path to model
- :param torch_dtype: precision to load model weights in as
- :return: loaded pretrained model
- """
- model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
- model.eval()
- model.seqlen = model.config.max_position_embeddings
+ model.seqlen = (
+ sequence_length if sequence_length else model.config.max_position_embeddings
+ )
return model

@staticmethod
def auto_model_from_pretrained(
- model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
+ model_path: str,
+ sequence_length: Optional[int] = None,
+ torch_dtype: Union[str, torch.dtype] = "auto",
) -> torch.nn.Module:
"""
Load a pretrained model using auto from the specified hugging face path
:param model_path: hugging face path to model
+ :param sequence_length: maximum allowable tokens in input sequence
:param torch_dtype: precision to load model weights in as
:return: loaded pretrained model
"""
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch_dtype
)
model.eval()
- model.seqlen = model.config.max_position_embeddings
+ model.seqlen = (
+ sequence_length if sequence_length else model.config.max_position_embeddings
+ )
return model


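The loaders now record model.seqlen from the caller-supplied sequence_length and only fall back to the config's max_position_embeddings when none is given; Llama checkpoints share the auto_model_from_pretrained path with Mistral. A short sketch of the fallback behavior, with a placeholder model stub:

from sparseml.transformers.utils.model import SparseCausalLM

# Explicit cap: model.seqlen follows the caller's value.
model = SparseCausalLM.auto_model_from_pretrained(
    "hf-org/some-causal-lm",  # placeholder Hugging Face stub
    sequence_length=2048,
)
assert model.seqlen == 2048

# No cap given: model.seqlen falls back to the config's context window.
model = SparseCausalLM.auto_model_from_pretrained("hf-org/some-causal-lm")
assert model.seqlen == model.config.max_position_embeddings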
