Skip to content

Commit

Permalink
Merge pull request #65 from VIDA-NYU/dhodcz2
Browse files Browse the repository at this point in the history
fix raster.inference() args
  • Loading branch information
Mary-h86 authored Sep 15, 2024
2 parents fb65aa7 + e030cbd commit 726982b
Show file tree
Hide file tree
Showing 22 changed files with 128 additions and 83 deletions.
6 changes: 6 additions & 0 deletions src/tile2net/raster/input_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from toolz import curried, curry as cur, pipe

from tile2net.raster.util import cached_descriptor
import os


if False:
Expand Down Expand Up @@ -50,6 +51,7 @@ def __set__(self, instance: Raster, value: str | PathLike):
return
if isinstance(value, Path):
value = str(value)
value = os.path.normpath(value)

self.original = value

Expand Down Expand Up @@ -141,9 +143,13 @@ def __fspath__(self):
return self.dir

def __repr__(self):
if not self.original:
return f'{super().__repr__()} not set'
return self.original

def __str__(self):
if not self.original:
return f'{super().__repr__()} not set'
return self.original

def __set_name__(self, owner, name):
Expand Down
19 changes: 17 additions & 2 deletions src/tile2net/raster/pednet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import datetime
import shutil
import warnings

import pandas as pd
import os
Expand Down Expand Up @@ -217,6 +218,11 @@ def create_crosswalk(self):
cw_ntw.geometry = cw_ntw.geometry.set_crs(3857)
smoothed = wrinkle_remover(cw_ntw, 1.3)
self.crosswalk = smoothed
else:
warnings.warn('No crosswalks found')
self.crosswalk = gpd.GeoDataFrame({
'geometry': [],
},crs=3857)

