Skip to content

Commit

Permalink
Crossfade
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 8, 2024
1 parent 2579ee3 commit 3f74efa
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def fetch(self, item):
raise NotImplementedError


class AbstractWriter(AbstractArrayLike):
class AbstractWriter(AbstractReader):
"""Provides numpy array-like interface to arbitrary source of data"""
def __setitem__(self, item, data):
item = self.parse_item(item)
Expand Down
33 changes: 33 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/numpyadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .abstractadapter import AbstractReader, AbstractWriter
from .boundsafemixin import BoundSafeReaderMixin, BoundSafeWriterMixin


class NumpyReader(BoundSafeReaderMixin, AbstractReader):
"""Works with numpy arrays. Useful for testing"""
def __init__(self, data, padding_mode='constant', **kwargs):
super().__init__(padding_mode, **kwargs)
self._data = data

@property
def ndim(self):
return self._data.ndim

@property
def shape(self):
return self._data.shape

def __len__(self):
return self.shape[0]

def fetch(self, item):
if isinstance(item, list):
item = tuple(item)
return self._data[item]


class NumpyWriter(BoundSafeWriterMixin, NumpyReader):
"""Works with numpy arrays. Useful for testing"""
def write(self, item, data):
if isinstance(item, list):
item = tuple(item)
self._data[item] = data
84 changes: 71 additions & 13 deletions aeronet_raster/aeronet_raster/dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .utils.samplers.gridsampler import GridSampler, make_grid, get_safe_shape
from .dataadapters.abstractadapter import AbstractArrayLike
from .dataadapters.imageadapter import ImageWriter, ImageReader
from typing import Sequence, Callable, Union
from typing import Sequence, Callable, Union, Final, Tuple, Optional
import numpy as np

ArrayLike = Union[np.array, AbstractArrayLike]

DST_MARGIN_MODES: Final[Tuple] = ('crop', 'crossfade')


def process(src: ArrayLike,
src_sampler: GridSampler,
Expand All @@ -15,6 +17,7 @@ def process(src: ArrayLike,
dst: ArrayLike,
dst_sampler: GridSampler,
dst_sample_size: Sequence[int],
mode: str = 'crop',
verbose: bool = False):
"""
Processes array-like data with predictor in windowed mode. Writes to dst inplace
Expand All @@ -26,6 +29,7 @@ def process(src: ArrayLike,
dst: destination, array-like
dst_sampler: must yield coordinate tuples
dst_sample_size: window size for dst
mode: 'crop' or 'crossfade'
verbose: verbose
"""

Expand All @@ -41,16 +45,63 @@ def process(src: ArrayLike,
src_coords[i]+src_sample_size[i],
1) for i in range(len(src_coords)))]
res = processor(sample)
# TODO: add weight matrix
dst[tuple(slice(dst_coords[i],
dst_coords[i]+dst_sample_size[i],
1) for i in range(len(dst_coords)))] = res
if mode == 'crop':
dst[tuple(slice(dst_coords[i],
dst_coords[i]+dst_sample_size[i],
1) for i in range(len(dst_coords)))] = res
elif mode == 'crossfade':
dst[tuple(slice(dst_coords[i],
dst_coords[i] + dst_sample_size[i],
1) for i in range(len(dst_coords)))] += res


def get_auto_cropped_processor(processor: Callable, margin: Sequence[int]) -> Callable:
"""Wraps processor, crops its output by margin"""
def get_blend_mask(shape: Sequence[int], margin: Sequence[int]) -> np.ndarray:
"""
Returns alpha-blend float mask with values within [0..1] and linear fades on each side
shape: mask shape
margin: margin values for every axis
Returns: np.ndarray
"""
if len(shape) != len(margin):
raise ValueError('len(shape) != len(margin). Margin for every axis is required')
mask = np.ones(shape)
for axis in range(len(shape)):
if margin[axis] == 0:
continue
if margin[axis]*2 >= shape[axis]:
raise ValueError(f'margin must be less than shape//2, got {margin[axis]}, {shape[axis]} along axis={axis}')
linear_mask = np.concatenate((np.linspace(1/margin[axis], 1, margin[axis]) if margin[axis] > 1 else np.array((0.5,)),
np.ones(shape[axis] - 2 * margin[axis]),
np.linspace(1, 1/margin[axis], margin[axis]) if margin[axis] > 1 else np.array((0.5,))))

