Skip to content

Commit

Permalink
Added support for CMYK in decode_jpeg (#7741)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
vfdev-5 and NicolasHug authored Aug 24, 2023
1 parent f514ab6 commit 4491ca2
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
5 changes: 1 addition & 4 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
if is_cmyk:
# libjpeg does not support the conversion
pytest.xfail("Decoding a CMYK jpeg isn't supported")
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))
if is_cmyk:
if is_cmyk and mode == ImageReadMode.UNCHANGED:
# flip the colors to match libjpeg
img_pil = 255 - img_pil

Expand Down
87 changes: 82 additions & 5 deletions torchvision/csrc/io/image/cpu/decode_jpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

// Computes one output channel from the K value and one (already complemented)
// C/M/Y value: k - round(k * cmy / 255), limited to the 0..255 range.
//
// The integer formula follows Pillow's CMYK conversion:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
inline unsigned char clamped_cmyk_rgb_convert(
    unsigned char k,
    unsigned char cmy) {
  const int key = k;
  // Adding 128 before the shifts rounds to nearest instead of truncating.
  const int biased_product = key * cmy + 128;
  // (x + (x >> 8)) >> 8 is the usual shift-only approximation of x / 255.
  const int scaled = (biased_product + (biased_product >> 8)) >> 8;
  const int channel = key - scaled;
  return static_cast<unsigned char>(std::clamp(channel, 0, 255));
}

// Converts one scanline of 4-byte pixels (read in C, M, Y, K order) into
// 3-byte R, G, B pixels.
//
// cinfo:     decompressor state; only output_width (pixels per line) is read.
// cmyk_line: input buffer, 4 bytes per pixel.
// rgb_line:  output buffer, 3 bytes per pixel.
void convert_line_cmyk_to_rgb(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* rgb_line) {
  const int num_pixels = cinfo->output_width;
  const unsigned char* src = cmyk_line;
  unsigned char* dst = rgb_line;
  for (int px = 0; px < num_pixels; ++px, src += 4, dst += 3) {
    const int cyan = src[0];
    const int magenta = src[1];
    const int yellow = src[2];
    const int black = src[3];

    // Each colour channel is complemented before being combined with K.
    dst[0] = clamped_cmyk_rgb_convert(black, 255 - cyan);
    dst[1] = clamped_cmyk_rgb_convert(black, 255 - magenta);
    dst[2] = clamped_cmyk_rgb_convert(black, 255 - yellow);
  }
}

// Weighted sum of R, G and B in 16.16 fixed point. The three weights add up
// to 65536 and 0x8000 is added so the final shift rounds to nearest.
//
// The weights follow Pillow's RGB -> L conversion:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
inline unsigned char rgb_to_gray(int r, int g, int b) {
  const int weighted_sum = r * 19595 + g * 38470 + b * 7471;
  const int rounded = weighted_sum + 0x8000;
  return static_cast<unsigned char>(rounded >> 16);
}

// Converts one scanline of 4-byte pixels (read in C, M, Y, K order) into
// single-byte gray values, going through an intermediate R, G, B triple.
//
// cinfo:     decompressor state; only output_width (pixels per line) is read.
// cmyk_line: input buffer, 4 bytes per pixel.
// gray_line: output buffer, 1 byte per pixel.
void convert_line_cmyk_to_gray(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* gray_line) {
  const int num_pixels = cinfo->output_width;
  const unsigned char* src = cmyk_line;
  for (int px = 0; px < num_pixels; ++px, src += 4) {
    const int cyan = src[0];
    const int magenta = src[1];
    const int yellow = src[2];
    const int black = src[3];

    // Same per-channel conversion as the RGB path, then collapse to gray.
    const int red = clamped_cmyk_rgb_convert(black, 255 - cyan);
    const int green = clamped_cmyk_rgb_convert(black, 255 - magenta);
    const int blue = clamped_cmyk_rgb_convert(black, 255 - yellow);

    gray_line[px] = rgb_to_gray(red, green, blue);
  }
}

} // namespace

torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
Expand Down Expand Up @@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_read_header(&cinfo, TRUE);

int channels = cinfo.num_components;
bool cmyk_to_rgb_or_gray = false;

if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
channels = 1;
break;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
channels = 3;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
Expand All @@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
}

while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
jpeg_read_scanlines(&cinfo, &ptr, 1);
if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);

if (channels == 3) {
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
} else if (channels == 1) {
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
}
} else {
jpeg_read_scanlines(&cinfo, &ptr, 1);
}
ptr += stride;
}

Expand Down

0 comments on commit 4491ca2

Please sign in to comment.