Skip to content

Commit

Permalink
Refactored adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 4, 2024
1 parent f2be760 commit 2579ee3
Show file tree
Hide file tree
Showing 12 changed files with 582 additions and 157 deletions.
75 changes: 13 additions & 62 deletions aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,16 @@
from ..utils.utils import validate_coord


class PaddedReaderMixin:
"""
Redefines __getitem__() so it works even if the coordinates are out of bounds
"""
def __init__(self, padding_mode: str = 'constant'):
self.padding_mode = padding_mode

def __getitem__(self, item):
item = self.parse_item(item)

pads, safe_coords = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
pads.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
pads.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + pads[-1][0], coords.stop - pads[-1][1], coords.step))
else:
raise ValueError(f'Can not parse coords={coords} at axis={axis}')

res = self.fetch(safe_coords)
return np.pad(res, pads, mode=self.padding_mode)


class PaddedWriterMixin:
"""
Redefines __setitem__() so it works even if the coordinates are out of bounds
"""
def __setitem__(self, item, data):
item = self.parse_item(item)
assert data.ndim == self.ndim == len(item)
safe_coords, crops = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
crops.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step))

self.write(safe_coords,
data[tuple(slice(crops[i][0], data.shape[i]-crops[i][1], 1) for i in range(data.ndim))])


class AbstractReader:
"""Provides numpy array-like interface to arbitrary source of data"""
def __getattr__(self, item):
return getattr(self._data, item)

class AbstractArrayLike:
@property
def shape(self):
return self._shape
raise NotImplementedError

@property
def ndim(self):
return len(self._shape)

def __getitem__(self, item):
item = self.parse_item(item)
return self.fetch(item)
raise NotImplementedError

def fetch(self, item):
"""Datasource-specific data fetching, e.g. rasterio.read()"""
def __len__(self):
raise NotImplementedError

def parse_item(self, item):
Expand Down Expand Up @@ -95,17 +39,24 @@ def parse_item(self, item):
return item


class AbstractWriter(AbstractReader):
class AbstractReader(AbstractArrayLike):
"""Provides numpy array-like interface to arbitrary source of data"""
def __getitem__(self, item):
raise NotImplementedError
item = self.parse_item(item)
return self.fetch(item)

def fetch(self, item):
"""Datasource-specific data fetching, e.g. rasterio.read()"""
raise NotImplementedError


class AbstractWriter(AbstractArrayLike):
"""Provides numpy array-like interface to arbitrary source of data"""
def __setitem__(self, item, data):
item = self.parse_item(item)
self.write(item, data)

def write(self, item, data):
raise NotImplementedError


50 changes: 50 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/boundsafemixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np


class BoundSafeReaderMixin:
"""
Redefines __getitem__() so it works even if the coordinates are out of bounds
"""
def __init__(self, padding_mode: str = 'constant', **kwargs):
super().__init__(**kwargs)
self.padding_mode = padding_mode

def __getitem__(self, item):
item = self.parse_item(item)

pads, safe_coords = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
pads.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
pads.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + pads[-1][0], coords.stop - pads[-1][1], coords.step))
else:
raise ValueError(f'Can not parse coords={coords} at axis={axis}')

res = self.fetch(safe_coords)
return np.pad(res, pads, mode=self.padding_mode)


class BoundSafeWriterMixin:
"""
Redefines __setitem__() so it works even if the coordinates are out of bounds
"""
def __setitem__(self, item, data):
item = self.parse_item(item)
assert data.ndim == self.ndim == len(item)
safe_coords, crops = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
crops.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step))

self.write(safe_coords,
data[tuple(slice(crops[i][0], data.shape[i]-crops[i][1], 1) for i in range(data.ndim))])

29 changes: 29 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/filemixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class FileMixin:
"""Abstract class, provides interface to work with a file (open, close and context manager)"""

def __init__(self, path, **kwargs):
super().__init__(**kwargs)
self._path = path
self._descriptor = None
self._shape = None

def open(self):
raise NotImplementedError

def close(self):
self._descriptor.close()
self._descriptor = None
self._shape = None

def __enter__(self):
self.open()
return self

def __exit__(self, exc_type, exc_val, traceback):
self.close()

@property
def shape(self):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
return self._shape
55 changes: 55 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/imageadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from .abstractadapter import AbstractReader, AbstractWriter
from .boundsafemixin import BoundSafeReaderMixin, BoundSafeWriterMixin


class ImageReader(BoundSafeReaderMixin, AbstractReader):
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int],
spatial dimensions with slices"""
def __init__(self, padding_mode='constant', **kwargs):
super().__init__(padding_mode, **kwargs)

def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
raise ValueError(f"mage must be indexed with 3 axes, got {item}")
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Image spatial axes (1 and 2) must be indexed with slices, got {item}"
return item

@property
def ndim(self):
return 3

@property
def shape(self):
return self._shape

def __len__(self):
return self.shape[0]


class ImageWriter(BoundSafeWriterMixin, AbstractWriter):
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int],
spatial dimensions with slices"""
def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
raise ValueError(f"PIL Image must be indexed with 3 axes, got {item}")
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"PIL Image spatial axes (1 and 2) must be indexed with slices, got {item}"
return item

