Skip to content

Commit

Permalink
adding OpenAI usage for street identification
Browse files Browse the repository at this point in the history
  • Loading branch information
tkalir committed Oct 3, 2024
1 parent 4f82f89 commit 135ce08
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 16 deletions.
83 changes: 83 additions & 0 deletions anyway/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from openai import OpenAI
import json
import tiktoken

from langchain.output_parsers import PydanticOutputParser, EnumOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import langchain
from enum import Enum
from anyway import secrets

# OpenAI API key is read once at import time from the project's secrets store.
api_key = secrets.get("OPENAI_API_KEY")
# Raw OpenAI client, used by ask_gpt() for direct chat-completion calls.
client = OpenAI(api_key=api_key)

# NOTE(review): LangChain debug tracing is enabled unconditionally at import —
# confirm this is intended outside of development before deploying.
langchain.debug = True
# Deterministic (temperature=0) LangChain chat model used by match_streets_with_langchain().
model = ChatOpenAI(api_key=api_key, temperature=0)


def match_streets_with_langchain(street_names, location):
    """Ask the LLM which of *street_names* is mentioned in *location*.

    Builds a one-off Enum of the candidate street names plus a "-" sentinel
    (meaning "no match") and runs a prompt|model|parser chain whose output is
    constrained to that Enum.

    :param street_names: list of candidate street name strings (not modified)
    :param location: free-text location string to match against
    :return: an Enum member whose value is the matched street name, or "-"
    """
    # Copy instead of append so the caller's list is not mutated.
    options = street_names + ["-"]
    Streets = Enum('Streets', {name: name for name in options})

    parser = EnumOutputParser(enum=Streets)
    prompt = PromptTemplate(
        template="Return the street that is mentioned in the location string. if non matches return '-'.\nstreets: {streets}\n" +
        "location_string:{location}\n{format_instructions}\n",
        input_variables=["streets", "location"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    chain = prompt | model | parser

    return chain.invoke({"streets": options, "location": location})


def count_tokens_for_prompt(messages, model):
    """Estimate the token count of a chat prompt for *model*.

    Each message is rendered as "role: content" and tokenized; a flat 4
    tokens per message approximates the role/delimiter framing overhead.

    :param messages: list of dicts with "role" and "content" keys
    :param model: OpenAI model name understood by tiktoken
    :return: approximate total token count (int)
    """
    tokenizer = tiktoken.encoding_for_model(model)
    return sum(
        len(tokenizer.encode(f"{msg['role']}: {msg['content']}")) + 4
        for msg in messages
    )


def count_tokens(text, model):
    """Return the number of tokens *text* encodes to under *model*'s encoding."""
    return len(tiktoken.encoding_for_model(model).encode(text))


def ask_gpt(system_message, user_message, model="gpt-4o"):
    """Send a system+user prompt to the OpenAI chat API and return the reply.

    The completion is requested with a JSON-object response format, so the
    returned message content is expected to be a JSON string.

    :param system_message: system-role instruction text
    :param user_message: user-role prompt text
    :param model: OpenAI model name (default "gpt-4o")
    :return: the first choice's message object from the completion
    """
    chat = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        response_format={"type": "json_object"},
        model=model,
        messages=chat,
    )
    print(f"tokens for prompt: {count_tokens_for_prompt(chat, model)}")
    return completion.choices[0].message


def ask_ai_about_street_matching(streets, location_string, model="gpt-4-turbo"):
    """Ask the LLM which street from *streets* appears in *location_string*.

    :param streets: list of candidate street name strings
    :param location_string: free-text location description
    :param model: OpenAI model name (default "gpt-4-turbo")
    :return: tuple (street, matched) where *street* is the model's answer
        (possibly "-" for no match) and *matched* is True only when the
        answer is one of the input *streets*.
    :raises: json.JSONDecodeError / KeyError if the model reply is not a JSON
        object with a "street" field (propagated to the caller).
    """
    system_message = """
    Given a list of streets, return the name of the street that is mentioned in the provided location string.
    Return the name exactly as appears in list.
    If no match is found, return "-".
    Return json with field "street" and your answer.
    Select one of the following options:
    """ + json.dumps(streets + ["-"])
    # Renamed from `input`, which shadowed the builtin.
    user_message = json.dumps({"streets": streets, "location": location_string})
    reply = ask_gpt(system_message, user_message, model)
    result = json.loads(reply.content)["street"]
    return result, result in streets
89 changes: 75 additions & 14 deletions anyway/parsers/location_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anyway.parsers.resolution_fields import ResolutionFields as RF
from anyway import secrets
from anyway.models import AccidentMarkerView, RoadSegments
from anyway.llm import ask_ai_about_street_matching
from sqlalchemy import not_
import pandas as pd
from sqlalchemy.orm import load_only
Expand Down Expand Up @@ -61,7 +62,7 @@ def get_road_segment_by_name_and_road(road_segment_name: str, road: int) -> Road
segments = db.session.query(RoadSegments).filter(RoadSegments.road == road).all()
for segment in segments:
if road_segment_name.startswith(segment.from_name) and road_segment_name.endswith(
segment.to_name
segment.to_name
):
return segment
err_msg = f"get_road_segment_by_name_and_road:{road_segment_name},{road}: not found"
Expand Down Expand Up @@ -246,6 +247,35 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
return final_loc


def read_n_closest_streets(db, n, latitude, longitude, road_no=None):
    """Return the names of the *n* distinct streets closest to a point.

    :param db: database adapter passed through to the marker query
    :param n: maximum number of unique street names to return
    :param latitude: point latitude
    :param longitude: point longitude
    :param road_no: optional road number filter, forwarded to the query
    :return: list of street1_hebrew strings, ordered nearest first
    """
    markers = read_markers_and_distance_from_location(db, latitude, longitude,
                                                      BE_CONST.ResolutionCategories.STREET, road_no)
    # Nearest first, then keep only the first (closest) row per street name.
    unique_closest = (
        markers.sort_values(by="dist_point")
        .drop_duplicates(subset="street1_hebrew")
        .head(n)
    )
    # Extract the single column directly instead of materializing record dicts.
    return unique_closest["street1_hebrew"].tolist()


def read_n_closest_markers(db, n, latitude, longitude, resolution, road_no=None):
    """Return the *n* markers closest to (latitude, longitude) as dicts.

    :param db: database adapter passed through to the marker query
    :param n: maximum number of markers to return
    :param latitude: point latitude
    :param longitude: point longitude
    :param resolution: resolution category forwarded to the query
    :param road_no: optional road number filter, forwarded to the query
    :return: list of row dicts, ordered nearest first
    """
    markers = read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no)
    closest = markers.sort_values(by="dist_point").head(n)
    return closest.to_dict(orient='records')


