Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/vector tiles training files #291

Merged
merged 23 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
19b3e45
Add centroid lat and lon
kshitijrajsharma Oct 8, 2024
db90cce
Add mercantile requirements
kshitijrajsharma Oct 8, 2024
d9d4f62
Added thumnail_url
kshitijrajsharma Oct 8, 2024
a15bf2a
Commented get_training function for now
kshitijrajsharma Oct 8, 2024
6ba4c0b
Addes tippecanoe to generate vector tiles from aois and labels along …
kshitijrajsharma Oct 10, 2024
0ab6df3
Updated docker compose file and added dockerfile frontend
kshitijrajsharma Oct 10, 2024
28b81d8
Feat : Search by model id
kshitijrajsharma Oct 14, 2024
0dc50ae
Feat : Added feedback count to training /get/id/ endpoint
kshitijrajsharma Oct 14, 2024
1bbb0e9
Feat : Add banners from backend
kshitijrajsharma Oct 14, 2024
d568faf
Add admin view and permission set for retrieving the data
kshitijrajsharma Oct 14, 2024
382274f
Refactor ::: created_by -> user
kshitijrajsharma Oct 14, 2024
7a44110
Feat : Authentication for admins and staffs
kshitijrajsharma Oct 14, 2024
feeaffd
Fix tippecanoe command
kshitijrajsharma Oct 14, 2024
4a10736
Use -L instead of full named layer
kshitijrajsharma Oct 14, 2024
2b3d9b3
move tippecanoe to furhter steps
kshitijrajsharma Oct 14, 2024
ba0c5ad
Try : Fix files labels
kshitijrajsharma Oct 14, 2024
77c9889
Use subprocess run instead of check output
kshitijrajsharma Oct 14, 2024
27a32ae
Added foundation model to Model table
kshitijrajsharma Oct 14, 2024
5c6d632
Fix : Change foudnation model to base model , Remove is_active from a…
kshitijrajsharma Oct 15, 2024
da9eddf
Add public methods in admin and staff permission , Feat : Add KPI stats
kshitijrajsharma Oct 15, 2024
f10382e
FIX : Remove is_displayable to API
kshitijrajsharma Oct 15, 2024
a9f14de
Add permission for user to be able to submit the training request but…
kshitijrajsharma Oct 15, 2024
40b8060
Enhance : Cache on the endpoint for key api stats
kshitijrajsharma Oct 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/api-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ fairpredictor==0.0.26

rasterio==1.3.8
numpy<2.0.0

mercantile==1.2.1

20 changes: 17 additions & 3 deletions backend/core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

@admin.register(Dataset)
class DatasetAdmin(geoadmin.OSMGeoAdmin):
list_display = ["name", "created_by"]
list_display = ["name", "user"]


@admin.register(Model)
class ModelAdmin(geoadmin.OSMGeoAdmin):
list_display = ["get_dataset_id", "name", "status", "created_at", "created_by"]
list_display = ["get_dataset_id", "name", "status", "created_at", "user"]

def get_dataset_id(self, obj):
return obj.dataset.id
Expand All @@ -28,7 +28,7 @@ class TrainingAdmin(geoadmin.OSMGeoAdmin):
"description",
"status",
"zoom_level",
"created_by",
"user",
"accuracy",
]
list_filter = ["status"]
Expand All @@ -47,3 +47,17 @@ class FeedbackAOIAdmin(geoadmin.OSMGeoAdmin):
@admin.register(Feedback)
class FeedbackAdmin(geoadmin.OSMGeoAdmin):
list_display = ["feedback_type", "training", "user", "created_at"]


@admin.register(Banner)
class BannerAdmin(admin.ModelAdmin):
list_display = ("message", "start_date", "end_date", "is_displayable")
list_filter = ("start_date", "end_date")
search_fields = ("message",)
readonly_fields = ("is_displayable",)

def is_displayable(self, obj):
return obj.is_displayable()

is_displayable.boolean = True
is_displayable.short_description = "Currently Displayable"
38 changes: 30 additions & 8 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models

from django.utils import timezone
from login.models import OsmUser

# Create your models here.
Expand All @@ -15,7 +15,7 @@ class DatasetStatus(models.IntegerChoices):
DRAFT = -1