def create_lines(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
"""
Expand Down Expand Up @@ -322,6 +328,11 @@ def create_sidewalks(self):
except:
# logging.info('cannot save modified')
self.sidewalk = sw_uni_lines
else:
warnings.warn('No sidewalk polygons found')
self.sidewalk = gpd.GeoDataFrame({
'geometry': [],
}, crs=3857)

self.sidewalk['f_type'] = 'sidewalk'

Expand All @@ -338,8 +349,12 @@ def convert_whole_poly2line(self):
points = get_line_sepoints(self.crosswalk)

# query LineString geometry to identify points intersecting 2 geometries
inp, res = self.crosswalk.sindex.query(geo2geodf(points).geometry,
predicate="intersects")
# inp, res = self.crosswalk.sindex.query(geo2geodf(points).geometry,
# predicate="intersects")
inp, res = (
self.crosswalk.sindex
.query(geo2geodf(points).geometry, predicate="intersects")
)
unique, counts = np.unique(inp, return_counts=True)
ends = np.unique(res[np.isin(inp, unique[counts == 1])])

Expand Down
31 changes: 14 additions & 17 deletions src/tile2net/raster/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from numpy import ndarray
from toolz.curried import *


if False:
from tile2net.raster.raster import Raster
from tile2net.raster.source import Source
Expand Down Expand Up @@ -174,6 +173,7 @@ def __get__(self, instance, owner) -> 'Weights':
# noinspection PyTypeChecker
return super().__get__(instance, owner)


class Segmentation(Directory):
def __fspath__(self):
raster = self.project.raster
Expand All @@ -198,7 +198,6 @@ def files(self, tiles: ndarray = None) -> Iterator[str]:
yield os.path.join(path, f'{r}_{c}_{i}.{extension}')



class Assets(Directory):
# @directory_method
def files(self, raster: 'Raster', **kwargs) -> list[Path]:
Expand All @@ -211,7 +210,6 @@ def __get__(self, instance, owner) -> 'Assets':
weights = Weights()



class Config(File):
# @directory_method
def __fspath__(self):
Expand Down Expand Up @@ -291,7 +289,6 @@ def __get__(self, instance, owner) -> 'Static':

_directory = WeakKeyDictionary()


def __fspath__(self):
raster = self.project.raster
source: 'Source' = raster.source
Expand All @@ -306,7 +303,6 @@ def __fspath__(self):
else:
raise ValueError('raster has no source or input_dir')


def files(self, tiles: ndarray = None) -> Iterator[Path]:
raster = self.project.raster
if tiles is None:
Expand All @@ -326,7 +322,6 @@ def files(self, tiles: ndarray = None) -> Iterator[Path]:
)



class Stitched(Directory):

def files(self, tiles: ndarray = None) -> Iterator[str]:
Expand Down Expand Up @@ -513,10 +508,10 @@ class Project(Directory):
segmentation = Segmentation()

def __init__(
self,
name: str,
outdir: PathLike,
raster: 'Raster'
self,
name: str,
outdir: PathLike,
raster: 'Raster',
):
"""
file structure for tile2net project
Expand Down Expand Up @@ -548,13 +543,15 @@ def __init__(
# free = psutil.disk_usage(self.__fspath__()).free
path = Path(self.__fspath__())
path.mkdir(parents=True, exist_ok=True)
free = psutil.disk_usage(path.__fspath__()).free
if free > 2 * 2 ** 30:
...
elif free > 1 * 2 ** 30:
warnings.warn(f'Low disk space: {free / 2 ** 30} GB')
else:
warnings.warn(f'Very low disk space: {free / 2 ** 20} MB')

# disabled because of python 3.12 bug
# free = psutil.disk_usage(path.__fspath__()).free
# if free > 2 * 2 ** 30:
# ...
# elif free > 1 * 2 ** 30:
# warnings.warn(f'Low disk space: {free / 2 ** 30} GB')
# else:
# warnings.warn(f'Very low disk space: {free / 2 ** 20} MB')

def to_file(self, path: PathLike = None) -> Path:
if path is None:
Expand Down
89 changes: 52 additions & 37 deletions src/tile2net/raster/raster.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,60 @@
from __future__ import annotations

import logging
import weakref
from functools import *

import copy
import inspect
import itertools
import json
import logging
import math

import tile2net.logger
from tile2net.raster import util
import subprocess
import mimetypes
import os
import certifi
import imageio.v2

import subprocess
import sys
import weakref
from concurrent.futures import Future, as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import cached_property
from os import PathLike as _PathLike
from pathlib import Path
from typing import Iterator, Optional, Type, Union

import certifi
import imageio.v2
import numpy as np
import requests
import itertools

from tqdm import tqdm

from typing import Iterator, Optional, Type, Union

from pathlib import Path
from os import PathLike as _PathLike
import json
from concurrent.futures import ThreadPoolExecutor, Future, as_completed
from functools import cached_property

from PIL import Image
import toolz
from PIL import Image
from numpy import ndarray
from toolz import curried, pipe, partial, curry
from tqdm import tqdm

from tile2net.raster import util
from tile2net.raster.grid import Grid
from tile2net.raster.input_dir import InputDir
from tile2net.raster.project import Project
from tile2net.raster.source import Source
from tile2net.raster.input_dir import InputDir
from tile2net.raster.validate import validate
from tile2net.logger import logger


def get_extension(url):
try:
response = requests.head(url, allow_redirects=True, timeout=5)
content_type = response.headers.get('content-type')
if content_type:
extension = mimetypes.guess_extension(content_type.split(';')[0])
return extension if extension else 'unknown'
return 'unknown'
except Exception:
return 'unknown'

def get_extensions(
urls: list[str] | ndarray,
) -> list[str]:
with ThreadPoolExecutor() as executor:
extensions = list(executor.map(get_extension, urls))
return extensions


PathLike = Union[str, _PathLike]


Expand Down Expand Up @@ -304,7 +317,7 @@ def __init__(
self.dest = ""
self.name = name
self.boundary_path = ""
self.input_dir: InputDir = input_dir
self.input_dir = input_dir
self.source = source
self.dump_percent = dump_percent
self.batch = -1
Expand Down Expand Up @@ -432,7 +445,10 @@ def stitch(self, step: int, force=False) -> None:
return

infiles: np.ndarray = pipe(
self.tiles, self.project.tiles.static.files, list, np.array
self.tiles,
self.project.tiles.static.files,
list,
np.array
)
indices = np.arange(self.tiles.size).reshape((self.width, self.height))
indices = (
Expand Down Expand Up @@ -845,16 +861,16 @@ def inference(self, *args: str, ):
"-m",
"tile2net",
"inference",
"--city_info",
str(info),
"--interactive",
"--dump_percent",
str(self.dump_percent),
*args,
]
logger.info(f"Running {args}")
# if eval_folder:
# args.extend(["--eval_folder", str(eval_folder)])
sargs = set(args)
extend = getattr(args, 'extend')
if '--city_info' not in sargs:
extend(["--city_info", str(info)])
if '--dump_percent' not in sargs:
extend(["--dump_percent", str(self.dump_percent)])
if '--interactive' not in sargs:
args.append("--interactive")
logger.debug(f'Running {" ".join(args)}')
try:
# todo: capture_outputs=False if want instant printout
Expand Down Expand Up @@ -960,4 +976,3 @@ def extension(self):
)
raster.generate(2)
raster.inference()

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
19 changes: 8 additions & 11 deletions src/tile2net/raster/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,6 @@ class SourceMeta(ABCMeta):
catalog: dict[str, Type[Source]] = {}
coverage = Coverage()

# @classmethod
# @property
# def coverage(cls) -> GeoSeries:
# return coverage

@not_found_none
def __getitem__(
cls: Type[Source],
Expand Down Expand Up @@ -371,6 +366,7 @@ class Massachusetts(ArcGis):
server = 'https://tiles.arcgis.com/tiles/hGdibHYSPO59RG1h/arcgis/rest/services/USGS_Orthos_2019/MapServer'
name = 'ma'
keyword = 'Massachusetts'
extension = 'jpg'


class KingCountyWashington(ArcGis):
Expand Down Expand Up @@ -432,12 +428,13 @@ class LosAngeles(ArcGis):
# extension = 'jpeg'
# keyword = 'Oregon'

class Oregon(ArcGis):
server = 'https://imagery.oregonexplorer.info/arcgis/rest/services/OSIP_2022/OSIP_2022_WM/ImageServer'
name = 'or'
extension = 'jpeg'
keyword = 'Oregon'

# todo: Oregon also has some SSL issues
# class Oregon(ArcGis):
# server = 'https://imagery.oregonexplorer.info/arcgis/rest/services/OSIP_2022/OSIP_2022_WM/ImageServer'
# name = 'or'
# extension = 'jpeg'
# keyword = 'Oregon'
#

class NewJersey(ArcGis):
server = 'https://maps.nj.gov/arcgis/rest/services/Basemap/Orthos_Natural_2020_NJ_WM/MapServer'
Expand Down
6 changes: 5 additions & 1 deletion src/tile2net/raster/tile_utils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
# import momepy
import pandas as pd
from rasterio.windows import shape

pd.options.mode.chained_assignment = None
os.environ['USE_PYGEOS'] = '0'
Expand Down Expand Up @@ -138,7 +139,10 @@ def get_start_end(line_gdf):


def vectorize_points(lst):
return np.apply_along_axis(Point, 1, lst)
try:
return np.apply_along_axis(Point, 1, lst)
except ValueError:
return np.array([], dtype=object)


def get_shortest(gdf1, gdf2, f_type: str, max_dist=12):
Expand Down
2 changes: 2 additions & 0 deletions src/tile2net/tileseg/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,5 +513,7 @@ def inference(args: Namespace):
if __name__ == '__main__':
"""
--city_info /tmp/tile2net/washington_square_park/tiles/washington_square_park_256_info.json --interactive --dump_percent 10
--city_info /tmp/tile2net/frisco_test_1/tiles/frisco_test_1_256_info.json --interactive --dump_percent 10
"""
argh.dispatch_command(inference)
7 changes: 2 additions & 5 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import shutil

import pytest
from tile2net.raster.raster import Raster

def test_small():
raster = Raster(
location='Washington Square Park, New York, NY, USA',
zoom=19,
# dump_percent=1,
# name='small'
name='tiny'
name='small'
# name='tiny'
)

raster.generate(2)
Expand Down
Loading

0 comments on commit 726982b

Please sign in to comment.