Skip to content

Commit

Permalink
Fix string format vulnerability
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 3, 2023
1 parent 308fe34 commit 406c574
Show file tree
Hide file tree
Showing 41 changed files with 141 additions and 150 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.11)

project(Kuzu VERSION 0.0.3 LANGUAGES CXX)
project(Kuzu VERSION 0.0.4 LANGUAGES CXX)

find_package(Threads REQUIRED)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
-COMPARE_RESULT 1
-QUERY MATCH (comment:Comment) RETURN MIN(gamma(comment.length % 2 + 2))
---- 1
1
1.000000
8 changes: 4 additions & 4 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ std::vector<PropertyNameDataType> Binder::bindPropertyNameDataTypes(
for (auto& propertyNameDataType : propertyNameDataTypes) {
if (boundPropertyNames.contains(propertyNameDataType.first)) {
throw BinderException(StringUtils::string_format(
"Duplicated column name: %s, column name must be unique.",
propertyNameDataType.first.c_str()));
"Duplicated column name: {}, column name must be unique.",
propertyNameDataType.first));
} else if (TableSchema::isReservedPropertyName(propertyNameDataType.first)) {
throw BinderException(
StringUtils::string_format("PropertyName: %s is an internal reserved propertyName.",
propertyNameDataType.first.c_str()));
StringUtils::string_format("PropertyName: {} is an internal reserved propertyName.",
propertyNameDataType.first));
}
StringUtils::toUpper(propertyNameDataType.second);
auto dataType = bindDataType(propertyNameDataType.second);
Expand Down
4 changes: 2 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ void Binder::validateNodeTableHasNoEdge(const Catalog& _catalog, table_id_t tabl
for (auto& tableIDSchema : _catalog.getReadOnlyVersion()->getRelTableSchemas()) {
if (tableIDSchema.second->isSrcOrDstTable(tableID)) {
throw BinderException(StringUtils::string_format(
"Cannot delete a node table with edges. It is on the edges of rel: %s.",
tableIDSchema.second->tableName.c_str()));
"Cannot delete a node table with edges. It is on the edges of rel: {}.",
tableIDSchema.second->tableName));
}
}
}
Expand Down
15 changes: 7 additions & 8 deletions src/catalog/catalog_structs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ std::string TableSchema::getPropertyName(property_id_t propertyID) const {
return property.name;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyID=%d.", tableName.c_str(), propertyID));
throw common::RuntimeException(StringUtils::string_format(
"Table: {} doesn't have a property with propertyID={}.", tableName, propertyID));
}

property_id_t TableSchema::getPropertyID(const std::string& propertyName) const {
Expand All @@ -57,9 +57,8 @@ property_id_t TableSchema::getPropertyID(const std::string& propertyName) const
return property.propertyID;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyName=%s.", tableName.c_str(),
propertyName.c_str()));
throw common::RuntimeException(StringUtils::string_format(
"Table: {} doesn't have a property with propertyName={}.", tableName, propertyName));
}

