Skip to content

Commit

Permalink
Fix string format vulnerability
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 2, 2023
1 parent d025c83 commit 0d52aa7
Show file tree
Hide file tree
Showing 42 changed files with 166 additions and 199 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
-COMPARE_RESULT 1
-QUERY MATCH (comment:Comment) RETURN MIN(gamma(comment.length % 2 + 2))
---- 1
1
1.000000
10 changes: 5 additions & 5 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ std::vector<PropertyNameDataType> Binder::bindPropertyNameDataTypes(
std::unordered_set<std::string> boundPropertyNames;
for (auto& propertyNameDataType : propertyNameDataTypes) {
if (boundPropertyNames.contains(propertyNameDataType.first)) {
throw BinderException(StringUtils::string_format(
"Duplicated column name: %s, column name must be unique.",
propertyNameDataType.first.c_str()));
throw BinderException(
fmt::format("Duplicated column name: {}, column name must be unique.",
propertyNameDataType.first));
} else if (TableSchema::isReservedPropertyName(propertyNameDataType.first)) {
throw BinderException(
StringUtils::string_format("PropertyName: %s is an internal reserved propertyName.",
propertyNameDataType.first.c_str()));
fmt::format("PropertyName: {} is an internal reserved propertyName.",
propertyNameDataType.first));
}
StringUtils::toUpper(propertyNameDataType.second);
auto dataType = bindDataType(propertyNameDataType.second);
Expand Down
6 changes: 3 additions & 3 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ bool Binder::validateStringParsingOptionName(std::string& parsingOptionName) {
void Binder::validateNodeTableHasNoEdge(const Catalog& _catalog, table_id_t tableID) {
for (auto& tableIDSchema : _catalog.getReadOnlyVersion()->getRelTableSchemas()) {
if (tableIDSchema.second->isSrcOrDstTable(tableID)) {
throw BinderException(StringUtils::string_format(
"Cannot delete a node table with edges. It is on the edges of rel: %s.",
tableIDSchema.second->tableName.c_str()));
throw BinderException(
fmt::format("Cannot delete a node table with edges. It is on the edges of rel: {}.",
tableIDSchema.second->tableName));
}
}
}
Expand Down
16 changes: 7 additions & 9 deletions src/catalog/catalog_structs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ std::string TableSchema::getPropertyName(property_id_t propertyID) const {
return property.name;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyID=%d.", tableName.c_str(), propertyID));
throw common::RuntimeException(fmt::format(
"Table: {} doesn't have a property with propertyID={}.", tableName, propertyID));
}

property_id_t TableSchema::getPropertyID(const std::string& propertyName) const {
Expand All @@ -57,9 +57,8 @@ property_id_t TableSchema::getPropertyID(const std::string& propertyName) const
return property.propertyID;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyName=%s.", tableName.c_str(),
propertyName.c_str()));
throw common::RuntimeException(fmt::format(
"Table: {} doesn't have a property with propertyName={}.", tableName, propertyName));
}

Property TableSchema::getProperty(property_id_t propertyID) const {
Expand All @@ -68,8 +67,8 @@ Property TableSchema::getProperty(property_id_t propertyID) const {
return property;
}
}
throw common::RuntimeException(common::StringUtils::string_format(
"Table: %s doesn't have a property with propertyID=%d.", tableName.c_str(), propertyID));
throw common::RuntimeException(fmt::format(
"Table: {} doesn't have a property with propertyID={}.", tableName, propertyID));
}

