diff --git a/test.py b/test.py index c0af91120e60..d099699bcad8 100644 --- a/test.py +++ b/test.py @@ -37,6 +37,7 @@ def test(data, plots=True, wandb_logger=None, compute_loss=None, + half_precision=True, is_coco=False): # Initialize/load model and set device training = model is not None @@ -61,7 +62,7 @@ def test(data, # model = nn.DataParallel(model) # Half - half = device.type != 'cpu' # half precision only supported on CUDA + half = device.type != 'cpu' and half_precision # half precision only supported on CUDA if half: model.half()