feat: add config to disable autounwrap
NanoCode012 committed Jan 11, 2024
1 parent 116f34a commit 9f02c8d
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/rlhf.md
@@ -33,3 +33,12 @@ datasets:
 ```yaml
 rl: ipo
 ```
+
+#### TRL autounwrap for PEFT
+
+TRL supports autounwrapping PEFT models, so a separate reference model does not need to be loaded, reducing VRAM usage. This is enabled by default. To turn it off, pass the following config:
+
+```yaml
+# Load a separate reference model when training with an adapter.
+rl_adapter_ref_model: true
+```
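The gating this config introduces can be sketched as a small standalone helper. This is a minimal illustrative sketch, not Axolotl's actual code: the function name `resolve_ref_model` and the use of a plain dict for `cfg` are assumptions; only the keys `rl`, `adapter`, and `rl_adapter_ref_model` come from this commit.

```python
def resolve_ref_model(cfg, load_model):
    """Return the reference model to pass to the RL trainer (illustrative sketch).

    With an adapter and autounwrap left on (rl_adapter_ref_model unset/false),
    TRL reuses the base weights by disabling the adapter, so no second model
    is loaded and model_ref stays None.
    """
    if not cfg.get("rl"):
        return None  # not an RL run; no reference model needed
    if cfg.get("adapter") and not cfg.get("rl_adapter_ref_model"):
        return None  # rely on TRL's built-in autounwrap, saving VRAM
    return load_model()  # load a second copy as the frozen baseline
```

Setting `rl_adapter_ref_model: true` thus trades extra VRAM for an explicitly loaded reference model instead of the autounwrapped base weights.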
4 changes: 2 additions & 2 deletions src/axolotl/train.py
@@ -63,10 +63,10 @@ def train(
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model_ref = None
     if cfg.rl:
-        if cfg.adapter:
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
             # use built-in trl autounwrap
             LOG.debug("Passing model_ref: None to RL trainer")
-            model_ref = None
+            model_ref = None  # explicit setting to None
         else:
             # load the model again for model_ref/baseline
             model_ref, _ = load_model(
