From fd16e5abef94e274572d40912f12baeffece8696 Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Wed, 7 Dec 2022 15:03:33 +0800 Subject: [PATCH] fix: check dtype when loading models (#872) * fix: check dtype when loading models * fix: black --- server/clip_server/model/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index e4cfe3c3a..12de55e5b 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -464,7 +464,9 @@ def load_openai_model( # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) - if dtype == torch.float32 or dtype.startswith('amp'): + if dtype == torch.float32 or ( + isinstance(dtype, str) and dtype.startswith('amp') + ): model.float() elif dtype == torch.bfloat16: convert_weights_to_lp(model, dtype=torch.bfloat16)