-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sakharov
committed
Mar 21, 2024
1 parent
4a48e19
commit 04ade79
Showing
13 changed files
with
531 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
112 changes: 112 additions & 0 deletions
112
aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import numpy as np | ||
from ..utils.utils import validate_coord | ||
|
||
|
||
class PaddedReaderMixin: | ||
""" | ||
Redefines __getitem__() so it works even if the coordinates are out of bounds | ||
""" | ||
def __init__(self, padding_mode: str = 'reflect'): | ||
self.padding_mode = padding_mode | ||
|
||
def __getitem__(self, item): | ||
item = self.parse_item(item) | ||
|
||
pads, safe_coords = list(), list() | ||
for axis, coords in enumerate(item): | ||
# coords can be either slice or tuple at this point (after parse_item) | ||
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted | ||
pads.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices | ||
safe_coords.append(coords) | ||
elif isinstance(coords, slice): # coords = (min:max:step) | ||
pads.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0))) | ||
safe_coords.append(slice(coords.start + pads[-1][0], coords.stop - pads[-1][1], coords.step)) | ||
else: | ||
raise ValueError(f'Can not parse coords={coords} at axis={axis}') | ||
|
||
res = self.fetch(safe_coords) | ||
return np.pad(res, pads, mode=self.padding_mode) | ||
|
||
|
||
class PaddedWriterMixin: | ||
""" | ||
Redefines __setitem__() so it works even if the coordinates are out of bounds | ||
""" | ||
def __setitem__(self, item, data): | ||
item = self.parse_item(item) | ||
assert data.ndim == self.ndim == len(item) | ||
safe_coords, crops = list(), list() | ||
for axis, coords in enumerate(item): | ||
# coords can be either slice or tuple at this point (after parse_item) | ||
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted | ||
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices | ||
safe_coords.append(coords) | ||
elif isinstance(coords, slice): # coords = (min:max:step) | ||
crops.append((max(-coords.start, 0), max(coords.stop - data.shape[axis], 0))) | ||
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step)) | ||
|
||
self.write(safe_coords, | ||
data[tuple(slice(crops[i][0], data.shape[i]-crops[i][1], 1) for i in range(data.ndim))]) | ||
|
||
|
||
class AbstractReader: | ||
"""Provides numpy array-like interface to arbitrary source of data""" | ||
def __getattr__(self, item): | ||
print(item) | ||
return getattr(self._data, item) | ||
|
||
@property | ||
def shape(self): | ||
return self._shape | ||
|
||
@property | ||
def ndim(self): | ||
return len(self._shape) | ||
|
||
def __getitem__(self, item): | ||
item = self.parse_item(item) | ||
return self.fetch(item) | ||
|
||
def fetch(self, item): | ||
"""Datasource-specific data fetching, e.g. rasterio.read()""" | ||
raise NotImplementedError | ||
|
||
def parse_item(self, item): | ||
"""Parse input for __getitem__() to handle arbitrary input | ||
Possible cases: | ||
- item is a single value (int) -> turns it into a tuple and adds the slice over the whole axis for every | ||
missing dimension | ||
- len(item) < self.ndim -> adds the slice over the whole axis for every missing dimension | ||
- len(item) > self.ndim -> raises IndexError | ||
- item contains slices without start or step defined -> defines start=0, step=1 | ||
- item contains negative indexes -> substitute them with (self.shape[axis] - index) | ||
""" | ||
if isinstance(item, (list, np.ndarray)): | ||
item = tuple(item) | ||
if not isinstance(item, tuple): | ||
item = (item, ) | ||
if len(item) > self.ndim: | ||
raise IndexError(f"Index={item} has more dimensions than data={self.shape}") | ||
item = list(item) | ||
while len(item) < self.ndim: | ||
item.append(None) | ||
|
||
for axis, coord in enumerate(item): | ||
item[axis] = validate_coord(coord, self.shape[axis]) | ||
return item | ||
|
||
|
||
class AbstractWriter(AbstractReader): | ||
"""Provides numpy array-like interface to arbitrary source of data""" | ||
def __getitem__(self, item): | ||
raise NotImplementedError | ||
|
||
def fetch(self, item): | ||
raise NotImplementedError | ||
|
||
def __setitem__(self, item, data): | ||
item = self.parse_item(item) | ||
self.write(item, data) | ||
|
||
def write(self, item, data): | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from .abstractadapter import AbstractReader | ||
import numpy as np | ||
import pkg_resources | ||
|
||
if 'PIL' in {pkg.key for pkg in pkg_resources.working_set}: | ||
from PIL import Image | ||
|
||
|
||
class PilReader(AbstractReader): | ||
"""Provides numpy array-like interface to PIL-compatible image file""" | ||
__slots__ = ('_path',) | ||
|
||
def __init__(self, path, verbose: bool = False, **kwargs): | ||
self._path = path | ||
self.verbose = verbose | ||
self._shape = np.array(Image.open(path)).transpose(2, 0, 1).shape | ||
|
||
def __getitem__(self, item): | ||
channels, y, x = self.parse_item(item) | ||
res = np.array(Image.open(self._path)).transpose(2, 0, 1)[channels, y.start:y.stop, x.start:x.stop] | ||
return res |
58 changes: 58 additions & 0 deletions
58
aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from .abstractadapter import AbstractReader, AbstractWriter, PaddedReaderMixin, PaddedWriterMixin | ||
import numpy as np | ||
import rasterio | ||
|
||
|
||
class RasterioReader(PaddedReaderMixin, AbstractReader): | ||
"""Provides numpy array-like interface to geotiff file via rasterio""" | ||
__slots__ = ('_path',) | ||
|
||
def __init__(self, path, verbose: bool = False, padding_mode: str = 'reflect', **kwargs): | ||
super().__init__(padding_mode) | ||
self._path = path | ||
self._data = rasterio.open(path) | ||
self.verbose = verbose | ||
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1])) | ||
|
||
def fetch(self, item): | ||
res = self._data.read([ch+1 for ch in item[0]], | ||
window=((item[1].start, item[1].stop), | ||
(item[2].start, item[2].stop)), | ||
boundless=True).astype(np.uint8) | ||
return res | ||
|
||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}" | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}" | ||
|
||
return item | ||
|
||
|
||
class RasterioWriter(PaddedWriterMixin, AbstractWriter): | ||
def __init__(self, path, **profile): | ||
self._path = path | ||
self._data = rasterio.open(path, 'w', **profile) | ||
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1])) | ||
|
||
def write(self, item, data): | ||
self._data.write(data, [ch+1 for ch in item[0]], | ||
window=((item[1].start, item[1].stop), | ||
(item[2].start, item[2].stop))) | ||
|
||
def __setitem__(self, item, data): | ||
item = self.parse_item(item) | ||
self.write(item, data) | ||
|
||
def parse_item(self, item): | ||
item = super().parse_item(item) | ||
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}" | ||
if isinstance(item[0], slice): | ||
item[0] = list(range(item[0].start, item[0].stop, item[0].step)) | ||
assert isinstance(item[1], slice) and isinstance(item[2], slice),\ | ||
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}" | ||
|
||
return item |
41 changes: 41 additions & 0 deletions
41
aeronet_raster/aeronet_raster/dataadapters/separatebandsadapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from .abstractadapter import AbstractReader, AbstractWriter | ||
import numpy as np | ||
from typing import Sequence | ||
|
||
|
||
class SeparateBandsReader(AbstractReader): | ||
"""Provides numpy array-like interface to separate data sources (image bands)""" | ||
|
||
def __init__(self, bands: Sequence[AbstractReader], verbose: bool = False, **kwargs): | ||
self._data = bands | ||
channels = bands[0].shape[0] | ||
if len(bands) > 1: | ||
for i, b in enumerate(bands[1:]): | ||
channels += b.shape[0] | ||
if b.shape[1:] != bands[0].shape[1:]: | ||
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}') | ||
self._shape = (channels, self._data[0].shape[1], self._data[0].shape[2]) | ||
|
||
def __getitem__(self, item): | ||
return np.concatenate([b[item] for b in self._data], 0) | ||
|
||
|
||
class SeparateBandsWriter(AbstractWriter): | ||
"""Provides numpy array-like interface to separate adapters, representing image bands""" | ||
__slots__ = ('_channels',) | ||
|
||
def __init__(self, bands: Sequence[AbstractWriter], **kwargs): | ||
self._data = bands | ||
channels = bands[0].shape[0] | ||
if len(bands) > 1: | ||
for i, b in enumerate(bands[1:]): | ||
self._channels += b.shape[0] | ||
if b.shape[1:] != bands[0].shape[1:]: | ||
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}') | ||
self._shape = (np.sum(self._channels), self._data[0].shape[1], self._data[0].shape[2]) | ||
|
||
def __setitem__(self, item, data): | ||
current_ch = 0 | ||
for band in self._data: | ||
band[item] = data[current_ch:current_ch+band.shape[0]] | ||
current_ch += band.shape[0] |
Empty file.
52 changes: 52 additions & 0 deletions
52
aeronet_raster/aeronet_raster/utils/samplers/gridsampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
from typing import Sequence | ||
|
||
|
||
def get_safe_shape(shape: Sequence[int], stride: Sequence[int]): | ||
""" | ||
Returns safe shape that is divisible by stride (Equal or bigger than original shape). | ||
""" | ||
assert len(shape) == len(stride) | ||
return [shape[i] if not shape[i]%stride[i] else shape[i]//stride[i]*stride[i]+stride[i] for i in range(len(shape))] | ||
|
||
|
||
def make_grid(boundaries, stride): | ||
assert len(boundaries) == len(stride) | ||
return np.stack([x.reshape(-1) for x in np.meshgrid(*tuple(np.arange(boundaries[i][0], | ||
boundaries[i][1], | ||
stride[i]) for i in range(len(boundaries))), | ||
indexing='ij')]).transpose(1, 0) | ||
|
||
|
||
class GridSampler: | ||
""" | ||
yields from grid | ||
Args: | ||
grid: array; | ||
verbose: print debug | ||
""" | ||
__slots__ = ('_grid', 'verbose') | ||
|
||
def __init__(self, | ||
grid: np.ndarray, | ||
verbose: bool = False): | ||
if verbose: | ||
print(f"Initializing sampler {type(self)}:\n") | ||
self.verbose = verbose | ||
self._grid = grid | ||
if verbose: | ||
print(f"grid shape is {self._grid.shape}\n") | ||
|
||
@property | ||
def grid(self) -> np.ndarray: | ||
return self._grid | ||
|
||
def __len__(self): | ||
return self._grid.shape[0] | ||
|
||
def __iter__(self): | ||
yield from self._grid | ||
|
||
|
||
# def get_sampler(shape: Sequence[int], stride: Sequence[int], offset: Sequence[int]): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.