Skip to content

Commit

Permalink
Merge pull request #64 from bugout-dev/spire-entity-extention-v2
Browse files Browse the repository at this point in the history
Spire entity extention v2
  • Loading branch information
kompotkot committed Jul 27, 2023
2 parents 5e3c72b + 6e0f395 commit 0173753
Show file tree
Hide file tree
Showing 11 changed files with 1,458 additions and 627 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"multidict",
"protobuf==3.19.1",
"psycopg2-binary>=2.9.1",
"pydantic",
"pydantic<=1.10.2",
"PyJWT==1.7.1",
"redis",
"requests",
Expand All @@ -56,6 +56,7 @@
"typed-ast",
"uvicorn>=0.17.6",
"uvloop",
"web3>=5.30.0",
"websockets",
"yarl",
],
Expand Down
14 changes: 5 additions & 9 deletions spire/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from fastapi.middleware.cors import CORSMiddleware

from .data import PingResponse, VersionResponse
from .github.api import app as github_api
from .go.api import app as go_api
from .humbug.api import app as humbug_app
from .journal.api import app as journal_api
from .preferences.api import app as preferences_api
from .public.api import app_public as public_api
from .slack.api import app as slack_api
from .github.api import app as github_api
from .preferences.api import app as preferences_api
from .humbug.api import app as humbug_app
from .utils.settings import SPIRE_RAW_ORIGINS_LST
from .version import SPIRE_VERSION

LOG_LEVEL = logging.INFO
Expand All @@ -26,14 +27,9 @@

app = FastAPI(openapi_url=None)

# CORS configuration
origins_raw = os.environ.get("SPIRE_CORS_ALLOWED_ORIGINS")
if origins_raw is None:
raise ValueError("SPIRE_CORS_ALLOWED_ORIGINS environment variable must be set")
origins = origins_raw.split(",")
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=SPIRE_RAW_ORIGINS_LST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand Down
143 changes: 106 additions & 37 deletions spire/journal/actions.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,58 @@
"""
Journal-related actions in Spire
"""
from datetime import date, timedelta, datetime
import calendar
import json
import logging
import os
import time
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from datetime import date, datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from uuid import UUID, uuid4

import boto3

from sqlalchemy.orm import Session, Query
from sqlalchemy import or_, func, text, and_, select
from fastapi import HTTPException, Request
from sqlalchemy import and_, func, or_, select, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Query, Session


