Skip to content

Commit

Permalink
ensure merged model matches the training dtype (#902)
Browse files Browse the repository at this point in the history
* ensure merged model matches the training dtype

* Update src/axolotl/cli/__init__.py

* Update src/axolotl/cli/__init__.py
  • Loading branch information
winglian committed Nov 29, 2023
1 parent 71b7ea3 commit 1d21aa6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def do_merge_lora(

LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
model.to(dtype=cfg.torch_dtype)

if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
Expand Down

0 comments on commit 1d21aa6

Please sign in to comment.