-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sakharov
committed
Apr 4, 2024
1 parent
f2be760
commit 2579ee3
Showing
12 changed files
with
582 additions
and
157 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
aeronet_raster/aeronet_raster/dataadapters/boundsafemixin.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
|
||
|
||
class BoundSafeReaderMixin: | ||
""" | ||
Redefines __getitem__() so it works even if the coordinates are out of bounds | ||
""" | ||
def __init__(self, padding_mode: str = 'constant', **kwargs): | ||
super().__init__(**kwargs) | ||
self.padding_mode = padding_mode | ||
|
||
def __getitem__(self, item): | ||
item = self.parse_item(item) | ||
|
||
pads, safe_coords = list(), list() | ||
for axis, coords in enumerate(item): | ||
# coords can be either slice or tuple at this point (after parse_item) | ||
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted | ||
pads.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices | ||
safe_coords.append(coords) | ||
elif isinstance(coords, slice): # coords = (min:max:step) | ||
pads.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0))) | ||
safe_coords.append(slice(coords.start + pads[-1][0], coords.stop - pads[-1][1], coords.step)) | ||
else: | ||
raise ValueError(f'Can not parse coords={coords} at axis={axis}') | ||
|
||
res = self.fetch(safe_coords) | ||
return np.pad(res, pads, mode=self.padding_mode) | ||
|
||
|
||
class BoundSafeWriterMixin: | ||
""" | ||
Redefines __setitem__() so it works even if the coordinates are out of bounds | ||
""" | ||
def __setitem__(self, item, data): | ||
item = self.parse_item(item) | ||
assert data.ndim == self.ndim == len(item) | ||
safe_coords, crops = list(), list() | ||
for axis, coords in enumerate(item): | ||
# coords can be either slice or tuple at this point (after parse_item) | ||
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted | ||
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices | ||
safe_coords.append(coords) | ||
elif isinstance(coords, slice): # coords = (min:max:step) | ||
crops.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0))) | ||
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step)) | ||
|
||
self.write(safe_coords, | ||
data[tuple(slice(crops[i][0], data.shape[i]-crops[i][1], 1) for i in range(data.ndim))]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
class FileMixin: | ||
"""Abstract class, provides interface to work with a file (open, close and context manager)""" | ||
|
||
def __init__(self, path, **kwargs): | ||
super().__init__(**kwargs) | ||
self._path = path | ||
self._descriptor = None | ||
self._shape = None | ||
|
||
def open(self): | ||
raise NotImplementedError | ||
|
||
def close(self): | ||
self._descriptor.close() | ||
self._descriptor = None | ||
self._shape = None | ||
|
||
def __enter__(self): | ||
self.open() | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, traceback): | ||
self.close() | ||
|
||
@property | ||
def shape(self): | ||
if not self._descriptor: | ||
raise ValueError(f'File {self._path} is not opened') | ||
return self._shape |
55 changes: 55 additions & 0 deletions
55
aeronet_raster/aeronet_raster/dataadapters/imageadapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from .abstractadapter import AbstractReader, AbstractWriter | ||
from .boundsafemixin import BoundSafeReaderMixin, BoundSafeWriterMixin | ||
|
||
|
||
class ImageReader(BoundSafeReaderMixin, AbstractReader): | ||
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int], | ||
spatial dimensions with slices""" | ||
def __init__(self, padding_mode='constant', **kwargs): | ||
super().__init__(padding_mode, **kwargs) | ||
|
||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
if not len(item) == 3: | ||
raise ValueError(f"mage must be indexed with 3 axes, got {item}") | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"Image spatial axes (1 and 2) must be indexed with slices, got {item}" | ||
return item | ||
|
||
@property | ||
def ndim(self): | ||
return 3 | ||
|
||
@property | ||
def shape(self): | ||
return self._shape | ||
|
||
def __len__(self): | ||
return self.shape[0] | ||
|
||
|
||
class ImageWriter(BoundSafeWriterMixin, AbstractWriter): | ||
"""Works with 3-dimensional data (channels, height, width), allows indexing channels with Sequence[int], | ||
spatial dimensions with slices""" | ||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
if not len(item) == 3: | ||
raise ValueError(f"PIL Image must be indexed with 3 axes, got {item}") | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"PIL Image spatial axes (1 and 2) must be indexed with slices, got {item}" | ||
return item | ||
|
||
@property | ||
def ndim(self): | ||
return 3 | ||
|
||
@property | ||
def shape(self): | ||
return self._shape | ||
|
||
def __len__(self): | ||
return self.shape[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,46 @@ | ||
from .abstractadapter import AbstractReader | ||
from .imageadapter import ImageReader, ImageWriter | ||
from .filemixin import FileMixin | ||
import numpy as np | ||
import pkg_resources | ||
|
||
if 'PIL' in {pkg.key for pkg in pkg_resources.working_set}: | ||
if 'pillow' in {pkg.key for pkg in pkg_resources.working_set}: | ||
from PIL import Image | ||
|
||
|
||
class PilReader(AbstractReader): | ||
"""Provides numpy array-like interface to PIL-compatible image file""" | ||
__slots__ = ('_path',) | ||
class PilReader(FileMixin, ImageReader): | ||
"""Provides numpy array-like interface to PIL-compatible image file.""" | ||
|
||
def __init__(self, path, verbose: bool = False, **kwargs): | ||
self._path = path | ||
self.verbose = verbose | ||
self._shape = np.array(Image.open(path)).transpose(2, 0, 1).shape | ||
def __init__(self, path, padding_mode='constant', **kwargs): | ||
super().__init__(path=path, padding_mode=padding_mode) | ||
|
||
def __getitem__(self, item): | ||
channels, y, x = self.parse_item(item) | ||
res = np.array(Image.open(self._path)).transpose(2, 0, 1)[channels, y.start:y.stop, x.start:x.stop] | ||
return res | ||
def open(self): | ||
self._descriptor = Image.open(self._path) | ||
self._shape = len(self._descriptor.getbands()), self._descriptor.height, self._descriptor.width | ||
|
||
def fetch(self, item): | ||
if not self._descriptor: | ||
raise ValueError(f'File {self._path} is not opened') | ||
channels, y, x = item | ||
return np.array(self._descriptor.crop((x.start, y.start, x.stop, y.stop))).transpose(2, 0, 1)[channels] | ||
|
||
|
||
class PilWriter(FileMixin, ImageWriter): | ||
"""Provides numpy array-like interface to PIL-compatible image file.""" | ||
|
||
def __init__(self, path, shape): | ||
super().__init__(path=path) | ||
if not shape[0] in (1, 3, 4): | ||
raise ValueError(f'Only 1, 3, and 4 channels supported, got {shape}') | ||
self._shape = shape | ||
|
||
def open(self): | ||
self._descriptor = np.zeros(self._shape, dtype=np.uint8) | ||
|
||
def close(self): | ||
Image.fromarray(self._descriptor.transpose(1, 2, 0)).save(self._path) | ||
|
||
def write(self, item, data): | ||
if not self._descriptor: | ||
raise ValueError(f'Writer is not opened') | ||
channels, y, x = item | ||
self._descriptor[channels, y.start:y.stop, x.start:x.stop] = data |
64 changes: 25 additions & 39 deletions
64
aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,39 @@ | ||
from .abstractadapter import AbstractReader, AbstractWriter, PaddedReaderMixin, PaddedWriterMixin | ||
from .imageadapter import ImageWriter, ImageReader | ||
from .filemixin import FileMixin | ||
import numpy as np | ||
import rasterio | ||
|
||
|
||
class RasterioReader(PaddedReaderMixin, AbstractReader): | ||
class RasterioReader(FileMixin, ImageReader): | ||
"""Provides numpy array-like interface to geotiff file via rasterio""" | ||
|
||
def __init__(self, path, verbose: bool = False, padding_mode: str = 'reflect', **kwargs): | ||
super().__init__(padding_mode) | ||
self._path = path | ||
self._data = rasterio.open(path) | ||
self.verbose = verbose | ||
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1])) | ||
def __init__(self, path, padding_mode: str = 'constant', **kwargs): | ||
super().__init__(path=path, padding_mode=padding_mode) | ||
|
||
def open(self): | ||
self._descriptor = rasterio.open(self._path) | ||
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1] | ||
|
||
def fetch(self, item): | ||
res = self._data.read([ch+1 for ch in item[0]], | ||
window=((item[1].start, item[1].stop), | ||
(item[2].start, item[2].stop)), | ||
boundless=True).astype(np.uint8) | ||
channels, y, x = item | ||
res = self._descriptor.read([ch+1 for ch in channels], | ||
window=((y.start, y.stop), | ||
(x.start, x.stop)), | ||
boundless=True).astype(np.uint8) | ||
return res | ||
|
||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}" | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}" | ||
|
||
return item | ||
|
||
|
||
class RasterioWriter(PaddedWriterMixin, AbstractWriter): | ||
class RasterioWriter(FileMixin, ImageWriter): | ||
def __init__(self, path, **profile): | ||
self._path = path | ||
self._data = rasterio.open(path, 'w', **profile) | ||
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1])) | ||
super().__init__(path=path) | ||
self.profile = profile | ||
|
||
def open(self): | ||
self._descriptor = rasterio.open(self._path, 'w', **self.profile) | ||
self._shape = self._descriptor.count, self._descriptor.shape[0], self._descriptor.shape[1] | ||
|
||
def write(self, item, data): | ||
self._data.write(data, [ch+1 for ch in item[0]], | ||
window=((item[1].start, item[1].stop), | ||
(item[2].start, item[2].stop))) | ||
|
||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}" | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}" | ||
|
||
return item | ||
channels, y, x = item | ||
self._descriptor.write(data, [ch+1 for ch in channels], | ||
window=((y.start, y.stop), | ||
(x.start, x.stop))) |
Oops, something went wrong.