Skip to content

Commit

Permalink
Set proper trainer output dir (#1696)
Browse files Browse the repository at this point in the history
  • Loading branch information
natuan committed Aug 2, 2023
1 parent d199188 commit b52a13d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def load_task_dataset(
or task == "sentiment-analysis"
or task == "text-classification"
):

from sparseml.transformers.text_classification import (
DataTrainingArguments,
get_tokenized_text_classification_dataset,
Expand Down Expand Up @@ -315,7 +314,8 @@ def export_transformer_to_onnx(

model = model.train()

args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
trainer_output_dir = os.path.dirname(model_path)
args = DeviceCPUTrainingArgs(output_dir=trainer_output_dir)
trainer = Trainer(
model=model,
args=args,
Expand Down

0 comments on commit b52a13d

Please sign in to comment.