Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Append mode with resizing #3148

Merged
merged 4 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def from_isolated_dataframe(
next_soma_joinid += 1
return cls(data=data, field_name=index_field_name)

def get_shape(self) -> int:
if len(self.data.values()) == 0:
return 0
else:
return 1 + max(self.data.values())

def to_json(self) -> str:
return json.dumps(self, default=attrs.asdict, sort_keys=True, indent=4)

Expand Down Expand Up @@ -490,20 +496,15 @@ def get_obs_shape(self) -> int:
"""Reports the new obs shape which the experiment will need to be
resized to in order to accommodate the data contained within the
registration."""
if len(self.obs_axis.data.values()) == 0:
return 0
return 1 + max(self.obs_axis.data.values())
return self.obs_axis.get_shape()

def get_var_shapes(self) -> Dict[str, int]:
"""Reports the new var shapes, one per measurement, which the experiment
will need to be resized to in order to accommodate the data contained
within the registration."""
retval: Dict[str, int] = {}
for key, axis in self.var_axes.items():
if len(axis.data.values()) == 0:
retval[key] = 0
else:
retval[key] = 1 + max(axis.data.values())
retval[key] = axis.get_shape()
return retval

def to_json(self) -> str:
Expand Down
6 changes: 6 additions & 0 deletions apis/python/src/tiledbsoma/io/_registration/id_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def is_identity(self) -> bool:
return False
return True

def get_shape(self) -> int:
if len(self.data) == 0:
return 0
else:
return 1 + max(self.data)

@classmethod
def identity(cls, n: int) -> Self:
"""This maps 0-up input-file offsets to 0-up soma_joinid values. This is
Expand Down
24 changes: 23 additions & 1 deletion apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
NotCreateableError,
SOMAError,
)
from .._flags import NEW_SHAPE_FEATURE_FLAG_ENABLED
from .._soma_array import SOMAArray
from .._soma_object import AnySOMAObject, SOMAObject
from .._tdb_handles import RawHandle
Expand Down Expand Up @@ -1164,6 +1165,7 @@ def _write_dataframe(
df,
df_uri,
id_column_name,
shape=axis_mapping.get_shape(),
ingestion_params=ingestion_params,
additional_metadata=additional_metadata,
original_index_metadata=original_index_metadata,
Expand All @@ -1177,6 +1179,7 @@ def _write_dataframe_impl(
df_uri: str,
id_column_name: Optional[str],
*,
shape: int,
ingestion_params: IngestionParams,
additional_metadata: AdditionalMetadata = None,
original_index_metadata: OriginalIndexMetadata = None,
Expand All @@ -1203,9 +1206,15 @@ def _write_dataframe_impl(
arrow_table = _extract_new_values_for_append(df_uri, arrow_table, context)

try:
# Note: tiledbsoma.io creates dataframes with soma_joinid being the one
# and only index column.
domain = None
if NEW_SHAPE_FEATURE_FLAG_ENABLED:
domain = ((0, shape - 1),)
johnkerl marked this conversation as resolved.
Show resolved Hide resolved
soma_df = DataFrame.create(
df_uri,
schema=arrow_table.schema,
domain=domain,
platform_config=platform_config,
context=context,
)
Expand Down Expand Up @@ -1304,8 +1313,19 @@ def _create_from_matrix(
logging.log_io(None, f"START WRITING {uri}")

try:
shape: Sequence[Union[int, None]] = ()
# A SparseNDArray must be appendable in soma.io.
shape = [None for _ in matrix.shape] if cls.is_sparse else matrix.shape
if NEW_SHAPE_FEATURE_FLAG_ENABLED:
# Instead of
# shape = tuple(int(e) for e in matrix.shape)
# we consult the registration mapping. This is important
# in the case when multiple H5ADs/AnnDatas are being
# ingested to an experiment which doesn't pre-exist.
shape = (axis_0_mapping.get_shape(), axis_1_mapping.get_shape())
elif cls.is_sparse:
shape = tuple(None for _ in matrix.shape)
else:
shape = matrix.shape
soma_ndarray = cls.create(
uri,
type=pa.from_numpy_dtype(matrix.dtype),
Expand Down Expand Up @@ -2711,6 +2731,7 @@ def _ingest_uns_1d_string_array(
df,
df_uri,
None,
shape=df.shape[0],
ingestion_params=ingestion_params,
platform_config=platform_config,
context=context,
Expand Down Expand Up @@ -2756,6 +2777,7 @@ def _ingest_uns_2d_string_array(
df,
df_uri,
None,
shape=df.shape[0],
ingestion_params=ingestion_params,
additional_metadata=additional_metadata,
platform_config=platform_config,
Expand Down
20 changes: 11 additions & 9 deletions apis/python/src/tiledbsoma/io/shaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,18 @@ def resize_experiment(
output_handle=output_handle,
)

# Do an early check on the nvars keys vs the experiment's
# measurent names. This isn't a can-do status for the experiment;
# it's a failure of the user's arguments.
# Extra user-provided keys not relevant to the experiment are ignored. This
# is important for the case when a new measurement, which is registered from
# AnnData/H5AD inputs, is registered and is about to be created but does not
# exist just yet in the experiment storage.
#
# If the user hasn't provided a key -- e.g. a from-anndata-append-with-resize
# on one measurement while the experiment's other measurements aren't being
# updated -- then we need to find those other measurements' var-shapes.
with tiledbsoma.Experiment.open(uri) as exp:
arg_keys = sorted(nvars.keys())
ms_keys = sorted(exp.ms.keys())
if arg_keys != ms_keys:
raise ValueError(
f"resize_experiment: provided nvar keys {arg_keys} do not match experiment keys {ms_keys}"
)
for ms_key in exp.ms.keys():
if ms_key not in nvars.keys():
nvars[ms_key] = exp.ms[ms_key].var._maybe_soma_joinid_shape or 1

ok = _treewalk(
uri,
Expand Down
5 changes: 5 additions & 0 deletions apis/python/tests/test_basic_anndata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,11 @@ def test_nan_append(conftest_pbmc_small, dtype, nans, new_obs_ids):
var_field_name="var_id",
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
nobs = rd.get_obs_shape()
nvars = rd.get_var_shapes()
tiledbsoma.io.resize_experiment(SOMA_URI, nobs=nobs, nvars=nvars)

# Append the second anndata object
tiledbsoma.io.from_anndata(
experiment_uri=SOMA_URI,
Expand Down
42 changes: 42 additions & 0 deletions apis/python/tests/test_registration_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ def test_multiples_without_experiment(
var_field_name=var_field_name,
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
nobs = rd.get_obs_shape()
nvars = rd.get_var_shapes()
tiledbsoma.io.resize_experiment(experiment_uri, nobs=nobs, nvars=nvars)

else:
# "Append" all the H5ADs where no experiment exists yet.
rd = registration.ExperimentAmbientLabelMapping.from_h5ad_appends_on_experiment(
Expand Down Expand Up @@ -451,6 +456,14 @@ def test_multiples_without_experiment(
h5ad_file_names[permutation[2]],
h5ad_file_names[permutation[3]],
]:
if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
if tiledbsoma.Experiment.exists(experiment_uri):
tiledbsoma.io.resize_experiment(
experiment_uri,
nobs=rd.get_obs_shape(),
nvars=rd.get_var_shapes(),
)

tiledbsoma.io.from_h5ad(
experiment_uri,
h5ad_file_name,
Expand Down Expand Up @@ -713,6 +726,13 @@ def test_append_items_with_experiment(obs_field_name, var_field_name):

original = adata2.copy()

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
tiledbsoma.io.resize_experiment(
soma1,
nobs=rd.get_obs_shape(),
nvars=rd.get_var_shapes(),
)

with tiledbsoma.Experiment.open(soma1, "w") as exp1:
tiledbsoma.io.append_obs(
exp1,
Expand Down Expand Up @@ -836,6 +856,13 @@ def test_append_with_disjoint_measurements(
var_field_name=var_field_name,
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
tiledbsoma.io.resize_experiment(
soma_uri,
nobs=rd.get_obs_shape(),
nvars=rd.get_var_shapes(),
)

tiledbsoma.io.from_anndata(
soma_uri,
anndata2,
Expand Down Expand Up @@ -1190,6 +1217,14 @@ def test_enum_bit_width_append(tmp_path, all_at_once, nobs_a, nobs_b):
tiledbsoma.io.from_anndata(
soma_uri, adata, measurement_name=measurement_name, registration_mapping=rd
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
tiledbsoma.io.resize_experiment(
soma_uri,
nobs=rd.get_obs_shape(),
nvars=rd.get_var_shapes(),
)

tiledbsoma.io.from_anndata(
soma_uri, bdata, measurement_name=measurement_name, registration_mapping=rd
)
Expand All @@ -1208,6 +1243,13 @@ def test_enum_bit_width_append(tmp_path, all_at_once, nobs_a, nobs_b):
assert rd.get_obs_shape() == nobs_a + nobs_b
assert rd.get_var_shapes() == {"meas": 4}

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
tiledbsoma.io.resize_experiment(
soma_uri,
nobs=rd.get_obs_shape(),
nvars=rd.get_var_shapes(),
)

tiledbsoma.io.from_anndata(
soma_uri, bdata, measurement_name=measurement_name, registration_mapping=rd
)
Expand Down