Property TableSchema::getProperty(property_id_t propertyID) const {
Expand All @@ -68,8 +67,8 @@ Property TableSchema::getProperty(property_id_t propertyID) const {
return property;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyID=%d.", tableName.c_str(), propertyID));
throw common::RuntimeException(StringUtils::string_format(
"Table: {} doesn't have a property with propertyID={}.", tableName, propertyID));
}

void TableSchema::renameProperty(property_id_t propertyID, const std::string& newName) {
Expand All @@ -80,7 +79,7 @@ void TableSchema::renameProperty(property_id_t propertyID, const std::string& ne
}
}
throw common::InternalException(
"Property with id=" + std::to_string(propertyID) + " not found.");
StringUtils::string_format("Property with id={} not found.", propertyID));
}

} // namespace catalog
Expand Down
2 changes: 1 addition & 1 deletion src/common/assert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void kuAssertInternal(bool condition, const char* condition_name, const char* fi
return;
}
throw InternalException(StringUtils::string_format(
"Assertion triggered in file \"%s\" on line %d: %s", file, linenr, condition_name));
"Assertion triggered in file \"{}\" on line {}: {}", file, linenr, condition_name));
}

} // namespace common
Expand Down
4 changes: 2 additions & 2 deletions src/common/csv_reader/csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ bool CSVReader::hasNextTokenOrError() {
if (!hasNextToken()) {
throw ReaderException(
StringUtils::string_format("CSV Reader was expecting more tokens but the line does not "
"have any tokens left. Last token: %s",
"have any tokens left. Last token: {}",
line + linePtrStart));
}
return true;
Expand Down Expand Up @@ -311,7 +311,7 @@ std::unique_ptr<Value> CSVReader::getList(const DataType& dataType) {
auto numBytesOfOverflow = listVal.size() * Types::getDataTypeSize(dataType.typeID);
if (numBytesOfOverflow >= BufferPoolConstants::DEFAULT_PAGE_SIZE) {
throw ReaderException(StringUtils::string_format(
"Maximum num bytes of a LIST is %d. Input list's num bytes is %d.",
"Maximum num bytes of a LIST is {}. Input list's num bytes is {}.",
BufferPoolConstants::DEFAULT_PAGE_SIZE, numBytesOfOverflow));
}
return std::make_unique<Value>(
Expand Down
38 changes: 18 additions & 20 deletions src/common/file_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void FileUtils::writeToFile(
FileInfo* fileInfo, uint8_t* buffer, uint64_t numBytes, uint64_t offset) {
auto fileSize = getFileSize(fileInfo->fd);
if (fileSize == -1) {
throw Exception(StringUtils::string_format("File %s not open.", fileInfo->path.c_str()));
throw Exception(StringUtils::string_format("File {} not open.", fileInfo->path));
}
uint64_t remainingNumBytesToWrite = numBytes;
uint64_t bufferOffset = 0;
Expand All @@ -30,9 +30,9 @@ void FileUtils::writeToFile(
pwrite(fileInfo->fd, buffer + bufferOffset, numBytesToWrite, offset);
if (numBytesWritten != numBytesToWrite) {
throw Exception(StringUtils::string_format(
"Cannot write to file. path: %s fileDescriptor: %d offsetToWrite: %llu "
"numBytesToWrite: %llu numBytesWritten: %llu",
fileInfo->path.c_str(), fileInfo->fd, offset, numBytesToWrite, numBytesWritten));
"Cannot write to file. path: {} fileDescriptor: {} offsetToWrite: {} "
"numBytesToWrite: {} numBytesWritten: {}",
fileInfo->path, fileInfo->fd, offset, numBytesToWrite, numBytesWritten));
}
remainingNumBytesToWrite -= numBytesWritten;
offset += numBytesWritten;
Expand All @@ -46,8 +46,8 @@ void FileUtils::overwriteFile(const std::string& from, const std::string& to) {
std::error_code errorCode;
if (!std::filesystem::copy_file(
from, to, std::filesystem::copy_options::overwrite_existing, errorCode)) {
throw Exception(StringUtils::string_format("Error copying file %s to %s. ErrorMessage: %s",
from.c_str(), to.c_str(), errorCode.message().c_str()));
throw Exception(StringUtils::string_format(
"Error copying file {} to {}. ErrorMessage: {}", from, to, errorCode.message()));
}
}

Expand All @@ -56,25 +56,24 @@ void FileUtils::readFromFile(
auto numBytesRead = pread(fileInfo->fd, buffer, numBytes, position);
if (numBytesRead != numBytes && getFileSize(fileInfo->fd) != position + numBytesRead) {
throw Exception(
StringUtils::string_format("Cannot read from file: %s fileDescriptor: %d "
"numBytesRead: %llu numBytesToRead: %llu position: %llu",
fileInfo->path.c_str(), fileInfo->fd, numBytesRead, numBytes, position));
StringUtils::string_format("Cannot read from file: {} fileDescriptor: {} "
"numBytesRead: {} numBytesToRead: {} position: {}",
fileInfo->path, fileInfo->fd, numBytesRead, numBytes, position));
}
}

void FileUtils::createDir(const std::string& dir) {
try {
if (std::filesystem::exists(dir)) {
throw Exception(
StringUtils::string_format("Directory %s already exists.", dir.c_str()));
throw Exception(StringUtils::string_format("Directory {} already exists.", dir));
}
if (!std::filesystem::create_directory(dir)) {
throw Exception(StringUtils::string_format(
"Directory %s cannot be created. Check if it exists and remove it.", dir.c_str()));
"Directory {} cannot be created. Check if it exists and remove it.", dir));
}
} catch (std::exception& e) {
throw Exception(StringUtils::string_format(
"Failed to create directory %s due to: %s", dir.c_str(), e.what()));
throw Exception(
StringUtils::string_format("Failed to create directory {} due to: {}", dir, e.what()));
}
}

