Skip to content

Commit

Permalink
Speed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sakharov committed Apr 1, 2024
1 parent de52d1e commit f2be760
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 36 deletions.
2 changes: 1 addition & 1 deletion aeronet_raster/aeronet_raster/collectionprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def __init__(self,
input_channels: List[str],
output_labels: List[str],
processing_fn: Callable,
sample_size: Tuple[int] = (1024, 1024),
sample_size: Tuple[int, int] = (1024, 1024),
bound: int = 256,
src_nodata=0,
nodata=None, dst_nodata=0,
Expand Down
4 changes: 2 additions & 2 deletions aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class PaddedReaderMixin:
"""
Redefines __getitem__() so it works even if the coordinates are out of bounds
"""
def __init__(self, padding_mode: str = 'reflect'):
def __init__(self, padding_mode: str = 'constant'):
self.padding_mode = padding_mode

def __getitem__(self, item):
Expand Down Expand Up @@ -42,7 +42,7 @@ def __setitem__(self, item, data):
crops.append((0, 0)) # do nothing since indexing out of bounds makes sense only with slices
safe_coords.append(coords)
elif isinstance(coords, slice): # coords = (min:max:step)
crops.append((max(-coords.start, 0), max(coords.stop - data.shape[axis], 0)))
crops.append((max(-coords.start, 0), max(coords.stop - self.shape[axis], 0)))
safe_coords.append(slice(coords.start + crops[-1][0], coords.stop - crops[-1][1], coords.step))

self.write(safe_coords,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def parse_item(self, item):
return item


class RasterioWriter(AbstractWriter):
class RasterioWriter(PaddedWriterMixin, AbstractWriter):
def __init__(self, path, **profile):
self._path = path
self._data = rasterio.open(path, 'w', **profile)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .abstractadapter import AbstractReader, AbstractWriter
from .abstractadapter import AbstractReader, AbstractWriter, PaddedReaderMixin, PaddedWriterMixin
import numpy as np
from typing import Sequence


class SeparateBandsReader(AbstractReader):
class SeparateBandsReader(PaddedReaderMixin, AbstractReader):
"""Provides numpy array-like interface to separate data sources (image bands)"""

def __init__(self, bands: Sequence[AbstractReader], verbose: bool = False, **kwargs):
Expand All @@ -16,6 +16,9 @@ def __init__(self, bands: Sequence[AbstractReader], verbose: bool = False, **kwa
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}')
self._shape = (len(self._channels), self._data[0].shape[1], self._data[0].shape[2])

def __getattr__(self, item):
return getattr(self._data[0], item)

def fetch(self, item):
res = list()
for ch in item[0]:
Expand All @@ -33,7 +36,7 @@ def parse_item(self, item):
return item


class SeparateBandsWriter(AbstractWriter):
class SeparateBandsWriter(PaddedWriterMixin, AbstractWriter):
"""Provides numpy array-like interface to separate adapters, representing image bands"""

def __init__(self, bands: Sequence[AbstractWriter], **kwargs):
Expand All @@ -46,6 +49,9 @@ def __init__(self, bands: Sequence[AbstractWriter], **kwargs):
raise ValueError(f'Band {i} shape = {b.shape[1:]} != Band 0 shape = {bands[0].shape[1:]}')
self._shape = (len(self._channels), self._data[0].shape[1], self._data[0].shape[2])

def __getattr__(self, item):
return getattr(self._data[0], item)

def write(self, item, data):
assert len(data) == len(item[0])
for data_ch, ch in enumerate(item[0]):
Expand Down
74 changes: 45 additions & 29 deletions aeronet_raster/test/unit/bands_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,54 @@
import rasterio
import numpy as np
import logging

import cProfile
from pstats import SortKey, Stats
logging.basicConfig(level=logging.INFO)

profile = {'width': 16, 'height': 16, 'count': 3, 'dtype': 'uint8' }

random_data = np.random.randint(0, 254, size=(3, 16, 16))
processing = lambda x: x+1

profile['count'] = 1
with rasterio.open('RED.tif', 'w', **profile) as dst:
dst.write(random_data[0], 1)
with rasterio.open('GRN.tif', 'w', **profile) as dst:
dst.write(random_data[1], 1)
with rasterio.open('BLU.tif', 'w', **profile) as dst:
dst.write(random_data[2], 1)

src = SeparateBandsReader([RasterioReader('RED.tif'), RasterioReader('GRN.tif'), RasterioReader('BLU.tif')])
dst = SeparateBandsWriter([RasterioWriter('RED_res.tif', **profile),
RasterioWriter('GRN_res.tif', **profile),
RasterioWriter('BLU_res.tif', **profile)])
def create_random_bands():
profile = {'width': 16, 'height': 16, 'count': 3, 'dtype': 'uint8' }
random_data = np.random.randint(0, 254, size=(3, 16, 16))
profile['count'] = 1
with rasterio.open('test_data/RED.tif', 'w', **profile) as dst:
dst.write(random_data[0], 1)
with rasterio.open('test_data/GRN.tif', 'w', **profile) as dst:
dst.write(random_data[1], 1)
with rasterio.open('test_data/BLU.tif', 'w', **profile) as dst:
dst.write(random_data[2], 1)
return random_data, profile

def validate_result(random_data):
res = list()
with rasterio.open('RED_res.tif') as d:
res.append(d.read(1))
with rasterio.open('GRN_res.tif') as d:
res.append(d.read(1))
with rasterio.open('BLU_res.tif') as d:
res.append(d.read(1))
print(res == processing(random_data))

processing = lambda x: x

src = SeparateBandsReader([RasterioReader('test_data/RED.tif'),
RasterioReader('test_data/GRN.tif'),
RasterioReader('test_data/BLU.tif')])
profile = src.profile
dst = SeparateBandsWriter([RasterioWriter('test_data/RED_out.tif', **profile),
RasterioWriter('test_data/GRN_out.tif', **profile),
RasterioWriter('test_data/BLU_out.tif', **profile)])

#profile['count'] = 3
#dst = RasterioWriter('test_data/output.tif', **profile)
res = list()
import time
for _ in range(10):
start = time.time()
process_image(src, 1024, 256, processing, dst)
res.append(time.time() - start)
print(res)
print(np.mean(res), np.std(res))
#cProfile.run('process_image(src, 1024, 256, processing, dst)', sort='tottime')

process_image(src, 8, 2, processing, dst, verbose=True)

del src
del dst

res = list()
with rasterio.open('RED_res.tif') as d:
res.append(d.read(1))
with rasterio.open('GRN_res.tif') as d:
res.append(d.read(1))
with rasterio.open('BLU_res.tif') as d:
res.append(d.read(1))

print(res == processing(random_data))
24 changes: 24 additions & 0 deletions aeronet_raster/test/unit/old_collectionprocessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from aeronet_raster.aeronet_raster import CollectionProcessor, BandCollection
import cProfile
import numpy as np
import time
start = time.time()

processor = CollectionProcessor(input_channels=['RED', 'GRN', 'BLU'],
output_labels=['RED_out', 'GRN_out', 'BLU_out'],
processing_fn=lambda x: x,
n_workers=0,
sample_size=(512, 512),
bound=256)

bc = BandCollection(['test_data/RED.tif', 'test_data/GRN.tif', 'test_data/BLU.tif'])
#cProfile.run('labels_bc = processor.process(bc, "test_data")', sort='tottime')

res = list()
for _ in range(10):
start = time.time()
processor.process(bc, "test_data")
res.append(time.time() - start)
print(res)
print(np.mean(res), np.std(res))

0 comments on commit f2be760

Please sign in to comment.