Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Mar 21, 2024
1 parent 4a48e19 commit 04ade79
Show file tree
Hide file tree
Showing 13 changed files with 531 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aeronet_raster/aeronet_raster/collectionprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from multiprocessing.pool import ThreadPool
from threading import Lock
from tqdm import tqdm
import cv2

from typing import Union, Optional, Callable, List, Tuple
from .band.band import Band
from .bandcollection.bandcollection import BandCollection
Expand Down
Empty file.
112 changes: 112 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
from ..utils.utils import validate_coord


class PaddedReaderMixin:
"""
Redefines __getitem__() so it works even if the coordinates are out of bounds
"""
def __init__(self, padding_mode: str = 'reflect'):
self.padding_mode = padding_mode

def __getitem__(self, item):
item = self.parse_item(item)

pads, safe_coords = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
pads.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
pads.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + pads[-1][0], coords.stop - pads[-1][1], coords.step))
else:
raise ValueError(f'Can not parse coords={coords} at axis={axis}')

res = self.fetch(safe_coords)
return np.pad(res, pads, mode=self.padding_mode)


class PaddedWriterMixin:
"""
Redefines __setitem__() so it works even if the coordinates are out of bounds
"""
def __setitem__(self, item, data):
item = self.parse_item(item)
assert data.ndim == self.ndim == len(item)
safe_coords, crops = list(), list()
for axis, coords in enumerate(item):
# coords can be either slice or tuple at this point (after parse_item)
if isinstance(coords, (list, tuple)): # coords = (coord1, coord2, ...), already sorted
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
crops.append((max(-coords.start, 0), max(coords.stop - data.shape[axis], 0)))
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step))

self.write(safe_coords,
data[tuple(slice(crops[i][0], data.shape[i]-crops[i][1], 1) for i in range(data.ndim))])


class AbstractReader:
"""Provides numpy array-like interface to arbitrary source of data"""
def __getattr__(self, item):
print(item)
return getattr(self._data, item)

@property
def shape(self):
return self._shape

@property
def ndim(self):
return len(self._shape)

def __getitem__(self, item):
item = self.parse_item(item)
return self.fetch(item)

def fetch(self, item):
"""Datasource-specific data fetching, e.g. rasterio.read()"""
raise NotImplementedError

def parse_item(self, item):
"""Parse input for __getitem__() to handle arbitrary input
Possible cases:
- item is a single value (int) -> turns it into a tuple and adds the slice over the whole axis for every
missing dimension
- len(item) < self.ndim -> adds the slice over the whole axis for every missing dimension
- len(item) > self.ndim -> raises IndexError
- item contains slices without start or step defined -> defines start=0, step=1
- item contains negative indexes -> substitute them with (self.shape[axis] - index)
"""
if isinstance(item, (list, np.ndarray)):
item = tuple(item)
if not isinstance(item, tuple):
item = (item, )
if len(item) > self.ndim:
raise IndexError(f"Index={item} has more dimensions than data={self.shape}")
item = list(item)
while len(item) < self.ndim:
item.append(None)

for axis, coord in enumerate(item):
item[axis] = validate_coord(coord, self.shape[axis])
return item


class AbstractWriter(AbstractReader):
"""Provides numpy array-like interface to arbitrary source of data"""
def __getitem__(self, item):
raise NotImplementedError

def fetch(self, item):
raise NotImplementedError

def __setitem__(self, item, data):
item = self.parse_item(item)
self.write(item, data)

def write(self, item, data):
raise NotImplementedError
21 changes: 21 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/piladapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .abstractadapter import AbstractReader
import numpy as np
import pkg_resources

if 'PIL' in {pkg.key for pkg in pkg_resources.working_set}:
from PIL import Image


class PilReader(AbstractReader):
"""Provides numpy array-like interface to PIL-compatible image file"""
__slots__ = ('_path',)

def __init__(self, path, verbose: bool = False, **kwargs):
self._path = path
self.verbose = verbose
self._shape = np.array(Image.open(path)).transpose(2, 0, 1).shape

def __getitem__(self, item):
channels, y, x = self.parse_item(item)
res = np.array(Image.open(self._path)).transpose(2, 0, 1)[channels, y.start:y.stop, x.start:x.stop]
return res
58 changes: 58 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from .abstractadapter import AbstractReader, AbstractWriter, PaddedReaderMixin, PaddedWriterMixin
import numpy as np
import rasterio


class RasterioReader(PaddedReaderMixin, AbstractReader):
"""Provides numpy array-like interface to geotiff file via rasterio"""
__slots__ = ('_path',)

def __init__(self, path, verbose: bool = False, padding_mode: str = 'reflect', **kwargs):
super().__init__(padding_mode)
self._path = path
self._data = rasterio.open(path)
self.verbose = verbose
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1]))

def fetch(self, item):
res = self._data.read([ch+1 for ch in item[0]],
window=((item[1].start, item[1].stop),
(item[2].start, item[2].stop)),
boundless=True).astype(np.uint8)
return res

def parse_item(self, item):
item = super().parse_item(item)
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}"
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}"

return item


class RasterioWriter(PaddedWriterMixin, AbstractWriter):
def __init__(self, path, **profile):
self._path = path
self._data = rasterio.open(path, 'w', **profile)
self._shape = np.array((self._data.count, self._data.shape[0], self._data.shape[1]))

def write(self, item, data):
self._data.write(data, [ch+1 for ch in item[0]],
window=((item[1].start, item[1].stop),
(item[2].start, item[2].stop)))

