Skip to content

Commit

Permalink
create a model card with axolotl badge
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 22, 2023
1 parent c25ba79 commit cb6d525
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from typing import Optional

import torch

# add src to the pythonpath so we don't need to pip install this
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer

Expand Down Expand Up @@ -103,6 +102,9 @@ def terminate_handler(_, __, model):
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)

badge_markdown = """[<img src="https://github.com/raw/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
Expand Down Expand Up @@ -138,4 +140,7 @@ def terminate_handler(_, __, model):

model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))

return model, tokenizer

0 comments on commit cb6d525

Please sign in to comment.