void TableSchema::renameProperty(property_id_t propertyID, const std::string& newName) {
Expand All @@ -79,8 +78,7 @@ void TableSchema::renameProperty(property_id_t propertyID, const std::string& ne
return;
}
}
throw common::InternalException(
"Property with id=" + std::to_string(propertyID) + " not found.");
throw common::InternalException(fmt::format("Property with id={} not found.", propertyID));
}

} // namespace catalog
Expand Down
4 changes: 2 additions & 2 deletions src/common/assert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ void kuAssertInternal(bool condition, const char* condition_name, const char* fi
if (condition) {
return;
}
throw InternalException(StringUtils::string_format(
"Assertion triggered in file \"%s\" on line %d: %s", file, linenr, condition_name));
throw InternalException(fmt::format(
"Assertion triggered in file \"{}\" on line {}: {}", file, linenr, condition_name));
}

} // namespace common
Expand Down
10 changes: 5 additions & 5 deletions src/common/csv_reader/csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ bool CSVReader::hasNextToken() {
bool CSVReader::hasNextTokenOrError() {
if (!hasNextToken()) {
throw ReaderException(
StringUtils::string_format("CSV Reader was expecting more tokens but the line does not "
"have any tokens left. Last token: %s",
fmt::format("CSV Reader was expecting more tokens but the line does not "
"have any tokens left. Last token: {}",
line + linePtrStart));
}
return true;
Expand Down Expand Up @@ -310,9 +310,9 @@ std::unique_ptr<Value> CSVReader::getList(const DataType& dataType) {
}
auto numBytesOfOverflow = listVal.size() * Types::getDataTypeSize(dataType.typeID);
if (numBytesOfOverflow >= BufferPoolConstants::DEFAULT_PAGE_SIZE) {
throw ReaderException(StringUtils::string_format(
"Maximum num bytes of a LIST is %d. Input list's num bytes is %d.",
BufferPoolConstants::DEFAULT_PAGE_SIZE, numBytesOfOverflow));
throw ReaderException(
fmt::format("Maximum num bytes of a LIST is {}. Input list's num bytes is {}.",
BufferPoolConstants::DEFAULT_PAGE_SIZE, numBytesOfOverflow));
}
return std::make_unique<Value>(
DataType(VAR_LIST, std::make_unique<DataType>(dataType)), std::move(listVal));
Expand Down
44 changes: 19 additions & 25 deletions src/common/file_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void FileUtils::writeToFile(
FileInfo* fileInfo, uint8_t* buffer, uint64_t numBytes, uint64_t offset) {
auto fileSize = getFileSize(fileInfo->fd);
if (fileSize == -1) {
throw Exception(StringUtils::string_format("File %s not open.", fileInfo->path.c_str()));
throw Exception(fmt::format("File {} not open.", fileInfo->path));
}
uint64_t remainingNumBytesToWrite = numBytes;
uint64_t bufferOffset = 0;
Expand All @@ -29,10 +29,10 @@ void FileUtils::writeToFile(
uint64_t numBytesWritten =
pwrite(fileInfo->fd, buffer + bufferOffset, numBytesToWrite, offset);
if (numBytesWritten != numBytesToWrite) {
throw Exception(StringUtils::string_format(
"Cannot write to file. path: %s fileDescriptor: %d offsetToWrite: %llu "
"numBytesToWrite: %llu numBytesWritten: %llu",
fileInfo->path.c_str(), fileInfo->fd, offset, numBytesToWrite, numBytesWritten));
throw Exception(
fmt::format("Cannot write to file. path: {} fileDescriptor: {} offsetToWrite: {} "
"numBytesToWrite: {} numBytesWritten: {}",
fileInfo->path, fileInfo->fd, offset, numBytesToWrite, numBytesWritten));
}
remainingNumBytesToWrite -= numBytesWritten;
offset += numBytesWritten;
Expand All @@ -46,35 +46,32 @@ void FileUtils::overwriteFile(const std::string& from, const std::string& to) {
std::error_code errorCode;
if (!std::filesystem::copy_file(
from, to, std::filesystem::copy_options::overwrite_existing, errorCode)) {
throw Exception(StringUtils::string_format("Error copying file %s to %s. ErrorMessage: %s",
from.c_str(), to.c_str(), errorCode.message().c_str()));
throw Exception(fmt::format(
"Error copying file {} to {}. ErrorMessage: {}", from, to, errorCode.message()));
}
}

void FileUtils::readFromFile(
FileInfo* fileInfo, void* buffer, uint64_t numBytes, uint64_t position) {
auto numBytesRead = pread(fileInfo->fd, buffer, numBytes, position);
if (numBytesRead != numBytes && getFileSize(fileInfo->fd) != position + numBytesRead) {
throw Exception(
StringUtils::string_format("Cannot read from file: %s fileDescriptor: %d "
"numBytesRead: %llu numBytesToRead: %llu position: %llu",
fileInfo->path.c_str(), fileInfo->fd, numBytesRead, numBytes, position));
throw Exception(fmt::format("Cannot read from file: {} fileDescriptor: {} "
"numBytesRead: {} numBytesToRead: {} position: {}",
fileInfo->path, fileInfo->fd, numBytesRead, numBytes, position));
}
}

void FileUtils::createDir(const std::string& dir) {
try {
if (std::filesystem::exists(dir)) {
throw Exception(
StringUtils::string_format("Directory %s already exists.", dir.c_str()));
throw Exception(fmt::format("Directory {} already exists.", dir));
}
if (!std::filesystem::create_directory(dir)) {
throw Exception(StringUtils::string_format(
"Directory %s cannot be created. Check if it exists and remove it.", dir.c_str()));
throw Exception(fmt::format(
"Directory {} cannot be created. Check if it exists and remove it.", dir));
}
} catch (std::exception& e) {
throw Exception(StringUtils::string_format(
"Failed to create directory %s due to: %s", dir.c_str(), e.what()));
throw Exception(fmt::format("Failed to create directory {} due to: {}", dir, e.what()));
}
}

Expand All @@ -83,9 +80,8 @@ void FileUtils::removeDir(const std::string& dir) {
if (!fileOrPathExists(dir))
return;
if (!std::filesystem::remove_all(dir, removeErrorCode)) {
throw Exception(
StringUtils::string_format("Error removing directory %s. Error Message: %s",
dir.c_str(), removeErrorCode.message().c_str()));
throw Exception(fmt::format(
"Error removing directory {}. Error Message: {}", dir, removeErrorCode.message()));
}
}

Expand All @@ -96,18 +92,16 @@ void FileUtils::renameFileIfExists(const std::string& oldName, const std::string
std::error_code errorCode;
std::filesystem::rename(oldName, newName, errorCode);
if (errorCode.value() != 0) {
throw Exception(
StringUtils::string_format("Error replacing file %s to %s. ErrorMessage: %s",
oldName.c_str(), newName.c_str(), errorCode.message().c_str()));
throw Exception(fmt::format("Error replacing file {} to {}. ErrorMessage: {}", oldName,
newName, errorCode.message()));
}
}

void FileUtils::removeFileIfExists(const std::string& path) {
if (!fileOrPathExists(path))
return;
if (remove(path.c_str()) != 0) {
throw Exception(StringUtils::string_format(
"Error removing directory or file %s. Error Message: ", path.c_str()));
throw Exception(fmt::format("Error removing directory or file {}. Error Message: ", path));
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/common/logging_level_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ spdlog::level::level_enum LoggingLevelUtils::convertStrToLevelEnum(std::string l
} else if (loggingLevel == "err") {
return spdlog::level::level_enum::err;
} else {
throw ConversionException(
StringUtils::string_format("Unsupported logging level: %s.", loggingLevel.c_str()));
throw ConversionException(fmt::format("Unsupported logging level: {}.", loggingLevel));
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ uint32_t TypeUtils::convertToUint32(const char* data) {
std::istringstream iss(data);
uint32_t val;
if (!(iss >> val)) {
throw ConversionException(
StringUtils::string_format("Failed to convert %s to uint32_t", data));
throw ConversionException(fmt::format("Failed to convert {} to uint32_t", data));
}
return val;
}
Expand Down
3 changes: 1 addition & 2 deletions src/common/types/date_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ void Date::Convert(date_t d, int32_t& year, int32_t& month, int32_t& day) {
date_t Date::FromDate(int32_t year, int32_t month, int32_t day) {
int32_t n = 0;
if (!Date::IsValid(year, month, day)) {
throw ConversionException(
StringUtils::string_format("Date out of range: %d-%d-%d.", year, month, day));
throw ConversionException(fmt::format("Date out of range: {}-{}-{}.", year, month, day));
}
while (year < 1970) {
year += Date::YEAR_INTERVAL;
Expand Down
10 changes: 5 additions & 5 deletions src/common/types/dtime_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ dtime_t Time::FromCString(const char* buf, uint64_t len) {
dtime_t result;
uint64_t pos;
if (!Time::TryConvertTime(buf, len, pos, result)) {
throw ConversionException(StringUtils::string_format(
"Error occurred during parsing time. Given: \"" + std::string(buf, len) +
"\". Expected format: (hh:mm:ss[.zzzzzz])."));
throw ConversionException(fmt::format("Error occurred during parsing time. Given: \"{}\". "
"Expected format: (hh:mm:ss[.zzzzzz]).",
std::string(buf, len)));
}
return result;
}
Expand All @@ -159,8 +159,8 @@ bool Time::IsValid(int32_t hour, int32_t minute, int32_t second, int32_t microse

dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) {
if (!Time::IsValid(hour, minute, second, microseconds)) {
throw ConversionException(StringUtils::string_format(
"Time field value out of range: %d:%d:%d[.%d].", hour, minute, second, microseconds));
throw ConversionException(fmt::format(
"Time field value out of range: {}:{}:{}[.{}].", hour, minute, second, microseconds));
}
int64_t result;
result = hour; // hours
Expand Down
21 changes: 10 additions & 11 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ void Value::copyValueFrom(const uint8_t* value) {
case DOUBLE: {
val.doubleVal = *((double*)value);
} break;
case FLOAT: {
val.floatVal = *((float_t*)value);
} break;
case DATE: {
val.dateVal = *((date_t*)value);
} break;
Expand All @@ -177,9 +180,6 @@ void Value::copyValueFrom(const uint8_t* value) {
case FIXED_LIST: {
listVal = convertKUFixedListToVector(value);
} break;
case FLOAT: {
val.floatVal = *((float_t*)value);
} break;
default:
throw RuntimeException(
"Data type " + Types::dataTypeToString(dataType) + " is not supported for Value::set");
Expand Down Expand Up @@ -209,6 +209,9 @@ void Value::copyValueFrom(const Value& other) {
case DOUBLE: {
val.doubleVal = other.val.doubleVal;
} break;
case FLOAT: {
val.floatVal = other.val.floatVal;
} break;
case DATE: {
val.dateVal = other.val.dateVal;
} break;
Expand Down Expand Up @@ -236,9 +239,6 @@ void Value::copyValueFrom(const Value& other) {
case REL: {
relVal = other.relVal->copy();
} break;
case FLOAT: {
val.floatVal = other.val.floatVal;
} break;
default:
throw NotImplementedException("Value::Value(const Value&) for type " +
Types::dataTypeToString(dataType) + " is not implemented.");
Expand All @@ -264,6 +264,8 @@ std::string Value::toString() const {
return TypeUtils::toString(val.int16Val);
case DOUBLE:
return TypeUtils::toString(val.doubleVal);
case FLOAT:
return TypeUtils::toString(val.floatVal);
case DATE:
return TypeUtils::toString(val.dateVal);
case TIMESTAMP:
Expand All @@ -287,8 +289,6 @@ std::string Value::toString() const {
return nodeVal->toString();
case REL:
return relVal->toString();
case FLOAT:
return TypeUtils::toString(val.floatVal);
default:
throw NotImplementedException("Value::toString for type " +
Types::dataTypeToString(dataType) + " is not implemented.");
Expand All @@ -305,9 +305,8 @@ void Value::validateType(DataTypeID typeID) const {

void Value::validateType(const DataType& type) const {
if (type != dataType) {
throw RuntimeException(
StringUtils::string_format("Cannot get %s value from the %s result value.",
Types::dataTypeToString(type).c_str(), Types::dataTypeToString(dataType).c_str()));
throw RuntimeException(fmt::format("Cannot get {} value from the {} result value.",
Types::dataTypeToString(type), Types::dataTypeToString(dataType)));
}
}

Expand Down
17 changes: 3 additions & 14 deletions src/include/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common/constants.h"
#include "exception.h"
#include "spdlog/fmt/fmt.h"

namespace spdlog {
class logger;
Expand Down Expand Up @@ -48,23 +49,11 @@ class StringUtils {
}

static bool CharacterIsDigit(char c) { return c >= '0' && c <= '9'; }
template<typename... Args>
static std::string string_format(const std::string& format, Args... args) {
int size_s = snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
if (size_s <= 0) {
throw Exception("Error during formatting.");
}
auto size = static_cast<size_t>(size_s);
auto buf = std::make_unique<char[]>(size);
snprintf(buf.get(), size, format.c_str(), args...);
return {buf.get(), buf.get() + size - 1}; // We don't want the '\0' inside
}

static std::string getLongStringErrorMessage(
const char* strToInsert, uint64_t maxAllowedStrSize) {
return StringUtils::string_format(
"Maximum length of strings is %d. Input string's length is %d.", maxAllowedStrSize,
strlen(strToInsert), strToInsert);
return fmt::format("Maximum length of strings is {}. Input string's length is {}.",
maxAllowedStrSize, strlen(strToInsert), strToInsert);
}
};

Expand Down
Loading

0 comments on commit 0d52aa7

Please sign in to comment.