[WIP] Enable mypy in lint checking #7

Closed · wants to merge 5 commits
8 changes: 6 additions & 2 deletions .pre-commit-config.yaml
@@ -3,14 +3,18 @@ repos:
rev: 23.9.1
hooks:
- id: black
files: backend|client/python
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
hooks:
- id: ruff
files: backend|client/python
args:
- --fix
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.4.1"
hooks:
- id: mypy
args: ["--config-file", "ingestion_tools/pyproject.toml"]
additional_dependencies: ["types-PyYAML", "types-dateparser", "types-requests"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
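The new hook pins mypy v1.4.1, points it at `ingestion_tools/pyproject.toml`, and installs stub packages via `additional_dependencies` so calls into third-party libraries are actually checked. A minimal sketch of the kind of call those stubs cover (the helper below is hypothetical, not part of this PR):

```python
from datetime import datetime
from typing import Optional

import dateparser  # typed via types-dateparser
import requests  # typed via types-requests


def fetch_last_modified(url: str) -> Optional[datetime]:
    # Without the stub packages mypy treats these modules as untyped; with
    # them, dateparser.parse() is known to return Optional[datetime], so a
    # caller that forgets the None case gets flagged.
    resp = requests.get(url, timeout=10)
    header = resp.headers.get("Last-Modified", "")
    return dateparser.parse(header) if header else None
```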
27 changes: 26 additions & 1 deletion ingestion_tools/poetry.lock

Some generated files are not rendered by default.

13 changes: 12 additions & 1 deletion ingestion_tools/pyproject.toml
@@ -44,6 +44,8 @@ imageio = "^2.33.1"
pytest = "^8.0.0"
boto3-stubs = {extras = ["s3"], version = "^1.34.34"}
mypy = "^1.8.0"
types-requests = "^2.31.0.20240125"
types-dateparser = "^1.1.4.20240106"

[tool.black]
line-length = 120
@@ -68,10 +70,17 @@ select = [
ignore = [
"E501", # line too long
"C408", # rewrite empty built-ins as literals
"T201", # print statements.
"DTZ007", # Datetime objects without timezones.
"DTZ005", # More datetimes without timezones.
]
line-length = 120
target-version = "py39"

[tool.ruff.lint.per-file-ignores]
# Ignore `SIM115` (not using open() in a context manager) since all calls to this method *do* use a context manager.
"scripts/common/fs.py" = ["SIM115"]

[tool.ruff.isort]
known-first-party =["common"]

@@ -82,4 +91,6 @@ docstring-quotes = "double"
show_error_codes = true
ignore_missing_imports = true
warn_unreachable = true
strict = true
explicit_package_bases = true
# We'd like to turn this on soon but we're not there yet.
# disallow_untyped_defs = true
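With a mypy this recent (and the strictness enabled here), an argument or field annotated `re.Pattern` can no longer silently default to `None`; the `Optional[...]` annotations in the code changes below follow from that. A minimal illustration, using a hypothetical dataclass rather than the real config:

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunFilter:
    name_glob: str
    # `name_regex: re.Pattern = None` is rejected now that implicit Optional
    # is off; the None default has to be spelled out in the annotation.
    name_regex: Optional[re.Pattern] = None
```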
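The `SIM115` per-file ignore above matches the pattern in `scripts/common/fs.py` where an open handle is returned and the caller owns the `with` block; a rough sketch of that shape (names hypothetical, not copied from fs.py):

```python
from typing import IO


def open_dest(path: str) -> IO[bytes]:
    # Ruff's SIM115 flags this bare open(), but the handle is deliberately
    # handed to the caller, which controls its lifetime.
    return open(path, "wb")


# Caller side: the context manager lives here, so the handle still gets closed.
with open_dest("/tmp/example.bin") as fh:
    fh.write(b"payload")
```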
42 changes: 19 additions & 23 deletions ingestion_tools/scripts/common/config.py
@@ -1,9 +1,10 @@
import contextlib
import csv
import os
import os.path
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import yaml

@@ -35,10 +36,10 @@ class DataImportConfig:
run_regex: re.Pattern
tomo_glob: str
tomo_format: str
tomo_regex: re.Pattern = None
tomo_key_photo_glob: str = None
tomo_regex: Optional[re.Pattern] = None
tomo_key_photo_glob: Optional[str] = None
tomo_voxel_size: str
ts_name_regex: re.Pattern = None
ts_name_regex: Optional[re.Pattern] = None
run_name_regex: re.Pattern
frames_name_regex: re.Pattern
frames_glob: str
@@ -109,13 +110,11 @@ def load_map_files(self):
self.load_run_ts_map()
self.load_run_data_map()

def load_run_metadata_file(self, file_attr: str):
mapdata = {}
def load_run_metadata_file(self, file_attr: str) -> dict[str, Any]:
mapdata: dict[str, Any] = {}
map_filename = None
try:
with contextlib.suppress(AttributeError):
map_filename = getattr(self, file_attr)
except AttributeError:
pass
if not map_filename:
return mapdata
with self.fs.open(f"{self.input_path}/{map_filename}", "r") as tsvfile:
@@ -127,13 +126,11 @@ def load_run_metadata_file(self, file_attr: str):
mapdata[row["run_name"]] = row
return mapdata

def load_run_csv_file(self, file_attr: str):
mapdata = {}
def load_run_csv_file(self, file_attr: str) -> dict[str, Any]:
mapdata: dict[str, Any] = {}
map_filename = None
try:
with contextlib.suppress(AttributeError):
map_filename = getattr(self, file_attr)
except AttributeError:
pass
if not map_filename:
return mapdata
with self.fs.open(f"{self.input_path}/{map_filename}", "r") as csvfile:
@@ -181,7 +178,7 @@ def get_run_data_map(self, run_name) -> dict[str, Any]:
return {}

def expand_string(self, run_name: str, string_template: Any) -> int | float | str:
if type(string_template) != str:
if not isinstance(string_template, str):
return string_template
if run_data := self.get_run_data_map(run_name):
string_template = string_template.format(**run_data)
@@ -195,16 +192,16 @@ def expand_string(self, run_name: str, string_template: Any) -> int | float | str:

def expand_metadata(self, run_name: str, metadata_dict: dict[str, Any]) -> dict[str, Any]:
for k, v in metadata_dict.items():
if type(v) == str:
if isinstance(v, str):
metadata_dict[k] = self.expand_string(run_name, v)
elif (type(v)) == dict:
elif isinstance(v, dict):
metadata_dict[k] = self.expand_metadata(run_name, v)
elif (type(v)) == list:
elif isinstance(v, list):
for idx in range(len(v)):
# Note - we're not supporting deeply nested lists,
# but we don't need to with our current data model.
item = v[idx]
if type(item) == str:
if isinstance(item, str):
v[idx] = self.expand_string(run_name, item)
return metadata_dict

@@ -229,10 +226,11 @@ def get_expanded_metadata(self, obj) -> dict[str, Any]:

def get_run_override(self, run_name: str) -> RunOverride | None:
if not self.overrides_by_run:
return
return None
for item in self.overrides_by_run:
if item.run_regex.match(run_name):
return item
return None

def get_metadata_path(self, obj: BaseImporter) -> str:
key = f"{obj.type_key}_metadata"
@@ -266,10 +264,8 @@ def glob_files(self, obj: BaseImporter, globstring: str) -> list[str]:
if not globstring:
return []
globvars = run.get_glob_vars()
try:
with contextlib.suppress(ValueError):
globvars["int_run_name"] = int(run.run_name)
except ValueError:
pass
expanded_glob = os.path.join(self.dataset_root_dir, globstring.format(**globvars))
results = self.fs.glob(expanded_glob)
if not results:
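Most of the config.py churn swaps `try/except/pass` for `contextlib.suppress` and exact `type(x) == T` comparisons for `isinstance()`. A small standalone sketch of both idioms, using made-up values rather than the real run metadata:

```python
import contextlib

metadata = {"name": "run_{index}", "count": 3}

# isinstance() is what ruff (E721) and mypy prefer over `type(v) == str`,
# and it also accepts subclasses.
for key, value in metadata.items():
    if isinstance(value, str):
        metadata[key] = value.format(index=7)

# contextlib.suppress() replaces a try/except/pass wrapped around a single
# statement that is allowed to fail.
run_name = "0042"
glob_vars: dict[str, object] = {"run_name": run_name}
with contextlib.suppress(ValueError):
    glob_vars["int_run_name"] = int(run_name)
```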
17 changes: 8 additions & 9 deletions ingestion_tools/scripts/common/fs.py
@@ -1,3 +1,4 @@
import contextlib
import glob
import os
import os.path
@@ -12,7 +13,7 @@ class FileSystemApi:
force_overwrite: bool = False

@classmethod
def get_fs_api(cls, mode: str, force_overwrite: bool, client_kwargs: None | dict[str, str]=None):
def get_fs_api(cls, mode: str, force_overwrite: bool, client_kwargs: None | dict[str, str] = None):
if mode == "s3":
return S3Filesystem(force_overwrite=force_overwrite, client_kwargs=client_kwargs)
else:
@@ -24,19 +25,19 @@ def glob(self, *args):
def open(self, path: str, mode: str):
pass

def localreadable(self, path: str) -> str:
def localreadable(self, path: str) -> str: # type: ignore
pass

def makedirs(self, path: str):
def makedirs(self, path: str) -> None:
pass

def localwritable(self, path) -> str:
def localwritable(self, path) -> str: # type: ignore
pass

def push(self, path):
pass

def destformat(self, path) -> str:
def destformat(self, path) -> str: # type: ignore
pass

def copy(self, src_path: str, dest_path: str):
@@ -89,10 +90,8 @@ def push(self, path):
remote_file = os.path.relpath(path, self.tmpdir)
src_size = os.path.getsize(path)
dest_size = 0
try:
with contextlib.suppress(FileNotFoundError):
dest_size = self.s3fs.size(remote_file)
except FileNotFoundError:
pass
if src_size == dest_size:
if self.force_overwrite:
print(f"Forcing re-upload of {path}")
@@ -156,7 +155,7 @@ def localreadable(self, path) -> str:
def localwritable(self, path) -> str:
return path

def makedirs(self, path: str):
def makedirs(self, path: str) -> None:
os.makedirs(path, exist_ok=True)

def destformat(self, path) -> str:
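The `# type: ignore` comments cover base-class stubs that declare a `str` return but only `pass`. If those ignores become a nuisance later, one common alternative is to raise in the stub, which keeps the annotation honest; a sketch of that option (not what this PR does):

```python
class FileSystemApi:
    def localreadable(self, path: str) -> str:
        # Raising satisfies mypy's return-type check without a `# type: ignore`
        # and makes calls on a backend that lacks the method fail loudly.
        raise NotImplementedError
```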
46 changes: 19 additions & 27 deletions ingestion_tools/scripts/common/image.py
@@ -2,7 +2,7 @@
import os
import os.path
from datetime import datetime
from typing import Any, List, Callable
from typing import Any, Callable, List, Optional

import mrcfile
import numpy as np
@@ -34,9 +34,7 @@ def __init__(self, fs: FileSystemApi, mrc_filename: str, header_only: bool = False
self.mrc_filename = fs.read_block(mrc_filename)
else:
self.mrc_filename = fs.localreadable(mrc_filename)
with mrcfile.open(
self.mrc_filename, permissive=True, header_only=header_only
) as mrc:
with mrcfile.open(self.mrc_filename, permissive=True, header_only=header_only) as mrc:
if mrc.data is None and not header_only:
raise Exception("missing mrc data")
self.header = mrc.header
@@ -62,8 +60,8 @@ def pyramid_to_mrc(
pyramid: List[np.ndarray],
mrc_filename: str,
write: bool = True,
header_mapper: Callable[[np.array], None] = None,
voxel_spacing: float = None,
header_mapper: Optional[Callable[[np.array], None]] = None,
voxel_spacing: Optional[float] = None,
) -> List[str]:
mrcfiles = []
# NOTE - 2023-10-24
@@ -83,7 +81,6 @@ def pyramid_to_mrc(
print(f"skipping remote push for {filename}")
return mrcfiles


def pyramid_to_omezarr(
self,
fs: FileSystemApi,
@@ -120,7 +117,7 @@ def update_headers(self, mrcfile: MrcFile, header_mapper, voxel_spacing):
header.cella.y = isotropic_voxel_size * data.shape[1]
header.cella.z = isotropic_voxel_size * data.shape[0]
header.label[0] = "{0:40s}{1:>39s}".format("Validated by cryoET data portal.", time)
header.rms = np.sqrt(np.mean((data - np.mean(data))**2))
header.rms = np.sqrt(np.mean((data - np.mean(data)) ** 2))
header.extra1 = self.header.extra1
header.extra2 = self.header.extra2

@@ -130,7 +127,7 @@ def update_headers(self, mrcfile: MrcFile, header_mapper, voxel_spacing):
header.exttyp = self.header.exttyp
else:
header.nsymbt = np.array(0, dtype="i4")
header.exttyp = np.array(b'MRCO', dtype="S4")
header.exttyp = np.array(b"MRCO", dtype="S4")

if header_mapper:
header_mapper(header)
@@ -168,10 +165,12 @@ def get_tomo_metadata(
scales = []
size: dict[str, float] = {}
omezarr_dir = fs.destformat(f"{output_prefix}.zarr")
zarrinfo = json.loads(open(fs.localreadable(os.path.join(omezarr_dir, ".zattrs")), "r").read())
with open(fs.localreadable(os.path.join(omezarr_dir, ".zattrs")), "r") as fh:
zarrinfo = json.loads(fh.read())
multiscales = zarrinfo["multiscales"][0]["datasets"]
for scale in multiscales:
scaleinfo = json.loads(open(fs.localreadable(os.path.join(omezarr_dir, scale["path"], ".zarray")), "r").read())
with open(fs.localreadable(os.path.join(omezarr_dir, scale["path"], ".zarray")), "r") as fh:
scaleinfo = json.loads(fh.read())
shape = scaleinfo["shape"]
dims = {"z": shape[0], "y": shape[1], "x": shape[2]}
if not size:
@@ -196,26 +195,19 @@ def get_header(fs: FileSystemApi, tomo_filename: str) -> MrcObject:


def scale_mrcfile(
fs: FileSystemApi,
output_prefix: str,
tomo_filename: str,
scale_z_axis: bool = True,
write_mrc: bool = True,
write_zarr: bool = True,
header_mapper: Callable[[np.array], None] = None,
voxel_spacing=None,
fs: FileSystemApi,
output_prefix: str,
tomo_filename: str,
scale_z_axis: bool = True,
write_mrc: bool = True,
write_zarr: bool = True,
header_mapper: Optional[Callable[[np.array], None]] = None,
voxel_spacing=None,
):
tc = TomoConverter(fs, tomo_filename)
pyramid = tc.make_pyramid(scale_z_axis=scale_z_axis)
_ = tc.pyramid_to_omezarr(fs, pyramid, f"{output_prefix}.zarr", write_zarr)
_ = tc.pyramid_to_mrc(
fs,
pyramid,
f"{output_prefix}.mrc",
write_mrc,
header_mapper,
voxel_spacing
)
_ = tc.pyramid_to_mrc(fs, pyramid, f"{output_prefix}.mrc", write_mrc, header_mapper, voxel_spacing)


def scale_maskfile(
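One note on `update_headers`: the reformatted `header.rms` expression is the root-mean-square deviation about the mean, which is exactly numpy's population standard deviation, so it can be sanity-checked against the built-in (a quick check, not part of the PR):

```python
import numpy as np

data = np.random.default_rng(0).normal(size=(4, 8, 8)).astype(np.float32)
rms = np.sqrt(np.mean((data - np.mean(data)) ** 2))
assert np.isclose(rms, np.std(data))  # the same quantity by definition
```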