from ..broodusers import bugout_api
from ..utils.confparse import scope_conf
from ..utils.settings import BUGOUT_CLIENT_ID_HEADER
from .data import (
JournalScopes,
JournalEntryScopes,
CreateJournalRequest,
JournalSpec,
JournalStatisticsSpecs,
ContextSpec,
CreateEntriesTagsRequest,
CreateJournalEntryRequest,
JournalEntryListContent,
CreateJournalEntryTagRequest,
CreateEntriesTagsRequest,
CreateJournalRequest,
EntitiesResponse,
EntityList,
EntityResponse,
EntryRepresentationTypes,
JournalEntryListContent,
JournalEntryResponse,
JournalEntryScopes,
JournalPermission,
JournalResponse,
JournalScopes,
JournalSearchResultsResponse,
JournalSpec,
JournalStatisticsResponse,
UpdateJournalSpec,
JournalStatisticsSpecs,
ListJournalEntriesResponse,
JournalEntryResponse,
JournalPermission,
ContextSpec,
ListJournalsResponse,
UpdateJournalSpec,
)
from .models import (
HolderType,
Journal,
JournalEntry,
JournalEntryLock,
JournalEntryTag,
JournalPermissions,
HolderType,
SpireOAuthScopes,
)
from ..utils.confparse import scope_conf
from ..broodusers import bugout_api
from .representations import journal_representation_parsers, parse_entity_to_entry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,6 +116,18 @@ class CommitFailed(Exception):
"""


def bugout_client_id_from_request(request: Request) -> Optional[str]:
"""
Returns Bugout search client ID from request if it has been passed.
"""
bugout_client_id: Optional[str] = request.headers.get(BUGOUT_CLIENT_ID_HEADER)
# We are deprecating the SIMIOTICS_CLIENT_ID_HEADER header in favor of BUGOUT_CLIENT_ID_HEADER, but
# this needs to be here for legacy support.
if bugout_client_id is None:
bugout_client_id = request.headers.get("x-simiotics-client-id")
return bugout_client_id


def acl_auth(
db_session: Session, user_id: str, user_group_id_list: List[str], journal_id: UUID
) -> Tuple[Journal, Dict[HolderType, List[str]]]:
Expand Down Expand Up @@ -192,6 +211,38 @@ def acl_check(
raise PermissionsNotFound("No permissions for requested information")


def ensure_journal_permission(
db_session: Session,
user_id: str,
user_group_ids: List[str],
journal_id: UUID,
required_scopes: Set[Union[JournalScopes, JournalEntryScopes]],
) -> Journal:
"""
Checks if the given user (who is a member of the groups specified by user_group_ids) holds the
given scope on the journal specified by journal_id.
Returns: None if the user is a holder of that scope, and raises the appropriate HTTPException
otherwise.
"""
try:
journal, acl = acl_auth(db_session, user_id, user_group_ids, journal_id)
acl_check(acl, required_scopes)
except PermissionsNotFound:
logger.error(
f"User (id={user_id}) does not have the appropriate permissions (scopes={required_scopes}) "
f"for journal (id={journal_id})"
)
raise HTTPException(status_code=404)
except Exception:
logger.error(
f"Error checking permissions for user (id={user_id}) in journal (id={journal_id})"
)
raise HTTPException(status_code=500)

return journal


async def find_journals(
db_session: Session, user_id: UUID, user_group_id_list: Optional[List[str]] = None
) -> List[Journal]:
Expand Down Expand Up @@ -388,7 +439,6 @@ async def journal_statistics(
tags: List[str],
user_group_id_list: Optional[List[str]] = None,
) -> JournalStatisticsResponse:

"""
Return journals statistics.
For now just amount of entries for default periods.
Expand Down Expand Up @@ -549,52 +599,75 @@ async def create_journal_entry(
async def create_journal_entries_pack(
db_session: Session,
journal_id: UUID,
entries_pack_request: JournalEntryListContent,
entries_pack_request: Union[JournalEntryListContent, EntityList],
) -> ListJournalEntriesResponse:
"""
Bulk pack of entries to database.
"""
representation: EntryRepresentationTypes
if type(entries_pack_request) == JournalEntryListContent:
representation = EntryRepresentationTypes.ENTRY
elif type(entries_pack_request) == EntityList:
representation = EntryRepresentationTypes.ENTITY

entries_response = ListJournalEntriesResponse(entries=[])

chunk_size = 50
chunks = [
entries_pack_request.entries[i : i + chunk_size]
for i in range(0, len(entries_pack_request.entries), chunk_size)
]
e_list = []
if representation == EntryRepresentationTypes.ENTRY:
e_list = entries_pack_request.entries
elif representation == EntryRepresentationTypes.ENTITY:
e_list = entries_pack_request.entities
chunks = [e_list[i : i + chunk_size] for i in range(0, len(e_list), chunk_size)]
logger.info(
f"Entries pack split into to {len(chunks)} chunks for journal {str(journal_id)}"
)

for chunk in chunks:
entries_pack = []
entries_tags_pack = []

for entry_request in chunk:
entry_id = uuid4()

title: str = ""
tags: Optional[List[str]] = None
content: str = ""
if representation == EntryRepresentationTypes.ENTRY:
title = entry_request.title
tags = entry_request.tags
content = entry_request.content
elif representation == EntryRepresentationTypes.ENTITY:
title, tags, content_raw = parse_entity_to_entry(
create_entity=entry_request,
)
content = json.dumps(content_raw)

entries_pack.append(
JournalEntry(
id=entry_id,
journal_id=journal_id,
title=entry_request.title,
content=entry_request.content,
title=title,
content=content,
context_id=entry_request.context_id,
context_url=entry_request.context_url,
context_type=entry_request.context_type,
created_at=entry_request.created_at,
)
)
if entry_request.tags is not None:
if tags is not None:
entries_tags_pack += [
JournalEntryTag(journal_entry_id=entry_id, tag=tag)
for tag in entry_request.tags
for tag in tags
if tag
]

entries_response.entries.append(
JournalEntryResponse(
id=entry_id,
title=entry_request.title,
content=entry_request.content,
tags=entry_request.tags if entry_request.tags is not None else [],
title=title,
content=content,
tags=tags if tags is not None else [],
context_url=entry_request.context_url,
context_type=entry_request.context_type,
context_id=entry_request.context_id,
Expand Down Expand Up @@ -1102,7 +1175,6 @@ async def create_journal_entries_tags(
journal: Journal,
entries_tags_request: CreateEntriesTagsRequest,
) -> List[UUID]:

"""
Create tags for entries in journal.
"""
Expand Down Expand Up @@ -1142,7 +1214,6 @@ async def delete_journal_entries_tags(
journal: Journal,
entries_tags_request: CreateEntriesTagsRequest,
) -> List[UUID]:

"""
Delete tags for entries in journal.
"""
Expand Down Expand Up @@ -1492,14 +1563,12 @@ async def entries_exists_check(
async def dedublicate_entries_tags(
entries_tags: CreateEntriesTagsRequest,
) -> List[Dict[str, Any]]:

values: List[Dict[str, Any]] = []

for entry_tag_request in entries_tags.entries:
entry_id = entry_tag_request.journal_entry_id

for tag in entry_tag_request.tags:

insert_object = {
"journal_entry_id": entry_id,
"tag": tag,
Expand Down
Loading

0 comments on commit 0173753

Please sign in to comment.