Skip to content

Commit

Permalink
simplify casting to device and dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 29, 2023
1 parent 7767d7d commit 097ce6c
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)

model = model.to(cfg.device)
if cfg.bf16:
model = model.to(torch.bfloat16)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

while True:
print("=" * 80)
Expand Down Expand Up @@ -170,9 +168,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)

model = model.to(cfg.device)
if cfg.bf16:
model = model.to(torch.bfloat16)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

def generate(instruction):
if not instruction:
Expand Down

0 comments on commit 097ce6c

Please sign in to comment.