diff --git a/icons/img_encoder.png b/icons/img_encoder.png new file mode 100644 index 0000000..b8f7c4a Binary files /dev/null and b/icons/img_encoder.png differ diff --git a/icons/img_encoder.svg b/icons/img_encoder.svg new file mode 100644 index 0000000..52cbd9f --- /dev/null +++ b/icons/img_encoder.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + + diff --git a/tools/SAMTool.py b/tools/SAMTool.py index 4f9bf06..7aaaf3d 100644 --- a/tools/SAMTool.py +++ b/tools/SAMTool.py @@ -4,7 +4,9 @@ import time import numpy as np -import rasterio as rio +import rasterio +from rasterio.transform import from_bounds as transform_from_bounds +from rasterio.features import shapes as get_shapes from PyQt5.QtWidgets import QMessageBox from qgis.core import QgsRectangle, QgsMessageLog, Qgis from torch.utils.data import DataLoader @@ -94,7 +96,7 @@ def sam_predict(self, canvas_points: Canvas_Points, canvas_rect: Canvas_Rectangl bbox = self.sample_bbox # batch['bbox'][0] # Change to sam.img_encoder.img_size img_width = img_height = self.predictor.model.image_encoder.img_size # 1024 - img_clip_transform = rio.transform.from_bounds( + img_clip_transform = transform_from_bounds( bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, img_width, img_height) input_point, input_label = canvas_points.get_points_and_labels( @@ -129,13 +131,14 @@ def sam_predict(self, canvas_points: Canvas_Points, canvas_rect: Canvas_Rectangl # results = ({'properties': {'raster_val': v}, 'geometry': s} # for i, (s, v) in enumerate(rio.features.shapes(mask.astype(np.uint8), mask=mask, transform=img_clip_transform))) # geoms = list(results) - shape_generator = rio.features.shapes( + shape_generator = get_shapes( mask.astype(np.uint8), mask=mask, + connectivity=8, # change from default:4 to 8 transform=img_clip_transform ) geojson = [{'properties': {'raster_val': value}, 'geometry': polygon} - for polygon, value in shape_generator] + for polygon, value in shape_generator] # add to layer sam_polygon.rollback_changes() diff --git a/tools/geoTool.py b/tools/geoTool.py index 53f7d81..b3e43a5 100644 --- a/tools/geoTool.py +++ b/tools/geoTool.py @@ -1,12 +1,13 @@ import os import typing import numpy as np -from qgis.core import (QgsProject, QgsCoordinateReferenceSystem, Qgis, QgsMessageLog, +from qgis.core import (QgsProject, QgsCoordinateReferenceSystem, Qgis, QgsMessageLog, QgsCoordinateTransform, QgsPointXY, QgsRectangle, QgsVectorLayer) class ImageCRSManager: '''Manage image crs and transform point and extent between image crs and other crs''' + def __init__(self, img_crs) -> None: self.img_crs = QgsCoordinateReferenceSystem( img_crs) # from str to QgsCRS @@ -63,10 +64,10 @@ def extent_to_img_crs( dst_crs, self.img_crs, QgsProject.instance()) extent_transformed = transform.transformBoundingBox(extent) return extent_transformed - + def img_extent_to_crs(self, extent: QgsRectangle, dst_crs: QgsCoordinateReferenceSystem): '''transform extent from this image crs to destination crs - + Parameters: ---------- extent: QgsRectangle @@ -74,7 +75,8 @@ def img_extent_to_crs(self, extent: QgsRectangle, dst_crs: QgsCoordinateReferenc dst_crs: QgsCoordinateReferenceSystem destination crs for extent ''' - transform = QgsCoordinateTransform(self.img_crs, dst_crs, QgsProject.instance()) + transform = QgsCoordinateTransform( + self.img_crs, dst_crs, QgsProject.instance()) extent_transformed = transform.transformBoundingBox(extent) return extent_transformed diff --git a/tools/torchgeo_sam.py b/tools/torchgeo_sam.py index ea51ccd..fbd63e0 100644 --- a/tools/torchgeo_sam.py +++ b/tools/torchgeo_sam.py @@ -1,4 +1,4 @@ -# Extension of torchgeo library by zyzhao +# Extension of torchgeo library by zyzhao import sys import glob import os @@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast, Union, Iterator, Iterable -import rasterio as rio +import rasterio from rasterio.vrt import WarpedVRT -from rasterio.windows import from_bounds +from rasterio.windows import from_bounds as window_from_bounds from rasterio.crs import CRS import numpy as np import pandas as pd @@ -26,8 +26,9 @@ from rtree.index import Index, Property + class TestGridGeoSampler(GridGeoSampler): - + def __iter__(self) -> Iterator[Dict[str, Any]]: """Return the index of a dataset. @@ -58,12 +59,14 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: maxx = bounds.maxx minx = bounds.maxx - self.size[1] query = {"bbox": BoundingBox(minx, maxx, miny, maxy, mint, maxt), - "path": cast(str, hit.object)} + "path": cast(str, hit.object)} + + # BoundingBox(minx, maxx, miny, maxy, mint, maxt) + yield query - yield query # BoundingBox(minx, maxx, miny, maxy, mint, maxt) class SamTestFeatureDataset(RasterDataset): - filename_glob = "*.tif" + filename_glob = "*.tif" # filename_regex = r"^S2.{5}_(?P\d{8})_N\d{4}_R\d{3}_6Bands_S\d{1}" filename_regex = ".*" date_format = "" @@ -76,9 +79,10 @@ def __init__(self, root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[ + Dict[str, Any]], Dict[str, Any]]] = None, cache: bool = True - ) -> None: + ) -> None: if self.separate_files: raise NotImplementedError( 'Testing for separated files are not supported yet' @@ -95,7 +99,8 @@ def __init__(self, root: str = "data", # Populate the dataset index i = 0 - pathname = os.path.join(root, "**", self.filename_glob) + # pathname = os.path.join(root, "**", self.filename_glob) + pathname = os.path.join(root, self.filename_glob) raster_list = glob.glob(pathname, recursive=True) dir_name = os.path.basename(root) csv_filepath = os.path.join(root, dir_name + '.csv') @@ -103,7 +108,8 @@ def __init__(self, root: str = "data", if os.path.exists(csv_filepath): self.index_df = pd.read_csv(csv_filepath) filepath_csv = self.index_df.loc[0, 'filepath'] - if len(self.index_df) == len(raster_list) and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]): + # and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]): + if len(self.index_df) == len(raster_list): for _, row_df in self.index_df.iterrows(): if crs is None: crs = row_df['crs'] @@ -113,28 +119,35 @@ def __init__(self, root: str = "data", coords = (row_df['minx'], row_df['maxx'], row_df['miny'], row_df['maxy'], row_df['mint'], row_df['maxt']) - filepath = row_df['filepath'] # TODO change to relative name + # change to relative path + filepath = os.path.join( + root, os.path.basename(row_df['filepath'])) self.index.insert(id, coords, filepath) i += 1 # print(coords[0].dtype) index_set = True - print('index loaded from: ', os.path.basename(csv_filepath)) + # print('index loaded from: ', os.path.basename(csv_filepath)) + QgsMessageLog.logMessage( + f"Index loaded from: {os.path.basename(csv_filepath)}", 'Geo SAM', level=Qgis.Info) else: - print('index file does not match the raster list, it will be recreated.') + # print('index file does not match the raster list, it will be recreated.') + QgsMessageLog.logMessage( + f"Index file does not match the raster list, will be recreated.", 'Geo SAM', level=Qgis.Info) if not index_set: - self.index_df = pd.DataFrame(columns = ['id', - 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt', - 'filepath', - 'crs', 'res']) + self.index_df = pd.DataFrame(columns=['id', + 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt', + 'filepath', + 'crs', 'res']) id_list = [] coords_list = [] filepath_list = [] filename_regex = re.compile(self.filename_regex, re.VERBOSE) - for filepath in raster_list: # glob.iglob(pathname, recursive=True): + # glob.iglob(pathname, recursive=True): + for filepath in raster_list: match = re.match(filename_regex, os.path.basename(filepath)) if match is not None: try: - with rio.open(filepath) as src: + with rasterio.open(filepath) as src: # See if file has a color map if len(self.cmap) == 0: try: @@ -149,7 +162,7 @@ def __init__(self, root: str = "data", with WarpedVRT(src, crs=crs) as vrt: minx, miny, maxx, maxy = vrt.bounds - except rio.errors.RasterioIOError: + except rasterio.errors.RasterioIOError: # Skip files that rasterio is unable to read continue else: @@ -157,29 +170,39 @@ def __init__(self, root: str = "data", maxt: float = sys.maxsize if "date" in match.groupdict(): date = match.group("date") - mint, maxt = disambiguate_timestamp(date, self.date_format) + mint, maxt = disambiguate_timestamp( + date, self.date_format) coords = (minx, maxx, miny, maxy, mint, maxt) self.index.insert(i, coords, filepath) id_list.append(i) coords_list.append(coords) - filepath_list.append(filepath) + # change to relative path + filepath_list.append(os.path.basename(filepath)) i += 1 self.index_df['id'] = id_list self.index_df['filepath'] = filepath_list - self.index_df['minx'] = pd.to_numeric([coord[0] for coord in coords_list], downcast='float') - self.index_df['maxx'] = pd.to_numeric([coord[1] for coord in coords_list], downcast='float') - self.index_df['miny'] = pd.to_numeric([coord[2] for coord in coords_list], downcast='float') - self.index_df['maxy'] = pd.to_numeric([coord[3] for coord in coords_list], downcast='float') - self.index_df['mint'] = pd.to_numeric([coord[4] for coord in coords_list], downcast='float') - self.index_df['maxt'] = pd.to_numeric([coord[5] for coord in coords_list], downcast='float') + self.index_df['minx'] = pd.to_numeric( + [coord[0] for coord in coords_list], downcast='float') + self.index_df['maxx'] = pd.to_numeric( + [coord[1] for coord in coords_list], downcast='float') + self.index_df['miny'] = pd.to_numeric( + [coord[2] for coord in coords_list], downcast='float') + self.index_df['maxy'] = pd.to_numeric( + [coord[3] for coord in coords_list], downcast='float') + self.index_df['mint'] = pd.to_numeric( + [coord[4] for coord in coords_list], downcast='float') + self.index_df['maxt'] = pd.to_numeric( + [coord[5] for coord in coords_list], downcast='float') # print(type(crs), res) self.index_df.loc[:, 'crs'] = str(crs) self.index_df.loc[:, 'res'] = res # print(self.index_df.dtypes) index_set = True self.index_df.to_csv(csv_filepath) - print('index file: ', os.path.basename(csv_filepath), ' saved') + # print('index file: ', os.path.basename(csv_filepath), ' saved') + QgsMessageLog.logMessage( + f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info) if i == 0: msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`" @@ -204,7 +227,6 @@ def __init__(self, root: str = "data", self._crs = cast(CRS, crs) self.res = cast(float, res) - def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -238,7 +260,7 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: # band_indexes = self.band_indexes src = vrt_fh - dest = src.read() # read all bands + dest = src.read() # read all bands # print(src.profile) # print(src.compression) @@ -248,22 +270,23 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: elif dest.dtype == np.uint32: dest = dest.astype(np.int64) - tensor = torch.tensor(dest) # .float() + tensor = torch.tensor(dest) # .float() # bbox may be useful to form the final mask results (geo-info) sample = {"crs": self.crs, "bbox": bbox, "path": filepath} if self.is_image: sample["image"] = tensor.float() else: - sample["mask"] = tensor # .float() #long() # modified zyzhao + sample["mask"] = tensor # .float() #long() # modified zyzhao if self.transforms is not None: sample = self.transforms(sample) return sample + class SamTestRasterDataset(RasterDataset): - filename_glob = "*.tif" + filename_glob = "*.tif" # filename_regex = r"^S2.{5}_(?P\d{8})_N\d{4}_R\d{3}_6Bands_S\d{1}" filename_regex = ".*" date_format = "" @@ -276,9 +299,10 @@ def __init__(self, root: str = "data", crs: Optional[CRS] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[ + Dict[str, Any]], Dict[str, Any]]] = None, cache: bool = True - ) -> None: + ) -> None: if self.separate_files: raise NotImplementedError( 'Testing for separated files are not supported yet' @@ -326,7 +350,7 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: dest = src.read( indexes=band_indexes, out_shape=out_shape, - window=from_bounds(*bounds, src.transform), + window=window_from_bounds(*bounds, src.transform), ) # fix numpy dtypes which are not supported by pytorch tensors @@ -335,13 +359,13 @@ def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: elif dest.dtype == np.uint32: dest = dest.astype(np.int64) - tensor = torch.tensor(dest) # .float() + tensor = torch.tensor(dest) # .float() sample = {"crs": self.crs, "bbox": bbox, "path": filepath} if self.is_image: sample["image"] = tensor.float() else: - sample["mask"] = tensor # .float() #long() # modified zyzhao + sample["mask"] = tensor # .float() #long() # modified zyzhao if self.transforms is not None: sample = self.transforms(sample) @@ -363,7 +387,7 @@ def plot(self, sample, bright=1): image = sample["image"][0, rgb_indices, :, :].permute(1, 2, 0) else: image = sample["image"][rgb_indices, :, :].permute(1, 2, 0) - + # if image.max() > 10: # image = self.apply_scale(image) image = torch.clamp(image*bright/255, min=0, max=1) @@ -375,6 +399,7 @@ def plot(self, sample, bright=1): return fig + class SamTestFeatureGeoSampler(GeoSampler): """Samples entire files at a time. @@ -417,7 +442,8 @@ def __init__( # roi = BoundingBox(*self.index.bounds) raise Exception('roi should be defined based on prompts!!!') else: - self.index = Index(interleaved=False, properties=Property(dimension=3)) + self.index = Index(interleaved=False, + properties=Property(dimension=3)) hits = dataset.index.intersection(tuple(roi), objects=True) # hit_nearest = list(dataset.index.nearest(tuple(roi), num_results=1, objects=True))[0] # print('nearest hit: ', hit_nearest.object) @@ -431,7 +457,8 @@ def __init__( center_x_roi = (roi.maxx + roi.minx)/2 center_y_roi = (roi.maxy + roi.miny)/2 - dist_roi_tmp = (center_x_bbox - center_x_roi)**2 + (center_y_bbox - center_y_roi)**2 + dist_roi_tmp = (center_x_bbox - center_x_roi)**2 + \ + (center_y_bbox - center_y_roi)**2 # print(dist_roi_tmp) if dist_roi_tmp < self.dist_roi: self.dist_roi = dist_roi_tmp @@ -474,5 +501,4 @@ def __len__(self) -> int: Returns: number of patches that will be sampled """ - return self.length #len(self.q_path) - + return self.length # len(self.q_path)