From 4ab7a28216211571fdddba414d4edd8426ab6489 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Fri, 19 Apr 2024 15:33:32 +0530 Subject: [PATCH] feat: Upgrade Weights & Biases callback (#30135) * feat: upgrade wandb callback with new features * fix: ci issues with imports and run fixup --- .../integrations/integration_utils.py | 101 +++++++++++++++++- 1 file changed, 96 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 63b9e050d4a1d3..45507bfda82198 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -31,8 +31,17 @@ import numpy as np import packaging.version +from .. import PreTrainedModel, TFPreTrainedModel from .. import __version__ as version -from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging +from ..utils import ( + PushToHubMixin, + flatten_dict, + is_datasets_available, + is_pandas_available, + is_tf_available, + is_torch_available, + logging, +) logger = logging.get_logger(__name__) @@ -69,6 +78,7 @@ except importlib.metadata.PackageNotFoundError: _has_neptune = False +from .. import modelcard # noqa: E402 from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from ..training_args import ParallelMode # noqa: E402 @@ -663,6 +673,22 @@ def on_train_end(self, args, state, control, **kwargs): self.tb_writer = None +def save_model_architecture_to_file(model: Any, output_dir: str): + with open(f"{output_dir}/model_architecture.txt", "w+") as f: + if isinstance(model, PreTrainedModel): + print(model, file=f) + elif is_tf_available() and isinstance(model, TFPreTrainedModel): + + def print_to_file(s): + print(s, file=f) + + model.summary(print_fn=print_to_file) + elif is_torch_available() and ( + isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model") + ): + print(model, file=f) + + class WandbCallback(TrainerCallback): """ A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/). @@ -728,6 +754,9 @@ def setup(self, args, state, model, **kwargs): if hasattr(model, "config") and model.config is not None: model_config = model.config.to_dict() combined_dict = {**model_config, **combined_dict} + if hasattr(model, "peft_config") and model.peft_config is not None: + peft_config = model.peft_config + combined_dict = {**{"peft_config": peft_config}, **combined_dict} trial_name = state.trial_name init_args = {} if trial_name is not None: @@ -756,6 +785,51 @@ def setup(self, args, state, model, **kwargs): self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) self._wandb.run._label(code="transformers_trainer") + # add number of model parameters to wandb config + if any( + ( + isinstance(model, PreTrainedModel), + isinstance(model, PushToHubMixin), + (is_tf_available() and isinstance(model, TFPreTrainedModel)), + (is_torch_available() and isinstance(model, torch.nn.Module)), + ) + ): + self._wandb.config["model/num_parameters"] = model.num_parameters() + + # log the initial model and architecture to an artifact + with tempfile.TemporaryDirectory() as temp_dir: + model_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + model_artifact = self._wandb.Artifact( + name=model_name, + type="model", + metadata={ + "model_config": model.config.to_dict() if hasattr(model, "config") else None, + "num_parameters": self._wandb.config.get("model/num_parameters"), + "initial_model": True, + }, + ) + model.save_pretrained(temp_dir) + # add the architecture to a separate text file + save_model_architecture_to_file(model, temp_dir) + + for f in Path(temp_dir).glob("*"): + if f.is_file(): + with model_artifact.new_file(f.name, mode="wb") as fa: + fa.write(f.read_bytes()) + self._wandb.run.log_artifact(model_artifact, aliases=["base_model"]) + + badge_markdown = ( + f'[Visualize in Weights & Biases]({self._wandb.run.get_url()})' + ) + + modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + def on_train_begin(self, args, state, control, model=None, **kwargs): if self._wandb is None: return @@ -786,20 +860,25 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwarg else { f"eval/{args.metric_for_best_model}": state.best_metric, "train/total_floss": state.total_flos, + "model/num_parameters": self._wandb.config.get("model/num_parameters"), } ) + metadata["final_model"] = True logger.info("Logging model artifacts. ...") model_name = ( f"model-{self._wandb.run.id}" if (args.run_name is None or args.run_name == args.output_dir) else f"model-{self._wandb.run.name}" ) + # add the model architecture to a separate text file + save_model_architecture_to_file(model, temp_dir) + artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata) for f in Path(temp_dir).glob("*"): if f.is_file(): with artifact.new_file(f.name, mode="wb") as fa: fa.write(f.read_bytes()) - self._wandb.run.log_artifact(artifact) + self._wandb.run.log_artifact(artifact, aliases=["final_model"]) def on_log(self, args, state, control, model=None, logs=None, **kwargs): single_value_scalars = [ @@ -829,18 +908,30 @@ def on_save(self, args, state, control, **kwargs): for k, v in dict(self._wandb.summary).items() if isinstance(v, numbers.Number) and not k.startswith("_") } + checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters") ckpt_dir = f"checkpoint-{state.global_step}" artifact_path = os.path.join(args.output_dir, ckpt_dir) logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...") checkpoint_name = ( - f"checkpoint-{self._wandb.run.id}" + f"model-{self._wandb.run.id}" if (args.run_name is None or args.run_name == args.output_dir) - else f"checkpoint-{self._wandb.run.name}" + else f"model-{self._wandb.run.name}" ) artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata) artifact.add_dir(artifact_path) - self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"]) + self._wandb.log_artifact( + artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"] + ) + + def on_predict(self, args, state, control, metrics, **kwargs): + if self._wandb is None: + return + if not self._initialized: + self.setup(args, state, **kwargs) + if state.is_world_process_zero: + metrics = rewrite_logs(metrics) + self._wandb.log(metrics) class CometCallback(TrainerCallback):