name = models.CharField(max_length=255)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
last_modified = models.DateTimeField(auto_now=True)
created_at = models.DateTimeField(auto_now_add=True)
source_imagery = models.URLField(blank=True, null=True)
Expand Down Expand Up @@ -47,6 +47,11 @@ class Label(models.Model):


class Model(models.Model):
BASE_MODEL_CHOICES = (
("RAMP", "RAMP"),
("YOLO", "YOLO"),
)

class ModelStatus(models.IntegerChoices):
ARCHIVED = 1
PUBLISHED = 0
Expand All @@ -57,9 +62,12 @@ class ModelStatus(models.IntegerChoices):
created_at = models.DateTimeField(auto_now_add=True)
last_modified = models.DateTimeField(auto_now=True)
description = models.TextField(max_length=500, null=True, blank=True)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
published_training = models.PositiveIntegerField(null=True, blank=True)
status = models.IntegerField(default=-1, choices=ModelStatus.choices) #
status = models.IntegerField(default=-1, choices=ModelStatus.choices)
base_model = models.CharField(
choices=BASE_MODEL_CHOICES, default="RAMP", max_length=10
)


class Training(models.Model):
Expand All @@ -81,14 +89,15 @@ class Training(models.Model):
models.PositiveIntegerField(),
size=4,
)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
started_at = models.DateTimeField(null=True, blank=True)
finished_at = models.DateTimeField(null=True, blank=True)
accuracy = models.FloatField(null=True, blank=True)
epochs = models.PositiveIntegerField()
chips_length = models.PositiveIntegerField(default=0)
batch_size = models.PositiveIntegerField()
freeze_layers = models.BooleanField(default=False)
centroid = geomodels.PointField(srid=4326, null=True, blank=True)


class Feedback(models.Model):
Expand Down Expand Up @@ -146,6 +155,19 @@ class ApprovedPredictions(models.Model):
srid=4326
) ## Making this geometry field to support point/line prediction later on
approved_at = models.DateTimeField(auto_now_add=True)
approved_by = models.ForeignKey(
OsmUser, to_field="osm_id", on_delete=models.CASCADE
)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)


class Banner(models.Model):
message = models.TextField()
start_date = models.DateTimeField(default=timezone.now)
end_date = models.DateTimeField(null=True, blank=True)

def is_displayable(self):
now = timezone.now()
return (self.start_date <= now) and (
self.end_date is None or self.end_date >= now
)

def __str__(self):
return self.message
61 changes: 47 additions & 14 deletions backend/core/serializers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import mercantile
from django.conf import settings
from login.models import OsmUser
from rest_framework import serializers
from rest_framework_gis.serializers import (
GeoFeatureModelSerializer, # this will be used if we used to serialize as geojson
)

from login.models import OsmUser

from .models import *

# from .tasks import train_model
Expand All @@ -18,14 +18,14 @@ class Meta:
model = Dataset
fields = "__all__" # defining all the fields to be included in curd for now , we can restrict few if we want
read_only_fields = (
"created_by",
"user",
"created_at",
"last_modified",
)

def create(self, validated_data):
user = self.context["request"].user
validated_data["created_by"] = user
validated_data["user"] = user
return super().create(validated_data)


Expand All @@ -45,30 +45,51 @@ class Meta:
]


class ModelSerializer(
serializers.ModelSerializer
): # serializers are used to translate models objects to api
created_by = UserSerializer(read_only=True)
class ModelSerializer(serializers.ModelSerializer):
user = UserSerializer(read_only=True)
accuracy = serializers.SerializerMethodField()
thumbnail_url = serializers.SerializerMethodField()

class Meta:
model = Model
fields = "__all__"
read_only_fields = (
"created_at",
"last_modified",
"created_by",
"user",
"published_training",
)

def create(self, validated_data):
user = self.context["request"].user
validated_data["created_by"] = user
validated_data["user"] = user
return super().create(validated_data)

def get_accuracy(
self, obj
): ## this might have performance problem when db grows bigger , consider adding indexes / view in db
# def get_training(self, obj):
# if not hasattr(self, "_cached_training"):
# self._cached_training = Training.objects.filter(
# id=obj.published_training
# ).first()
# return self._cached_training

def get_thumbnail_url(self, obj):
training = Training.objects.filter(id=obj.published_training).first()

