Skip to content

Commit

Permalink
Furthur code refactoring (lanha#5)
Browse files Browse the repository at this point in the history
* Freeze keras version to be compatible with tf1

* minor refactoring

* Revert naming of saved patches.

* Uncomment the commented line

* Add tests for inference

* Black update and fix some liniting issues.
  • Loading branch information
nekhtiari committed Aug 5, 2020
1 parent 26658d5 commit 823e82e
Show file tree
Hide file tree
Showing 13 changed files with 194 additions and 49 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
tensorflow==1.15.2
keras
keras==2.3.1
scikit-image
imageio
rasterio
Expand Down
81 changes: 51 additions & 30 deletions testing/s2_tiles_supres.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def save_result(


# pylint: disable-msg=too-many-arguments
def update(data, size_10m: Tuple, model_output: np.ndarray, xmi: int, ymi: int):
def update(pr_10m, size_10m: Tuple, model_output: np.ndarray, xmi: int, ymi: int):
"""
This method creates the proper georeferencing for the output image.
:param data: The raster file for 10m resolution.
Expand All @@ -47,16 +47,14 @@ def update(data, size_10m: Tuple, model_output: np.ndarray, xmi: int, ymi: int):
# Here based on the params.json file, the output image dimension will be calculated.
out_dims = model_output.shape[2]

with rasterio.open(data) as d_s:
p_r = d_s.profile
new_transform = p_r["transform"] * A.translation(xmi, ymi)
p_r.update(dtype=rasterio.float32)
p_r.update(driver="GTiff")
p_r.update(width=size_10m[1])
p_r.update(height=size_10m[0])
p_r.update(count=out_dims)
p_r.update(transform=new_transform)
return p_r
new_transform = pr_10m["transform"] * A.translation(xmi, ymi)
pr_10m.update(dtype=rasterio.float32)
pr_10m.update(driver="GTiff")
pr_10m.update(width=size_10m[1])
pr_10m.update(height=size_10m[0])
pr_10m.update(count=out_dims)
pr_10m.update(transform=new_transform)
return pr_10m


class Superresolution(DATA_UTILS):
Expand All @@ -68,6 +66,7 @@ def __init__(self, data_file_path, clip_to_aoi, copy_original_bands, output_dir)

super().__init__(data_file_path)

# pylint: disable=attribute-defined-outside-init
def start(self):
data_list = self.get_data()

Expand All @@ -93,17 +92,24 @@ def start(self):
for dsdesc in data_list:
if "10m" in dsdesc:
LOGGER.info("Selected 10m bands:")
validated_10m_bands, validated_10m_indices, dic_10m = self.validate(
dsdesc
)
(
self.validated_10m_bands,
validated_10m_indices,
dic_10m,
) = self.validate(dsdesc)
data10 = self.data_final(
dsdesc, validated_10m_indices, xmin, ymin, xmax, ymax, 1
)
with rasterio.open(dsdesc) as d_s:
pr_10m = d_s.profile

if "20m" in dsdesc:
LOGGER.info("Selected 20m bands:")
validated_20m_bands, validated_20m_indices, dic_20m = self.validate(
dsdesc
)
(
self.validated_20m_bands,
validated_20m_indices,
dic_20m,
) = self.validate(dsdesc)
data20 = self.data_final(
dsdesc,
validated_20m_indices,
Expand All @@ -115,9 +121,11 @@ def start(self):
)
if "60m" in dsdesc:
LOGGER.info("Selected 60m bands:")
validated_60m_bands, validated_60m_indices, dic_60m = self.validate(
dsdesc
)
(
self.validated_60m_bands,
validated_60m_indices,
dic_60m,
) = self.validate(dsdesc)
data60 = self.data_final(
dsdesc,
validated_60m_indices,
Expand All @@ -128,9 +136,16 @@ def start(self):
1 // 6,
)

validated_descriptions_all = {**dic_10m, **dic_20m, **dic_60m}
self.validated_descriptions_all = {**dic_10m, **dic_20m, **dic_60m}
return data10, data20, data60, [xmin, ymin, xmax, ymax], pr_10m

if validated_60m_bands and validated_20m_bands and validated_10m_bands:
def inference(self, data10, data20, data60, coord, pr_10m):

if (
self.validated_60m_bands
and self.validated_20m_bands
and self.validated_10m_bands
):
LOGGER.info("Super-resolving the 60m data into 10m bands")
sr60 = dsen2_60(data10, data20, data60, deep=False)
LOGGER.info("Super-resolving the 20m data into 10m bands")
Expand All @@ -142,15 +157,17 @@ def start(self):
if self.copy_original_bands:
sr_final = np.concatenate((data10, sr20, sr60), axis=2)
validated_sr_final_bands = (
validated_10m_bands + validated_20m_bands + validated_60m_bands
self.validated_10m_bands
+ self.validated_20m_bands
+ self.validated_60m_bands
)
else:
sr_final = np.concatenate((sr20, sr60), axis=2)
validated_sr_final_bands = validated_20m_bands + validated_60m_bands
validated_sr_final_bands = (
self.validated_20m_bands + self.validated_60m_bands
)

for dsdesc in data_list:
if "10m" in dsdesc:
p_r = update(dsdesc, data10.shape, sr_final, xmin, ymin)
pr_10m_updated = update(pr_10m, data10.shape, sr_final, coord[0], coord[1])

path_to_output_img = self.data_name.split(".")[0] + "_superresolution.tif"
filename = os.path.join(self.output_dir, path_to_output_img)
Expand All @@ -159,14 +176,18 @@ def start(self):
save_result(
sr_final,
validated_sr_final_bands,
validated_descriptions_all,
p_r,
self.validated_descriptions_all,
pr_10m_updated,
filename,
)
del sr_final
LOGGER.info("This is for releasing memory: %s", gc.collect())
LOGGER.info("Writing the super-resolved bands is finished.")

def process(self):
data10, data20, data60, coord, pr_10m = self.start()
self.inference(data10, data20, data60, coord, pr_10m)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -200,4 +221,4 @@ def start(self):
args = parser.parse_args()
Superresolution(
args.data_file_path, args.clip_to_aoi, args.copy_original_bands, args.output_dir
).start()
).process()
5 changes: 1 addition & 4 deletions testing/supres.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import division

import sys

sys.path.append("../")
from utils.DSen2Net import s2model
from utils.patches import get_test_patches, get_test_patches60, recompose_images


SCALE = 2000
MDL_PATH = "../models/"
MDL_PATH = "./models/"


def dsen2_20(d10, d20, deep=False):
Expand Down
5 changes: 4 additions & 1 deletion tests/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils/"))
)

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../testing/"))
)

# pylint: disable=unused-import,wrong-import-position
from data_utils import DATA_UTILS, get_logger
from s2_tiles_supres import Superresolution
Binary file added tests/mock_data/data_10.npy
Binary file not shown.
Binary file added tests/mock_data/data_20.npy
Binary file not shown.
Binary file added tests/mock_data/data_60.npy
Binary file not shown.
Binary file added tests/mock_data/test_10m.tif
Binary file not shown.
Binary file added tests/mock_data/test_20m.tif
Binary file not shown.
Binary file added tests/mock_data/test_60m.tif
Binary file not shown.
119 changes: 119 additions & 0 deletions tests/test_s2_tiles_superres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import numpy as np

import pytest
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

from blockutils.common import ensure_data_directories_exist
from context import Superresolution


# pylint: disable=redefined-outer-name
@pytest.fixture(scope="session")
def fixture_superresolution_clip():
ensure_data_directories_exist()
return Superresolution(
"a.SAFE", "50.550671,26.15174,50.596161,26.19195", True, "/tmp/output"
)


def test_start(fixture_superresolution_clip, monkeypatch):
_location_ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
data_10m = os.path.join(_location_, "mock_data/test_10m.tif")
data_20m = os.path.join(_location_, "mock_data/test_20m.tif")
data_60m = os.path.join(_location_, "mock_data/test_60m.tif")
expected_final_dset = [data_10m, data_20m, data_60m]

def _mock_getdata(self):
return expected_final_dset

monkeypatch.setattr(Superresolution, "get_data", _mock_getdata)
(
data10,
data20,
data60,
[xmin, ymin, xmax, ymax],
pr,
) = fixture_superresolution_clip.start()
assert data10.shape == (444, 456, 4)
assert data20.shape == (221, 227, 6)
assert data60.shape == (73, 75, 2)
assert [xmin, ymin, xmax, ymax] == [48, 174, 503, 617]
assert pr == {
"driver": "GTiff",
"dtype": "uint16",
"nodata": None,
"width": 1584,
"height": 1762,
"count": 4,
"crs": CRS.from_epsg(32639),
"transform": Affine(10.0, 0.0, 454590.0, 0.0, -10.0, 2898770.0),
"blockxsize": 128,
"blockysize": 128,
"tiled": True,
"interleave": "pixel",
}


def test_inference(fixture_superresolution_clip):
_location_ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
fixture_superresolution_clip.validated_10m_bands = ["B4", "B3", "B2", "B8"]
fixture_superresolution_clip.validated_20m_bands = [
"B5",
"B6",
"B7",
"B8A",
"B11",
"B12",
]
fixture_superresolution_clip.validated_60m_bands = ["B1", "B9"]
fixture_superresolution_clip.validated_descriptions_all = {
"B4": "B4 (665 nm)",
"B3": "B3 (560 nm)",
"B2": "B2 (490 nm)",
"B8": "B8 (842 nm)",
"B5": "B5 (705 nm)",
"B6": "B6 (740 nm)",
"B7": "B7 (783 nm)",
"B8A": "B8A (865 nm)",
"B11": "B11 (1610 nm)",
"B12": "B12 (2190 nm)",
"B1": "B1 (443 nm)",
"B9": "B9 (945 nm)",
}
fixture_superresolution_clip.data_name = (
"S2B_MSIL1C_20200708T070629_N0209_R106_T39RVJ_20200708T100303.SAFE"
)

data10 = np.load(os.path.join(_location_, "mock_data/data_10.npy"))
data20 = np.load(os.path.join(_location_, "mock_data/data_20.npy"))
data60 = np.load(os.path.join(_location_, "mock_data/data_60.npy"))

coord = [48, 174, 503, 617]
pr = {
"driver": "GTiff",
"dtype": "uint16",
"nodata": None,
"width": 1584,
"height": 1762,
"count": 4,
"crs": CRS.from_epsg(32639),
"transform": Affine(10.0, 0.0, 454590.0, 0.0, -10.0, 2898770.0),
"blockxsize": 128,
"blockysize": 128,
"tiled": True,
"interleave": "pixel",
}

fixture_superresolution_clip.inference(data10, data20, data60, coord, pr)
result_path = os.path.join(
"/tmp/output",
fixture_superresolution_clip.data_name.split(".")[0] + "_superresolution.tif",
)
assert os.path.isfile(result_path)
with rasterio.open(result_path) as src:
assert src.count == 12
assert src.transform == Affine(10.0, 0.0, 455070.0, 0.0, -10.0, 2897030.0)
assert src.profile["driver"] == "GTiff"
24 changes: 15 additions & 9 deletions training/create_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def process_patches(self):
else:
scale = 2

self.name = self.data_name.split(".")[0]
# self.name = self.data_name.split(".")[0]

data10, data20, data60, xmin, ymin, xmax, ymax = self.get_original_image()

Expand Down Expand Up @@ -159,7 +159,7 @@ def saving_test_data(self, data10, data20, data60):
data10, data20, data60
)
out_per_image0 = self.save_prefix + "test60/"
out_per_image = self.save_prefix + "test60/" + self.name + "/"
out_per_image = self.save_prefix + "test60/" + self.data_name + "/"
if not os.path.isdir(out_per_image0):
os.mkdir(out_per_image0)
if not os.path.isdir(out_per_image):
Expand All @@ -171,7 +171,7 @@ def saving_test_data(self, data10, data20, data60):
else:
data10_lr, data20_lr = self.get_downsampled_images(data10, data20, data60)
out_per_image0 = self.save_prefix + "test/"
out_per_image = self.save_prefix + "test/" + self.name + "/"
out_per_image = self.save_prefix + "test/" + self.data_name + "/"
if not os.path.isdir(out_per_image0):
os.mkdir(out_per_image0)
if not os.path.isdir(out_per_image):
Expand All @@ -198,7 +198,9 @@ def saving_test_data(self, data10, data20, data60):
out_per_image + "no_tiling/" + "data20_gt", data20.astype(np.float32)
)
self.save_band(
self.save_prefix, data10_lr[:, :, 0:3], "/test/" + self.name + "/RGB",
self.save_prefix,
data10_lr[:, :, 0:3],
"/test/" + self.data_name + "/RGB",
)
np.save(out_per_image + "no_tiling/" + "data10", data10_lr.astype(np.float32))
np.save(out_per_image + "no_tiling/" + "data20", data20_lr.astype(np.float32))
Expand All @@ -209,16 +211,20 @@ def create_rgb_images(self, data10, data20, data60):
data10_lr, data20_lr = self.get_downsampled_images(data10, data20, data60)
LOGGER.info("Creating RGB images...")
self.save_band(
self.save_prefix, data10_lr[:, :, 0:3], "/raw/rgbs/" + self.name + "RGB",
self.save_prefix,
data10_lr[:, :, 0:3],
"/raw/rgbs/" + self.data_name + "RGB",
)
self.save_band(
self.save_prefix, data20_lr[:, :, 0:3], "/raw/rgbs/" + self.name + "RGB20",
self.save_prefix,
data20_lr[:, :, 0:3],
"/raw/rgbs/" + self.data_name + "RGB20",
)

def saving_true_data(self, data10, data20, data60):
# elif true_data:
out_per_image0 = self.save_prefix + "true/"
out_per_image = self.save_prefix + "true/" + self.name + "/"
out_per_image = self.save_prefix + "true/" + self.data_name + "/"
if not os.path.isdir(out_per_image0):
os.mkdir(out_per_image0)
if not os.path.isdir(out_per_image):
Expand All @@ -244,7 +250,7 @@ def saving_train_data(self, data10, data20, data60):
# if train_data
if self.run_60:
out_per_image0 = self.save_prefix + "train60/"
out_per_image = self.save_prefix + "train60/" + self.name + "/"
out_per_image = self.save_prefix + "train60/" + self.data_name + "/"
if not os.path.isdir(out_per_image0):
os.mkdir(out_per_image0)
if not os.path.isdir(out_per_image):
Expand All @@ -260,7 +266,7 @@ def saving_train_data(self, data10, data20, data60):
)
else:
out_per_image0 = self.save_prefix + "train/"
out_per_image = self.save_prefix + "train/" + self.name + "/"
out_per_image = self.save_prefix + "train/" + self.data_name + "/"
if not os.path.isdir(out_per_image0):
os.mkdir(out_per_image0)
if not os.path.isdir(out_per_image):
Expand Down
Loading

0 comments on commit 823e82e

Please sign in to comment.