feat: Upgrade Weights & Biases callback (#30135)
* feat: upgrade wandb callback with new features

* fix: ci issues with imports and run fixup
parambharat committed Apr 19, 2024
1 parent 30b4532 commit 4ab7a28
Showing 1 changed file with 96 additions and 5 deletions.
101 changes: 96 additions & 5 deletions src/transformers/integrations/integration_utils.py
@@ -31,8 +31,17 @@
import numpy as np
import packaging.version

from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
from ..utils import (
PushToHubMixin,
flatten_dict,
is_datasets_available,
is_pandas_available,
is_tf_available,
is_torch_available,
logging,
)


logger = logging.get_logger(__name__)
@@ -69,6 +78,7 @@
except importlib.metadata.PackageNotFoundError:
_has_neptune = False

from .. import modelcard # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
@@ -663,6 +673,22 @@ def on_train_end(self, args, state, control, **kwargs):
self.tb_writer = None


def save_model_architecture_to_file(model: Any, output_dir: str):
with open(f"{output_dir}/model_architecture.txt", "w+") as f:
if isinstance(model, PreTrainedModel):
print(model, file=f)
elif is_tf_available() and isinstance(model, TFPreTrainedModel):

def print_to_file(s):
print(s, file=f)

model.summary(print_fn=print_to_file)
elif is_torch_available() and (
isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
):
print(model, file=f)


class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@@ -728,6 +754,9 @@ def setup(self, args, state, model, **kwargs):
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
if hasattr(model, "peft_config") and model.peft_config is not None:
peft_config = model.peft_config
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
@@ -756,6 +785,51 @@ def setup(self, args, state, model, **kwargs):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")

# add number of model parameters to wandb config
if any(
(
isinstance(model, PreTrainedModel),
isinstance(model, PushToHubMixin),
(is_tf_available() and isinstance(model, TFPreTrainedModel)),
(is_torch_available() and isinstance(model, torch.nn.Module)),
)
):
self._wandb.config["model/num_parameters"] = model.num_parameters()

# log the initial model and architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
model.save_pretrained(temp_dir)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)

for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])

badge_markdown = (
f'[<img src="https://github.com/raw/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)

modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return
@@ -786,20 +860,25 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwarg
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos,
"model/num_parameters": self._wandb.config.get("model/num_parameters"),
}
)
metadata["final_model"] = True
logger.info("Logging model artifacts. ...")
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
# add the model architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)

artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(artifact)
self._wandb.run.log_artifact(artifact, aliases=["final_model"])

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
single_value_scalars = [
@@ -829,18 +908,30 @@ def on_save(self, args, state, control, **kwargs):
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")

ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
f"checkpoint-{self._wandb.run.id}"
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"checkpoint-{self._wandb.run.name}"
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
self._wandb.log_artifact(
artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
)

def on_predict(self, args, state, control, metrics, **kwargs):
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, **kwargs)
if state.is_world_process_zero:
metrics = rewrite_logs(metrics)
self._wandb.log(metrics)


class CometCallback(TrainerCallback):
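
For anyone trying the upgraded callback, here is a minimal usage sketch. It is hedged: the project and run names are placeholders, and it assumes wandb is installed alongside a transformers build that contains this commit. WANDB_PROJECT, WANDB_LOG_MODEL, and WANDB_WATCH are environment variables the integration already reads, and report_to="wandb" is what attaches WandbCallback to the Trainer.

import os

from transformers import TrainingArguments

# Environment variables read by the W&B integration (values are examples).
os.environ["WANDB_PROJECT"] = "my-project"    # placeholder project name
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # "checkpoint", "end", or "false"
os.environ["WANDB_WATCH"] = "false"           # skip gradient/parameter watching

args = TrainingArguments(
    output_dir="out",
    run_name="my-run",  # used for the W&B run (and artifact) name when it differs from output_dir
    report_to="wandb",  # attaches WandbCallback
    logging_steps=50,
    save_steps=200,
)
# Pass `args` to a Trainer as usual. With model logging enabled as above, the
# upgraded callback logs the initial model under the "base_model" alias, each
# saved checkpoint under epoch/global-step aliases, and the final model under
# "final_model"; on_predict() additionally reports prediction metrics to the run.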

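The new save_model_architecture_to_file helper writes a plain-text dump of the model architecture next to the saved weights, so the logged model artifacts also carry a model_architecture.txt. Below is an illustrative sketch of calling it directly; TinyPeftLike is a made-up stand-in for a torch.nn.Module that exposes a base_model attribute (the helper's last branch), and the import path assumes a transformers build that contains this commit.

import tempfile
from pathlib import Path

import torch

from transformers.integrations.integration_utils import save_model_architecture_to_file


class TinyPeftLike(torch.nn.Module):
    # Made-up stand-in: a torch.nn.Module with a `base_model` attribute,
    # which is what the helper's final branch checks for.
    def __init__(self):
        super().__init__()
        self.base_model = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.base_model(x)


with tempfile.TemporaryDirectory() as tmp:
    save_model_architecture_to_file(TinyPeftLike(), tmp)
    # The helper prints repr(model) into model_architecture.txt.
    print(Path(tmp, "model_architecture.txt").read_text())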
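
Because on_save now publishes checkpoints under the same artifact name as the model (model-{run.id}, or model-{run.name} when a distinct run_name is set) with epoch and global-step aliases, a specific checkpoint can later be fetched by alias. A hedged sketch using the public wandb API; the entity, project, and run id are placeholders.

import wandb

api = wandb.Api()
# Replace entity/project and the run id with your own values.
artifact = api.artifact("entity/project/model-abc123xy:checkpoint_global_step_200")
local_dir = artifact.download()  # directory with the checkpoint files logged by on_save
print(local_dir)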