diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index a15a1692b4f..02a8bb69c61 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -40,8 +40,9 @@
                         Path to directory where model files for weights,
                         config, and tokenizer are stored
   --sequence_length SEQUENCE_LENGTH
-                        Sequence length to use. Default is 384. Can be
-                        overwritten later
+                        Sequence length to use. Default is
+                        `config.max_position_embeddings`. Can be overwritten
+                        later
   --no_convert_qat      Set flag to not perform QAT to fully quantized
                         conversion after export
   --finetuning_task FINETUNING_TASK
@@ -238,7 +239,7 @@ def load_task_dataset(
 def export_transformer_to_onnx(
     task: str,
     model_path: str,
-    sequence_length: int = 384,
+    sequence_length: Optional[int] = None,
     convert_qat: bool = True,
     finetuning_task: Optional[str] = None,
     onnx_file_name: str = MODEL_ONNX_NAME,
@@ -294,6 +295,13 @@ def export_transformer_to_onnx(
         trust_remote_code=trust_remote_code,
         **config_args,
     )
+
+    if sequence_length is None:
+        _LOGGER.info(
+            f"Using default sequence length of {config.max_position_embeddings}"
+        )
+        sequence_length = config.max_position_embeddings
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length
     )
@@ -514,8 +522,11 @@ def _parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--sequence_length",
         type=int,
-        default=384,
-        help="Sequence length to use. Default is 384. Can be overwritten later",
+        default=None,
+        help=(
+            "Sequence length to use. Default is `config.max_position_embeddings`. "
+            "Can be overwritten later"
+        ),
     )
     parser.add_argument(
         "--no_convert_qat",
@@ -579,7 +590,7 @@ def _parse_args() -> argparse.Namespace:
 def export(
     task: str,
     model_path: str,
-    sequence_length: int,
+    sequence_length: Optional[int],
     no_convert_qat: bool,
     finetuning_task: str,
     onnx_file_name: str,
diff --git a/src/sparseml/transformers/sparsification/obcq/export.py b/src/sparseml/transformers/sparsification/obcq/export.py
index ad29b3c5056..dda9712a989 100644
--- a/src/sparseml/transformers/sparsification/obcq/export.py
+++ b/src/sparseml/transformers/sparsification/obcq/export.py
@@ -40,8 +40,9 @@
                         Path to directory where model files for weights,
                         config, and tokenizer are stored
   --sequence_length SEQUENCE_LENGTH
-                        Sequence length to use. Default is 384. Can be
-                        overwritten later
+                        Sequence length to use. Default is
+                        `config.max_position_embeddings`. Can be overwritten
+                        later
   --no_convert_qat      Set flag to not perform QAT to fully quantized
                         conversion after export
   --onnx_file_name ONNX_FILE_NAME
@@ -304,7 +305,7 @@ def load_task_dataset(
 def export_transformer_to_onnx(
     task: str,
     model_path: str,
-    sequence_length: int = 384,
+    sequence_length: Optional[int] = None,
     convert_qat: bool = True,
     onnx_file_name: str = MODEL_ONNX_NAME,
     num_export_samples: int = 0,
@@ -353,6 +354,10 @@ def export_transformer_to_onnx(
         trust_remote_code=trust_remote_code,
         **config_args,
     )
+
+    if sequence_length is None:
+        sequence_length = config.max_position_embeddings
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length
     )
@@ -543,8 +548,11 @@ def _parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--sequence_length",
         type=int,
-        default=384,
-        help="Sequence length to use. Default is `config.max_position_embeddings`. Can be overwritten later",
+        default=None,
+        help=(
+            "Sequence length to use. Default is `config.max_position_embeddings`. "
+            "Can be overwritten later"
+        ),
     )
     parser.add_argument(
         "--no_convert_qat",
@@ -586,7 +594,7 @@ def _parse_args() -> argparse.Namespace:
 def export(
     task: str,
     model_path: str,
-    sequence_length: int,
+    sequence_length: Optional[int],
     no_convert_qat: bool,
     onnx_file_name: str,
     num_export_samples: int = 0,