Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/tkalir/anyway into 2669-Impr…
Browse files Browse the repository at this point in the history
…ove-location-accuracy-for-street-newsflashes
  • Loading branch information
tkalir committed Sep 17, 2024
2 parents b65c013 + 0e9a621 commit 4f82f89
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 41 deletions.
59 changes: 34 additions & 25 deletions anyway/parsers/location_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from sqlalchemy.orm import load_only
from datetime import date


INTEGER_FIELDS = ["road1", "road2", "road_segment_id", "yishuv_symbol", "street1", "street2"]


Expand Down Expand Up @@ -522,31 +521,41 @@ def extract_location_text(text):
return text


def extract_geo_features(db, newsflash: NewsFlash, update_cbs_location_only: bool) -> None:
location_from_db = None
if update_cbs_location_only:
if newsflash.resolution in ["כביש בינעירוני", "רחוב"]:
location_from_db = get_db_matching_location(
db, newsflash.lat, newsflash.lon, newsflash.resolution, newsflash.road1
)
else:
newsflash.location = extract_location_text(newsflash.description) or extract_location_text(
newsflash.title
def update_coordinates_and_resolution_using_location_text(newsflash):
newsflash.location = extract_location_text(newsflash.description) or extract_location_text(
newsflash.title
)
geo_location = geocode_extract(newsflash.location)
if geo_location is not None:
newsflash.lat = geo_location["geom"]["lat"]
newsflash.lon = geo_location["geom"]["lng"]
newsflash.road1 = geo_location["road_no"]
newsflash.resolution = set_accident_resolution(geo_location)
return geo_location is not None


def extract_geo_features(db, newsflash: NewsFlash, use_existing_coordinates_only: bool) -> None:
if not use_existing_coordinates_only:
update_coordinates_and_resolution_using_location_text(newsflash)

if newsflash.resolution in BE_CONST.SUPPORTED_RESOLUTIONS:
location_from_db = get_db_matching_location(
db, newsflash.lat, newsflash.lon, newsflash.resolution, newsflash.road1
)
geo_location = geocode_extract(newsflash.location)
if geo_location is not None:
newsflash.lat = geo_location["geom"]["lat"]
newsflash.lon = geo_location["geom"]["lng"]
newsflash.resolution = set_accident_resolution(geo_location)
location_from_db = get_db_matching_location(
db, newsflash.lat, newsflash.lon, newsflash.resolution, geo_location["road_no"]
)
if location_from_db is not None:
for k, v in location_from_db.items():
setattr(newsflash, k, v)
for field in RF.get_all_location_fields():
if field not in location_from_db:
setattr(newsflash, field, None)
if location_from_db is not None:
update_location_fields(newsflash, location_from_db)
try_find_segment_id(newsflash)


def update_location_fields(newsflash, location_from_db):
for k, v in location_from_db.items():
setattr(newsflash, k, v)
for field in RF.get_all_location_fields():
if field not in location_from_db:
setattr(newsflash, field, None)


def try_find_segment_id(newsflash):
if (
newsflash.road_segment_id is None
and newsflash.road_segment_name is not None
Expand Down
6 changes: 3 additions & 3 deletions anyway/parsers/news_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def update_all_in_db(source=None, newsflash_id=None, update_cbs_location_only=Fa
newsflash.accident = classify(newsflash.title)
if newsflash.accident:
extract_geo_features(
db=db, newsflash=newsflash, update_cbs_location_only=update_cbs_location_only
db=db, newsflash=newsflash, use_existing_coordinates_only=update_cbs_location_only
)
if i % 1000 == 0:
db.commit()
Expand All @@ -53,7 +53,7 @@ def scrape_extract_store_rss(site_name, db):
newsflash.organization = classify_organization(site_name)
if newsflash.accident:
# FIX: No accident-accurate date extracted
extract_geo_features(db=db, newsflash=newsflash, update_cbs_location_only=False)
extract_geo_features(db=db, newsflash=newsflash, use_existing_coordinates_only=False)
db.insert_new_newsflash(newsflash)


Expand All @@ -66,7 +66,7 @@ def scrape_extract_store_twitter(screen_name, db):
newsflash.accident = classify_tweets(newsflash.description)
newsflash.organization = classify_organization("twitter")
if newsflash.accident:
extract_geo_features(db=db, newsflash=newsflash, update_cbs_location_only=False)
extract_geo_features(db=db, newsflash=newsflash, use_existing_coordinates_only=False)
db.insert_new_newsflash(newsflash)


Expand Down
64 changes: 54 additions & 10 deletions anyway/views/news_flash/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import os

from typing import List, Optional
from typing import List, Optional, Tuple, Any
from http import HTTPStatus
from collections import OrderedDict

Expand Down Expand Up @@ -37,12 +37,20 @@
DEFAULT_OFFSET_REQ_PARAMETER = 0
DEFAULT_LIMIT_REQ_PARAMETER = 100
DEFAULT_NUMBER_OF_YEARS_AGO = 5
PAGE_NUMBER = "pageNumber"
PAGE_SIZE = "pageSize"
NEWS_FALSH_ID = "newsFlash_Id"
ID = "id"
LIMIT = "limit"
OFFSET = "offset"


class NewsFlashQuery(BaseModel):
id: Optional[int]
road_number: Optional[int]
offset: Optional[int] = DEFAULT_OFFSET_REQ_PARAMETER
pageNumber: Optional[int]
pageSize: Optional[int]
limit: Optional[int] = DEFAULT_LIMIT_REQ_PARAMETER
resolution: Optional[List[str]]
source: Optional[List[BE_CONST.Source]]
Expand Down Expand Up @@ -75,19 +83,54 @@ def news_flash():
except ValidationError as e:
return make_response(jsonify(e.errors()[0]["msg"]), 404)

if "id" in validated_query_params:
return get_news_flash_by_id(validated_query_params["id"])
pagination, validated_query_params = set_pagination_params(validated_query_params)

query = gen_news_flash_query(db.session, validated_query_params)
if ID in validated_query_params:
return get_news_flash_by_id(validated_query_params[ID])

total, query = gen_news_flash_query(db.session, validated_query_params)
news_flashes = query.all()

news_flashes_jsons = [n.serialize() for n in news_flashes]
for news_flash in news_flashes_jsons:
news_flashes_dicts = [n.serialize() for n in news_flashes]
for news_flash in news_flashes_dicts:
set_display_source(news_flash)
return Response(json.dumps(news_flashes_jsons, default=str), mimetype="application/json")


def gen_news_flash_query(session, valid_params: dict):
if pagination:
res = add_pagination_to_result(validated_query_params, news_flashes_dicts, total)
else:
res = news_flashes_dicts
return Response(json.dumps(res, default=str), mimetype="application/json")


def set_pagination_params(validated_params: dict) -> Tuple[bool, dict]:
pagination = False
if NEWS_FALSH_ID in validated_params:
validated_params[ID] = validated_params.pop(NEWS_FALSH_ID)
if PAGE_NUMBER in validated_params:
pagination = True
page_number = validated_params[PAGE_NUMBER]
page_size = validated_params.get(PAGE_SIZE, DEFAULT_LIMIT_REQ_PARAMETER)
validated_params[OFFSET] = (page_number - 1) * page_size
validated_params[LIMIT] = page_size
return pagination, validated_params


def add_pagination_to_result(validated_params: dict, news_flashes: list, num_nf: int) -> dict:
page_size = validated_params[PAGE_SIZE]
page_num = validated_params[PAGE_NUMBER]
total_pages = num_nf // page_size + (1 if num_nf % page_size else 0)
return {
"data": news_flashes,
"pagination": {
"pageNumber": page_num,
"pageSize": page_size,
"totalPages": total_pages,
"totalRecords": num_nf
}
}


def gen_news_flash_query(session, valid_params: dict) -> Tuple[int, Any]:
query = session.query(NewsFlash)
supported_resolutions = set([x.value for x in BE_CONST.SUPPORTED_RESOLUTIONS])
query = query.filter(NewsFlash.resolution.in_(supported_resolutions))
Expand All @@ -113,11 +156,12 @@ def gen_news_flash_query(session, valid_params: dict):
not_(and_(NewsFlash.lat == None, NewsFlash.lon == None)),
)
).order_by(NewsFlash.date.desc())
total = query.count()
if "offset" in valid_params:
query = query.offset(valid_params["offset"])
if "limit" in valid_params:
query = query.limit(valid_params["limit"])
return query
return total, query


def set_display_source(news_flash):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_news_flash_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,21 @@ def _test_update_news_flash_qualifying_not_manual_exists_location_db(self):
def test_gen_news_flash_query(self):
orig_supported_resolutions = BE_CONST.SUPPORTED_RESOLUTIONS
BE_CONST.SUPPORTED_RESOLUTIONS = [BE_CONST.ResolutionCategories.DISTRICT]
actual = gen_news_flash_query(self.session, {"road_number": 12345678})
_, actual = gen_news_flash_query(self.session, {"road_number": 12345678})
news_flashes = actual.all()
self.assertEqual(len(news_flashes), 1, "single news flash")
self.assertEqual(
news_flashes[0].description, self.district_description, "district description"
)

BE_CONST.SUPPORTED_RESOLUTIONS = [BE_CONST.ResolutionCategories.REGION]
actual = gen_news_flash_query(self.session, {"road_number": 12345678})
_, actual = gen_news_flash_query(self.session, {"road_number": 12345678})
news_flashes = actual.all()
self.assertEqual(len(news_flashes), 1, "single news flash")
self.assertEqual(news_flashes[0].description, self.region_description, "region description")

BE_CONST.SUPPORTED_RESOLUTIONS = [BE_CONST.ResolutionCategories.CITY]
actual = gen_news_flash_query(self.session, {"road_number": 12345678})
_, actual = gen_news_flash_query(self.session, {"road_number": 12345678})
news_flashes = actual.all()
self.assertEqual(len(news_flashes), 0, "zero news flash")

Expand Down

0 comments on commit 4f82f89

Please sign in to comment.