Skip to content

Commit

Permalink
Crossfade fixes 3
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 17, 2024
1 parent 69d226f commit 331e54a
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 112 deletions.
2 changes: 1 addition & 1 deletion aeronet_raster/aeronet_raster/band/bandsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BandSample(GeoObject):
It implements all the interfaces of the GeoObject, and stores the raster data in memory
Args:
name (str): a name of the sample, which is used as a defaule name for saving to file
name (str): a name of the sample, which is used as a default name for saving to file
raster (np.array): the raster data
crs: geographical coordinate reference system, as :obj:`CRS` or string representation
transform (Affine): affine transform
Expand Down
14 changes: 9 additions & 5 deletions aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ def open(self):

def fetch(self, item):
channels, y, x = item
res = self._descriptor.read([ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)),
boundless=True).astype(np.uint8)
return res
return self._descriptor.read([ch+1 for ch in channels],
window=((y.start, y.stop),
(x.start, x.stop)))

@property
def profile(self):
Expand All @@ -46,6 +44,12 @@ def count(self):
raise ValueError(f'File {self._path} is not opened')
return self._descriptor.count

@property
def dtype(self):
if not self._descriptor:
raise ValueError(f'File {self._path} is not opened')
return self._descriptor.profile['dtype']


class RasterioWriter(ImageWriter, RasterioReader):
def __init__(self, path, profile, padding_mode: str = 'constant', **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion aeronet_raster/aeronet_raster/dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def process(src: ArrayLike,
dst_coords[i] + dst_sample_size[i],
1) for i in range(len(dst_coords)))] = readen + res


def get_blend_mask(shape: Sequence[int], margin: Sequence[int]) -> np.ndarray:
"""
Returns alpha-blend float mask with values within [0..1] and linear fades on each side
Expand Down Expand Up @@ -173,6 +172,8 @@ def add_ch_ndim(size, n_ch):
processor = get_auto_cropped_processor(processor, dst_margin, mode)
dst_margin = np.array((0, 0, 0)) # zero margin
elif mode == 'crossfade':
if dst.dtype not in ('float', 'float32', 'float16', np.float64):
logging.warning(f'For crossfade mode it is recommended to set destination dtype to float, got {dst.dtype}')
mask = get_blend_mask(dst_sample_size, dst_margin)
processor = get_auto_cropped_processor(processor, dst_margin, mode, mask)

Expand Down
165 changes: 63 additions & 102 deletions aeronet_raster/test/CrossfadeTest.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions aeronet_raster/test/unit/crossfade_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import numpy as np
from aeronet_raster.aeronet_raster import dataprocessor
from aeronet_raster.aeronet_raster.dataadapters import rasterioadapter
import logging
logging.basicConfig(level=logging.INFO)

with rasterioadapter.RasterioReader('test_data/input.tif') as src:
with rasterioadapter.RasterioReader('test_data/input2.tif') as src:
profile = src.profile
profile['dtype'] = 'float32'
with rasterioadapter.RasterioWriter('test_data/output.tif', profile) as dst:
dataprocessor.process_image(src=src,
src_sample_size=512,
src_margin=64,
processor=lambda x: x,
processor=lambda x: x,#np.ones((3, 512, 512))*127,
dst=dst,
mode='crossfade',
verbose=False)
verbose=True)

0 comments on commit 331e54a

Please sign in to comment.