Skip to content

Commit

Permalink
merge from Joey
Browse files Browse the repository at this point in the history
  • Loading branch information
Fanchengyan committed Jun 19, 2023
2 parents e9d7c8e + 6950cbf commit aa86dd9
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 33 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ After activating the Geo SAM plugin, you may find the tool under the `Plugins` m
or somewhere on the toolbar near the Python Plugin.

<p align="center">
<img src="assets/Toolbar_geo_sam.png" width="350" title="Plugin menu">
<img src="assets/Toolbar_geo_sam.png" width="350" title="Plugin toolbar">
</p>

## Use the GeoSAM Tool
Expand All @@ -107,7 +107,7 @@ A user interface will be shown as below.
<!-- ![ui_geo_sam](assets/ui_geo_sam.png) -->

<p align="center">
<img src="assets/ui_geo_sam.png" width="600" title="Try Geo SAM">
<img src="assets/ui_geo_sam.png" width="600" title="Geo SAM UI">
</p>

### Add Points
Expand Down
85 changes: 58 additions & 27 deletions tools/SAMTool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing
import numpy as np
from pathlib import Path
import time

import numpy as np
import rasterio as rio
Expand All @@ -10,6 +11,7 @@
from .torchgeo_sam import SamTestFeatureDataset, SamTestFeatureGeoSampler
from .sam_ext import sam_model_registry_no_encoder, SamPredictorNoImgEncoder
from .geoTool import LayerExtent, ImageCRSManager
from .canvasTool import SAM_PolygonFeature, Canvas_Rectangle, Canvas_Points
from torchgeo.datasets import BoundingBox, stack_samples
from torchgeo.samplers import Units

Expand All @@ -20,6 +22,7 @@ def __init__(self, feature_dir, cwd, model_type="vit_h"):
self.sam_checkpoint = cwd + "/checkpoint/sam_vit_h_4b8939_no_img_encoder.pth"
self.model_type = model_type
self._prepare_data_and_layer()
self.sample_path = None

def _prepare_data_and_layer(self):
"""Prepares data and layer."""
Expand All @@ -35,7 +38,7 @@ def _prepare_data_and_layer(self):
self.extent = QgsRectangle(
feature_bounds[0], feature_bounds[2], feature_bounds[1], feature_bounds[3])

def sam_predict(self, canvas_points, canvas_rect, sam_polygon):
def sam_predict(self, canvas_points: Canvas_Points, canvas_rect: Canvas_Rectangle, sam_polygon: SAM_PolygonFeature):
extent_union = LayerExtent.union_extent(
canvas_points.extent, canvas_rect.extent)

Expand All @@ -48,6 +51,7 @@ def sam_predict(self, canvas_points, canvas_rect, sam_polygon):
points_roi = BoundingBox(
min_x, max_x, min_y, max_y, self.test_features.index.bounds[4], self.test_features.index.bounds[5])

start_time = time.time()
test_sampler = SamTestFeatureGeoSampler(
self.test_features, feature_size=64, roi=points_roi, units=Units.PIXELS) # Units.CRS or Units.PIXELS

Expand All @@ -57,56 +61,83 @@ def sam_predict(self, canvas_points, canvas_rect, sam_polygon):
mb.setText('Point is located outside of the image boundary')
mb.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
return_value = mb.exec()
# TODO: Clear last point falls outside the boundary
if return_value == QMessageBox.Ok:
print('You pressed OK')
elif return_value == QMessageBox.Cancel:
print('You pressed Cancel')

return False
test_dataloader = DataLoader(
self.test_features, batch_size=1, sampler=test_sampler, collate_fn=stack_samples)

for batch in test_dataloader:
# print(batch.keys())
# print(batch['image'].shape)
# print(batch['path'])
# print(batch['bbox'])
# print(len(batch['image']))
# break
pass

bbox = batch['bbox'][0]
# TODO: Change to sam.img_encoder.img_size
width = height = 1024

for query in test_sampler:
# different query than last time, update feature
if query['path'] == self.sample_path:
break
sample = self.test_features[query]
self.sample_path = sample['path']
self.sample_bbox = sample['bbox']
self.img_features = sample['image']
break

# test_dataloader = DataLoader(
# self.test_features, batch_size=1, sampler=test_sampler, collate_fn=stack_samples)

# for batch in test_dataloader:
# # print(batch.keys())
# # print(batch['image'].shape)
# # print(batch['path'])
# # print(batch['bbox'])
# # print(len(batch['image']))
# # break
# pass

bbox = self.sample_bbox # batch['bbox'][0]
# Change to sam.img_encoder.img_size
img_width = img_height = self.predictor.model.image_encoder.img_size # 1024
img_clip_transform = rio.transform.from_bounds(
bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, width, height)
bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, img_width, img_height)

input_point, input_label = canvas_points.get_points_and_labels(
img_clip_transform)
box = canvas_rect.get_img_box(img_clip_transform)
print("box", box)

img_features = batch['image']
self.predictor.set_image_feature(img_features, img_shape=(1024, 1024))
input_box = canvas_rect.get_img_box(img_clip_transform)
# print("box", input_box)

# img_features = batch['image']
self.predictor.set_image_feature(
self.img_features, img_shape=(img_height, img_width))

# TODO: Points or rectangles can be negative or exceed 1024, should be regulated
# also may consider remove those points after checking
masks, scores, logits = self.predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=box,
box=input_box,
multimask_output=False,
)
end_time = time.time()
# get the execution time of sam predictor, ms
elapsed_time = (end_time - start_time) * 1000

QgsMessageLog.logMessage(
"SAM predict executed", 'Geo SAM', level=Qgis.Info)
f"SAM predict executed with {elapsed_time:.3f} ms", 'Geo SAM', level=Qgis.Info)

