Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add delete entries_tags and create entries_tags endpoints. #62

Merged
merged 6 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 240 additions & 1 deletion spire/journal/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import boto3

from sqlalchemy.orm import Session, Query
from sqlalchemy import or_, func, text, and_
from sqlalchemy import or_, func, text, and_, select
from sqlalchemy.dialects import postgresql


Expand All @@ -26,6 +26,7 @@
CreateJournalEntryRequest,
JournalEntryListContent,
CreateJournalEntryTagRequest,
CreateEntriesTagsRequest,
JournalSearchResultsResponse,
JournalStatisticsResponse,
UpdateJournalSpec,
Expand Down Expand Up @@ -67,6 +68,17 @@ class EntryNotFound(Exception):
"""


# Excption with list of not found entries
class EntriesNotFound(Exception):
"""
Raised on actions that involve journal entries which are not present in the database.
"""

def __init__(self, message: str, entries: List[UUID] = []):
super().__init__(message)
self.entries = entries


class EntryLocked(Exception):
"""
Raised on actions when entry is not released for editing by other users.
Expand All @@ -91,6 +103,12 @@ class InvalidParameters(ValueError):
"""


class CommitFailed(Exception):
"""
Raised when commit failed.
"""


def acl_auth(
db_session: Session, user_id: str, user_group_id_list: List[str], journal_id: UUID
) -> Tuple[Journal, Dict[HolderType, List[str]]]:
Expand Down Expand Up @@ -687,6 +705,71 @@ async def get_journal_entry_with_tags(
return entry, tags, entry_lock


async def get_journal_entries_with_tags(
db_session: Session, journal_entries_ids: List[UUID]
) -> List[JournalEntryResponse]:
"""
Returns a journal entries by its id with tags.
"""
objects = (
db_session.query(JournalEntry, JournalEntryTag.tag)
.join(
JournalEntryTag,
JournalEntryTag.journal_entry_id == JournalEntry.id,
isouter=True,
)
.join(
JournalEntryLock,
JournalEntryLock.journal_entry_id == JournalEntry.id,
isouter=True,
)
.filter(JournalEntry.id.in_(journal_entries_ids))
).cte("entries")

entries = (
db_session.query(
objects.c.id.label("id"),
objects.c.journal_id.label("journal_id"),
objects.c.title.label("title"),
objects.c.content.label("content"),
func.array_agg(objects.c.tag).label("tags"),
objects.c.created_at.label("created_at"),
objects.c.updated_at.label("updated_at"),
objects.c.context_url.label("context_url"),
objects.c.context_type.label("context_type"),
objects.c.context_id.label("context_id"),
)
.group_by(
objects.c.id,
objects.c.journal_id,
objects.c.title,
objects.c.content,
objects.c.created_at,
objects.c.updated_at,
objects.c.context_url,
objects.c.context_type,
objects.c.context_id,
)
.all()
)

return [
JournalEntryResponse(
id=entry.id,
title=entry.title,
content=entry.content,
tags=list(entry.tags if entry.tags != [None] else []),
context_url=entry.context_url,
context_type=entry.context_type,
context_id=entry.context_id,
created_at=entry.created_at,
updated_at=entry.updated_at,
locked_by=None,
)
for entry in entries
]


async def update_journal_entry(
db_session: Session,
new_title: str,
Expand Down Expand Up @@ -1014,6 +1097,104 @@ async def update_journal_entry_tags(
return query.all()


async def create_journal_entries_tags(
db_session: Session,
journal: Journal,
entries_tags_request: CreateEntriesTagsRequest,
) -> List[UUID]:

"""
Create tags for entries in journal.
"""

# For more useful error message
requested_entries = [
entry.journal_entry_id for entry in entries_tags_request.entries
]

await entries_exists_check(
db_session=db_session, journal_id=journal.id, entries_ids=requested_entries
)

deduplicated_values = await dedublicate_entries_tags(
entries_tags=entries_tags_request
)

insert_statement = (
postgresql.insert(JournalEntryTag)
.values(deduplicated_values)
.on_conflict_do_nothing(index_elements=["journal_entry_id", "tag"])
)

try:
db_session.execute(insert_statement)
db_session.commit()
except Exception as err:
logger.error(f"Could not create tags for entries error: {err}")
db_session.rollback()
raise CommitFailed("Could not create tags")

return requested_entries


async def delete_journal_entries_tags(
db_session: Session,
journal: Journal,
entries_tags_request: CreateEntriesTagsRequest,
) -> List[UUID]:

"""
Delete tags for entries in journal.
"""

requested_entries = [
entry.journal_entry_id for entry in entries_tags_request.entries
]

await entries_exists_check(
db_session=db_session, journal_id=journal.id, entries_ids=requested_entries
)

deduplicated_values = await dedublicate_entries_tags(
entries_tags=entries_tags_request
)

selected_tags = (
db_session.query(
JournalEntryTag.id.label("id"),
)
.join(JournalEntry, JournalEntryTag.journal_entry_id == JournalEntry.id)
.filter(JournalEntry.journal_id == journal.id)
.filter(
JournalEntryTag.journal_entry_id.in_(
[entry["journal_entry_id"] for entry in deduplicated_values]
)
)
.filter(
JournalEntryTag.tag.in_([entry["tag"] for entry in deduplicated_values])
)
.cte("selected_tags")
)

delete_statement = (
db_session.query(JournalEntryTag)
.filter(JournalEntryTag.id.in_(select(selected_tags.c.id)))
.delete(synchronize_session=False)
)

try:
db_session.commit()
logger.info(
f"Deleted {delete_statement} tags in journal {journal.id} for {len(requested_entries)} entries"
)
except Exception as err:
logger.error(f"Could not delete tags for entries error: {err}")
db_session.rollback()
raise CommitFailed("Could not delete tags")

return requested_entries


async def delete_journal_entry_tag(
db_session: Session,
journal_spec: JournalSpec,
Expand Down Expand Up @@ -1279,3 +1460,61 @@ async def delete_journal_scopes(
db_session.commit()

return permission_list


async def entries_exists_check(
db_session: Session,
journal_id: UUID,
entries_ids: List[UUID],
) -> None:
"""
Check if entries exists in journal.
"""

# Index scan for entries ids
existing_entries_obj = (
db_session.query(JournalEntry.id)
.filter(JournalEntry.journal_id == journal_id)
.filter(JournalEntry.id.in_(entries_ids))
.all()
)

### perfomance test https://stackoverflow.com/a/3462202/13271066

existing_entries: Set[UUID] = set([entry[0] for entry in existing_entries_obj])

diff = [x for x in entries_ids if x not in existing_entries]

if len(diff) > 0:
raise EntriesNotFound("Could not find some of the given entries", diff)


async def dedublicate_entries_tags(
entries_tags: CreateEntriesTagsRequest,
) -> List[Dict[str, Any]]:

values: List[Dict[str, Any]] = []

for entry_tag_request in entries_tags.entries:
entry_id = entry_tag_request.journal_entry_id

for tag in entry_tag_request.tags:

insert_object = {
"journal_entry_id": entry_id,
"tag": tag,
}

values.append(insert_object)

# Deduplicate tags

seen = set()
deduplicated_values = []
for d in values:
t = tuple(sorted(d.items()))
if t not in seen:
seen.add(t)
deduplicated_values.append(d)

return deduplicated_values
Loading