diff --git a/icons/img_encoder.png b/icons/img_encoder.png
new file mode 100644
index 0000000..b8f7c4a
Binary files /dev/null and b/icons/img_encoder.png differ
diff --git a/icons/img_encoder.svg b/icons/img_encoder.svg
new file mode 100644
index 0000000..52cbd9f
--- /dev/null
+++ b/icons/img_encoder.svg
@@ -0,0 +1,118 @@
+
+
+
+
diff --git a/tools/SAMTool.py b/tools/SAMTool.py
index 4f9bf06..7aaaf3d 100644
--- a/tools/SAMTool.py
+++ b/tools/SAMTool.py
@@ -4,7 +4,9 @@
import time
import numpy as np
-import rasterio as rio
+import rasterio
+from rasterio.transform import from_bounds as transform_from_bounds
+from rasterio.features import shapes as get_shapes
from PyQt5.QtWidgets import QMessageBox
from qgis.core import QgsRectangle, QgsMessageLog, Qgis
from torch.utils.data import DataLoader
@@ -94,7 +96,7 @@ def sam_predict(self, canvas_points: Canvas_Points, canvas_rect: Canvas_Rectangl
bbox = self.sample_bbox # batch['bbox'][0]
# Change to sam.img_encoder.img_size
img_width = img_height = self.predictor.model.image_encoder.img_size # 1024
- img_clip_transform = rio.transform.from_bounds(
+ img_clip_transform = transform_from_bounds(
bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, img_width, img_height)
input_point, input_label = canvas_points.get_points_and_labels(
@@ -129,13 +131,14 @@ def sam_predict(self, canvas_points: Canvas_Points, canvas_rect: Canvas_Rectangl
# results = ({'properties': {'raster_val': v}, 'geometry': s}
# for i, (s, v) in enumerate(rio.features.shapes(mask.astype(np.uint8), mask=mask, transform=img_clip_transform)))
# geoms = list(results)
- shape_generator = rio.features.shapes(
+ shape_generator = get_shapes(
mask.astype(np.uint8),
mask=mask,
+ connectivity=8, # change from default:4 to 8
transform=img_clip_transform
)
geojson = [{'properties': {'raster_val': value}, 'geometry': polygon}
- for polygon, value in shape_generator]
+ for polygon, value in shape_generator]
# add to layer
sam_polygon.rollback_changes()
diff --git a/tools/geoTool.py b/tools/geoTool.py
index 53f7d81..b3e43a5 100644
--- a/tools/geoTool.py
+++ b/tools/geoTool.py
@@ -1,12 +1,13 @@
import os
import typing
import numpy as np
-from qgis.core import (QgsProject, QgsCoordinateReferenceSystem, Qgis, QgsMessageLog,
+from qgis.core import (QgsProject, QgsCoordinateReferenceSystem, Qgis, QgsMessageLog,
QgsCoordinateTransform, QgsPointXY, QgsRectangle, QgsVectorLayer)
class ImageCRSManager:
'''Manage image crs and transform point and extent between image crs and other crs'''
+
def __init__(self, img_crs) -> None:
self.img_crs = QgsCoordinateReferenceSystem(
img_crs) # from str to QgsCRS
@@ -63,10 +64,10 @@ def extent_to_img_crs(
dst_crs, self.img_crs, QgsProject.instance())
extent_transformed = transform.transformBoundingBox(extent)
return extent_transformed
-
+
def img_extent_to_crs(self, extent: QgsRectangle, dst_crs: QgsCoordinateReferenceSystem):
'''transform extent from this image crs to destination crs
-
+
Parameters:
----------
extent: QgsRectangle
@@ -74,7 +75,8 @@ def img_extent_to_crs(self, extent: QgsRectangle, dst_crs: QgsCoordinateReferenc
dst_crs: QgsCoordinateReferenceSystem
destination crs for extent
'''
- transform = QgsCoordinateTransform(self.img_crs, dst_crs, QgsProject.instance())
+ transform = QgsCoordinateTransform(
+ self.img_crs, dst_crs, QgsProject.instance())
extent_transformed = transform.transformBoundingBox(extent)
return extent_transformed
diff --git a/tools/torchgeo_sam.py b/tools/torchgeo_sam.py
index ea51ccd..fbd63e0 100644
--- a/tools/torchgeo_sam.py
+++ b/tools/torchgeo_sam.py
@@ -1,4 +1,4 @@
-# Extension of torchgeo library by zyzhao
+# Extension of torchgeo library by zyzhao
import sys
import glob
import os
@@ -8,9 +8,9 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast, Union, Iterator, Iterable
-import rasterio as rio
+import rasterio
from rasterio.vrt import WarpedVRT
-from rasterio.windows import from_bounds
+from rasterio.windows import from_bounds as window_from_bounds
from rasterio.crs import CRS
import numpy as np
import pandas as pd
@@ -26,8 +26,9 @@
from rtree.index import Index, Property
+
class TestGridGeoSampler(GridGeoSampler):
-
+
def __iter__(self) -> Iterator[Dict[str, Any]]:
"""Return the index of a dataset.
@@ -58,12 +59,14 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]
query = {"bbox": BoundingBox(minx, maxx, miny, maxy, mint, maxt),
- "path": cast(str, hit.object)}
+ "path": cast(str, hit.object)}
+
+ # BoundingBox(minx, maxx, miny, maxy, mint, maxt)
+ yield query
- yield query # BoundingBox(minx, maxx, miny, maxy, mint, maxt)
class SamTestFeatureDataset(RasterDataset):
- filename_glob = "*.tif"
+ filename_glob = "*.tif"
# filename_regex = r"^S2.{5}_(?P\d{8})_N\d{4}_R\d{3}_6Bands_S\d{1}"
filename_regex = ".*"
date_format = ""
@@ -76,9 +79,10 @@ def __init__(self, root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
- transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ transforms: Optional[Callable[[
+ Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True
- ) -> None:
+ ) -> None:
if self.separate_files:
raise NotImplementedError(
'Testing for separated files are not supported yet'
@@ -95,7 +99,8 @@ def __init__(self, root: str = "data",
# Populate the dataset index
i = 0
- pathname = os.path.join(root, "**", self.filename_glob)
+ # pathname = os.path.join(root, "**", self.filename_glob)
+ pathname = os.path.join(root, self.filename_glob)
raster_list = glob.glob(pathname, recursive=True)
dir_name = os.path.basename(root)
csv_filepath = os.path.join(root, dir_name + '.csv')
@@ -103,7 +108,8 @@ def __init__(self, root: str = "data",
if os.path.exists(csv_filepath):
self.index_df = pd.read_csv(csv_filepath)
filepath_csv = self.index_df.loc[0, 'filepath']
- if len(self.index_df) == len(raster_list) and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]):
+ # and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]):
+ if len(self.index_df) == len(raster_list):
for _, row_df in self.index_df.iterrows():
if crs is None:
crs = row_df['crs']
@@ -113,28 +119,35 @@ def __init__(self, root: str = "data",
coords = (row_df['minx'], row_df['maxx'],
row_df['miny'], row_df['maxy'],
row_df['mint'], row_df['maxt'])
- filepath = row_df['filepath'] # TODO change to relative name
+ # change to relative path
+ filepath = os.path.join(
+ root, os.path.basename(row_df['filepath']))
self.index.insert(id, coords, filepath)
i += 1
# print(coords[0].dtype)
index_set = True
- print('index loaded from: ', os.path.basename(csv_filepath))
+ # print('index loaded from: ', os.path.basename(csv_filepath))
+ QgsMessageLog.logMessage(
+ f"Index loaded from: {os.path.basename(csv_filepath)}", 'Geo SAM', level=Qgis.Info)
else:
- print('index file does not match the raster list, it will be recreated.')
+ # print('index file does not match the raster list, it will be recreated.')
+ QgsMessageLog.logMessage(
+ f"Index file does not match the raster list, will be recreated.", 'Geo SAM', level=Qgis.Info)
if not index_set:
- self.index_df = pd.DataFrame(columns = ['id',
- 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt',
- 'filepath',
- 'crs', 'res'])
+ self.index_df = pd.DataFrame(columns=['id',
+ 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt',
+ 'filepath',
+ 'crs', 'res'])
id_list = []
coords_list = []
filepath_list = []
filename_regex = re.compile(self.filename_regex, re.VERBOSE)
- for filepath in raster_list: # glob.iglob(pathname, recursive=True):
+ # glob.iglob(pathname, recursive=True):
+ for filepath in raster_list:
match = re.match(filename_regex, os.path.basename(filepath))
if match is not None:
try:
- with rio.open(filepath) as src:
+ with rasterio.open(filepath) as src:
# See if file has a color map
if len(self.cmap) == 0:
try:
@@ -149,7 +162,7 @@ def __init__(self, root: str = "data",
with WarpedVRT(src, crs=crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
- except rio.errors.RasterioIOError:
+ except rasterio.errors.RasterioIOError:
# Skip files that rasterio is unable to read
continue
else:
@@ -157,29 +170,39 @@ def __init__(self, root: str = "data",
maxt: float = sys.maxsize
if "date" in match.groupdict():
date = match.group("date")
- mint, maxt = disambiguate_timestamp(date, self.date_format)
+ mint, maxt = disambiguate_timestamp(
+ date, self.date_format)
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(i, coords, filepath)
id_list.append(i)
coords_list.append(coords)
- filepath_list.append(filepath)
+ # change to relative path
+ filepath_list.append(os.path.basename(filepath))
i += 1
self.index_df['id'] = id_list
self.index_df['filepath'] = filepath_list
- self.index_df['minx'] = pd.to_numeric([coord[0] for coord in coords_list], downcast='float')
- self.index_df['maxx'] = pd.to_numeric([coord[1] for coord in coords_list], downcast='float')
- self.index_df['miny'] = pd.to_numeric([coord[2] for coord in coords_list], downcast='float')
- self.index_df['maxy'] = pd.to_numeric([coord[3] for coord in coords_list], downcast='float')
- self.index_df['mint'] = pd.to_numeric([coord[4] for coord in coords_list], downcast='float')
- self.index_df['maxt'] = pd.to_numeric([coord[5] for coord in coords_list], downcast='float')
+ self.index_df['minx'] = pd.to_numeric(
+ [coord[0] for coord in coords_list], downcast='float')
+ self.index_df['maxx'] = pd.to_numeric(
+ [coord[1] for coord in coords_list], downcast='float')
+ self.index_df['miny'] = pd.to_numeric(
+ [coord[2] for coord in coords_list], downcast='float')
+ self.index_df['maxy'] = pd.to_numeric(
+ [coord[3] for coord in coords_list], downcast='float')
+ self.index_df['mint'] = pd.to_numeric(
+ [coord[4] for coord in coords_list], downcast='float')
+ self.index_df['maxt'] = pd.to_numeric(
+ [coord[5] for coord in coords_list], downcast='float')
# print(type(crs), res)
self.index_df.loc[:, 'crs'] = str(crs)
self.index_df.loc[:, 'res'] = res
# print(self.index_df.dtypes)
index_set = True
self.index_df.to_csv(csv_filepath)
- print('index file: ', os.path.basename(csv_filepath), ' saved')
+ # print('index file: ', os.path.basename(csv_filepath), ' saved')
+ QgsMessageLog.logMessage(
+ f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info)
if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
@@ -204,7 +227,6 @@ def __init__(self, root: str = "data",
self._crs = cast(CRS, crs)
self.res = cast(float, res)
-
def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@@ -238,7 +260,7 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]:
# band_indexes = self.band_indexes
src = vrt_fh
- dest = src.read() # read all bands
+ dest = src.read() # read all bands
# print(src.profile)
# print(src.compression)
@@ -248,22 +270,23 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]:
elif dest.dtype == np.uint32:
dest = dest.astype(np.int64)
- tensor = torch.tensor(dest) # .float()
+ tensor = torch.tensor(dest) # .float()
# bbox may be useful to form the final mask results (geo-info)
sample = {"crs": self.crs, "bbox": bbox, "path": filepath}
if self.is_image:
sample["image"] = tensor.float()
else:
- sample["mask"] = tensor # .float() #long() # modified zyzhao
+ sample["mask"] = tensor # .float() #long() # modified zyzhao
if self.transforms is not None:
sample = self.transforms(sample)
return sample
+
class SamTestRasterDataset(RasterDataset):
- filename_glob = "*.tif"
+ filename_glob = "*.tif"
# filename_regex = r"^S2.{5}_(?P\d{8})_N\d{4}_R\d{3}_6Bands_S\d{1}"
filename_regex = ".*"
date_format = ""
@@ -276,9 +299,10 @@ def __init__(self, root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
- transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ transforms: Optional[Callable[[
+ Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True
- ) -> None:
+ ) -> None:
if self.separate_files:
raise NotImplementedError(
'Testing for separated files are not supported yet'
@@ -326,7 +350,7 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]:
dest = src.read(
indexes=band_indexes,
out_shape=out_shape,
- window=from_bounds(*bounds, src.transform),
+ window=window_from_bounds(*bounds, src.transform),
)
# fix numpy dtypes which are not supported by pytorch tensors
@@ -335,13 +359,13 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]:
elif dest.dtype == np.uint32:
dest = dest.astype(np.int64)
- tensor = torch.tensor(dest) # .float()
+ tensor = torch.tensor(dest) # .float()
sample = {"crs": self.crs, "bbox": bbox, "path": filepath}
if self.is_image:
sample["image"] = tensor.float()
else:
- sample["mask"] = tensor # .float() #long() # modified zyzhao
+ sample["mask"] = tensor # .float() #long() # modified zyzhao
if self.transforms is not None:
sample = self.transforms(sample)
@@ -363,7 +387,7 @@ def plot(self, sample, bright=1):
image = sample["image"][0, rgb_indices, :, :].permute(1, 2, 0)
else:
image = sample["image"][rgb_indices, :, :].permute(1, 2, 0)
-
+
# if image.max() > 10:
# image = self.apply_scale(image)
image = torch.clamp(image*bright/255, min=0, max=1)
@@ -375,6 +399,7 @@ def plot(self, sample, bright=1):
return fig
+
class SamTestFeatureGeoSampler(GeoSampler):
"""Samples entire files at a time.
@@ -417,7 +442,8 @@ def __init__(
# roi = BoundingBox(*self.index.bounds)
raise Exception('roi should be defined based on prompts!!!')
else:
- self.index = Index(interleaved=False, properties=Property(dimension=3))
+ self.index = Index(interleaved=False,
+ properties=Property(dimension=3))
hits = dataset.index.intersection(tuple(roi), objects=True)
# hit_nearest = list(dataset.index.nearest(tuple(roi), num_results=1, objects=True))[0]
# print('nearest hit: ', hit_nearest.object)
@@ -431,7 +457,8 @@ def __init__(
center_x_roi = (roi.maxx + roi.minx)/2
center_y_roi = (roi.maxy + roi.miny)/2
- dist_roi_tmp = (center_x_bbox - center_x_roi)**2 + (center_y_bbox - center_y_roi)**2
+ dist_roi_tmp = (center_x_bbox - center_x_roi)**2 + \
+ (center_y_bbox - center_y_roi)**2
# print(dist_roi_tmp)
if dist_roi_tmp < self.dist_roi:
self.dist_roi = dist_roi_tmp
@@ -474,5 +501,4 @@ def __len__(self) -> int:
Returns:
number of patches that will be sampled
"""
- return self.length #len(self.q_path)
-
+ return self.length # len(self.q_path)