Skip to content

Commit

Permalink
fix: add check for adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Jan 8, 2024
1 parent 3ed71af commit 284672e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ def train(
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# use built-in trl autounwrap
model_ref = None

# load the model again for model_ref/baseline
# model_ref, _ = load_model(
# cfg, tokenizer, inference=cli_args.inference, reference_model=True
# )
if cfg.adapter:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)

safe_serialization = cfg.save_safetensors is True

Expand Down

0 comments on commit 284672e

Please sign in to comment.