Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed May 17, 2023
1 parent 4d60988 commit d202b89
Show file tree
Hide file tree
Showing 46 changed files with 300 additions and 247 deletions.
4 changes: 2 additions & 2 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ LogicalType Binder::bindDataType(const std::string& dataType) {
std::to_string(numElementsInList) + ".");
}
auto numElementsPerPage = storage::PageUtils::getNumElementsInAPage(
Types::getLogicalTypeSize(boundType), true /* hasNull */);
storage::StorageUtils::getDataTypeSize(boundType), true /* hasNull */);
if (numElementsPerPage == 0) {
throw common::BinderException(
StringUtils::string_format("Cannot store a fixed list of size {} in a page.",
Types::getLogicalTypeSize(boundType)));
storage::StorageUtils::getDataTypeSize(boundType)));
}
}
return boundType;
Expand Down
7 changes: 1 addition & 6 deletions src/c_api/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,5 @@ uint64_t kuzu_data_type_get_fixed_num_elements_in_list(kuzu_data_type* data_type
if (parent_type->getLogicalTypeID() != LogicalTypeID::FIXED_LIST) {
return 0;
}
auto extra_info = static_cast<LogicalType*>(data_type->_data_type)->getExtraTypeInfo();
if (extra_info == nullptr) {
return 0;
}
auto fixed_list_info = dynamic_cast<FixedListTypeInfo*>(extra_info);
return fixed_list_info->getNumElementsInList();
return FixedListType::getNumElementsInList(static_cast<LogicalType*>(data_type->_data_type));
}
6 changes: 2 additions & 4 deletions src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,13 @@ kuzu_value* kuzu_value_get_list_element(kuzu_value* value, uint64_t index) {
uint64_t kuzu_value_get_struct_num_fields(kuzu_value* value) {
auto val = static_cast<Value*>(value->_value);
auto data_type = val->getDataType();
auto struct_type_info = reinterpret_cast<StructTypeInfo*>(data_type.getExtraTypeInfo());
return struct_type_info->getStructFields().size();
return StructType::getNumFields(&data_type);
}

char* kuzu_value_get_struct_field_name(kuzu_value* value, uint64_t index) {
auto val = static_cast<Value*>(value->_value);
auto data_type = val->getDataType();
auto struct_type_info = reinterpret_cast<StructTypeInfo*>(data_type.getExtraTypeInfo());
auto struct_field_name = struct_type_info->getStructFields()[index]->getName();
auto struct_field_name = StructType::getStructFields(&data_type)[index]->getName();
auto* c_struct_field_name = (char*)malloc(sizeof(char) * (struct_field_name.size() + 1));
strcpy(c_struct_field_name, struct_field_name.c_str());
return c_struct_field_name;
Expand Down
8 changes: 5 additions & 3 deletions src/common/arrow/arrow_row_batch.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common/arrow/arrow_row_batch.h"

#include "common/types/value.h"
#include "storage/storage_utils.h"

namespace kuzu {
namespace common {
Expand All @@ -19,7 +20,7 @@ template<LogicalTypeID DT>
void ArrowRowBatch::templateInitializeVector(
ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) {
initializeNullBits(vector->validity, capacity);
vector->data.reserve(Types::getLogicalTypeSize(DT) * capacity);
vector->data.reserve(storage::StorageUtils::getDataTypeSize(LogicalType{DT}) * capacity);
}

template<>
Expand Down Expand Up @@ -160,7 +161,7 @@ void ArrowRowBatch::appendValue(
template<LogicalTypeID DT>
void ArrowRowBatch::templateCopyNonNullValue(
ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) {
auto valSize = Types::getLogicalTypeSize(DT);
auto valSize = storage::StorageUtils::getDataTypeSize(LogicalType{DT});
std::memcpy(vector->data.data() + pos * valSize, &value->val, valSize);
}

Expand Down Expand Up @@ -201,7 +202,8 @@ void ArrowRowBatch::templateCopyNonNullValue<LogicalTypeID::VAR_LIST>(
}
if (typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::VAR_LIST) {
vector->childData[0]->data.resize(
numChildElements * Types::getLogicalTypeSize(typeInfo.childrenTypesInfo[0]->typeID));
numChildElements * storage::StorageUtils::getDataTypeSize(
LogicalType{typeInfo.childrenTypesInfo[0]->typeID}));
}
for (auto i = 0u; i < numElements; i++) {
appendValue(vector->childData[0].get(), *typeInfo.childrenTypesInfo[0],
Expand Down
2 changes: 1 addition & 1 deletion src/common/in_mem_overflow_buffer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void InMemOverflowBufferUtils::copyListRecursiveIfNested(const ku_list_t& src, k
assert(srcEndIdx < src.size);
auto numElements = srcEndIdx - srcStartIdx + 1;
auto childType = VarListType::getChildType(&dataType);
auto elementSize = Types::getLogicalTypeSize(*childType);
auto elementSize = storage::StorageUtils::getDataTypeSize(*childType);
InMemOverflowBufferUtils::allocateSpaceForList(
dst, numElements * elementSize, inMemOverflowBuffer);
memcpy((uint8_t*)dst.overflowPtr, (uint8_t*)src.overflowPtr + srcStartIdx * elementSize,
Expand Down
6 changes: 3 additions & 3 deletions src/common/types/ku_list.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#include "common/types/ku_list.h"

#include <cassert>
#include "storage/storage_utils.h"

namespace kuzu {
namespace common {

void ku_list_t::set(const uint8_t* values, const LogicalType& dataType) const {
memcpy(reinterpret_cast<uint8_t*>(overflowPtr), values,
size * Types::getLogicalTypeSize(*VarListType::getChildType(&dataType)));
size * storage::StorageUtils::getDataTypeSize(*VarListType::getChildType(&dataType)));
}

void ku_list_t::set(const std::vector<uint8_t*>& parameters, LogicalTypeID childTypeId) {
this->size = parameters.size();
auto numBytesOfListElement = Types::getLogicalTypeSize(childTypeId);
auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize(LogicalType{childTypeId});
for (auto i = 0u; i < parameters.size(); i++) {
memcpy(reinterpret_cast<uint8_t*>(this->overflowPtr) + (i * numBytesOfListElement),
parameters[i], numBytesOfListElement);
Expand Down
91 changes: 20 additions & 71 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <stdexcept>

#include "common/exception.h"
#include "common/null_buffer.h"
#include "common/ser_deser.h"
#include "common/types/types_include.h"

Expand Down Expand Up @@ -128,7 +127,7 @@ LogicalType::LogicalType(const LogicalType& other) {
typeID = other.typeID;
physicalType = other.physicalType;
if (other.extraTypeInfo != nullptr) {
extraTypeInfo = other.getExtraTypeInfo()->copy();
extraTypeInfo = other.extraTypeInfo->copy();
}
}

Expand Down Expand Up @@ -386,76 +385,26 @@ std::string Types::dataTypesToString(const std::vector<LogicalTypeID>& dataTypeI
return result;
}

uint32_t Types::getLogicalTypeSize(LogicalTypeID dataTypeID) {
switch (dataTypeID) {
case LogicalTypeID::INTERNAL_ID:
return sizeof(internalID_t);
case LogicalTypeID::BOOL:
return sizeof(uint8_t);
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
uint32_t Types::getFixedTypeSize(kuzu::common::PhysicalType physicalType) {
switch (physicalType) {
case PhysicalType::BOOL:
return sizeof(bool);
case PhysicalType::INT64:
return sizeof(int64_t);
case LogicalTypeID::INT32:
case PhysicalType::INT32:
return sizeof(int32_t);
case LogicalTypeID::INT16:
case PhysicalType::INT16:
return sizeof(int16_t);
case LogicalTypeID::DOUBLE:
case PhysicalType::DOUBLE:
return sizeof(double_t);
case LogicalTypeID::FLOAT:
case PhysicalType::FLOAT:
return sizeof(float_t);
case LogicalTypeID::DATE:
return sizeof(date_t);
case LogicalTypeID::TIMESTAMP:
return sizeof(timestamp_t);
case LogicalTypeID::INTERVAL:
case PhysicalType::INTERVAL:
return sizeof(interval_t);
case LogicalTypeID::STRING:
return sizeof(ku_string_t);
case LogicalTypeID::VAR_LIST:
return sizeof(ku_list_t);
case PhysicalType::INTERNAL_ID:
return sizeof(internalID_t);
default:
throw InternalException(
"Cannot infer the size of dataTypeID: " + dataTypeToString(dataTypeID) + ".");
}
}

// This function returns the size of the dataType when stored in a row layout. (e.g.
// factorizedTable).
uint32_t Types::getLogicalTypeSize(const LogicalType& dataType) {
switch (dataType.typeID) {
case LogicalTypeID::FIXED_LIST: {
auto fixedListTypeInfo = reinterpret_cast<FixedListTypeInfo*>(dataType.extraTypeInfo.get());
return getLogicalTypeSize(*fixedListTypeInfo->getChildType()) *
fixedListTypeInfo->getNumElementsInList();
}
case LogicalTypeID::STRUCT: {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.extraTypeInfo.get());
uint32_t size = 0;
for (auto& childType : structTypeInfo->getChildrenTypes()) {
size += getLogicalTypeSize(*childType);
}
size += NullBuffer::getNumBytesForNullValues(structTypeInfo->getChildrenNames().size());
return size;
}
case LogicalTypeID::INTERNAL_ID:
case LogicalTypeID::BOOL:
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
case LogicalTypeID::INT32:
case LogicalTypeID::INT16:
case LogicalTypeID::DOUBLE:
case LogicalTypeID::FLOAT:
case LogicalTypeID::DATE:
case LogicalTypeID::TIMESTAMP:
case LogicalTypeID::INTERVAL:
case LogicalTypeID::STRING:
case LogicalTypeID::VAR_LIST: {
return getLogicalTypeSize(dataType.typeID);
}
default: {
throw InternalException(
"Cannot infer the size of dataTypeID: " + dataTypeToString(dataType.typeID) + ".");
}
throw RuntimeException{"Cannot infer the size of a variable dataType."};
}
}

Expand Down Expand Up @@ -591,15 +540,15 @@ uint64_t SerDeser::serializeValue(const LogicalType& value, FileInfo* fileInfo,
offset = SerDeser::serializeValue(value.getLogicalTypeID(), fileInfo, offset);
switch (value.getLogicalTypeID()) {
case LogicalTypeID::VAR_LIST: {
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(value.getExtraTypeInfo());
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*varListTypeInfo, fileInfo, offset);
} break;
case LogicalTypeID::FIXED_LIST: {
auto fixedListTypeInfo = reinterpret_cast<FixedListTypeInfo*>(value.getExtraTypeInfo());
auto fixedListTypeInfo = reinterpret_cast<FixedListTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*fixedListTypeInfo, fileInfo, offset);
} break;
case LogicalTypeID::STRUCT: {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(value.getExtraTypeInfo());
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*structTypeInfo, fileInfo, offset);
} break;
default:
Expand All @@ -616,18 +565,18 @@ uint64_t SerDeser::deserializeValue(LogicalType& value, FileInfo* fileInfo, uint
case LogicalTypeID::VAR_LIST: {
value.extraTypeInfo = std::make_unique<VarListTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<VarListTypeInfo*>(value.getExtraTypeInfo()), fileInfo, offset);
*reinterpret_cast<VarListTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);

} break;
case LogicalTypeID::FIXED_LIST: {
value.extraTypeInfo = std::make_unique<FixedListTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<FixedListTypeInfo*>(value.getExtraTypeInfo()), fileInfo, offset);
*reinterpret_cast<FixedListTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);
} break;
case LogicalTypeID::STRUCT: {
value.extraTypeInfo = std::make_unique<StructTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<StructTypeInfo*>(value.getExtraTypeInfo()), fileInfo, offset);
*reinterpret_cast<StructTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);
} break;
default:
break;
Expand Down
13 changes: 6 additions & 7 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "common/null_buffer.h"
#include "common/string_utils.h"
#include "storage/storage_utils.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -299,8 +300,7 @@ std::string Value::toString() const {
}
case LogicalTypeID::STRUCT: {
std::string result = "{";
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.getExtraTypeInfo());
auto childrenNames = structTypeInfo->getChildrenNames();
auto childrenNames = common::StructType::getStructFieldNames(&dataType);
for (auto i = 0u; i < nestedTypeVal.size(); ++i) {
result += childrenNames[i];
result += ": ";
Expand Down Expand Up @@ -329,7 +329,7 @@ Value::Value(LogicalType dataType) : dataType{std::move(dataType)}, isNull_{true
std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(ku_list_t& list) const {
std::vector<std::unique_ptr<Value>> listResultValue;
auto childType = VarListType::getChildType(&dataType);
auto numBytesPerElement = Types::getLogicalTypeSize(*childType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType);
auto listNullBytes = reinterpret_cast<uint8_t*>(list.overflowPtr);
auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size);
auto listValues = listNullBytes + numBytesForNullValues;
Expand All @@ -351,7 +351,7 @@ std::vector<std::unique_ptr<Value>> Value::convertKUFixedListToVector(
auto numElementsInList = FixedListType::getNumElementsInList(&dataType);
std::vector<std::unique_ptr<Value>> fixedListResultVal{numElementsInList};
auto childType = FixedListType::getChildType(&dataType);
auto numBytesPerElement = Types::getLogicalTypeSize(*childType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType);
switch (childType->getLogicalTypeID()) {
case common::LogicalTypeID::INT64: {
putValuesIntoVector<int64_t>(fixedListResultVal, fixedList, numBytesPerElement);
Expand All @@ -375,9 +375,8 @@ std::vector<std::unique_ptr<Value>> Value::convertKUFixedListToVector(
}

std::vector<std::unique_ptr<Value>> Value::convertKUStructToVector(const uint8_t* kuStruct) const {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.getExtraTypeInfo());
std::vector<std::unique_ptr<Value>> structVal;
auto childrenTypes = structTypeInfo->getChildrenTypes();
auto childrenTypes = StructType::getStructFieldTypes(&dataType);
auto numFields = childrenTypes.size();
auto structNullValues = kuStruct;
auto structValues = structNullValues + NullBuffer::getNumBytesForNullValues(numFields);
Expand All @@ -389,7 +388,7 @@ std::vector<std::unique_ptr<Value>> Value::convertKUStructToVector(const uint8_t
childValue->copyValueFrom(structValues);
}
structVal.emplace_back(std::move(childValue));
structValues += Types::getLogicalTypeSize(*childrenTypes[i]);
structValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[i]);
}
return structVal;
}
Expand Down
6 changes: 3 additions & 3 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ void StringAuxiliaryBuffer::addString(

StructAuxiliaryBuffer::StructAuxiliaryBuffer(
const LogicalType& type, storage::MemoryManager* memoryManager) {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(type.getExtraTypeInfo());
childrenVectors.reserve(structTypeInfo->getChildrenTypes().size());
for (auto structFieldType : structTypeInfo->getChildrenTypes()) {
auto structFieldTypes = StructType::getStructFieldTypes(&type);
childrenVectors.reserve(structFieldTypes.size());
for (auto structFieldType : structFieldTypes) {
childrenVectors.push_back(std::make_shared<ValueVector>(*structFieldType, memoryManager));
}
}
Expand Down
23 changes: 15 additions & 8 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace common {

ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager)
: dataType{std::move(dataType)} {
setNumBytesPerValue();
numBytesPerValue = getDataTypeSize(this->dataType);
initializeValueBuffer();
nullMask = std::make_unique<NullMask>();
auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager);
Expand Down Expand Up @@ -56,16 +56,23 @@ void ValueVector::setValue(uint32_t pos, std::string val) {
StringVector::addString(this, pos, val.data(), val.length());
}

void ValueVector::setNumBytesPerValue() {
switch (dataType.getLogicalTypeID()) {
uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
switch (type.getLogicalTypeID()) {
case common::LogicalTypeID::STRING: {
return sizeof(common::ku_string_t);
}
case common::LogicalTypeID::FIXED_LIST: {
return getDataTypeSize(*common::FixedListType::getChildType(&type)) *
common::FixedListType::getNumElementsInList(&type);
}
case LogicalTypeID::STRUCT: {
numBytesPerValue = sizeof(struct_entry_t);
} break;
return sizeof(struct_entry_t);
}
case LogicalTypeID::VAR_LIST: {
numBytesPerValue = sizeof(list_entry_t);
} break;
return sizeof(list_entry_t);
}
default: {
numBytesPerValue = Types::getLogicalTypeSize(dataType);
return Types::getFixedTypeSize(type.getPhysicalType());
}
}
}
Expand Down
Loading

0 comments on commit d202b89

Please sign in to comment.