From 33b5c27bd2cbc3ff64e9614ceca95c0ca0350b32 Mon Sep 17 00:00:00 2001 From: Cyberbeing Date: Tue, 30 Jan 2024 06:44:35 -0800 Subject: [PATCH] ToDType before ToPILImage for uint8 rounding ToPILImage converts to uint8 without rounding: (npimg * 255).astype(np.uint8) ToDtype converts to uint8 while rounding: image.mul(255 + 1.0 - 0.001).to(torch.uint8) --- modules/upscaler_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 5b45060796a..b901959dbae 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -21,7 +21,8 @@ def upscale_pil_patch(model, img: Image.Image) -> Image.Image: tensor = T.PILToTensor()(img) tensor = T.ToDtype(torch.float32, scale=True)(tensor) tensor = tensor.clamp_(0.0, 1.0).unsqueeze(0).to(device=param.device, dtype=param.dtype) - return T.ToPILImage(mode="RGB")(model(tensor).squeeze(0).clamp_(0.0, 1.0)) + tensorimg = T.ToDtype(torch.uint8, scale=True)(model(tensor).squeeze(0).clamp_(0.0, 1.0)) + return T.ToPILImage(mode="RGB")(tensorimg) def upscale_with_model( @@ -165,5 +166,6 @@ def upscale_2( desc=desc, device=param.device, ) - return T.ToPILImage(mode="RGB")(output.squeeze(0).clamp_(0.0, 1.0)) + output = T.ToDtype(torch.uint8, scale=True)(output.squeeze(0).clamp_(0.0, 1.0)) + return T.ToPILImage(mode="RGB")(output)