Expand All @@ -83,9 +82,8 @@ void FileUtils::removeDir(const std::string& dir) {
if (!fileOrPathExists(dir))
return;
if (!std::filesystem::remove_all(dir, removeErrorCode)) {
throw Exception(
StringUtils::string_format("Error removing directory %s. Error Message: %s",
dir.c_str(), removeErrorCode.message().c_str()));
throw Exception(StringUtils::string_format(
"Error removing directory {}. Error Message: {}", dir, removeErrorCode.message()));
}
}

Expand All @@ -97,8 +95,8 @@ void FileUtils::renameFileIfExists(const std::string& oldName, const std::string
std::filesystem::rename(oldName, newName, errorCode);
if (errorCode.value() != 0) {
throw Exception(
StringUtils::string_format("Error replacing file %s to %s. ErrorMessage: %s",
oldName.c_str(), newName.c_str(), errorCode.message().c_str()));
StringUtils::string_format("Error replacing file {} to {}. ErrorMessage: {}", oldName,
newName, errorCode.message()));
}
}

Expand All @@ -107,7 +105,7 @@ void FileUtils::removeFileIfExists(const std::string& path) {
return;
if (remove(path.c_str()) != 0) {
throw Exception(StringUtils::string_format(
"Error removing directory or file %s. Error Message: ", path.c_str()));
"Error removing directory or file {}. Error Message: ", path));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/logging_level_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ spdlog::level::level_enum LoggingLevelUtils::convertStrToLevelEnum(std::string l
return spdlog::level::level_enum::err;
} else {
throw ConversionException(
StringUtils::string_format("Unsupported logging level: %s.", loggingLevel.c_str()));
StringUtils::string_format("Unsupported logging level: {}.", loggingLevel));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ uint32_t TypeUtils::convertToUint32(const char* data) {
uint32_t val;
if (!(iss >> val)) {
throw ConversionException(
StringUtils::string_format("Failed to convert %s to uint32_t", data));
StringUtils::string_format("Failed to convert {} to uint32_t", data));
}
return val;
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/date_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ date_t Date::FromDate(int32_t year, int32_t month, int32_t day) {
int32_t n = 0;
if (!Date::IsValid(year, month, day)) {
throw ConversionException(
StringUtils::string_format("Date out of range: %d-%d-%d.", year, month, day));
StringUtils::string_format("Date out of range: {}-{}-{}.", year, month, day));
}
while (year < 1970) {
year += Date::YEAR_INTERVAL;
Expand Down
9 changes: 5 additions & 4 deletions src/common/types/dtime_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,10 @@ dtime_t Time::FromCString(const char* buf, uint64_t len) {
dtime_t result;
uint64_t pos;
if (!Time::TryConvertTime(buf, len, pos, result)) {
throw ConversionException(StringUtils::string_format(
"Error occurred during parsing time. Given: \"" + std::string(buf, len) +
"\". Expected format: (hh:mm:ss[.zzzzzz])."));
throw ConversionException(
StringUtils::string_format("Error occurred during parsing time. Given: \"{}\". "
"Expected format: (hh:mm:ss[.zzzzzz]).",
std::string(buf, len)));
}
return result;
}
Expand All @@ -160,7 +161,7 @@ bool Time::IsValid(int32_t hour, int32_t minute, int32_t second, int32_t microse
dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) {
if (!Time::IsValid(hour, minute, second, microseconds)) {
throw ConversionException(StringUtils::string_format(
"Time field value out of range: %d:%d:%d[.%d].", hour, minute, second, microseconds));
"Time field value out of range: {}:{}:{}[.{}].", hour, minute, second, microseconds));
}
int64_t result;
result = hour; // hours
Expand Down
20 changes: 10 additions & 10 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ void Value::copyValueFrom(const uint8_t* value) {
case DOUBLE: {
val.doubleVal = *((double*)value);
} break;
case FLOAT: {
val.floatVal = *((float_t*)value);
} break;
case DATE: {
val.dateVal = *((date_t*)value);
} break;
Expand All @@ -177,9 +180,6 @@ void Value::copyValueFrom(const uint8_t* value) {
case FIXED_LIST: {
listVal = convertKUFixedListToVector(value);
} break;
case FLOAT: {
val.floatVal = *((float_t*)value);
} break;
default:
throw RuntimeException(
"Data type " + Types::dataTypeToString(dataType) + " is not supported for Value::set");
Expand Down Expand Up @@ -209,6 +209,9 @@ void Value::copyValueFrom(const Value& other) {
case DOUBLE: {
val.doubleVal = other.val.doubleVal;
} break;
case FLOAT: {
val.floatVal = other.val.floatVal;
} break;
case DATE: {
val.dateVal = other.val.dateVal;
} break;
Expand Down Expand Up @@ -236,9 +239,6 @@ void Value::copyValueFrom(const Value& other) {
case REL: {
relVal = other.relVal->copy();
} break;
case FLOAT: {
val.floatVal = other.val.floatVal;
} break;
default:
throw NotImplementedException("Value::Value(const Value&) for type " +
Types::dataTypeToString(dataType) + " is not implemented.");
Expand All @@ -264,6 +264,8 @@ std::string Value::toString() const {
return TypeUtils::toString(val.int16Val);
case DOUBLE:
return TypeUtils::toString(val.doubleVal);
case FLOAT:
return TypeUtils::toString(val.floatVal);
case DATE:
return TypeUtils::toString(val.dateVal);
case TIMESTAMP:
Expand All @@ -287,8 +289,6 @@ std::string Value::toString() const {
return nodeVal->toString();
case REL:
return relVal->toString();
case FLOAT:
return TypeUtils::toString(val.floatVal);
default:
throw NotImplementedException("Value::toString for type " +
Types::dataTypeToString(dataType) + " is not implemented.");
Expand All @@ -306,8 +306,8 @@ void Value::validateType(DataTypeID typeID) const {
void Value::validateType(const DataType& type) const {
if (type != dataType) {
throw RuntimeException(
StringUtils::string_format("Cannot get %s value from the %s result value.",
Types::dataTypeToString(type).c_str(), Types::dataTypeToString(dataType).c_str()));
StringUtils::string_format("Cannot get {} value from the {} result value.",
Types::dataTypeToString(type), Types::dataTypeToString(dataType)));
}
}

Expand Down
22 changes: 8 additions & 14 deletions src/include/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common/constants.h"
#include "exception.h"
#include "spdlog/fmt/fmt.h"

namespace spdlog {
class logger;
Expand All @@ -33,6 +34,11 @@ class LoggerUtils {
class StringUtils {

public:
template<typename... Args>
inline static std::string string_format(const std::string& format, Args... args) {
return fmt::format(fmt::runtime(format), args...);
}

static std::vector<std::string> split(const std::string& input, const std::string& delimiter);

static void toUpper(std::string& input) {
Expand All @@ -48,23 +54,11 @@ class StringUtils {
}

static bool CharacterIsDigit(char c) { return c >= '0' && c <= '9'; }
template<typename... Args>
static std::string string_format(const std::string& format, Args... args) {
int size_s = snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
if (size_s <= 0) {
throw Exception("Error during formatting.");
}
auto size = static_cast<size_t>(size_s);
auto buf = std::make_unique<char[]>(size);
snprintf(buf.get(), size, format.c_str(), args...);
return {buf.get(), buf.get() + size - 1}; // We don't want the '\0' inside
}

static std::string getLongStringErrorMessage(
const char* strToInsert, uint64_t maxAllowedStrSize) {
return StringUtils::string_format(
"Maximum length of strings is %d. Input string's length is %d.", maxAllowedStrSize,
strlen(strToInsert), strToInsert);
return string_format("Maximum length of strings is {}. Input string's length is {}.",
maxAllowedStrSize, strlen(strToInsert), strToInsert);
}
};

Expand Down
Loading

0 comments on commit 406c574

Please sign in to comment.