From 4491ca2e8585375ccbab43a42f8af5c664414090 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 24 Aug 2023 17:50:20 +0200 Subject: [PATCH] Added support for CMYK in decode_jpeg (#7741) Co-authored-by: Nicolas Hug --- test/test_image.py | 5 +- torchvision/csrc/io/image/cpu/decode_jpeg.cpp | 87 +++++++++++++++++-- 2 files changed, 83 insertions(+), 9 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index b24ac07d956..a87f5fa2d1e 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode): with Image.open(img_path) as img: is_cmyk = img.mode == "CMYK" if pil_mode is not None: - if is_cmyk: - # libjpeg does not support the conversion - pytest.xfail("Decoding a CMYK jpeg isn't supported") img = img.convert(pil_mode) img_pil = torch.from_numpy(np.array(img)) - if is_cmyk: + if is_cmyk and mode == ImageReadMode.UNCHANGED: # flip the colors to match libjpeg img_pil = 255 - img_pil diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index 09a0618ad1f..63a4e5b42ec 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr( src->pub.next_input_byte = src->data; } +inline unsigned char clamped_cmyk_rgb_convert( + unsigned char k, + unsigned char cmy) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 + int v = k * cmy + 128; + v = ((v >> 8) + v) >> 8; + return std::clamp(k - v, 0, 255); +} + +void convert_line_cmyk_to_rgb( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* rgb_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c); + rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m); + rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y); + } +} + +inline unsigned char rgb_to_gray(int r, int g, int b) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 + return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; +} + +void convert_line_cmyk_to_gray( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* gray_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + int r = clamped_cmyk_rgb_convert(k, 255 - c); + int g = clamped_cmyk_rgb_convert(k, 255 - m); + int b = clamped_cmyk_rgb_convert(k, 255 - y); + + gray_line[i] = rgb_to_gray(r, g, b); + } +} + } // namespace torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { @@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { jpeg_read_header(&cinfo, TRUE); int channels = cinfo.num_components; + bool cmyk_to_rgb_or_gray = false; if (mode != IMAGE_READ_MODE_UNCHANGED) { switch (mode) { case IMAGE_READ_MODE_GRAY: - if (cinfo.jpeg_color_space != JCS_GRAYSCALE) { + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { cinfo.out_color_space = JCS_GRAYSCALE; - channels = 1; } + channels = 1; break; case IMAGE_READ_MODE_RGB: - if (cinfo.jpeg_color_space != JCS_RGB) { + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { cinfo.out_color_space = JCS_RGB; - channels = 3; } + channels = 3; break; /* * Libjpeg does not support converting from CMYK to grayscale etc. There @@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { auto tensor = torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); auto ptr = tensor.data_ptr(); + torch::Tensor cmyk_line_tensor; + if (cmyk_to_rgb_or_gray) { + cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + } + while (cinfo.output_scanline < cinfo.output_height) { /* jpeg_read_scanlines expects an array of pointers to scanlines. * Here the array is only one element long, but you could ask for * more than one scanline at a time if that's more convenient. */ - jpeg_read_scanlines(&cinfo, &ptr, 1); + if (cmyk_to_rgb_or_gray) { + auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); + + if (channels == 3) { + convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr); + } else if (channels == 1) { + convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr); + } + } else { + jpeg_read_scanlines(&cinfo, &ptr, 1); + } ptr += stride; }