Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: add tagging support to axolotl #1004

Merged
merged 3 commits into from
Dec 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from functools import partial, wraps
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
"""

args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]

def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
Expand Down Expand Up @@ -290,12 +291,41 @@ def compute_loss(self, model, inputs, return_outputs=False):
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs

@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)

return super().push_to_hub(*args, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""

tag_names = ["axolotl", "mamba"]

def compute_loss(
self,
model,
Expand All @@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler
"""

tag_names = ["axolotl", "onecycle"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
Expand Down Expand Up @@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler
"""

tag_names = ["axolotl", "relora"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
Expand Down