@property
def ndim(self):
return 3

@property
def shape(self):
return self._shape

def __len__(self):
return self.shape[0]
51 changes: 38 additions & 13 deletions aeronet_raster/aeronet_raster/dataadapters/piladapter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,46 @@
from .abstractadapter import AbstractReader
from .imageadapter import ImageReader, ImageWriter
from .filemixin import FileMixin
import numpy as np
import pkg_resources

if 'PIL' in {pkg.key for pkg in pkg_resources.working_set}:
if 'pillow' in {pkg.key for pkg in pkg_resources.working_set}:
from PIL import Image


class PilReader(AbstractReader):
"""Provides numpy array-like interface to PIL-compatible image file"""
__slots__ = ('_path',)
class PilReader(FileMixin, ImageReader):
"""Provides numpy array-like interface to PIL-compatible image file."""

def __init__(self, path, verbose: bool = False, **kwargs):
self._path = path
self.verbose = verbose
self._shape = np.array(Image.open(path)).transpose(2, 0, 1).shape
def __init__(self, path, padding_mode='constant', **kwargs):
super().__init__(path=path, padding_mode=padding_mode)

def __getitem__(self, item):
channels, y, x = self.parse_item(item)
res = np.array(Image.open(self._path)).transpose(2, 0, 1)[channels, y.start:y.stop, x.start:x.stop]
return res
def open(self):
self._descriptor = Image.open(self._path)
self._shape = len(self._descriptor.getbands()), self._descriptor.height, self._descriptor.width

def fetch(self, item):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
channels, y, x = item
return np.array(self._descriptor.crop((x.start, y.start, x.stop, y.stop))).transpose(2, 0, 1)[channels]


class PilWriter(FileMixin, ImageWriter):
"""Provides numpy array-like interface to PIL-compatible image file."""

def __init__(self, path, shape):
super().__init__(path=path)
if not shape[0] in (1, 3, 4):
raise ValueError(f'Only 1, 3, and 4 channels supported, got {shape}')
self._shape = shape

def open(self):
self._descriptor = np.zeros(self._shape, dtype=np.uint8)

def close(self):
Image.fromarray(self._descriptor.transpose(1, 2, 0)).save(self._path)

def write(self, item, data):
if not self._descriptor:
raise ValueError(f'Writer is not opened')
channels, y, x = item
self._descriptor[channels, y.start:y.stop, x.start:x.stop] = data
64 changes: 25 additions & 39 deletions aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,39 @@
from .abstractadapter import AbstractReader, AbstractWriter, PaddedReaderMixin, PaddedWriterMixin
from .imageadapter import ImageWriter, ImageReader
from .filemixin import FileMixin
import numpy as np
import rasterio


class RasterioReader(PaddedReaderMixin, AbstractReader):
class RasterioReader(FileMixin, ImageReader):
"""Provides numpy array-like interface to geotiff file via rasterio"""

def __init__(self, path, verbose: bool = False, padding_mode: str = 'reflect', **kwargs):
super().__init__(padding_mode)
self._path = path
self._data = rasterio.open(path)
self.verbose = verbose
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1]))
def __init__(self, path, padding_mode: str = 'constant', **kwargs):
super().__init__(path=path, padding_mode=padding_mode)

def open(self):
self._descriptor = rasterio.open(self._path)
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1]

def fetch(self, item):
res = self._data.read([ch+1 for ch in item[0]],
window=((item[1].start, item[1].stop),
(item[2].start, item[2].stop)),
boundless=True).astype(np.uint8)
channels, y, x = item
res = self._descriptor.read([ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)),
boundless=True).astype(np.uint8)
return res

def parse_item(self, item):
item = super().parse_item(item)
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}"
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}"

return item


class RasterioWriter(PaddedWriterMixin, AbstractWriter):
class RasterioWriter(FileMixin, ImageWriter):
def __init__(self, path, **profile):
self._path = path
self._data = rasterio.open(path, 'w', **profile)
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1]))
super().__init__(path=path)
self.profile = profile

def open(self):
self._descriptor = rasterio.open(self._path, 'w', **self.profile)
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1]

def write(self, item, data):
self._data.write(data, [ch+1 for ch in item[0]],
window=((item[1].start, item[1].stop),
(item[2].start, item[2].stop)))

def parse_item(self, item):
item = super().parse_item(item)
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}"
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}"

return item
channels, y, x = item
self._descriptor.write(data, [ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)))
Loading

0 comments on commit 2579ee3

Please sign in to comment.