Skip to content

Commit

Permalink
TensorStore
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 24, 2024
1 parent 314d8ae commit 7f3f390
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 0 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/tensorstore-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: TensorStore tests

on:
schedule:
# Every weekday at 03:58 UTC, see https://crontab.guru/
- cron: "58 3 * * 1-5"
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.11"]

steps:
- name: Checkout source
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v2

- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[test]' 'tensorstore'
- name: Run tests
run: |
# exclude tests that rely on the nchunks_initialized array attribute
pytest -k "not test_resume"
env:
CUBED_STORAGE_NAME: tensorstore
4 changes: 4 additions & 0 deletions cubed/storage/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def open_backend_array(
from cubed.storage.backends.zarr_python import open_zarr_array

open_func = open_zarr_array
elif storage_name == "tensorstore":
from cubed.storage.backends.tensorstore import open_tensorstore_array

open_func = open_tensorstore_array
else:
raise ValueError(f"Unrecognized storage name: {storage_name}")

Expand Down
155 changes: 155 additions & 0 deletions cubed/storage/backends/tensorstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import dataclasses
import math
from typing import Any, Dict, Optional

import numpy as np
import tensorstore

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.utils import join_path


@dataclasses.dataclass(frozen=True)
class TensorStoreArray:
array: tensorstore.TensorStore

@property
def shape(self) -> tuple[int, ...]:
return self.array.shape

@property
def dtype(self) -> np.dtype:
return self.array.dtype.numpy_dtype

@property
def chunks(self) -> tuple[int, ...]:
return self.array.chunk_layout.read_chunk.shape or ()

@property
def ndim(self) -> int:
return len(self.shape)

@property
def size(self) -> int:
return math.prod(self.shape)

@property
def oindex(self):
return self.array.oindex

def __getitem__(self, key):
# read eagerly
return self.array.__getitem__(key).read().result()

def __setitem__(self, key, value):
self.array.__setitem__(key, value)


class TensorStoreGroup(dict):
def __init__(
self,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
):
dict.__init__(self)
self.shape = shape
self.dtype = dtype
self.chunks = chunks

def __getitem__(self, key):
if isinstance(key, str):
return super().__getitem__(key)
return {field: zarray[key] for field, zarray in self.items()}

def set_basic_selection(self, selection, value, fields=None):
self[fields][selection] = value


def encode_dtype(d):
if d.fields is None:
return d.str
else:
return d.descr


def get_metadata(dtype, chunks):
metadata = {}
if dtype is not None:
dtype = np.dtype(dtype)
metadata["dtype"] = encode_dtype(dtype)
if chunks is not None:
if isinstance(chunks, int):
chunks = (chunks,)
metadata["chunks"] = chunks
return metadata


def open_tensorstore_array(
store: T_Store,
mode: str,
*,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
path: Optional[str] = None,
**kwargs,
):
store = str(store) # TODO: check if Path or str

spec: Dict[str, Any]
if "://" in store:
spec = {"driver": "zarr", "kvstore": store}
else:
spec = {
"driver": "zarr",
"kvstore": {"driver": "file", "path": store},
"path": path or "",
}

if mode == "r":
open_kwargs = dict(read=True, open=True)
elif mode == "r+":
open_kwargs = dict(read=True, write=True, open=True)
elif mode == "a":
open_kwargs = dict(read=True, write=True, open=True, create=True)
elif mode == "w":
open_kwargs = dict(read=True, write=True, create=True, delete_existing=True)
elif mode == "w-":
open_kwargs = dict(read=True, write=True, create=True)
else:
raise ValueError(f"Mode not supported: {mode}")

if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
metadata = get_metadata(dtype, chunks)
if metadata:
spec["metadata"] = metadata

return TensorStoreArray(
tensorstore.open(
spec,
shape=shape,
dtype=dtype,
**open_kwargs,
).result()
)
else:
ret = TensorStoreGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_path = field if path is None else join_path(path, field)
spec["path"] = field_path

field_dtype, _ = dtype.fields[field]
metadata = get_metadata(field_dtype, chunks)
if metadata:
spec["metadata"] = metadata

ret[field] = TensorStoreArray(
tensorstore.open(
spec,
shape=shape,
dtype=field_dtype,
**open_kwargs,
).result()
)
return ret
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-tenacity.*]
ignore_missing_imports = True
[mypy-tensorstore.*]
ignore_missing_imports = True
[mypy-tlz.*]
ignore_missing_imports = True
[mypy-toolz.*]
Expand Down

0 comments on commit 7f3f390

Please sign in to comment.