Skip to content

Commit

Permalink
[c++/python] Add update_columns in C++
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Jul 30, 2024
1 parent 27784fa commit b672685
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 87 deletions.
6 changes: 0 additions & 6 deletions apis/python/src/tiledbsoma/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,6 @@ def cast_values_to_target_schema(values: pa.Table, schema: pa.Schema) -> pa.Tabl
When writing data to a SOMAArray, the values that the user passes in may not
match the schema on disk. Cast the values to the correct dtypes.
"""
# Ensure fields are in the correct order
# target_schema = []
# for input_field in values.schema:
# target_schema.append(schema.field(input_field.name))

# return values.cast(pa.schema(target_schema, values.schema.metadata))
return values


Expand Down
97 changes: 22 additions & 75 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
from somacore.options import PlatformConfig
from typing_extensions import get_args

import tiledb

from .. import (
Collection,
DataFrame,
Expand All @@ -54,7 +52,7 @@
eta,
logging,
)
from .._arrow_types import df_to_arrow, tiledb_type_from_arrow_type
from .._arrow_types import df_to_arrow
from .._collection import AnyTileDBCollection, CollectionBase
from .._common_nd_array import NDArray
from .._constants import SOMA_JOINID
Expand Down Expand Up @@ -1509,29 +1507,10 @@ def _update_dataframe(
new_data, default_index_name
)

with DataFrame.open(
sdf.uri, mode="r", context=context, platform_config=platform_config
) as sdf_r:
# Until we someday support deletes, this is the correct check on the existing,
# contiguous soma join IDs compared to the new contiguous ones about to be created.
old_jids = sorted(
e.as_py()
for e in sdf_r.read(column_names=["soma_joinid"]).concat()["soma_joinid"]
)
new_jids = list(range(len(new_data)))
if old_jids != new_jids:
raise ValueError(
f"{caller_name}: old and new data must have the same row count; got {len(old_jids)} != {len(new_jids)}",
)

old_keys = set(old_sig.keys())
new_keys = set(new_sig.keys())
drop_keys = old_keys.difference(new_keys)
add_keys = new_keys.difference(old_keys)
common_keys = old_keys.intersection(new_keys)

tiledb_create_options = TileDBCreateOptions.from_platform_config(platform_config)

msgs = []
for key in common_keys:
old_type = old_sig[key]
Expand All @@ -1543,61 +1522,29 @@ def _update_dataframe(
msg = ", ".join(msgs)
raise ValueError(f"unsupported type updates: {msg}")

se = tiledb.ArraySchemaEvolution(sdf.context.tiledb_ctx)
for drop_key in drop_keys:
se.drop_attribute(drop_key)

arrow_table = df_to_arrow(new_data)
arrow_schema = arrow_table.schema.remove_metadata()

for add_key in add_keys:
# Don't directly use the new dataframe's dtypes. Go through the
# to-Arrow-schema logic, and back, as this recapitulates the original
# schema-creation logic.
atype = arrow_schema.field(add_key).type
dtype = tiledb_type_from_arrow_type(atype)

enum_label: Optional[str] = None
if pa.types.is_dictionary(arrow_table.schema.field(add_key).type):
enum_label = add_key
dt = cast(pd.CategoricalDtype, new_data[add_key].dtype)
se.add_enumeration(
tiledb.Enumeration(
name=add_key, ordered=atype.ordered, values=list(dt.categories)
)
)

filters = tiledb_create_options.attr_filters_tiledb(add_key, ["ZstdFilter"])

# An update can create (or drop) columns, or mutate existing ones. A
# brand-new column might have nulls in it -- or it might not. And a
# subsequent mutator-update might set null values to non-null -- or vice
# versa. Therefore we must be careful to set nullability for all types.
#
# Note: this must match what DataFrame.create does:
# * DataFrame.create sets nullability for obs/var columns on initial ingest
# * Here, we set nullabiliity for obs/var columns on update_obs
# Users should get the same behavior either way.
#
# Note: this is specific to tiledbsoma.io.
# * In the SOMA API -- e.g. soma.DataFrame.create -- users bring their
# own Arrow schema (including nullabilities) and we must do what they
# say.
# * In the tiledbsoma.io API, users bring their AnnData objects, and
# we compute Arrow schemas on their behalf, and we must accommodate
# reasonable/predictable needs.

se.add_attribute(
tiledb.Attr(
name=add_key,
dtype=dtype,
filters=filters,
enum_label=enum_label,
nullable=True,
)
with DataFrame.open(
sdf.uri, mode="r", context=context, platform_config=platform_config
) as sdf_r:
# Until we someday support deletes, this is the correct check on the existing,
# contiguous soma join IDs compared to the new contiguous ones about to be created.
old_jids = sorted(
e.as_py()
for e in sdf_r.read(column_names=["soma_joinid"]).concat()["soma_joinid"]
)
new_jids = list(range(len(new_data)))
if old_jids != new_jids:
raise ValueError(
f"{caller_name}: old and new data must have the same row count; got {len(old_jids)} != {len(new_jids)}",
)

se.array_evolve(uri=sdf.uri)
new_data.reset_index(inplace=True)
if default_index_name is not None:
if default_index_name in new_data:
if "index" in new_data:
new_data.drop(columns=["index"], inplace=True)
else:
new_data.rename(columns={"index": default_index_name}, inplace=True)
sdf_r._handle._handle.update(df_to_arrow(new_data).schema)

_write_dataframe(
df_uri=sdf.uri,
Expand Down
15 changes: 15 additions & 0 deletions apis/python/src/tiledbsoma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ void write_coords(
}
}

void update(SOMAArray& array, py::handle pyarrow_schema) {
ArrowSchema arrow_schema;
uintptr_t arrow_schema_ptr = (uintptr_t)(&arrow_schema);
pyarrow_schema.attr("_export_to_c")(arrow_schema_ptr);

try {
array.update_columns(std::make_unique<ArrowSchema>(arrow_schema));
} catch (const std::exception& e) {
TPY_ERROR_LOC(e.what());
}
arrow_schema.release(&arrow_schema);
}

void load_soma_array(py::module& m) {
py::class_<SOMAArray, SOMAObject>(m, "SOMAArray")
.def(
Expand Down Expand Up @@ -518,6 +531,8 @@ void load_soma_array(py::module& m) {

.def("write_coords", write_coords)

.def("update", update)

.def("nnz", &SOMAArray::nnz, py::call_guard<py::gil_scoped_release>())

.def_property_readonly("shape", &SOMAArray::shape)
Expand Down
127 changes: 127 additions & 0 deletions libtiledbsoma/src/soma/soma_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,133 @@ void SOMAArray::write(bool sort_coords) {
array_buffer_ = nullptr;
}

void SOMAArray::update_columns(std::unique_ptr<ArrowSchema> arrow_schema) {
std::vector<std::string> old_cols;
for (auto attr : tiledb_schema()->attributes()) {
old_cols.push_back(attr.first);
}
for (auto dim : tiledb_schema()->domain().dimensions()) {
old_cols.push_back(dim.name());
}

std::vector<std::string> new_cols;
for (auto i = 0; i < arrow_schema->n_children; ++i) {
new_cols.push_back(arrow_schema->children[i]->name);
}

std::sort(new_cols.begin(), new_cols.end());
std::sort(old_cols.begin(), old_cols.end());

std::vector<std::string>::iterator it;

std::vector<std::string> common_cols(old_cols.size() + new_cols.size());
it = std::set_intersection(
old_cols.begin(),
old_cols.end(),
new_cols.begin(),
new_cols.end(),
common_cols.begin());
common_cols.resize(it - common_cols.begin());
if (!common_cols.empty()) {
for (auto name : common_cols) {
for (auto i = 0; i < arrow_schema->n_children; ++i) {
auto arrow_sch_ = arrow_schema->children[i];
if (name != arrow_sch_->name) {
continue;
}
auto new_type = ArrowAdapter::to_tiledb_format(
arrow_sch_->format);
if (!tiledb_schema()->has_attribute(arrow_sch_->name)) {
continue;
}
auto attr = tiledb_schema()->attribute(arrow_sch_->name);
auto old_type = attr.type();
auto enmr_name = AttributeExperimental::get_enumeration_name(
*ctx_->tiledb_ctx(), attr);

if (!enmr_name.has_value() && (new_type != old_type)) {
throw std::invalid_argument(fmt::format(
"Unsupported type update for {}: {} != {}",
arrow_sch_->name,
tiledb::impl::type_to_str(new_type),
tiledb::impl::type_to_str(old_type)));
}
break;
}
}
}

std::vector<std::string> drop_cols(old_cols.size());
it = std::set_difference(
old_cols.begin(),
old_cols.end(),
new_cols.begin(),
new_cols.end(),
drop_cols.begin());
drop_cols.resize(it - drop_cols.begin());

if (!drop_cols.empty()) {
ArraySchemaEvolution se(*ctx_->tiledb_ctx());
for (it = drop_cols.begin(); it != drop_cols.end(); ++it) {
if (tiledb_schema()->has_attribute(*it)) {
se.drop_attribute(*it);
}
}
se.array_evolve(uri_);
}

std::vector<std::string> add_cols(new_cols.size());
it = std::set_difference(
new_cols.begin(),
new_cols.end(),
old_cols.begin(),
old_cols.end(),
add_cols.begin());
add_cols.resize(it - add_cols.begin());

if (!add_cols.empty()) {
ArraySchemaEvolution se(*ctx_->tiledb_ctx());
for (it = add_cols.begin(); it != add_cols.end(); ++it) {
for (auto i = 0; i < arrow_schema->n_children; ++i) {
auto arrow_sch_ = arrow_schema->children[i];
if (*it != arrow_sch_->name) {
continue;
}
if (arrow_sch_->dictionary != nullptr) {
auto enmr_format = arrow_sch_->dictionary->format;
auto enmr_type = ArrowAdapter::to_tiledb_format(
enmr_format);
auto enmr = Enumeration::create_empty(
*ctx_->tiledb_ctx(),
arrow_sch_->name,
enmr_type,
ArrowAdapter::is_var_arrow_format(enmr_format) ?
TILEDB_VAR_NUM :
1,
arrow_sch_->flags & ARROW_FLAG_DICTIONARY_ORDERED);
se.add_enumeration(enmr);
}
auto type = ArrowAdapter::to_tiledb_format(arrow_sch_->format);
Attribute attr(*ctx_->tiledb_ctx(), *it, type);
if (arrow_sch_->dictionary != nullptr) {
AttributeExperimental::set_enumeration_name(
*ctx_->tiledb_ctx(), attr, arrow_sch_->name);
}
se.add_attribute(attr);
break;
}
}
se.array_evolve(uri_);
}

// Use the updated array in ManagedQuery
if (!drop_cols.empty() || !add_cols.empty()) {
arr_->close();
arr_->open(TILEDB_WRITE);
mq_ = std::make_unique<ManagedQuery>(arr_, ctx_->tiledb_ctx(), name_);
}
}

void SOMAArray::consolidate_and_vacuum(std::vector<std::string> modes) {
for (auto mode : modes) {
auto cfg = ctx_->tiledb_ctx()->config();
Expand Down
2 changes: 2 additions & 0 deletions libtiledbsoma/src/soma/soma_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ class SOMAArray : public SOMAObject {
*/
void write(bool sort_coords = true);

void update_columns(std::unique_ptr<ArrowSchema> arrow_schema);

/**
* @brief Consolidates and vacuums fragment metadata and commit files.
*
Expand Down
10 changes: 6 additions & 4 deletions libtiledbsoma/src/utils/arrow_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ ArraySchema ArrowAdapter::tiledb_schema_from_arrow_schema(
for (int64_t i = 0; i < index_column_schema->n_children; ++i) {
auto col_name = index_column_schema->children[i]->name;
if (strcmp(child->name, col_name) == 0) {
if (ArrowAdapter::_isvar(child->format)) {
if (ArrowAdapter::is_var_arrow_format(child->format)) {
type = TILEDB_STRING_ASCII;
}

Expand Down Expand Up @@ -549,7 +549,7 @@ ArraySchema ArrowAdapter::tiledb_schema_from_arrow_schema(
attr.set_nullable(true);
}

if (ArrowAdapter::_isvar(child->format)) {
if (ArrowAdapter::is_var_arrow_format(child->format)) {
attr.set_cell_val_num(TILEDB_VAR_NUM);
}

Expand All @@ -560,7 +560,9 @@ ArraySchema ArrowAdapter::tiledb_schema_from_arrow_schema(
*ctx,
child->name,
enmr_type,
ArrowAdapter::_isvar(enmr_format) ? TILEDB_VAR_NUM : 1,
ArrowAdapter::is_var_arrow_format(enmr_format) ?
TILEDB_VAR_NUM :
1,
child->flags & ARROW_FLAG_DICTIONARY_ORDERED);
ArraySchemaExperimental::add_enumeration(*ctx, schema, enmr);
AttributeExperimental::set_enumeration_name(
Expand Down Expand Up @@ -846,7 +848,7 @@ ArrowAdapter::to_arrow(std::shared_ptr<ColumnBuffer> column) {
return std::pair(std::move(array), std::move(schema));
}

bool ArrowAdapter::_isvar(const char* format) {
bool ArrowAdapter::is_var_arrow_format(const char* format) {
if ((strcmp(format, "U") == 0) || (strcmp(format, "Z") == 0) ||
(strcmp(format, "u") == 0) || (strcmp(format, "z") == 0)) {
return true;
Expand Down
4 changes: 2 additions & 2 deletions libtiledbsoma/src/utils/arrow_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class ArrowAdapter {
static std::string_view to_arrow_format(
tiledb_datatype_t tiledb_dtype, bool use_large = true);

static bool is_var_arrow_format(const char* format);

/**
* @brief Get TileDB datatype from Arrow format string.
*
Expand Down Expand Up @@ -244,8 +246,6 @@ class ArrowAdapter {
return Dimension::create<T>(*ctx, name, {b[0], b[1]}, b[2]);
}

static bool _isvar(const char* format);

static FilterList _create_filter_list(
std::string filters, std::shared_ptr<Context> ctx);

Expand Down

0 comments on commit b672685

Please sign in to comment.