Skip to content

Commit

Permalink
Merge readers and writers
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 18, 2024
1 parent 331e54a commit 03577f2
Show file tree
Hide file tree
Showing 16 changed files with 342 additions and 271 deletions.
13 changes: 6 additions & 7 deletions aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from ..utils.utils import validate_coord


class AbstractArrayLike:
class AbstractAdapter:
"""Base abstract class for adapters. Provides numpy array-like interface for arbitrary data source"""
@property
def shape(self):
raise NotImplementedError
Expand Down Expand Up @@ -42,9 +43,7 @@ def parse_item(self, item):
item[axis] = validate_coord(coord, self.shape[axis])
return item


class AbstractReader(AbstractArrayLike):
"""Provides numpy array-like interface to arbitrary source of data"""
# Read -------------------------------------------------------------------------------------------------------------
def __getitem__(self, item):
item = self.parse_item(item)
return self.fetch(item)
Expand All @@ -53,9 +52,7 @@ def fetch(self, item):
"""Datasource-specific data fetching, e.g. rasterio.read()"""
raise NotImplementedError


class AbstractWriter(AbstractReader):
"""Provides numpy array-like interface to arbitrary source of data"""
# Write ------------------------------------------------------------------------------------------------------------
def __setitem__(self, item, data):
item = self.parse_item(item)
self.write(item, data)
Expand All @@ -64,3 +61,5 @@ def write(self, item, data):
raise NotImplementedError




9 changes: 2 additions & 7 deletions aeronet_raster/aeronet_raster/dataadapters/boundsafemixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np


class BoundSafeReaderMixin:
class BoundSafeMixin:
"""
Redefines __getitem__() so it works even if the coordinates are out of bounds
Redefines __getitem__() and __setitem__() so it works even if the coordinates are out of bounds
"""
def __init__(self, padding_mode: str = 'constant', **kwargs):
super().__init__(**kwargs)
Expand All @@ -27,11 +27,6 @@ def __getitem__(self, item):
res = self.fetch(safe_coords)
return np.pad(res, pads, mode=self.padding_mode)


class BoundSafeWriterMixin:
"""
Redefines __setitem__() so it works even if the coordinates are out of bounds
"""
def __setitem__(self, item, data):
item = self.parse_item(item)
assert data.ndim == self.ndim == len(item)
Expand Down
45 changes: 5 additions & 40 deletions aeronet_raster/aeronet_raster/dataadapters/imageadapter.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,10 @@
from .abstractadapter import AbstractReader, AbstractWriter
from .boundsafemixin import BoundSafeReaderMixin, BoundSafeWriterMixin
from .abstractadapter import AbstractAdapter
from .boundsafemixin import BoundSafeMixin


class ImageReader(BoundSafeReaderMixin, AbstractReader):
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int],
spatial dimensions with slices"""
def __init__(self, padding_mode='constant', **kwargs):
super().__init__(padding_mode, **kwargs)

def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
raise ValueError(f"Image must be indexed with 3 axes, got {item}")
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Image spatial axes (1 and 2) must be indexed with slices, got {item}"
return item

@property
def ndim(self):
return 3

@property
def shape(self):
return self._shape

def __len__(self):
return self.shape[0]


class ImageWriter(BoundSafeWriterMixin, AbstractWriter):
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int],
spatial dimensions with slices"""
class ImageAdapter(BoundSafeMixin, AbstractAdapter):
"""Abstract class. Redefines parse_item() so that it works with 3-dimensional data (channels, height, width),
allows indexing channels with Sequence[int], spatial dimensions with slices"""
def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
Expand All @@ -46,10 +18,3 @@ def parse_item(self, item):
@property
def ndim(self):
return 3

@property
def shape(self):
return self._shape

def __len__(self):
return self.shape[0]
11 changes: 4 additions & 7 deletions aeronet_raster/aeronet_raster/dataadapters/numpyadapter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .abstractadapter import AbstractReader, AbstractWriter
from .boundsafemixin import BoundSafeReaderMixin, BoundSafeWriterMixin
from .abstractadapter import AbstractAdapter
from .boundsafemixin import BoundSafeMixin


class NumpyReader(BoundSafeReaderMixin, AbstractReader):
"""Works with numpy arrays. Useful for testing"""
class NumpyAdapter(BoundSafeMixin, AbstractAdapter):
"""Bound-safe adapter for numpy array"""
def __init__(self, data, padding_mode='constant', **kwargs):
super().__init__(padding_mode, **kwargs)
self._data = data
Expand All @@ -28,9 +28,6 @@ def fetch(self, item):
item = tuple(item)
return self._data[item]