if training:
if training.source_imagery:
aoi = AOI.objects.filter(dataset=obj.dataset).first()
if aoi and aoi.geom:
centroid = (
aoi.geom.centroid.coords
) ## Centroid can be stored in db table if required when project grows bigger
try:
tile = mercantile.tile(centroid[0], centroid[1], zoom=18)
return training.source_imagery.format(x=tile.x, y=tile.y, z=18)
except Exception as ex:
pass
return None

def get_accuracy(self, obj):
training = Training.objects.filter(id=obj.published_training).first()
if training:
return training.accuracy
Expand All @@ -82,7 +103,8 @@ class ModelCentroidSerializer(GeoFeatureModelSerializer):
class Meta:
model = Model
geo_field = "geometry"
fields = ("mid", "name", "geometry")
fields = ("mid", "geometry")
# fields = ("mid", "name", "geometry")

def get_geometry(self, obj):
"""
Expand Down Expand Up @@ -371,3 +393,14 @@ def validate(self, data):
data["area_threshold"]
)
return data


class BannerSerializer(serializers.ModelSerializer):
class Meta:
model = Banner
fields = [
"id",
"message",
"start_date",
"end_date",
]
30 changes: 23 additions & 7 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import traceback
from shutil import rmtree

from celery import shared_task
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone

from core.models import AOI, Feedback, FeedbackAOI, FeedbackLabel, Label, Training
from core.serializers import (
AOISerializer,
Expand All @@ -23,6 +18,11 @@
LabelFileSerializer,
)
from core.utils import bbox, is_dir_empty
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,6 +135,10 @@ def train_model(
raise ValueError(
f"No AOI is attached with supplied dataset id:{dataset_id}, Create AOI first",
)
first_aoi_centroid = aois[0].geom.centroid
training_instance.centroid = first_aoi_centroid
training_instance.save()

for obj in aois:
bbox_coords = bbox(obj.geom.coords[0])
for z in zoom_level:
Expand Down Expand Up @@ -309,6 +313,18 @@ def train_model(
) as f:
f.write(json.dumps(aoi_serializer.data))

tippecanoe_command = f"""tippecanoe -o {os.path.join(output_path,"meta.pmtiles")} -Z7 -z18 -L aois:{ os.path.join(output_path, "aois.geojson")} -L labels:{os.path.join(output_path, "labels.geojson")} --force --read-parallel -rg --drop-densest-as-needed"""
logging.info("Starting to generate vector tiles for aois and labels")
try:
result = subprocess.run(
tippecanoe_command, shell=True, check=True, capture_output=True
)
logging.info(result.stdout.decode("utf-8"))
except subprocess.CalledProcessError as ex:
logger.error(ex.output)
raise ex
logging.info("Vector tile generation done !")

# copy aois and labels to preprocess output before compressing it to tar
shutil.copyfile(
os.path.join(output_path, "aois.geojson"),
Expand All @@ -332,7 +348,7 @@ def train_model(
training_instance.save()
response = {}
response["accuracy"] = float(final_accuracy)
# response["model_path"] = os.path.join(output_path, "checkpoint.tf")
response["tiles_path"] = os.path.join(output_path, "meta.pmtiles")
response["model_path"] = os.path.join(output_path, "checkpoint.h5")
response["graph_path"] = os.path.join(output_path, "graphs")
sys.stdout = sys.__stdout__
Expand Down
4 changes: 4 additions & 0 deletions backend/core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .views import ( # APIStatus,
AOIViewSet,
ApprovedPredictionsViewSet,
BannerViewSet,
ConflateGeojson,
DatasetViewSet,
FeedbackAOIViewset,
Expand All @@ -26,6 +27,7 @@
UsersView,
download_training_data,
geojson2osmconverter,
get_kpi_stats,
publish_training,
run_task_status,
)
Expand All @@ -44,6 +46,7 @@
router.register(r"feedback", FeedbackViewset)
router.register(r"feedback-aoi", FeedbackAOIViewset)
router.register(r"feedback-label", FeedbackLabelViewset)
router.register(r"banner", BannerViewSet)


urlpatterns = [
Expand Down Expand Up @@ -71,6 +74,7 @@
"workspace/download/<path:lookup_dir>/", TrainingWorkspaceDownloadView.as_view()
),
path("workspace/<path:lookup_dir>/", TrainingWorkspaceView.as_view()),
path("kpi/stats/", get_kpi_stats, name="get_kpi_stats"),
]
if settings.ENABLE_PREDICTION_API:
urlpatterns.append(path("prediction/", PredictionView.as_view()))
Loading
Loading