diff --git a/spire/broodusers.py b/spire/broodusers.py index cd99035..0c0c542 100644 --- a/spire/broodusers.py +++ b/spire/broodusers.py @@ -5,7 +5,7 @@ import uuid from bugout.app import Bugout # type: ignore -import requests +import requests # type: ignore from sqlalchemy.orm import Session from .utils.settings import auth_url_from_env, SPIRE_API_URL, BUGOUT_CLIENT_ID_HEADER diff --git a/spire/db.py b/spire/db.py index 7fa67ea..1a66861 100644 --- a/spire/db.py +++ b/spire/db.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from typing import Optional -import redis +import redis # type: ignore from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session diff --git a/spire/github/api.py b/spire/github/api.py index 78ab93c..460c525 100644 --- a/spire/github/api.py +++ b/spire/github/api.py @@ -5,7 +5,7 @@ import json import time import logging -import dateutil.parser +import dateutil.parser # type: ignore from typing import Any, Dict, List, Optional import uuid @@ -19,7 +19,7 @@ ) import jwt # type: ignore -import requests +import requests # type: ignore from starlette.responses import RedirectResponse from sqlalchemy.orm import Session diff --git a/spire/github/calls.py b/spire/github/calls.py index 8cff0c7..86e6b35 100644 --- a/spire/github/calls.py +++ b/spire/github/calls.py @@ -5,7 +5,7 @@ import logging from typing import Any, cast, Dict, List, Optional -import requests +import requests # type: ignore logger = logging.getLogger(__name__) diff --git a/spire/humbug/actions.py b/spire/humbug/actions.py index 98dce96..5303ef3 100644 --- a/spire/humbug/actions.py +++ b/spire/humbug/actions.py @@ -3,7 +3,7 @@ from uuid import UUID, uuid4 from sqlalchemy.orm import Session -import requests +import requests # type: ignore from ..journal.actions import create_journal_entries_pack from .data import HumbugEventDependencies, HumbugReport diff --git a/spire/indices.py b/spire/indices.py index 7b8ba3d..becc347 100644 --- a/spire/indices.py +++ b/spire/indices.py @@ -1,8 +1,8 @@ import logging from typing import Any, Callable, cast, Dict, List, Union -import requests -from requests.api import head +import requests # type: ignore +from requests.api import head # type: ignore from sqlalchemy.orm import Session from .slack.data import Index diff --git a/spire/journal/actions.py b/spire/journal/actions.py index f354397..7539c27 100644 --- a/spire/journal/actions.py +++ b/spire/journal/actions.py @@ -13,7 +13,7 @@ import boto3 from sqlalchemy.orm import Session, Query -from sqlalchemy import or_, func, text, and_ +from sqlalchemy import or_, func, text, and_, select from sqlalchemy.dialects import postgresql @@ -26,6 +26,7 @@ CreateJournalEntryRequest, JournalEntryListContent, CreateJournalEntryTagRequest, + CreateEntriesTagsRequest, JournalSearchResultsResponse, JournalStatisticsResponse, UpdateJournalSpec, @@ -67,6 +68,17 @@ class EntryNotFound(Exception): """ +# Excption with list of not found entries +class EntriesNotFound(Exception): + """ + Raised on actions that involve journal entries which are not present in the database. + """ + + def __init__(self, message: str, entries: List[UUID] = []): + super().__init__(message) + self.entries = entries + + class EntryLocked(Exception): """ Raised on actions when entry is not released for editing by other users. @@ -91,6 +103,12 @@ class InvalidParameters(ValueError): """ +class CommitFailed(Exception): + """ + Raised when commit failed. + """ + + def acl_auth( db_session: Session, user_id: str, user_group_id_list: List[str], journal_id: UUID ) -> Tuple[Journal, Dict[HolderType, List[str]]]: @@ -687,6 +705,71 @@ async def get_journal_entry_with_tags( return entry, tags, entry_lock +async def get_journal_entries_with_tags( + db_session: Session, journal_entries_ids: List[UUID] +) -> List[JournalEntryResponse]: + """ + Returns a journal entries by its id with tags. + """ + objects = ( + db_session.query(JournalEntry, JournalEntryTag.tag) + .join( + JournalEntryTag, + JournalEntryTag.journal_entry_id == JournalEntry.id, + isouter=True, + ) + .join( + JournalEntryLock, + JournalEntryLock.journal_entry_id == JournalEntry.id, + isouter=True, + ) + .filter(JournalEntry.id.in_(journal_entries_ids)) + ).cte("entries") + + entries = ( + db_session.query( + objects.c.id.label("id"), + objects.c.journal_id.label("journal_id"), + objects.c.title.label("title"), + objects.c.content.label("content"), + func.array_agg(objects.c.tag).label("tags"), + objects.c.created_at.label("created_at"), + objects.c.updated_at.label("updated_at"), + objects.c.context_url.label("context_url"), + objects.c.context_type.label("context_type"), + objects.c.context_id.label("context_id"), + ) + .group_by( + objects.c.id, + objects.c.journal_id, + objects.c.title, + objects.c.content, + objects.c.created_at, + objects.c.updated_at, + objects.c.context_url, + objects.c.context_type, + objects.c.context_id, + ) + .all() + ) + + return [ + JournalEntryResponse( + id=entry.id, + title=entry.title, + content=entry.content, + tags=list(entry.tags if entry.tags != [None] else []), + context_url=entry.context_url, + context_type=entry.context_type, + context_id=entry.context_id, + created_at=entry.created_at, + updated_at=entry.updated_at, + locked_by=None, + ) + for entry in entries + ] + + async def update_journal_entry( db_session: Session, new_title: str, @@ -1014,6 +1097,104 @@ async def update_journal_entry_tags( return query.all() +async def create_journal_entries_tags( + db_session: Session, + journal: Journal, + entries_tags_request: CreateEntriesTagsRequest, +) -> List[UUID]: + + """ + Create tags for entries in journal. + """ + + # For more useful error message + requested_entries = [ + entry.journal_entry_id for entry in entries_tags_request.entries + ] + + await entries_exists_check( + db_session=db_session, journal_id=journal.id, entries_ids=requested_entries + ) + + deduplicated_values = await dedublicate_entries_tags( + entries_tags=entries_tags_request + ) + + insert_statement = ( + postgresql.insert(JournalEntryTag) + .values(deduplicated_values) + .on_conflict_do_nothing(index_elements=["journal_entry_id", "tag"]) + ) + + try: + db_session.execute(insert_statement) + db_session.commit() + except Exception as err: + logger.error(f"Could not create tags for entries error: {err}") + db_session.rollback() + raise CommitFailed("Could not create tags") + + return requested_entries + + +async def delete_journal_entries_tags( + db_session: Session, + journal: Journal, + entries_tags_request: CreateEntriesTagsRequest, +) -> List[UUID]: + + """ + Delete tags for entries in journal. + """ + + requested_entries = [ + entry.journal_entry_id for entry in entries_tags_request.entries + ] + + await entries_exists_check( + db_session=db_session, journal_id=journal.id, entries_ids=requested_entries + ) + + deduplicated_values = await dedublicate_entries_tags( + entries_tags=entries_tags_request + ) + + selected_tags = ( + db_session.query( + JournalEntryTag.id.label("id"), + ) + .join(JournalEntry, JournalEntryTag.journal_entry_id == JournalEntry.id) + .filter(JournalEntry.journal_id == journal.id) + .filter( + JournalEntryTag.journal_entry_id.in_( + [entry["journal_entry_id"] for entry in deduplicated_values] + ) + ) + .filter( + JournalEntryTag.tag.in_([entry["tag"] for entry in deduplicated_values]) + ) + .cte("selected_tags") + ) + + delete_statement = ( + db_session.query(JournalEntryTag) + .filter(JournalEntryTag.id.in_(select(selected_tags.c.id))) + .delete(synchronize_session=False) + ) + + try: + db_session.commit() + logger.info( + f"Deleted {delete_statement} tags in journal {journal.id} for {len(requested_entries)} entries" + ) + except Exception as err: + logger.error(f"Could not delete tags for entries error: {err}") + db_session.rollback() + raise CommitFailed("Could not delete tags") + + return requested_entries + + async def delete_journal_entry_tag( db_session: Session, journal_spec: JournalSpec, @@ -1279,3 +1460,61 @@ async def delete_journal_scopes( db_session.commit() return permission_list + + +async def entries_exists_check( + db_session: Session, + journal_id: UUID, + entries_ids: List[UUID], +) -> None: + """ + Check if entries exists in journal. + """ + + # Index scan for entries ids + existing_entries_obj = ( + db_session.query(JournalEntry.id) + .filter(JournalEntry.journal_id == journal_id) + .filter(JournalEntry.id.in_(entries_ids)) + .all() + ) + + ### perfomance test https://stackoverflow.com/a/3462202/13271066 + + existing_entries: Set[UUID] = set([entry[0] for entry in existing_entries_obj]) + + diff = [x for x in entries_ids if x not in existing_entries] + + if len(diff) > 0: + raise EntriesNotFound("Could not find some of the given entries", diff) + + +async def dedublicate_entries_tags( + entries_tags: CreateEntriesTagsRequest, +) -> List[Dict[str, Any]]: + + values: List[Dict[str, Any]] = [] + + for entry_tag_request in entries_tags.entries: + entry_id = entry_tag_request.journal_entry_id + + for tag in entry_tag_request.tags: + + insert_object = { + "journal_entry_id": entry_id, + "tag": tag, + } + + values.append(insert_object) + + # Deduplicate tags + + seen = set() + deduplicated_values = [] + for d in values: + t = tuple(sorted(d.items())) + if t not in seen: + seen.add(t) + deduplicated_values.append(d) + + return deduplicated_values diff --git a/spire/journal/api.py b/spire/journal/api.py index c7030a5..4d855e6 100644 --- a/spire/journal/api.py +++ b/spire/journal/api.py @@ -3,6 +3,7 @@ from typing import Any, cast, Dict, List, Optional, Set, Union, Tuple from uuid import UUID + from elasticsearch import Elasticsearch from fastapi import ( FastAPI, @@ -14,7 +15,7 @@ Path, ) from fastapi.middleware.cors import CORSMiddleware -import requests +import requests # type: ignore from sqlalchemy.orm import Session import boto3 @@ -31,6 +32,7 @@ JournalEntryContent, JournalEntryListContent, CreateJournalEntryTagRequest, + CreateEntriesTagsRequest, CreateJournalEntryTagsAPIRequest, DeleteJournalEntryTagAPIRequest, DeleteJournalEntriesByTagsAPIRequest, @@ -42,6 +44,7 @@ JournalResponse, JournalEntryResponse, JournalEntryTagsResponse, + JournalsEntriesTagsResponse, JournalEntryIds, JournalStatisticsSpecs, JournalStatisticsResponse, @@ -1950,6 +1953,177 @@ async def update_tags( return api_tag_request.tags +@app.post( + "/{journal_id}/bulk_entries_tags", + tags=["tags"], + response_model=List[JournalEntryResponse], +) +async def create_entries_tags( + journal_id: UUID, + entries_tags_request: CreateEntriesTagsRequest, + request: Request, + db_session: Session = Depends(db.yield_connection_from_env), + es_client: Elasticsearch = Depends(es.yield_es_client_from_env), +) -> List[JournalEntryResponse]: + + """ + Create tags for multiple journal entries. + """ + + ensure_journal_permission( + db_session, + request.state.user_id, + request.state.user_group_id_list, + journal_id, + {JournalEntryScopes.UPDATE}, + ) + + journal_spec = JournalSpec(id=journal_id, bugout_user_id=request.state.user_id) + try: + journal = await actions.find_journal( + db_session=db_session, + journal_spec=journal_spec, + user_group_id_list=request.state.user_group_id_list, + ) + except actions.JournalNotFound: + logger.error( + f"Journal not found with ID={journal_id} for user={request.state.user_id}" + ) + raise HTTPException(status_code=404) + except Exception as e: + logger.error(f"Error retrieving journal: {str(e)}") + raise HTTPException(status_code=500) + es_index = journal.search_index + try: + updated_entry_ids = await actions.create_journal_entries_tags( + db_session, journal, entries_tags_request + ) + except actions.EntriesNotFound as e: + logger.error(f"Entries not found with entries") + raise HTTPException( + status_code=404, detail=f"Not entries with ids: {e.entries}" + ) + except actions.CommitFailed as e: + logger.error(f"Can't write tags for entries to database") + raise HTTPException( + status_code=409, detail=f"Can't write tags for entries to database" + ) + except Exception as e: + logger.error(f"Error journal entries tags update: {str(e)}") + raise HTTPException(status_code=500) + + try: + entries_objects = await actions.get_journal_entries_with_tags( + db_session, journal_entries_ids=updated_entry_ids + ) + except Exception as e: + logger.error(f"Error get journal entries: {str(e)}") + raise HTTPException(status_code=500) + + if es_index is not None: + + try: + + search.bulk_create_entries( + es_client, + es_index=es_index, + journal_id=journal_id, + entries=entries_objects, + ) + + except Exception as e: + logger.warning( + f"Error creating tags for entry ({updated_entry_ids}) in journal ({str(journal_id)}) for " + f"user ({request.state.user_id}): {repr(e)}" + ) + + return entries_objects + + +@app.delete( + "/{journal_id}/bulk_entries_tags", + tags=["tags"], + response_model=List[JournalEntryResponse], +) +async def delete_entries_tags( + journal_id: UUID, + entries_tags_request: CreateEntriesTagsRequest, + request: Request, + db_session: Session = Depends(db.yield_connection_from_env), + es_client: Elasticsearch = Depends(es.yield_es_client_from_env), +) -> List[JournalEntryResponse]: + + """ + Delete tags for multiple journal entries. + """ + + ensure_journal_permission( + db_session, + request.state.user_id, + request.state.user_group_id_list, + journal_id, + {JournalEntryScopes.UPDATE}, + ) + + journal_spec = JournalSpec(id=journal_id, bugout_user_id=request.state.user_id) + try: + journal = await actions.find_journal( + db_session=db_session, + journal_spec=journal_spec, + user_group_id_list=request.state.user_group_id_list, + ) + except actions.JournalNotFound: + logger.error( + f"Journal not found with ID={journal_id} for user={request.state.user_id}" + ) + raise HTTPException(status_code=404) + except Exception as e: + logger.error(f"Error retrieving journal: {str(e)}") + raise HTTPException(status_code=500) + es_index = journal.search_index + + try: + deleted_entry_ids = await actions.delete_journal_entries_tags( + db_session, journal, entries_tags_request + ) + except actions.CommitFailed as e: + logger.error(f"Can't delete tags form entries") + raise HTTPException( + status_code=409, detail=f"Can't delete tags from entries in database" + ) + except actions.EntriesNotFound as e: + logger.error(f"Entries not found with entries") + raise HTTPException( + status_code=404, detail=f"Not entries with ids: {e.entries}" + ) + except Exception as e: + logger.error(f"Error journal entries tags update: {str(e)}") + raise HTTPException(status_code=500) + + entries_objects = await actions.get_journal_entries_with_tags( + db_session, journal_entries_ids=deleted_entry_ids + ) + + if es_index is not None: + + try: + + search.bulk_create_entries( + es_client, + es_index=es_index, + journal_id=journal_id, + entries=entries_objects, + ) + + except Exception as e: + logger.warning( + f"Error creating tags for entry ({deleted_entry_ids}) in journal ({str(journal_id)}) for " + f"user ({request.state.user_id}): {repr(e)}" + ) + + return entries_objects + + @app.delete( "/{journal_id}/entries/{entry_id}/tags", tags=["tags"], diff --git a/spire/journal/data.py b/spire/journal/data.py index ecf5f22..44f2302 100644 --- a/spire/journal/data.py +++ b/spire/journal/data.py @@ -57,14 +57,14 @@ class RuleActions(Enum): class CreateJournalAPIRequest(BaseModel): # group_id is Optional to have possibility send null via update update_journal() name: str - group_id: Optional[str] + group_id: Optional[str] = None journal_type: JournalTypes = JournalTypes.DEFAULT class CreateJournalRequest(BaseModel): bugout_user_id: str name: str - search_index: Optional[str] + search_index: Optional[str] = None class JournalResponse(BaseModel): @@ -81,8 +81,8 @@ class ListJournalsResponse(BaseModel): class UpdateJournalSpec(BaseModel): - holder_id: Optional[str] - name: Optional[str] + holder_id: Optional[str] = None + name: Optional[str] = None class JournalEntryIds(BaseModel): @@ -90,10 +90,10 @@ class JournalEntryIds(BaseModel): class JournalSpec(BaseModel): - id: Optional[uuid.UUID] - bugout_user_id: Optional[str] - holder_ids: Optional[Set[str]] - name: Optional[str] + id: Optional[uuid.UUID] = None + bugout_user_id: Optional[str] = None + holder_ids: Optional[Set[str]] = None + name: Optional[str] = None class CreateJournalEntryRequest(BaseModel): @@ -101,21 +101,21 @@ class CreateJournalEntryRequest(BaseModel): title: str content: str tags: List[str] = Field(default_factory=list) - context_url: Optional[str] - context_id: Optional[str] - context_type: Optional[str] - created_at: Optional[datetime] + context_url: Optional[str] = None + context_id: Optional[str] = None + context_type: Optional[str] = None + created_at: Optional[datetime] = None class JournalEntryContent(BaseModel): title: str content: str tags: List[str] = Field(default_factory=list) - context_url: Optional[str] - context_id: Optional[str] - context_type: Optional[str] - created_at: Optional[datetime] - locked_by: Optional[str] + context_url: Optional[str] = None + context_id: Optional[str] = None + context_type: Optional[str] = None + created_at: Optional[datetime] = None + locked_by: Optional[str] = None class JournalEntryListContent(BaseModel): @@ -124,15 +124,15 @@ class JournalEntryListContent(BaseModel): class JournalEntryResponse(BaseModel): id: uuid.UUID - journal_url: Optional[str] - content_url: Optional[str] - title: Optional[str] - content: Optional[str] + journal_url: Optional[str] = None + content_url: Optional[str] = None + title: Optional[str] = None + content: Optional[str] = None tags: List[str] = Field(default_factory=list) - created_at: Optional[datetime] - updated_at: Optional[datetime] - context_url: Optional[str] - context_type: Optional[str] + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + context_url: Optional[str] = None + context_type: Optional[str] = None context_id: Optional[str] = None locked_by: Optional[str] = None @@ -152,18 +152,18 @@ class DronesStatisticsResponce(BaseModel): class JournalStatisticsSpecs(BaseModel): - entries_hour: Optional[bool] - entries_day: Optional[bool] - entries_week: Optional[bool] - entries_month: Optional[bool] - entries_total: Optional[bool] - most_used_tags: Optional[bool] + entries_hour: Optional[bool] = None + entries_day: Optional[bool] = None + entries_week: Optional[bool] = None + entries_month: Optional[bool] = None + entries_total: Optional[bool] = None + most_used_tags: Optional[bool] = None class UpdateStatsRequest(BaseModel): stats_version: int - stats_type: List[str] = [] - timescale: List[str] = [] + stats_type: List[str] = Field(default_factory=list) + timescale: List[str] = Field(default_factory=list) push_to_bucket: Optional[bool] = True @@ -173,11 +173,15 @@ class ListJournalEntriesResponse(BaseModel): class CreateJournalEntryTagRequest(BaseModel): journal_entry_id: uuid.UUID - tags: List[str] + tags: List[str] = Field(default_factory=list) + + +class CreateEntriesTagsRequest(BaseModel): + entries: List[CreateJournalEntryTagRequest] = Field(default_factory=list) class CreateJournalEntryTagsAPIRequest(BaseModel): - tags: List[str] + tags: List[str] = Field(default_factory=list) class DeleteJournalEntryTagAPIRequest(BaseModel): @@ -191,7 +195,11 @@ class DeleteJournalEntriesByTagsAPIRequest(BaseModel): class JournalEntryTagsResponse(BaseModel): journal_id: uuid.UUID entry_id: uuid.UUID - tags: List[str] + tags: List[str] = Field(default_factory=list) + + +class JournalsEntriesTagsResponse(BaseModel): + entries: List[JournalEntryTagsResponse] = Field(default_factory=list) class JournalEntriesByTagsDeletionResponse(BaseModel): @@ -210,8 +218,8 @@ class JournalSearchResult(BaseModel): entry_url: str content_url: str title: str - content: Optional[str] - tags: List[str] + content: Optional[str] = None + tags: List[str] = Field(default_factory=list) created_at: str updated_at: str score: float @@ -278,7 +286,7 @@ class ContextSpec(BaseModel): class JournalTTLRuleResponse(BaseModel): id: int - journal_id: Optional[uuid.UUID] + journal_id: Optional[uuid.UUID] = None name: str conditions: Dict[str, Any] action: RuleActions diff --git a/spire/journal/search.py b/spire/journal/search.py index 6c3a8c6..5f52149 100644 --- a/spire/journal/search.py +++ b/spire/journal/search.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, List, Tuple, Union from uuid import UUID -from dateutil.parser import parse as parse_datetime +from dateutil.parser import parse as parse_datetime # type: ignore import elasticsearch from elasticsearch.client import IndicesClient from elasticsearch.helpers import bulk diff --git a/spire/middleware.py b/spire/middleware.py index f23cf73..f523725 100644 --- a/spire/middleware.py +++ b/spire/middleware.py @@ -4,7 +4,7 @@ from tokenize import group from typing import Callable, Awaitable, List, Optional -import requests +import requests # type: ignore from starlette.middleware.base import BaseHTTPMiddleware from fastapi import Request, Response diff --git a/spire/slack/api.py b/spire/slack/api.py index accda56..1249f45 100644 --- a/spire/slack/api.py +++ b/spire/slack/api.py @@ -13,7 +13,7 @@ HTTPException, ) -import requests +import requests # type: ignore from starlette.responses import RedirectResponse from sqlalchemy.orm import Session diff --git a/spire/slack/commands.py b/spire/slack/commands.py index bb79f9e..b1cab39 100644 --- a/spire/slack/commands.py +++ b/spire/slack/commands.py @@ -10,7 +10,7 @@ import urllib import urllib.parse -import requests +import requests # type: ignore from sqlalchemy.orm import Session from . import admin as slack_admin diff --git a/spire/slack/data.py b/spire/slack/data.py index f5e3745..836c231 100644 --- a/spire/slack/data.py +++ b/spire/slack/data.py @@ -21,9 +21,9 @@ class Index(BaseModel): class BroodUser(BaseModel): id: uuid.UUID - username: Optional[str] - email: Optional[str] - token: Optional[uuid.UUID] + username: Optional[str] = None + email: Optional[str] = None + token: Optional[uuid.UUID] = None class BroodGroup(BaseModel): diff --git a/spire/slack/indices.py b/spire/slack/indices.py index 8e4e39e..3282f98 100644 --- a/spire/slack/indices.py +++ b/spire/slack/indices.py @@ -8,7 +8,7 @@ from typing import Any, Callable, cast, Dict, List, Union import uuid -import requests +import requests # type: ignore from sqlalchemy.orm import Session from .data import Index diff --git a/spire/slack/reactions.py b/spire/slack/reactions.py index a9e8b44..fab1bc9 100644 --- a/spire/slack/reactions.py +++ b/spire/slack/reactions.py @@ -6,7 +6,7 @@ import logging from typing import Any, cast, Dict, List, Union -import requests +import requests # type: ignore from concurrent.futures import ThreadPoolExecutor diff --git a/spire/utils/confparse.py b/spire/utils/confparse.py index 5db9e70..ed45d93 100644 --- a/spire/utils/confparse.py +++ b/spire/utils/confparse.py @@ -1,5 +1,5 @@ from pathlib import Path -import toml +import toml # type: ignore MODULE_PATH = Path(__file__).parent.parent.parent.resolve()