Skip to content

Commit

Permalink
Crossfade 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 9, 2024
1 parent 3f74efa commit 5b50af3
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 62 deletions.
10 changes: 10 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/filemixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, traceback):
self.close()

def __getitem__(self, item):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
return super().__getitem__(item)

def __setitem__(self, item, data):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
super().__setitem__(item, data)

@property
def shape(self):
if not self._descriptor:
Expand Down
6 changes: 3 additions & 3 deletions aeronet_raster/aeronet_raster/dataadapters/imageadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, padding_mode='constant', **kwargs):
def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
raise ValueError(f"mage must be indexed with 3 axes, got {item}")
raise ValueError(f"Image must be indexed with 3 axes, got {item}")
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
Expand All @@ -36,11 +36,11 @@ class ImageWriter(BoundSafeWriterMixin, AbstractWriter):
def parse_item(self, item):
item = super().parse_item(item)
if not len(item) == 3:
raise ValueError(f"PIL Image must be indexed with 3 axes, got {item}")
raise ValueError(f"Image must be indexed with 3 axes, got {item}")
if isinstance(item[0], slice):
item[0] = list(range(item[0].start, item[0].stop, item[0].step))
assert isinstance(item[1], slice) and isinstance(item[2], slice),\
f"PIL Image spatial axes (1 and 2) must be indexed with slices, got {item}"
f"Image spatial axes (1 and 2) must be indexed with slices, got {item}"
return item

@property
Expand Down
14 changes: 10 additions & 4 deletions aeronet_raster/aeronet_raster/dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,14 @@ def get_blend_mask(shape: Sequence[int], margin: Sequence[int]) -> np.ndarray:
continue
if margin[axis]*2 >= shape[axis]:
raise ValueError(f'margin must be less than shape//2, got {margin[axis]}, {shape[axis]} along axis={axis}')
linear_mask = np.concatenate((np.linspace(1/margin[axis], 1, margin[axis]) if margin[axis] > 1 else np.array((0.5,)),
min_v = 1/(margin[axis] + 1)
linear_mask = np.concatenate((np.linspace(min_v,
1 - min_v,
margin[axis]) if margin[axis] > 1 else np.array((0.5,)),
np.ones(shape[axis] - 2 * margin[axis]),
np.linspace(1, 1/margin[axis], margin[axis]) if margin[axis] > 1 else np.array((0.5,))))
np.linspace(1 - min_v,
min_v,
margin[axis]) if margin[axis] > 1 else np.array((0.5,))))

mask = np.swapaxes(mask, len(shape)-1, axis)
mask = mask*linear_mask
Expand Down Expand Up @@ -128,8 +133,9 @@ def process_image(src: ImageReader,
dst_margin_mode: 'crop' or 'crossfade'
verbose: verbose
"""
def build_sampler(shape, sample_size, margin):
stride = sample_size - 2 * margin
def build_sampler(shape, sample_size, margin, mode):
assert mode in DST_MARGIN_MODES
stride = sample_size - 2 * margin if mode == 'crop' else sample_size - margin
assert np.all(stride > 0)
safe_shape = get_safe_shape(shape, stride)
return GridSampler(make_grid([(-margin[i],
Expand Down
Loading

0 comments on commit 5b50af3

Please sign in to comment.