mask = np.swapaxes(mask, len(shape)-1, axis)
mask = mask*linear_mask
mask = np.swapaxes(mask, len(shape)-1, axis)
return mask


def get_auto_cropped_processor(processor: Callable, margin: Sequence[int], mode: str = 'crop',
blend_mask: Optional[np.ndarray] = None) -> Callable:
"""Wraps processor, crops its output by margin
processor: function to wrap
margin: margin values for every axis
mode: 'crop' - crop every axis by margin, 'crossfade' - apply alpha-blend mask.
blend_mask: mask to use. Must be same size as the sample. If not specified, will be calculated for each sample
Returns: Callable
"""
if mode not in DST_MARGIN_MODES:
raise ValueError(f'mode must be one of {DST_MARGIN_MODES}')

def inner(x):
return processor(x)[tuple(slice(margin[i], x.shape[i]-margin[i], 1) for i in range(len(margin)))]
if mode == 'crop':
return processor(x)[tuple(slice(margin[i], x.shape[i]-margin[i], 1) for i in range(len(margin)))]
if mode == 'crossfade':
output = processor(x)
if blend_mask is None:
mask = get_blend_mask(output.shape, margin)
else:
mask = blend_mask
return output*mask
return inner


Expand All @@ -61,6 +112,7 @@ def process_image(src: ImageReader,
dst: ImageWriter,
dst_sample_size: Union[int, Sequence[int], None] = None,
dst_margin: Union[int, Sequence[int], None] = None,
dst_margin_mode: str = 'crop',
verbose: bool = False):
"""
Helper function that prepares samplers and mimics the behavior of the old collectionprocessor
Expand All @@ -73,6 +125,7 @@ def process_image(src: ImageReader,
dst_sample_size: size of the window including margins (processor output), so stride = sample_size - 2 * margin.
If None - same as src_sample_size
dst_margin: processor output crop along each axis. If None - same as src_margin
dst_margin_mode: 'crop' or 'crossfade'
verbose: verbose
"""
def build_sampler(shape, sample_size, margin):
Expand All @@ -97,16 +150,21 @@ def add_ch_ndim(size, n_ch):
src_sample_size = add_ch_ndim(src_sample_size, src.shape[0])
src_margin = add_ch_ndim(src_margin, 0)

processor = get_auto_cropped_processor(processor, dst_margin)

src_sampler = build_sampler(src.shape, src_sample_size, src_margin)
if verbose:
logging.info(f'Src sampler grid: {src_sampler.grid}')

dst_sample_size = dst_sample_size-2*dst_margin # exclude margin from dst sample size since we crop it in processor
dst_sampler = build_sampler(dst.shape, dst_sample_size, np.array((0, 0, 0))) # zero margin
if dst_margin_mode == 'crop':
dst_sample_size = dst_sample_size-2*dst_margin # exclude margin from dst sample size since we crop it in the processor
processor = get_auto_cropped_processor(processor, dst_margin, dst_margin_mode)
dst_margin = np.array((0, 0, 0)) # zero margin
elif dst_margin_mode == 'crossfade':
mask = get_blend_mask(dst_sample_size, dst_margin)
processor = get_auto_cropped_processor(processor, dst_margin, dst_margin_mode, mask)

dst_sampler = build_sampler(dst.shape, dst_sample_size, dst_margin)
if verbose:
logging.info(f'Dst sampler grid: {dst_sampler.grid}')

process(src, src_sampler, src_sample_size, processor,
dst, dst_sampler, dst_sample_size, verbose)
dst, dst_sampler, dst_sample_size, dst_margin_mode, verbose)
Loading

0 comments on commit 3f74efa

Please sign in to comment.