Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for NaNs in Smooth Quant #1872

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sparseml/modifiers/smoothquant/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

_LOGGER = logging.getLogger(__name__)

MINIMUM_SMOOTHING_SCALE = 1e-5

__all__ = ["SmoothQuantModifierPyTorch"]


Expand Down Expand Up @@ -156,6 +158,9 @@ def _apply_smoothing(self):
balance_layers = mapping.balance_layers

scales = self._calculate_smoothing_scales(balance_layers, activation_scales)
scales = torch.maximum(
scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
)

# invert the smoothing in the following layers
for layer in balance_layers:
Expand Down
20 changes: 14 additions & 6 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def one_shot(
model_path: str,
dataset_name: str,
num_samples: int = 128,
sequence_length: Optional[int] = None,
device: str = "cuda:0",
deploy_dir: Optional[str] = ".",
recipe_file: Optional[str] = None,
Expand All @@ -59,6 +60,7 @@ def one_shot(
:param model_path: path to Hugging Face stub
:param dataset_name: Dataset to extract calibration data from
:param num_samples: Number of samples to extract from the dataset
:param sequence_length: Maximum input sequence length to the model
:param device: Device (cuda:index or cpu) to use for computation
:param deploy_dir: The output directory to save the model to
:param recipe_file: recipe containing SparseGPT configuration
Expand Down Expand Up @@ -88,16 +90,15 @@ def one_shot(
if "opt" in model_type:
model_loader_fn = SparseCausalLM.opt_model_from_pretrained
forward_fn = opt_forward
elif "llama" in model_type:
model_loader_fn = SparseCausalLM.llama_model_from_pretrained
forward_fn = llama_forward
elif "mistral" in model_type:
elif "llama" in model_type or "mistral" in model_type:
model_loader_fn = SparseCausalLM.auto_model_from_pretrained
forward_fn = llama_forward
else:
raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}")
torch_dtype = _parse_dtype(precision)
model = model_loader_fn(model_path, torch_dtype=torch_dtype)
model = model_loader_fn(
model_path, sequence_length=sequence_length, torch_dtype=torch_dtype
)

if dataset_name not in SUPPORTED_DATASETS:
raise ValueError(
Expand All @@ -106,7 +107,7 @@ def one_shot(
dataset = TransformersDataset.load_from_registry(
dataset_name,
model=model_path,
seqlen=model.seqlen,
seqlen=sequence_length,
nsamples=num_samples,
seed=0,
split="train",
Expand Down Expand Up @@ -191,6 +192,12 @@ def _fallback_to_cpu(device):
parser.add_argument(
"--nsamples", type=int, default=512, help="Number of calibration data samples"
)
parser.add_argument(
"--seqlen",
type=int,
default=None,
help="Maximum input sequence length to the model",
)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--deploy-dir", type=str, default=".")
parser.add_argument("--recipe", type=str, default=None)
Expand All @@ -215,6 +222,7 @@ def _fallback_to_cpu(device):
dataset_name=args.dataset,
deploy_dir=args.deploy_dir,
num_samples=args.nsamples,
sequence_length=args.seqlen,
device=args.device,
recipe_file=args.recipe,
precision=args.precision,
Expand Down
35 changes: 14 additions & 21 deletions src/sparseml/transformers/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
LlamaForCausalLM,
OPTForCausalLM,
)
from transformers.file_utils import WEIGHTS_NAME
Expand Down Expand Up @@ -430,12 +429,15 @@ class SparseCausalLM:

@staticmethod
def opt_model_from_pretrained(
model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
model_path: str,
sequence_length: Optional[int] = None,
torch_dtype: Union[str, torch.dtype] = "auto",
) -> torch.nn.Module:
"""
Load a pretrained OPT model from the specified hugging face path

:param model_path: hugging face or local path to model
:param sequence_length: maximum allowable tokens in input sequence
:param torch_dtype: precision to load model weights in as
:return: loaded pretrained model
"""
Expand All @@ -449,41 +451,32 @@ def skip(*args, **kwargs):

model = OPTForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
model.eval()
model.seqlen = model.config.max_position_embeddings
return model

@staticmethod
def llama_model_from_pretrained(
model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
) -> torch.nn.Module:
"""
Load a pretrained Llama model from the specified hugging face path

:param model_path: hugging face path to model
:param torch_dtype: precision to load model weights in as
:return: loaded pretrained model
"""
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
model.eval()
model.seqlen = model.config.max_position_embeddings
model.seqlen = (
sequence_length if sequence_length else model.config.max_position_embeddings
)
return model

@staticmethod
def auto_model_from_pretrained(
model_path: str, torch_dtype: Union[str, torch.dtype] = "auto"
model_path: str,
sequence_length: Optional[int] = None,
torch_dtype: Union[str, torch.dtype] = "auto",
) -> torch.nn.Module:
"""
Load a pretrained model using auto from the specified hugging face path

:param model_path: hugging face path to model
:param sequence_length: maximum allowable tokens in input sequence
:param torch_dtype: precision to load model weights in as
:return: loaded pretrained model
"""
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch_dtype
)
model.eval()
model.seqlen = model.config.max_position_embeddings
model.seqlen = (
sequence_length if sequence_length else model.config.max_position_embeddings
)
return model


Expand Down
Loading