Skip to content

Commit

Permalink
Change Transformers export default sequence length to max_position_embeddings (#1826)
Browse files Browse the repository at this point in the history

* Change Transformers export default sequence length to max_position_embeddings

* Fix style
  • Loading branch information
mgoin authored and bfineran committed Nov 16, 2023
1 parent a17ffe4 commit ec4a11e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
23 changes: 17 additions & 6 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
Path to directory where model files for weights,
config, and tokenizer are stored
--sequence_length SEQUENCE_LENGTH
Sequence length to use. Default is 384. Can be
overwritten later
Sequence length to use. Default is
`config.max_position_embeddings`. Can be overwritten
later
--no_convert_qat Set flag to not perform QAT to fully quantized
conversion after export
--finetuning_task FINETUNING_TASK
Expand Down Expand Up @@ -238,7 +239,7 @@ def load_task_dataset(
def export_transformer_to_onnx(
task: str,
model_path: str,
sequence_length: int = 384,
sequence_length: Optional[int] = None,
convert_qat: bool = True,
finetuning_task: Optional[str] = None,
onnx_file_name: str = MODEL_ONNX_NAME,
Expand Down Expand Up @@ -294,6 +295,13 @@ def export_transformer_to_onnx(
trust_remote_code=trust_remote_code,
**config_args,
)

if sequence_length is None:
_LOGGER.info(
f"Using default sequence length of {config.max_position_embeddings}"
)
sequence_length = config.max_position_embeddings

tokenizer = AutoTokenizer.from_pretrained(
model_path, model_max_length=sequence_length
)
Expand Down Expand Up @@ -514,8 +522,11 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--sequence_length",
type=int,
default=384,
help="Sequence length to use. Default is 384. Can be overwritten later",
default=None,
help=(
"Sequence length to use. Default is `config.max_position_embeddings`. "
"Can be overwritten later"
),
)
parser.add_argument(
"--no_convert_qat",
Expand Down Expand Up @@ -579,7 +590,7 @@ def _parse_args() -> argparse.Namespace:
def export(
task: str,
model_path: str,
sequence_length: int,
sequence_length: Optional[int],
no_convert_qat: bool,
finetuning_task: str,
onnx_file_name: str,
Expand Down
20 changes: 14 additions & 6 deletions src/sparseml/transformers/sparsification/obcq/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
Path to directory where model files for weights,
config, and tokenizer are stored
--sequence_length SEQUENCE_LENGTH
Sequence length to use. Default is 384. Can be
overwritten later
Sequence length to use. Default is
`config.max_position_embeddings`. Can be overwritten
later
--no_convert_qat Set flag to not perform QAT to fully quantized
conversion after export
--onnx_file_name ONNX_FILE_NAME
Expand Down Expand Up @@ -304,7 +305,7 @@ def load_task_dataset(
def export_transformer_to_onnx(
task: str,
model_path: str,
sequence_length: int = 384,
sequence_length: Optional[int] = None,
convert_qat: bool = True,
onnx_file_name: str = MODEL_ONNX_NAME,
num_export_samples: int = 0,
Expand Down Expand Up @@ -353,6 +354,10 @@ def export_transformer_to_onnx(
trust_remote_code=trust_remote_code,
**config_args,
)

if sequence_length is None:
sequence_length = config.max_position_embeddings

tokenizer = AutoTokenizer.from_pretrained(
model_path, model_max_length=sequence_length
)
Expand Down Expand Up @@ -543,8 +548,11 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--sequence_length",
type=int,
default=384,
help="Sequence length to use. Default is 384. Can be overwritten later",
default=None,
help=(
"Sequence length to use. Default is `config.max_position_embeddings`. "
"Can be overwritten later"
),
)
parser.add_argument(
"--no_convert_qat",
Expand Down Expand Up @@ -586,7 +594,7 @@ def _parse_args() -> argparse.Namespace:
def export(
task: str,
model_path: str,
sequence_length: int,
sequence_length: Optional[int],
no_convert_qat: bool,
onnx_file_name: str,
num_export_samples: int = 0,
Expand Down

0 comments on commit ec4a11e

Please sign in to comment.