diff --git a/setup.py b/setup.py index 7f383b82ec4..dbe8ce58aa2 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" +USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default! USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default! USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) @@ -50,6 +51,7 @@ print(f"{USE_PNG = }") print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") +print(f"{USE_HEIC = }") print(f"{USE_AVIF = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") @@ -334,6 +336,21 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") + if USE_HEIC: + heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h") + if heic_found: + print("Building torchvision with HEIC support") + print(f"{heic_include_dir = }") + print(f"{heic_library_dir = }") + if heic_include_dir is not None and heic_library_dir is not None: + # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. + include_dirs.append(heic_include_dir) + library_dirs.append(heic_library_dir) + libraries.append("heif") + define_macros += [("HEIC_FOUND", 1)] + else: + warnings.warn("Building torchvision without HEIC support") + if USE_AVIF: avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") if avif_found: diff --git a/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic b/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic new file mode 100644 index 00000000000..4c29ac3c71c Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic differ diff --git a/test/test_image.py b/test/test_image.py index f1fe70135fe..4d14af638a0 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -15,6 +15,7 @@ from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence from torchvision.io.image import ( _decode_avif, + _decode_heic, decode_gif, decode_image, decode_jpeg, @@ -41,6 +42,19 @@ IS_WINDOWS = sys.platform in ("win32", "cygwin") IS_MACOS = sys.platform == "darwin" PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) +WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "") + +# Hacky way of figuring out whether we compiled with libavif/libheif (those are +# currenlty disabled by default) +try: + _decode_avif(torch.arange(10, dtype=torch.uint8)) +except Exception as e: + DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e) + +try: + _decode_heic(torch.arange(10, dtype=torch.uint8)) +except Exception as e: + DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e) def _get_safe_image_name(name): @@ -148,17 +162,6 @@ def test_invalid_exif(tmpdir, size): torch.testing.assert_close(expected, output) -def test_decode_jpeg_errors(): - with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - - with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100,), dtype=torch.float16)) - - with pytest.raises(RuntimeError, match="Not a JPEG file"): - decode_jpeg(torch.empty((100), dtype=torch.uint8)) - - def test_decode_bad_huffman_images(): # sanity check: make sure we can decode the bad Huffman encoding bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")) @@ -234,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun): def test_decode_png_errors(): - with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_png(torch.empty((), dtype=torch.uint8)) - with pytest.raises(RuntimeError, match="Content is not png"): - decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) with pytest.raises(RuntimeError, match="Out of bound read in decode_png"): decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png"))) with pytest.raises(RuntimeError, match="Content is too small for png"): @@ -863,8 +862,20 @@ def test_decode_gif(tmpdir, name, scripted): torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0) -@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp)) -def test_decode_gif_webp_errors(decode_fun): +decode_fun_and_match = [ + (decode_png, "Content is not png"), + (decode_jpeg, "Not a JPEG file"), + (decode_gif, re.escape("DGifOpenFileName() failed - 103")), + (decode_webp, "WebPGetFeatures failed."), +] +if DECODE_AVIF_ENABLED: + decode_fun_and_match.append((_decode_avif, "BMFF parsing failed")) +if DECODE_HEIC_ENABLED: + decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box")) + + +@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match) +def test_decode_bad_encoded_data(decode_fun, match): encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8) with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"): decode_fun(encoded_data[None]) @@ -872,11 +883,7 @@ def test_decode_gif_webp_errors(decode_fun): decode_fun(encoded_data.float()) with pytest.raises(RuntimeError, match="Input tensor must be contiguous"): decode_fun(encoded_data[::2]) - if decode_fun is decode_gif: - expected_match = re.escape("DGifOpenFileName() failed - 103") - elif decode_fun is decode_webp: - expected_match = "WebPGetFeatures failed." - with pytest.raises(RuntimeError, match=expected_match): + with pytest.raises(RuntimeError, match=match): decode_fun(encoded_data) @@ -889,21 +896,27 @@ def test_decode_webp(decode_fun, scripted): img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -# This test is skipped because it requires webp images that we're not including -# within the repo. The test images were downloaded from the different pages of -# https://developers.google.com/speed/webp/gallery -# Note that converting an RGBA image to RGB leads to bad results because the -# transparent pixels aren't necessarily set to "black" or "white", they can be -# random stuff. This is consistent with PIL results. -@pytest.mark.skip(reason="Need to download test images first") +# This test is skipped by default because it requires webp images that we're not +# including within the repo. The test images were downloaded manually from the +# different pages of https://developers.google.com/speed/webp/gallery +@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set") @pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( - "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None)) + "mode, pil_mode", + ( + # Note that converting an RGBA image to RGB leads to bad results because the + # transparent pixels aren't necessarily set to "black" or "white", they can be + # random stuff. This is consistent with PIL results. + (ImageReadMode.RGB, "RGB"), + (ImageReadMode.RGB_ALPHA, "RGBA"), + (ImageReadMode.UNCHANGED, None), + ), ) -@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp")) +@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name) def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename): encoded_bytes = read_file(filename) if scripted: @@ -914,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename) pil_img = Image.open(filename).convert(pil_mode) from_pil = F.pil_to_tensor(pil_img) assert_equal(img, from_pil) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -@pytest.mark.xfail(reason="AVIF support not enabled yet.") +@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.") @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) def test_decode_avif(decode_fun, scripted): @@ -926,13 +940,20 @@ def test_decode_avif(decode_fun, scripted): img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -@pytest.mark.xfail(reason="AVIF support not enabled yet.") # Note: decode_image fails because some of these files have a (valid) signature # we don't recognize. We should probably use libmagic.... -# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) -@pytest.mark.parametrize("decode_fun", (_decode_avif,)) +decode_funs = [] +if DECODE_AVIF_ENABLED: + decode_funs.append(_decode_avif) +if DECODE_HEIC_ENABLED: + decode_funs.append(_decode_heic) + + +@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.") +@pytest.mark.parametrize("decode_fun", decode_funs) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( "mode, pil_mode", @@ -942,8 +963,10 @@ def test_decode_avif(decode_fun, scripted): (ImageReadMode.UNCHANGED, None), ), ) -@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif")) -def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): +@pytest.mark.parametrize( + "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name +) +def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename): if "reversed_dimg_order" in str(filename): # Pillow properly decodes this one, but we don't (order of parts of the # image is wrong). This is due to a bug that was recently fixed in @@ -960,7 +983,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) except RuntimeError as e: if any( s in str(e) - for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image") + for s in ( + "BMFF parsing failed", + "avifDecoderParse failed: ", + "file contains more than one image", + "no 'ispe' property", + "'iref' has double references", + "Invalid image grid", + ) ): pytest.skip(reason="Expected failure, that's OK") else: @@ -970,22 +1000,48 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) assert img.shape[0] == 3 if mode == ImageReadMode.RGB_ALPHA: assert img.shape[0] == 4 + if img.dtype == torch.uint16: img = F.to_dtype(img, dtype=torch.uint8, scale=True) + try: + from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) + except RuntimeError as e: + if "Invalid image grid" in str(e): + pytest.skip(reason="PIL failure") + else: + raise e - from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) - if False: + if True: from torchvision.utils import make_grid g = make_grid([img, from_pil]) F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) - if mode != ImageReadMode.RGB: - # We don't compare against PIL for RGB because results look pretty - # different on RGBA images (other images are fine). The result on - # torchvision basically just plainly ignores the alpha channel, resuting - # in transparent pixels looking dark. PIL seems to be using a sort of - # k-nn thing, looking at the output. Take a look at the resuting images. - torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" + if mode == ImageReadMode.RGB and not is_decode_heic: + # We don't compare torchvision's AVIF against PIL for RGB because + # results look pretty different on RGBA images (other images are fine). + # The result on torchvision basically just plainly ignores the alpha + # channel, resuting in transparent pixels looking dark. PIL seems to be + # using a sort of k-nn thing (Take a look at the resuting images) + return + if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic: + return + + torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + +@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.") +@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_heic(decode_fun, scripted): + encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic"))) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes) + assert img.shape == (3, 100, 100) + assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib if __name__ == "__main__": diff --git a/torchvision/csrc/io/image/common.cpp b/torchvision/csrc/io/image/common.cpp new file mode 100644 index 00000000000..16b7ac2f91e --- /dev/null +++ b/torchvision/csrc/io/image/common.cpp @@ -0,0 +1,43 @@ + +#include "common.h" +#include + +namespace vision { +namespace image { + +void validate_encoded_data(const torch::Tensor& encoded_data) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1 && encoded_data.numel() > 0, + "Input tensor must be 1-dimensional and non-empty, got ", + encoded_data.dim(), + " dims and ", + encoded_data.numel(), + " numels."); +} + +bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + ImageReadMode mode, + bool has_alpha) { + // Return true if the calling decoding function should return a 3D RGB tensor, + // and false if it should return a 4D RGBA tensor. + // This function ignores the requested "grayscale" modes and treats it as + // "unchanged", so it should only used on decoders who don't support grayscale + // outputs. + + if (mode == IMAGE_READ_MODE_RGB) { + return true; + } + if (mode == IMAGE_READ_MODE_RGB_ALPHA) { + return false; + } + // From here we assume mode is "unchanged", even for grayscale ones. + return !has_alpha; +} + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image_read_mode.h b/torchvision/csrc/io/image/common.h similarity index 65% rename from torchvision/csrc/io/image/image_read_mode.h rename to torchvision/csrc/io/image/common.h index 84425265c34..d81acfda7d4 100644 --- a/torchvision/csrc/io/image/image_read_mode.h +++ b/torchvision/csrc/io/image/common.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace vision { namespace image { @@ -13,5 +14,11 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; const ImageReadMode IMAGE_READ_MODE_RGB = 3; const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; +void validate_encoded_data(const torch::Tensor& encoded_data); + +bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + ImageReadMode mode, + bool has_alpha); + } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp index 5752f04a448..3cb326e2f11 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_avif.cpp @@ -1,4 +1,5 @@ #include "decode_avif.h" +#include "../common.h" #if AVIF_FOUND #include "avif/avif.h" @@ -33,16 +34,7 @@ torch::Tensor decode_avif( // Refer there for more detail about what each function does, and which // structure/data is available after which call. - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); DecoderPtr decoder(avifDecoderCreate()); TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); @@ -60,6 +52,7 @@ torch::Tensor decode_avif( result == AVIF_RESULT_OK, "avifDecoderParse failed: ", avifResultToString(result)); + printf("avif num images = %d\n", decoder->imageCount); TORCH_CHECK( decoder->imageCount == 1, "Avif file contains more than one image"); @@ -78,18 +71,9 @@ torch::Tensor decode_avif( auto use_uint8 = (decoder->image->depth <= 8); rgb.depth = use_uint8 ? 8 : 16; - if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && - mode != IMAGE_READ_MODE_RGB_ALPHA) { - // Other modes aren't supported, but we don't error or even warn because we - // have generic entry points like decode_image which may support all modes, - // it just depends on the underlying decoder. - mode = IMAGE_READ_MODE_UNCHANGED; - } - - // If return_rgb is false it means we return rgba - nothing else. auto return_rgb = - (mode == IMAGE_READ_MODE_RGB || - (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent)); + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, decoder->alphaPresent); auto num_channels = return_rgb ? 3 : 4; rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; diff --git a/torchvision/csrc/io/image/cpu/decode_avif.h b/torchvision/csrc/io/image/cpu/decode_avif.h index 0510c2104e5..7feee1adfcb 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.h +++ b/torchvision/csrc/io/image/cpu/decode_avif.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_gif.cpp b/torchvision/csrc/io/image/cpu/decode_gif.cpp index 183d42e86a4..f26d37950e3 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_gif.cpp @@ -1,5 +1,6 @@ #include "decode_gif.h" #include +#include "../common.h" #include "giflib/gif_lib.h" namespace vision { @@ -34,16 +35,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // Refer over there for more details on the libgif API, API ref, and a // detailed description of the GIF format. - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); int error = D_GIF_SUCCEEDED; diff --git a/torchvision/csrc/io/image/cpu/decode_heic.cpp b/torchvision/csrc/io/image/cpu/decode_heic.cpp new file mode 100644 index 00000000000..e245c25f9d7 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.cpp @@ -0,0 +1,135 @@ +#include "decode_heic.h" +#include "../common.h" + +#if HEIC_FOUND +#include "libheif/heif_cxx.h" +#endif // HEIC_FOUND + +namespace vision { +namespace image { + +#if !HEIC_FOUND +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK( + false, "decode_heic: torchvision not compiled with libheif support"); +} +#else + +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + validate_encoded_data(encoded_data); + + auto return_rgb = true; + + int height = 0; + int width = 0; + int num_channels = 0; + int stride = 0; + uint8_t* decoded_data = nullptr; + heif::Image img; + int bit_depth = 0; + + try { + heif::Context ctx; + ctx.read_from_memory_without_copy( + encoded_data.data_ptr(), encoded_data.numel()); + + // TODO properly error on (or support) image sequences. Right now, I think + // this function will always return the first image in a sequence, which is + // inconsistent with decode_gif (which returns a batch) and with decode_avif + // (which errors loudly). + // Why? I'm struggling to make sense of + // ctx.get_number_of_top_level_images(). It disagrees with libavif's + // imageCount. For example on some of the libavif test images: + // + // - colors-animated-12bpc-keyframes-0-2-3.avif + // avif num images = 5 + // heif num images = 1 // Why is this 1 when clearly this is supposed to + // be a sequence? + // - sofa_grid1x5_420.avif + // avif num images = 1 + // heif num images = 6 // If we were to error here we won't be able to + // decode this image which is otherwise properly + // decoded by libavif. + // I can't find a libheif function that does what we need here, or at least + // that agrees with libavif. + + // TORCH_CHECK( + // ctx.get_number_of_top_level_images() == 1, + // "heic file contains more than one image"); + + heif::ImageHandle handle = ctx.get_primary_image_handle(); + bit_depth = handle.get_luma_bits_per_pixel(); + + return_rgb = + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, handle.has_alpha_channel()); + + height = handle.get_height(); + width = handle.get_width(); + + num_channels = return_rgb ? 3 : 4; + heif_chroma chroma; + if (bit_depth == 8) { + chroma = return_rgb ? heif_chroma_interleaved_RGB + : heif_chroma_interleaved_RGBA; + } else { + // TODO: This, along with our 10bits -> 16bits range mapping down below, + // may not work on BE platforms + chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE + : heif_chroma_interleaved_RRGGBBAA_LE; + } + + img = handle.decode_image(heif_colorspace_RGB, chroma); + + decoded_data = img.get_plane(heif_channel_interleaved, &stride); + } catch (const heif::Error& err) { + // We need this try/catch block and call TORCH_CHECK, because libheif may + // otherwise throw heif::Error that would just be reported as "An unknown + // exception occurred" when we move back to Python. + TORCH_CHECK(false, "decode_heif failed: ", err.get_message()); + } + TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding."); + + auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16; + auto out = torch::empty({height, width, num_channels}, dtype); + uint8_t* out_ptr = (uint8_t*)out.data_ptr(); + + // decoded_data is *almost* the raw decoded data, but not quite: for some + // images, there may be some padding at the end of each row, i.e. when stride + // != row_size_in_bytes. So we can't copy decoded_data into the tensor's + // memory directly, we have to copy row by row. Oh, and if you think you can + // take a shortcut when stride == row_size_in_bytes and just do: + // out = torch::from_blob(decoded_data, ...) + // you can't, because decoded_data is owned by the heif::Image object and it + // gets freed when it gets out of scope! + auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2); + for (auto h = 0; h < height; h++) { + memcpy( + out_ptr + h * row_size_in_bytes, + decoded_data + h * stride, + row_size_in_bytes); + } + if (bit_depth > 8) { + // Say bit depth is 10. decodec_data and out_ptr contain 10bits values + // over 2 bytes, stored into uint16_t. In torchvision a uint16 value is + // expected to be in [0, 2**16), so we have to map the 10bits value to that + // range. Note that other libraries like libavif do that mapping + // automatically. + // TODO: It's possible to avoid the memcpy call above in this case, and do + // the copy at the same time as the conversation. Whether it's worth it + // should be benchmarked. + auto out_ptr_16 = (uint16_t*)out_ptr; + for (auto p = 0; p < height * width * num_channels; p++) { + out_ptr_16[p] <<= (16 - bit_depth); + } + } + return out.permute({2, 0, 1}); +} +#endif // HEIC_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_heic.h b/torchvision/csrc/io/image/cpu/decode_heic.h new file mode 100644 index 00000000000..10b414f554d --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include "../common.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_heic( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index e5a421b7287..9c1a7ff3ef4 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -2,6 +2,7 @@ #include "decode_avif.h" #include "decode_gif.h" +#include "decode_heic.h" #include "decode_jpeg.h" #include "decode_png.h" #include "decode_webp.h" @@ -61,6 +62,17 @@ torch::Tensor decode_image( return decode_avif(data, mode); } + // Similarly for heic we assume the signature is "ftypeheic" but some files + // may come as "ftypmif1" where the "heic" part is defined later in the file. + // We can't be re-inventing libmagic here. We might need to start relying on + // it though... + const uint8_t heic_signature[8] = { + 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic" + TORCH_CHECK(data.numel() >= 12, err_msg); + if ((memcmp(heic_signature, datap + 4, 8) == 0)) { + return decode_heic(data, mode); + } + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" diff --git a/torchvision/csrc/io/image/cpu/decode_image.h b/torchvision/csrc/io/image/cpu/decode_image.h index f0e66d397ac..f66d47eccd4 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.h +++ b/torchvision/csrc/io/image/cpu/decode_image.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index ec5953e4106..052b98e1be9 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -1,4 +1,5 @@ #include "decode_jpeg.h" +#include "../common.h" #include "common_jpeg.h" #include "exif.h" @@ -134,12 +135,8 @@ torch::Tensor decode_jpeg( bool apply_exif_orientation) { C10_LOG_API_USAGE_ONCE( "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + + validate_encoded_data(data); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.h b/torchvision/csrc/io/image/cpu/decode_jpeg.h index e0c9a24c846..7412a46d2ea 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index ac14ae934a4..ede14c1e94a 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -1,4 +1,5 @@ #include "decode_png.h" +#include "../common.h" #include "common_png.h" #include "exif.h" @@ -27,12 +28,8 @@ torch::Tensor decode_png( ImageReadMode mode, bool apply_exif_orientation) { C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + + validate_encoded_data(data); auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h index 0866711e987..faaffa7ae49 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.h +++ b/torchvision/csrc/io/image/cpu/decode_png.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index bf115c23c41..b202473c039 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -1,4 +1,5 @@ #include "decode_webp.h" +#include "../common.h" #if WEBP_FOUND #include "webp/decode.h" @@ -19,16 +20,7 @@ torch::Tensor decode_webp( torch::Tensor decode_webp( const torch::Tensor& encoded_data, ImageReadMode mode) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); auto encoded_data_p = encoded_data.data_ptr(); auto encoded_data_size = encoded_data.numel(); @@ -40,18 +32,9 @@ torch::Tensor decode_webp( TORCH_CHECK( !features.has_animation, "Animated webp files are not supported."); - if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && - mode != IMAGE_READ_MODE_RGB_ALPHA) { - // Other modes aren't supported, but we don't error or even warn because we - // have generic entry points like decode_image which may support all modes, - // it just depends on the underlying decoder. - mode = IMAGE_READ_MODE_UNCHANGED; - } - - // If return_rgb is false it means we return rgba - nothing else. auto return_rgb = - (mode == IMAGE_READ_MODE_RGB || - (mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha)); + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, features.has_alpha); auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA; auto num_channels = return_rgb ? 3 : 4; diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index 5632ea56ff9..d5c81547c42 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 6314ececef1..2079ca5f919 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -139,7 +139,7 @@ std::vector decode_jpegs_cuda( } CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, + : original_device{torch::kCUDA, c10::cuda::current_device()}, target_device{target_device}, stream{ target_device.has_index() diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 2458a103a3a..6f72d9e35b2 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,7 +1,7 @@ #pragma once #include #include -#include "../image_read_mode.h" +#include "../common.h" #if NVJPEG_FOUND #include diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 3fdf715b00f..8c3ad8f9a9d 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" #include "decode_jpegs_cuda.h" #include "encode_jpegs_cuda.h" diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index a777d19d3bd..f0ce91144a6 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -23,6 +23,8 @@ static auto registry = &decode_jpeg) .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", &decode_webp) + .op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor", + &decode_heic) .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", &decode_avif) .op("image::encode_jpeg", &encode_jpeg) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 91a5144fa1c..23493f3c030 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -2,6 +2,7 @@ #include "cpu/decode_avif.h" #include "cpu/decode_gif.h" +#include "cpu/decode_heic.h" #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" #include "cpu/decode_png.h" diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 08a0d6d62b7..a604ea1fdb6 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -61,6 +61,7 @@ "decode_image", "decode_jpeg", "decode_png", + "decode_heic", "decode_webp", "decode_gif", "encode_jpeg", diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e169c0a4f7a..f1df0d52672 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -417,5 +417,31 @@ def _decode_avif( Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_webp) + _log_api_usage_once(_decode_avif) return torch.ops.image.decode_avif(input, mode.value) + + +def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + """ + Decode an HEIC image into a 3 dimensional RGB[A] Tensor. + + The values of the output tensor are in uint8 in [0, 255] for most images. If + the image has a bit-depth of more than 8, then the output tensor is uint16 + in [0, 65535]. Since uint16 support is limited in pytorch, we recommend + calling :func:`torchvision.transforms.v2.functional.to_dtype()` with + ``scale=True`` after this function to convert the decoded image into a uint8 + or float tensor. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the HEIC image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(_decode_heic) + return torch.ops.image.decode_heic(input, mode.value) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2a6e0ce12c0..07932390efe 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1649,9 +1649,9 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace raise TypeError("Argument value should be either a number or str or a sequence") if isinstance(value, str) and value != "random": raise ValueError("If value is str, it should be 'random'") - if not isinstance(scale, (tuple, list)): + if not isinstance(scale, Sequence): raise TypeError("Scale should be a sequence") - if not isinstance(ratio, (tuple, list)): + if not isinstance(ratio, Sequence): raise TypeError("Ratio should be a sequence") if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index f085ef3ca6e..b1dd5083408 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import PIL.Image import torch @@ -56,8 +56,8 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def __init__( self, p: float = 0.5, - scale: Tuple[float, float] = (0.02, 0.33), - ratio: Tuple[float, float] = (0.3, 3.3), + scale: Sequence[float] = (0.02, 0.33), + ratio: Sequence[float] = (0.3, 3.3), value: float = 0.0, inplace: bool = False, ): @@ -66,9 +66,9 @@ def __init__( raise TypeError("Argument value should be either a number or str or a sequence") if isinstance(value, str) and value != "random": raise ValueError("If value is str, it should be 'random'") - if not isinstance(scale, (tuple, list)): + if not isinstance(scale, Sequence): raise TypeError("Scale should be a sequence") - if not isinstance(ratio, (tuple, list)): + if not isinstance(ratio, Sequence): raise TypeError("Ratio should be a sequence") if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)")