Skip to content

Commit

Permalink
Add scale and weight dtype check for quantization config (#1519)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss authored Apr 28, 2024
1 parent ef82cd3 commit 307c1a8
Showing 1 changed file with 40 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1738,10 +1738,49 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Fall back to fp32 compute when the requested compute dtype cannot run on
# this host: bf16 without native bf16 support, or fp16 on CPU.
if ((not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16")
        or (use_cpu and quantization_config.compute_dtype == "fp16")):
    quantization_config.compute_dtype = "fp32"

# Default and validate the scale dtype; anything unsupported degrades to fp32
# with a warning rather than failing the load.
if quantization_config.scale_dtype is None:
    quantization_config.scale_dtype = "fp32"
if quantization_config.scale_dtype not in ["fp32", "fp16", "bf16"]:
    logger.warning("scale_dtype only supports fp32, bf16, fp16.")
    quantization_config.scale_dtype = "fp32"
    logger.warning("fp32 scale_dtype is used, please change the config.json if you don't want to use it.")

# weight dtype is higher priority than bits in config.json when both existed.
if quantization_config.weight_dtype is None:
    # No explicit weight_dtype: derive it from `bits` (4 -> int4_clip,
    # 8 -> int8); any other bits value falls back to int4_clip with a warning.
    if quantization_config.bits == 4:
        quantization_config.weight_dtype = "int4_clip"
        logger.info(
            "{} quantization weight_dtype is used due to bits is 4 in config.json.".format(
                quantization_config.weight_dtype)
        )
    elif quantization_config.bits == 8:
        quantization_config.weight_dtype = "int8"
        logger.info(
            "{} quantization weight_dtype is used due to bits is 8 in config.json.".format(
                quantization_config.weight_dtype)
        )
    else:
        logger.warning("bits number only supports 4, 8.")
        quantization_config.weight_dtype = "int4_clip"
        logger.warning(
            "int4_clip weight_dtype is used, please change the config.json if you don't want to use it.")
else:
    # Explicit weight_dtype: reject anything outside the supported set.
    if quantization_config.weight_dtype not in ["int4_fullrange",
                                                "int4_clip",
                                                "int8",
                                                "fp8_e5m2",
                                                "fp8_e4m3",
                                                "nf4",
                                                "fp4_e2m1_bnb",
                                                "fp4_e2m1"]:
        logger.warning("Please provide the correct bits number or weight_dtype in config.json.")
        # Fixed message: quotes previously fused the last two names into one
        # item ('fp8_e5m2, fp8_e4m3'); also dropped the spurious f-prefixes.
        raise ValueError(
            "weight_dtype must be a string in "
            "'int8', 'int4_fullrange', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', "
            "'fp4_e2m1', 'fp8_e5m2', 'fp8_e4m3'"
        )
    else:
        logger.info("{} quantization weight_dtype is used.".format(quantization_config.weight_dtype))

# Contexts used below to skip weight initialization during model construction.
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts.append(init_empty_weights())
Expand Down

0 comments on commit 307c1a8

Please sign in to comment.