class NumpyWriter(BoundSafeWriterMixin, NumpyReader):
"""Works with numpy arrays. Useful for testing"""
def write(self, item, data):
if isinstance(item, list):
item = tuple(item)
Expand Down
36 changes: 15 additions & 21 deletions aeronet_raster/aeronet_raster/dataadapters/piladapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .imageadapter import ImageReader, ImageWriter
from .imageadapter import ImageAdapter
from .filemixin import FileMixin
import numpy as np
import pkg_resources
Expand All @@ -7,12 +7,9 @@
from PIL import Image


class PilReader(FileMixin, ImageReader):
class PilAdapter(FileMixin, ImageAdapter):
"""Provides numpy array-like interface to PIL-compatible image file."""

def __init__(self, path, padding_mode='constant', **kwargs):
super().__init__(path=path, padding_mode=padding_mode)

def open(self):
self._descriptor = Image.open(self._path)
self._shape = len(self._descriptor.getbands()), self._descriptor.height, self._descriptor.width
Expand All @@ -23,24 +20,21 @@ def fetch(self, item):
channels, y, x = item
return np.array(self._descriptor.crop((x.start, y.start, x.stop, y.stop))).transpose(2, 0, 1)[channels]

def write(self, key, value):
raise AttributeError('PIL Image is not writable. Use NumpyAdapter and save it as Image ')

class PilWriter(FileMixin, ImageWriter):
"""Provides numpy array-like interface to PIL-compatible image file."""
@property
def dtype(self):
return np.uint8

def __init__(self, path, shape):
super().__init__(path=path)
if not shape[0] in (1, 3, 4):
raise ValueError(f'Only 1, 3, and 4 channels supported, got {shape}')
self._shape = shape
@property
def ndim(self):
return 3

def open(self):
self._descriptor = np.zeros(self._shape, dtype=np.uint8)
@property
def shape(self):
return self._shape

def close(self):
Image.fromarray(self._descriptor.transpose(1, 2, 0)).save(self._path)
def __len__(self):
return self._shape[0]

def write(self, item, data):
if not self._descriptor:
raise ValueError(f'Writer is not opened')
channels, y, x = item
self._descriptor[channels, y.start:y.stop, x.start:x.stop] = data
53 changes: 32 additions & 21 deletions aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
from .imageadapter import ImageWriter, ImageReader
from .imageadapter import ImageAdapter
from .filemixin import FileMixin
import numpy as np
import rasterio

RASTERIO_OPEN_MODES = {'r', 'r+', 'w', 'w+'}

class RasterioReader(FileMixin, ImageReader):

class RasterioAdapter(FileMixin, ImageAdapter):
"""Provides numpy array-like interface to geotiff file via rasterio"""

def __init__(self, path, padding_mode: str = 'constant', **kwargs):
def __init__(self, path, mode='r', profile=None, padding_mode: str = 'constant', **kwargs):
super().__init__(path=path, padding_mode=padding_mode)
if mode not in RASTERIO_OPEN_MODES:
raise ValueError(f'Mode must be one of {RASTERIO_OPEN_MODES}')
if mode.startswith('w') and not profile:
raise ValueError(f'Profile must be specified for mode={mode}')
self._mode = mode
self._profile = profile

def open(self):
self._descriptor = rasterio.open(self._path)
if self._mode.startswith('w'):
self._descriptor = rasterio.open(self._path, self._mode, **self._profile)
else:
self._descriptor = rasterio.open(self._path, self._mode)
self._profile = self._descriptor.profile
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1]

def fetch(self, item):
channels, y, x = item
return self._descriptor.read([ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)))
@property
def shape(self):
return self._shape

@property
def ndim(self):
return 3

def __len__(self):
return self._shape[0]

@property
def profile(self):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
return self._descriptor.profile
return self._profile

@property
def crs(self):
Expand All @@ -50,15 +65,11 @@ def dtype(self):
raise ValueError(f'File {self._path} is not opened')
return self._descriptor.profile['dtype']


class RasterioWriter(ImageWriter, RasterioReader):
def __init__(self, path, profile, padding_mode: str = 'constant', **kwargs):
super().__init__(path=path, padding_mode=padding_mode)
self.write_profile = profile

def open(self):
self._descriptor = rasterio.open(self._path, 'w+', **self.write_profile)
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1]
def fetch(self, item):
channels, y, x = item
return self._descriptor.read([ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)))

def write(self, item, data):
channels, y, x = item
Expand Down
Loading

0 comments on commit 03577f2

Please sign in to comment.