Skip to content

Commit

Permalink
FEAT: add tagging support to axolotl (#1004)
Browse files Browse the repository at this point in the history
* add tagging support to axolotl

* chore: lint

* fix method w self

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
younesbelkada and winglian committed Dec 27, 2023
1 parent 6ef46f8 commit db9094d
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from functools import partial, wraps
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
"""

args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]

def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
Expand Down Expand Up @@ -290,12 +291,41 @@ def compute_loss(self, model, inputs, return_outputs=False):
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs

@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)

return super().push_to_hub(*args, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""

tag_names = ["axolotl", "mamba"]

def compute_loss(
self,
model,
Expand All @@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler
"""

tag_names = ["axolotl", "onecycle"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
Expand Down Expand Up @@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler
"""

tag_names = ["axolotl", "relora"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
Expand Down

0 comments on commit db9094d

Please sign in to comment.