feat: add frame & gain ingestion to db import (#308)
* Remove deprecated `Tomogram.is_standardized` field
jgadling authored Oct 14, 2024
1 parent 740cd4c commit 7a167d6
Showing 39 changed files with 732 additions and 306 deletions.

8 changes: 4 additions & 4 deletions apiv2/database/models/frame.py

2 changes: 1 addition & 1 deletion apiv2/database/models/gain_file.py

1 change: 0 additions & 1 deletion apiv2/database/models/tomogram.py

103 changes: 103 additions & 0 deletions apiv2/db_import/common/config.py
@@ -0,0 +1,103 @@
import json
import logging
from datetime import datetime
from pathlib import PurePath
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError
from sqlalchemy.orm import Session

if TYPE_CHECKING:
    from mypy_boto3_s3 import S3Client
else:
    S3Client = object

logger = logging.getLogger("config")


class DBImportConfig:
    s3_client: S3Client
    bucket_name: str
    s3_prefix: str
    https_prefix: str
    session: Session

    def __init__(
        self,
        s3_client: S3Client,
        bucket_name: str,
        https_prefix: str,
        session: Session,
    ):
        self.s3_client = s3_client
        self.bucket_name = bucket_name
        self.s3_prefix = f"s3://{bucket_name}"
        self.https_prefix = https_prefix if https_prefix else "https://files.cryoetdataportal.cziscience.com"
        self.session = session

    def get_db_session(self) -> Session:
        return self.session

    def find_subdirs_with_files(self, prefix: str, target_filename: str) -> list[str]:
        paginator = self.s3_client.get_paginator("list_objects_v2")
        logger.info("looking for prefix %s", prefix)
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter="/")
        result = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                try:
                    subdir = obj["Prefix"]
                    self.s3_client.head_object(Bucket=self.bucket_name, Key=f"{subdir}{target_filename}")
                    result.append(subdir)
                except Exception:
                    continue
        return result

    def glob_s3(self, prefix: str, glob_string: str, is_file: bool = True):
        paginator = self.s3_client.get_paginator("list_objects_v2")
        if prefix.startswith("s3://"):
            prefix = "/".join(prefix.split("/")[3:])
        logger.info("looking for prefix %s%s", prefix, glob_string)
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter="/")
        page_key = "Contents" if is_file else "CommonPrefixes"
        obj_key = "Key" if is_file else "Prefix"
        for page in pages:
            for obj in page.get(page_key, {}):
                if not obj:
                    break
                if PurePath(obj[obj_key]).match(glob_string):
                    yield obj[obj_key]

    def load_key_json(self, key: str, is_file_required: bool = True) -> dict[str, Any] | None:
        """
        Loads the file matching the key as JSON. If the file does not exist, raises an error when
        is_file_required is True; otherwise returns None.
        """
        try:
            text = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return json.loads(text["Body"].read())
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey" and not is_file_required:
                logger.warning("NoSuchKey on bucket_name=%s key=%s", self.bucket_name, key)
                return None
            else:
                raise


def map_to_value(db_key: str, mapping: dict[str, Any], data: dict[str, Any]) -> Any:
    """
    For a key, maps to a value by traversing the JSON as specified in the mapping parameter,
    or returns the precomputed value if the mapping holds a non-list value.
    """
    data_path = mapping.get(db_key)
    if not isinstance(data_path, list):
        return data_path

    value = None
    for path_part in data_path:
        value = data.get(path_part) if not value else value.get(path_part)
        if not value:
            break
    if value and "date" in db_key:
        value = datetime.strptime(value, "%Y-%m-%d")  # type: ignore
    return value
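
As a quick illustration of map_to_value above, a minimal sketch; the mapping and metadata dicts here are hypothetical examples, not values taken from this importer:

# Hypothetical mapping: a list value is a JSON path to traverse;
# a non-list value is returned as-is as a precomputed constant.
mapping = {
    "deposition_date": ["dates", "deposition_date"],
    "is_portal_standard": True,
}
metadata = {"dates": {"deposition_date": "2024-10-14"}}

map_to_value("deposition_date", mapping, metadata)     # -> datetime(2024, 10, 14, 0, 0)
map_to_value("is_portal_standard", mapping, metadata)  # -> True (precomputed value)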
44 changes: 44 additions & 0 deletions apiv2/db_import/common/finders.py
@@ -0,0 +1,44 @@
import logging
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from db_import.common.config import DBImportConfig

from platformics.database.models.base import Base

logger = logging.getLogger("db_import")

if TYPE_CHECKING:
    from db_import.importers.base import ItemDBImporter
else:
    ItemDBImporter = Any


class ItemFinder(ABC):
    @abstractmethod
    def __init__(self, config: DBImportConfig, **kwargs):
        pass

    @abstractmethod
    def find(self, item_importer: ItemDBImporter, parents: dict[str, Base]) -> list[ItemDBImporter]:
        pass


class FileFinder(ItemFinder):
    def __init__(self, config: DBImportConfig, path: str, glob: str, match_regex: str | None):
        self.config = config
        self.path = path
        self.glob = glob
        self.match_regex = None
        if match_regex:
            self.match_regex = re.compile(match_regex)

    def find(self, item_importer: ItemDBImporter, parents: dict[str, Base]) -> list[ItemDBImporter]:
        results: list[ItemDBImporter] = []
        for file in self.config.glob_s3(self.path, self.glob, is_file=True):
            # match_regex is optional; skip the regex filter when none was provided.
            if self.match_regex is None or self.match_regex.match(file):
                data = {"file": file}
                data.update(parents)
                results.append(item_importer(self.config, data))
        return results
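
For context, a sketch of how FileFinder composes with an ItemDBImporter subclass. FrameItem, the S3 prefix layout, and the regex below are hypothetical stand-ins (the real importer classes live elsewhere in this PR), and config, run_obj, and dataset_obj are assumed to already be in scope:

# FrameItem is a hypothetical stand-in for an ItemDBImporter subclass.
class FrameItem:
    def __init__(self, config: DBImportConfig, data: dict):
        self.config = config
        self.data = data  # e.g. {"file": ..., "run": ..., "dataset": ...}

finder = FileFinder(
    config,
    path="10000/RUN1/Frames/",           # hypothetical S3 prefix layout
    glob="*",
    match_regex=r".*\.(mrc|tiff|eer)$",  # hypothetical frame file extensions
)
items = finder.find(FrameItem, parents={"run": run_obj, "dataset": dataset_obj})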
23 changes: 21 additions & 2 deletions apiv2/db_import/importer.py
@@ -5,15 +5,17 @@
 import click
 from botocore import UNSIGNED
 from botocore.config import Config
+from db_import.common.config import DBImportConfig
 from db_import.importers.annotation import (
     AnnotationAuthorDBImporter,
     AnnotationDBImporter,
     AnnotationMethodLinkDBImporter,
     StaleAnnotationDeletionDBImporter,
 )
-from db_import.importers.base_importer import DBImportConfig
 from db_import.importers.dataset import DatasetAuthorDBImporter, DatasetDBImporter, DatasetFundingDBImporter
 from db_import.importers.deposition import DepositionAuthorDBImporter, DepositionDBImporter, DepositionTypeDBImporter
+from db_import.importers.frame import FrameImporter
+from db_import.importers.gain import GainImporter
 from db_import.importers.run import RunDBImporter, StaleRunDeletionDBImporter
 from db_import.importers.tiltseries import StaleTiltSeriesDeletionDBImporter, TiltSeriesDBImporter
 from db_import.importers.tomogram import StaleTomogramDeletionDBImporter, TomogramAuthorDBImporter, TomogramDBImporter
@@ -39,6 +41,8 @@ def db_import_options(func):
     options.append(click.option("--import-dataset-funding", is_flag=True, default=False))
     options.append(click.option("--import-depositions", is_flag=True, default=False))
     options.append(click.option("--import-runs", is_flag=True, default=False))
+    options.append(click.option("--import-gains", is_flag=True, default=False))
+    options.append(click.option("--import-frames", is_flag=True, default=False))
     options.append(click.option("--import-tiltseries", is_flag=True, default=False))
     options.append(click.option("--import-tomograms", is_flag=True, default=False))
     options.append(click.option("--import-tomogram-authors", is_flag=True, default=False))
@@ -82,6 +86,8 @@ def load(
     import_dataset_funding: bool,
     import_depositions: bool,
     import_runs: bool,
+    import_gains: bool,
+    import_frames: bool,
     import_tiltseries: bool,
     import_tomograms: bool,
     import_tomogram_authors: bool,
@@ -104,6 +110,8 @@ def load(
         import_dataset_funding,
         import_depositions,
         import_runs,
+        import_gains,
+        import_frames,
         import_tiltseries,
         import_tomograms,
         import_tomogram_authors,
@@ -128,6 +136,8 @@ def load_func(
     import_dataset_funding: bool = False,
     import_depositions: bool = False,
     import_runs: bool = False,
+    import_gains: bool = False,
+    import_frames: bool = False,
     import_tiltseries: bool = False,
     import_tomograms: bool = False,
     import_tomogram_authors: bool = False,
@@ -149,6 +159,8 @@ def load_func(
         import_dataset_funding = True
         import_depositions = True
         import_runs = True
+        import_gains = True
+        import_frames = True
         import_tiltseries = True
         import_tomograms = True
         import_tomogram_authors = True
@@ -157,7 +169,7 @@ def load_func(
     import_annotations = max(import_annotations, import_annotation_authors, import_annotation_method_links)
     import_tomograms = max(import_tomograms, import_tomogram_authors)
     import_tomogram_voxel_spacing = max(import_annotations, import_tomograms, import_tomogram_voxel_spacing)
-    import_runs = max(import_runs, import_tiltseries, import_tomogram_voxel_spacing)
+    import_runs = max(import_runs, import_gains, import_frames, import_tiltseries, import_tomogram_voxel_spacing)

     s3_config = Config(signature_version=UNSIGNED) if anonymous else None
     s3_client = boto3.client("s3", endpoint_url=endpoint_url, config=s3_config)
@@ -198,6 +210,13 @@ def load_func(
                 run_id = run_obj.id
                 run_cleaner.mark_as_active(run_obj)

+            parents = {"run": run_obj, "dataset": dataset_obj}
+            if import_frames:
+                frame_importer = FrameImporter(config, **parents)
+                frame_importer.import_items()
+            if import_gains:
+                gain_importer = GainImporter(config, **parents)
+                gain_importer.import_items()
             if import_tiltseries:
                 tiltseries_cleaner = StaleTiltSeriesDeletionDBImporter(run_id, config)
                 tiltseries = TiltSeriesDBImporter.get_item(run_id, run, config)
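
One detail worth noting in load_func above: the import flags cascade via max(), which over booleans acts as a logical OR, so requesting a child entity (frames or gains) implicitly enables the parent run import. A minimal sketch of the idea:

# max() over bools behaves as OR: requesting frames forces runs to be imported too.
import_runs, import_gains, import_frames = False, False, True
import_runs = max(import_runs, import_gains, import_frames)
assert import_runs is True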
2 changes: 1 addition & 1 deletion apiv2/db_import/importers/annotation.py
@@ -1,10 +1,10 @@
 from typing import Any, Iterator

 from database import models
+from db_import.common.config import DBImportConfig
 from db_import.importers.base_importer import (
     AuthorsStaleDeletionDBImporter,
     BaseDBImporter,
-    DBImportConfig,
     StaleDeletionDBImporter,
     StaleParentDeletionDBImporter,
 )