diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 207d8ebdffce30..932fd937d26f3a 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -16,7 +16,6 @@
 Callbacks to use with the Trainer class and customize the training loop.
 """
 
-import copy
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -617,13 +616,16 @@ def on_predict(self, args, state, control, **kwargs):
 
     def on_log(self, args, state, control, logs=None, **kwargs):
         if state.is_world_process_zero and self.training_bar is not None:
-            # avoid modifying the logs object as it is shared between callbacks
-            logs = copy.deepcopy(logs)
-            _ = logs.pop("total_flos", None)
+            # make a shallow copy of logs so we can mutate the fields copied
+            # but avoid doing any value pickling.
+            shallow_logs = {}
+            for k, v in logs.items():
+                shallow_logs[k] = v
+            _ = shallow_logs.pop("total_flos", None)
             # round numbers so that it looks better in console
-            if "epoch" in logs:
-                logs["epoch"] = round(logs["epoch"], 2)
-            self.training_bar.write(str(logs))
+            if "epoch" in shallow_logs:
+                shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
+            self.training_bar.write(str(shallow_logs))
 
     def on_train_end(self, args, state, control, **kwargs):
         if state.is_world_process_zero:
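For context, here is a minimal self-contained sketch (not part of the patch) of why a shallow copy is sufficient here. The callback only pops and replaces top-level keys, so copying the dict itself leaves the shared `logs` object untouched for other callbacks, while `copy.deepcopy` would recurse into every value, falling back to the pickle protocol (`__reduce_ex__`) for objects it has no dedicated handler for, which is the "value pickling" the patch comment refers to. The `Unpicklable` class and the literal log values below are hypothetical stand-ins, not anything from the Trainer:

```python
import copy


class Unpicklable:
    """Hypothetical stand-in for a log value that deepcopy cannot handle."""

    def __deepcopy__(self, memo):
        raise TypeError("cannot deepcopy this value")


# A logs dict as it might be shared between callbacks (values are made up).
logs = {"loss": 0.4213, "epoch": 1.2345, "total_flos": 1e12, "handle": Unpicklable()}

# Shallow copy: a new dict holding references to the same value objects,
# so popping/replacing top-level keys does not touch the shared `logs`.
shallow_logs = dict(logs)
shallow_logs.pop("total_flos", None)
shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)

assert "total_flos" in logs      # the shared dict still has the key
assert logs["epoch"] == 1.2345   # and the original, unrounded value

# A deep copy would have tried to copy every value and failed here:
try:
    copy.deepcopy(logs)
except TypeError as exc:
    print(f"deepcopy failed: {exc}")
```

Note that the explicit `for k, v in logs.items()` loop in the patch is behaviorally equivalent to `dict(logs)` or `logs.copy()`; the sketch uses the one-liner form only for brevity.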