diff --git a/Tests/test_image_putpalette.py b/Tests/test_image_putpalette.py index 725ecaade01..3b29769a7a4 100644 --- a/Tests/test_image_putpalette.py +++ b/Tests/test_image_putpalette.py @@ -74,4 +74,5 @@ def test_putpalette_with_alpha_values(): def test_rgba_palette(mode, palette): im = Image.new("P", (1, 1)) im.putpalette(palette, mode) + assert im.getpalette() == [1, 2, 3] assert im.palette.colors == {(1, 2, 3, 4): 0} diff --git a/Tests/test_image_quantize.py b/Tests/test_image_quantize.py index 7c8f3126c78..d48ee6c86c3 100644 --- a/Tests/test_image_quantize.py +++ b/Tests/test_image_quantize.py @@ -108,3 +108,18 @@ def test_palette(method, color): converted = im.quantize(method=method) converted_px = converted.load() assert converted_px[0, 0] == converted.palette.colors[color] + + +def test_small_palette(): + # Arrange + im = hopper() + + colors = (255, 0, 0, 0, 0, 255) + p = Image.new("P", (1, 1)) + p.putpalette(colors) + + # Act + im = im.quantize(palette=p) + + # Assert + assert len(im.getcolors()) == 2 diff --git a/src/PIL/Image.py b/src/PIL/Image.py index 7fd26059c95..9c35d332881 100644 --- a/src/PIL/Image.py +++ b/src/PIL/Image.py @@ -882,7 +882,7 @@ def load(self): if self.im and self.palette and self.palette.dirty: # realize palette mode, arr = self.palette.getdata() - palette_length = self.im.putpalette(mode, arr) + self.im.putpalette(mode, arr) self.palette.dirty = 0 self.palette.rawmode = None if "transparency" in self.info and mode in ("LA", "PA"): @@ -894,9 +894,7 @@ def load(self): else: palette_mode = "RGBA" if mode.startswith("RGBA") else "RGB" self.palette.mode = palette_mode - self.palette.palette = self.im.getpalette(palette_mode, palette_mode)[ - : palette_length * len(palette_mode) - ] + self.palette.palette = self.im.getpalette(palette_mode, palette_mode) if self.im: if cffi and USE_CFFI_ACCESS: diff --git a/src/_imaging.c b/src/_imaging.c index 2ea517816a0..0888188fb20 100644 --- a/src/_imaging.c +++ b/src/_imaging.c @@ -1063,7 +1063,7 @@ _gaussian_blur(ImagingObject *self, PyObject *args) { static PyObject * _getpalette(ImagingObject *self, PyObject *args) { PyObject *palette; - int palettesize = 256; + int palettesize; int bits; ImagingShuffler pack; @@ -1084,6 +1084,7 @@ _getpalette(ImagingObject *self, PyObject *args) { return NULL; } + palettesize = self->image->palette->size; palette = PyBytes_FromStringAndSize(NULL, palettesize * bits / 8); if (!palette) { return NULL; @@ -1672,9 +1673,11 @@ _putpalette(ImagingObject *self, PyObject *args) { self->image->palette = ImagingPaletteNew(palette_mode); - unpack(self->image->palette->palette, palette, palettesize * 8 / bits); + self->image->palette->size = palettesize * 8 / bits; + unpack(self->image->palette->palette, palette, self->image->palette->size); - return PyLong_FromLong(palettesize * 8 / bits); + Py_INCREF(Py_None); + return Py_None; } static PyObject * diff --git a/src/libImaging/Imaging.h b/src/libImaging/Imaging.h index 9b1c1024dc4..b65f8eadd51 100644 --- a/src/libImaging/Imaging.h +++ b/src/libImaging/Imaging.h @@ -143,6 +143,7 @@ struct ImagingPaletteInstance { char mode[IMAGING_MODE_LENGTH]; /* Band names */ /* Data */ + int size; UINT8 palette[1024]; /* Palette data (same format as image data) */ INT16 *cache; /* Palette cache (used for predefined palettes) */ diff --git a/src/libImaging/Palette.c b/src/libImaging/Palette.c index 43bea61e327..20c6bc84b1a 100644 --- a/src/libImaging/Palette.c +++ b/src/libImaging/Palette.c @@ -40,6 +40,7 @@ ImagingPaletteNew(const char *mode) { palette->mode[IMAGING_MODE_LENGTH - 1] = 0; /* Initialize to ramp */ + palette->size = 256; for (i = 0; i < 256; i++) { palette->palette[i * 4 + 0] = palette->palette[i * 4 + 1] = palette->palette[i * 4 + 2] = (UINT8)i; @@ -193,7 +194,7 @@ ImagingPaletteCacheUpdate(ImagingPalette palette, int r, int g, int b) { dmax = (unsigned int)~0; - for (i = 0; i < 256; i++) { + for (i = 0; i < palette->size; i++) { int r, g, b; unsigned int tmin, tmax; @@ -226,7 +227,7 @@ ImagingPaletteCacheUpdate(ImagingPalette palette, int r, int g, int b) { d[i] = (unsigned int)~0; } - for (i = 0; i < 256; i++) { + for (i = 0; i < palette->size; i++) { if (dmin[i] <= dmax) { int rd, gd, bd; int ri, gi, bi;