def __setitem__(self, item, data):
item = self.parse_item(item)
self.write(item, data)

def parse_item(self, item):
item = super().parse_item(item)
assert len(item) == 3, f"Rasterio geotif must be indexed with 3 axes, got {item}"
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"Rasterio geotif spatial dimensions (1, 2) must be indexed with slices, got {item}"

return item
41 changes: 41 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/separatebandsadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from .abstractadapter import AbstractReader, AbstractWriter
import numpy as np
from typing import Sequence


class SeparateBandsReader(AbstractReader):
"""Provides numpy array-like interface to separate data sources (image bands)"""

def __init__(self, bands: Sequence[AbstractReader], verbose: bool = False, **kwargs):
self._data = bands
channels = bands[0].shape[0]
if len(bands) > 1:
for i, b in enumerate(bands[1:]):
channels += b.shape[0]
if b.shape[1:] != bands[0].shape[1:]:
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}')
self._shape = (channels, self._data[0].shape[1], self._data[0].shape[2])

def __getitem__(self, item):
return np.concatenate([b[item] for b in self._data], 0)


class SeparateBandsWriter(AbstractWriter):
"""Provides numpy array-like interface to separate adapters, representing image bands"""
__slots__ = ('_channels',)

def __init__(self, bands: Sequence[AbstractWriter], **kwargs):
self._data = bands
channels = bands[0].shape[0]
if len(bands) > 1:
for i, b in enumerate(bands[1:]):
self._channels += b.shape[0]
if b.shape[1:] != bands[0].shape[1:]:
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}')
self._shape = (np.sum(self._channels), self._data[0].shape[1], self._data[0].shape[2])

def __setitem__(self, item, data):
current_ch = 0
for band in self._data:
band[item] = data[current_ch:current_ch+band.shape[0]]
current_ch += band.shape[0]
Empty file.
52 changes: 52 additions & 0 deletions aeronet_raster/aeronet_raster/utils/samplers/gridsampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
from typing import Sequence


def get_safe_shape(shape: Sequence[int], stride: Sequence[int]):
"""
Returns safe shape that is divisible by stride (Equal or bigger than original shape).
"""
assert len(shape) == len(stride)
return [shape[i] if not shape[i]%stride[i] else shape[i]//stride[i]*stride[i]+stride[i] for i in range(len(shape))]


def make_grid(boundaries, stride):
assert len(boundaries) == len(stride)
return np.stack([x.reshape(-1) for x in np.meshgrid(*tuple(np.arange(boundaries[i][0],
boundaries[i][1],
stride[i]) for i in range(len(boundaries))),
indexing='ij')]).transpose(1, 0)


class GridSampler:
"""
yields from grid
Args:
grid: array;
verbose: print debug
"""
__slots__ = ('_grid', 'verbose')

def __init__(self,
grid: np.ndarray,
verbose: bool = False):
if verbose:
print(f"Initializing sampler {type(self)}:\n")
self.verbose = verbose
self._grid = grid
if verbose:
print(f"grid shape is {self._grid.shape}\n")

@property
def grid(self) -> np.ndarray:
return self._grid

def __len__(self):
return self._grid.shape[0]

def __iter__(self):
yield from self._grid


# def get_sampler(shape: Sequence[int], stride: Sequence[int], offset: Sequence[int]):
46 changes: 45 additions & 1 deletion aeronet_raster/aeronet_raster/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from typing import Final, Tuple, List
from typing import Final, Tuple, List, Union, Optional
import string
import random
import numpy as np

TMP_DIR: Final[str] = '/tmp/raster'
IntCoords = Union[Tuple[int, int], List[int], np.ndarray]
IntBox = Union[Tuple[IntCoords, IntCoords], List[IntCoords], np.ndarray]


def parse_directory(directory: str, names: Tuple[str],
Expand Down Expand Up @@ -42,3 +44,45 @@ def band_shape_guard(raster: np.ndarray) -> np.ndarray:
def random_name(length: int = 10) -> str:
letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length))


def validate_coord(coord: Union[int, tuple, list, slice, None], dim_size: int) -> Union[List, slice]:
"""
:param coord: Anything that can be used as an index (single value, sequence or slice)
:param dim_size: size of current data dimension
:return: tuple of indexes or valid slice
"""
if coord is None:
return slice(0, dim_size, 1)
if isinstance(coord, int):
coord = [coord, ]
if isinstance(coord, tuple):
coord = list(coord)
if isinstance(coord, list):
for i, c in enumerate(coord):
if -dim_size <= c < 0:
coord[i] = dim_size + c
elif 0 <= c < dim_size:
pass
else:
raise IndexError(f'{c} is out of bounds for axis size {dim_size}')
return sorted(coord)
if isinstance(coord, slice):
start, stop, step = coord.start or 0, coord.stop or dim_size, coord.step or 1
#if not(-dim_size <= start < dim_size) or \
# not(-dim_size-1 <= stop <= dim_size) or \
# not((step > 0) ^ (stop-start <= 0)) or stop == start:
# raise IndexError(f'Invalid slice {start, stop, step} for axis size {dim_size}')
return slice(start, stop, step)


def to_np_2(value) -> np.ndarray:
"""
converts any value to (2) shaped np array if possible
"""
if isinstance(value, slice):
value = list(range(value.start, value.stop))
if isinstance(value, (int, float)):
value = (value, value)
assert len(value) == 2, f"Length {value} = {len(value)} != 2"
return np.array(value).astype(int)
Loading

0 comments on commit 04ade79

Please sign in to comment.