mask = masks[0, ...]
# mask = mask_morph

# convert mask to geojson
results = ({'properties': {'raster_val': v}, 'geometry': s}
for i, (s, v) in enumerate(rio.features.shapes(mask.astype(np.uint8), mask=mask, transform=img_clip_transform)))
geoms = list(results)
# results = ({'properties': {'raster_val': v}, 'geometry': s}
# for i, (s, v) in enumerate(rio.features.shapes(mask.astype(np.uint8), mask=mask, transform=img_clip_transform)))
# geoms = list(results)
shape_generator = rio.features.shapes(
mask.astype(np.uint8),
mask=mask,
transform=img_clip_transform
)
geojson = [{'properties': {'raster_val': value}, 'geometry': polygon}
for polygon, value in shape_generator]

# add to layer
sam_polygon.rollback_changes()
sam_polygon.add_geojson_feature(geoms)
sam_polygon.add_geojson_feature(geojson)
return True
4 changes: 2 additions & 2 deletions tools/canvasTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def rectangle(self):
self.startPoint.y() == self.endPoint.y()):
return None
else:
# TODO startPoint endPoint transform
# startPoint endPoint transform
if self.qgis_project.crs() != self.img_crs_manager.img_crs:
self.startPoint = self.img_crs_manager.point_to_img_crs(
self.startPoint, self.qgis_project.crs())
Expand Down Expand Up @@ -353,7 +353,7 @@ def add_geojson_feature(self, geojson):
points = []
coordinates = geom['geometry']['coordinates'][0]
for coord in coordinates:
# TODO transform pointXY from img_crs to polygon layer crs, if not match
# transform pointXY from img_crs to polygon layer crs, if not match
point = QgsPointXY(*coord)
if self.layer.crs() != self.img_crs_manager.img_crs:
point = self.img_crs_manager.img_point_to_crs(
Expand Down
7 changes: 5 additions & 2 deletions tools/torchgeo_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchgeo.samplers.constants import Units
from torchgeo.samplers.utils import _to_tuple, get_random_bounding_box, tile_to_chips
import matplotlib.pyplot as plt
from qgis.core import QgsMessageLog, Qgis

from rtree.index import Index, Property

Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(self, root: str = "data",
coords = (row_df['minx'], row_df['maxx'],
row_df['miny'], row_df['maxy'],
row_df['mint'], row_df['maxt'])
filepath = row_df['filepath']
filepath = row_df['filepath'] # TODO change to relative name
self.index.insert(id, coords, filepath)
i += 1
# print(coords[0].dtype)
Expand Down Expand Up @@ -438,7 +439,9 @@ def __init__(
self.q_path = hit.object
# self.index.insert(hit.id, tuple(bbox), hit.object)

print('intersected features: ', idx)
# print('intersected features: ', idx)
QgsMessageLog.logMessage(
f"Prompt intersected with {idx} feature patches", 'Geo SAM', level=Qgis.Info)
# print('selected hit: ', self.q_path)
if self.q_bbox is None:
self.length = 0
Expand Down
108 changes: 108 additions & 0 deletions ui/gsDockTest.ui
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>gsDockWidget</class>
<widget class="QgsDockWidget" name="gsDockWidget">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>430</width>
<height>306</height>
</rect>
</property>
<property name="windowTitle">
<string>gsDockWidget</string>
</property>
<widget class="QWidget" name="dockWidgetContents">
<widget class="QgsCollapsibleGroupBox" name="mGroupBox">
<property name="geometry">
<rect>
<x>50</x>
<y>120</y>
<width>300</width>
<height>100</height>
</rect>
</property>
</widget>
<widget class="QgsFileWidget" name="mQgsFileWidget">
<property name="geometry">
<rect>
<x>130</x>
<y>60</y>
<width>90</width>
<height>27</height>
</rect>
</property>
<property name="storageMode">
<enum>QgsFileWidget::GetDirectory</enum>
</property>
</widget>
<widget class="QgsSymbolButton" name="mSymbolButton">
<property name="geometry">
<rect>
<x>90</x>
<y>60</y>
<width>27</width>
<height>27</height>
</rect>
</property>
</widget>
<widget class="QgsOpacityWidget" name="mOpacityWidget">
<property name="geometry">
<rect>
<x>70</x>
<y>100</y>
<width>160</width>
<height>27</height>
</rect>
</property>
</widget>
<widget class="QgsMapLayerComboBox" name="mMapLayerComboBox">
<property name="geometry">
<rect>
<x>230</x>
<y>30</y>
<width>160</width>
<height>32</height>
</rect>
</property>
</widget>
</widget>
</widget>
<customwidgets>
<customwidget>
<class>QgsCollapsibleGroupBox</class>
<extends>QGroupBox</extends>
<header>qgscollapsiblegroupbox.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>QgsDockWidget</class>
<extends>QDockWidget</extends>
<header>qgsdockwidget.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>QgsFileWidget</class>
<extends>QWidget</extends>
<header>qgsfilewidget.h</header>
</customwidget>
<customwidget>
<class>QgsMapLayerComboBox</class>
<extends>QComboBox</extends>
<header>qgsmaplayercombobox.h</header>
</customwidget>
<customwidget>
<class>QgsOpacityWidget</class>
<extends>QWidget</extends>
<header>qgsopacitywidget.h</header>
</customwidget>
<customwidget>
<class>QgsSymbolButton</class>
<extends>QToolButton</extends>
<header>qgssymbolbutton.h</header>
</customwidget>
</customwidgets>
<resources/>
<connections/>
</ui>

0 comments on commit aa86dd9

Please sign in to comment.