def set_accident_resolution(accident_row):
"""
set the resolution of the accident
Expand Down Expand Up @@ -288,11 +318,12 @@ def reverse_geocode_extract(latitude, longitude):
try:
gmaps = googlemaps.Client(key=secrets.get("GOOGLE_MAPS_KEY"))
geocode_result = gmaps.reverse_geocode((latitude, longitude))

print(geocode_result)
# if we got no results, move to next iteration of location string
if not geocode_result:
return None
except Exception as _:
logging.info(_)
logging.info("exception in gmaps")
return None
# logging.info(geocode_result)
Expand Down Expand Up @@ -458,30 +489,30 @@ def extract_location_text(text):
punc_after_ind = text.find(punc_to_try, forbid_ind)
if punc_before_ind != -1 or punc_after_ind != -1:
if punc_before_ind == -1:
text = text[(punc_after_ind + 1) :]
text = text[(punc_after_ind + 1):]
elif punc_after_ind == -1:
text = text[:punc_before_ind]
else:
text = text[:punc_before_ind] + " " + text[(punc_after_ind + 1) :]
text = text[:punc_before_ind] + " " + text[(punc_after_ind + 1):]
removed_punc = True
break
if (not removed_punc) and (forbid_word in hospital_words):
for hospital_name in hospital_names:
hospital_ind = text.find(hospital_name)
if (
hospital_ind == forbid_ind + len(forbid_word) + 1
or hospital_ind == forbid_ind + len(forbid_word) + 2
hospital_ind == forbid_ind + len(forbid_word) + 1
or hospital_ind == forbid_ind + len(forbid_word) + 2
):
text = (
text[:hospital_ind] + text[hospital_ind + len(hospital_name) + 1 :]
text[:hospital_ind] + text[hospital_ind + len(hospital_name) + 1:]
)
forbid_ind = text.find(forbid_word)
text = text[:forbid_ind] + text[forbid_ind + len(forbid_word) + 1 :]
text = text[:forbid_ind] + text[forbid_ind + len(forbid_word) + 1:]
found_hospital = True
if (not found_hospital) and (not removed_punc):
text = (
text[:forbid_ind]
+ text[text.find(" ", forbid_ind + len(forbid_word) + 2) :]
text[:forbid_ind]
+ text[text.find(" ", forbid_ind + len(forbid_word) + 2):]
)

except Exception as _:
Expand Down Expand Up @@ -517,7 +548,7 @@ def extract_location_text(text):
for token in near_tokens:
i = text.find(token)
if i >= 0:
text = text[:i] + token + text[i + len(token) :]
text = text[:i] + token + text[i + len(token):]
return text


Expand Down Expand Up @@ -545,6 +576,36 @@ def extract_geo_features(db, newsflash: NewsFlash, use_existing_coordinates_only
if location_from_db is not None:
update_location_fields(newsflash, location_from_db)
try_find_segment_id(newsflash)
logging.debug(newsflash.resolution)
if newsflash.resolution == BE_CONST.ResolutionCategories.STREET:
try_improve_street_identification(newsflash)


def try_improve_street_identification(newsflash, num_candidates=20, first_batch_size=5):
    """Try to refine a newsflash's street1_hebrew using LLM street matching.

    Gathers the closest candidate streets to the newsflash coordinates and
    asks the model (up to two attempts: a small nearest batch first, then the
    remainder) which one appears in the newsflash's location text. Updates
    newsflash.street1_hebrew in place when a confident match is found.

    :param newsflash: NewsFlash-like object with lat, lon, location, id and
        street1_hebrew attributes (street1_hebrew may be mutated)
    :param num_candidates: total number of nearby streets to consider
    :param first_batch_size: size of the first, nearest batch tried
    """
    # Imported locally (as in the rest of this module) to avoid an import cycle.
    from anyway.parsers import news_flash_db_adapter

    db = news_flash_db_adapter.init_db()
    candidates = read_n_closest_streets(db, num_candidates, newsflash.lat, newsflash.lon)

    # First try the nearest few streets; fall back to the rest if no match.
    batches = [candidates[:first_batch_size], candidates[first_batch_size:]]
    for attempt, streets in enumerate(batches, start=1):
        result, result_in_input = ask_ai_about_street_matching(streets, newsflash.location)
        logging.debug(f"result of try {attempt}: {result}")
        if result_in_input:
            if result == newsflash.street1_hebrew:
                logging.debug("street matching succeeded, street not changed")
            else:
                logging.debug(f"street matching succeeded, street updated for {newsflash.id} "
                              f"from {newsflash.street1_hebrew} to {result}")
                newsflash.street1_hebrew = result
            return
        logging.debug(f"street matching failed try {attempt} for newsflash {newsflash.id}")


def update_location_fields(newsflash, location_from_db):
Expand All @@ -557,9 +618,9 @@ def update_location_fields(newsflash, location_from_db):

def try_find_segment_id(newsflash):
if (
newsflash.road_segment_id is None
and newsflash.road_segment_name is not None
and newsflash.road1 is not None
newsflash.road_segment_id is None
and newsflash.road_segment_name is not None
and newsflash.road1 is not None
):
try:
seg = get_road_segment_by_name_and_road(newsflash.road_segment_name, newsflash.road1)
Expand Down
13 changes: 12 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def update(source, news_flash_id, update_cbs_location_only):
source = None
if not news_flash_id:
news_flash_id = None
return news_flash.update_all_in_db(source, news_flash_id, update_cbs_location_only)
return news_flash.update_all_in_db(source, news_flash_id, False)


@update_news_flash.command()
Expand Down Expand Up @@ -333,6 +333,17 @@ def infographics_pictures(id):
raise Exception("generation failed")


@process.command()
@click.option("--id", type=int)
def street_name(id):
    """Re-run AI street identification for a single news flash, by id."""
    from anyway.parsers import news_flash_db_adapter
    from anyway.parsers.location_extraction import try_improve_street_identification

    db = news_flash_db_adapter.init_db()
    newsflash = db.get_newsflash_by_id(id).first()
    # .first() returns None when no row matches — fail with a message instead
    # of an AttributeError inside the street-matching code.
    if newsflash is None:
        click.echo(f"news flash {id} not found")
        return
    try_improve_street_identification(newsflash)


@process.group()
def cache():
pass
Expand Down
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Flask-Login==0.5.0
Flask-SQLAlchemy==2.4.1
flask-restx==0.5.1
Jinja2==3.1.4
SQLAlchemy==1.3.17
SQLAlchemy==1.4
Werkzeug==2.0.3
alembic==1.4.2
attrs==23.1.0
Expand Down Expand Up @@ -53,3 +53,7 @@ swifter==1.3.4
telebot==0.0.5
selenium==4.11.2
apache-airflow-client==2.6.2
openai==1.45.0
langchain==0.2.16
langchain_openai==0.1.25
python-dotenv

0 comments on commit 135ce08

Please sign in to comment.