Skip to content

Commit

Permalink
Fix: bf16 support for inference (#981)
Browse files Browse the repository at this point in the history
* Fix: bf16 torch dtype

* simplify casting to device and dtype

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
taziksh and winglian committed Dec 29, 2023
1 parent f8ae59b commit 3678a6c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)

model = model.to(cfg.device)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

while True:
print("=" * 80)
Expand Down Expand Up @@ -168,7 +168,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)

model = model.to(cfg.device)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

def generate(instruction):
if not instruction:
Expand Down

0 comments on commit 3678a6c

Please sign in to comment.