From 758f6d041b63c20923fa007d537743b6dce3f15a Mon Sep 17 00:00:00 2001 From: ffacs Date: Tue, 12 May 2026 13:56:58 +0800 Subject: [PATCH 1/3] ORC-2166: [C++] Validate string lengths in direct encoding reader --- c++/src/ColumnReader.cc | 25 ++++++++++++++-- c++/test/TestColumnReader.cc | 57 ++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc index ab1fe24b1c..fe1eed65c4 100644 --- a/c++/src/ColumnReader.cc +++ b/c++/src/ColumnReader.cc @@ -666,7 +666,11 @@ namespace orc { while (done < numValues) { uint64_t step = std::min(BUFFER_SIZE, static_cast(numValues - done)); lengthRle_->next(buffer, step, nullptr); - totalBytes += computeSize(buffer, nullptr, step); + size_t stepBytes = computeSize(buffer, nullptr, step); + if (totalBytes > std::numeric_limits::max() - stepBytes) { + throw ParseError("String length overflow in StringDirectColumn"); + } + totalBytes += stepBytes; done += step; } if (totalBytes <= lastBufferLength_) { @@ -694,17 +698,32 @@ namespace orc { size_t StringDirectColumnReader::computeSize(const int64_t* lengths, const char* notNull, uint64_t numValues) { size_t totalLength = 0; + bool hasNegativeLength = false; + bool hasLengthOverflow = false; + auto addLength = [&](int64_t value) { + hasNegativeLength |= value < 0; + size_t length = static_cast(value); + size_t nextTotalLength = totalLength + length; + hasLengthOverflow |= nextTotalLength < totalLength; + totalLength = nextTotalLength; + }; if (notNull) { for (size_t i = 0; i < numValues; ++i) { if (notNull[i]) { - totalLength += static_cast(lengths[i]); + addLength(lengths[i]); } } } else { for (size_t i = 0; i < numValues; ++i) { - totalLength += static_cast(lengths[i]); + addLength(lengths[i]); } } + if (hasNegativeLength) { + throw ParseError("Negative string length in StringDirectColumn"); + } + if (hasLengthOverflow) { + throw ParseError("String length overflow in StringDirectColumn"); + } return totalLength; } diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc index d2aa38cb66..6ef2de09f3 100644 --- a/c++/test/TestColumnReader.cc +++ b/c++/test/TestColumnReader.cc @@ -42,6 +42,44 @@ namespace orc { return timeptr != nullptr; } + void expectStringDirectLengthError(const unsigned char* lengthData, uint64_t lengthDataSize, + uint64_t numValues, bool skip) { + MockStripeStreams streams; + + std::vector selectedColumns(2, true); + EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns)); + + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding)); + EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr)); + + EXPECT_CALL(streams, getStreamProxy(0, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return(nullptr)); + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return(nullptr)); + + const char blob[] = {'x'}; + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(blob, ARRAY_SIZE(blob)))); + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) + .WillRepeatedly( + testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize))); + + std::unique_ptr rowType = createStructType(); + rowType->addStructField("col0", createPrimitiveType(STRING)); + + std::unique_ptr reader = buildReader(*rowType, streams); + if (skip) { + EXPECT_THROW(reader->skip(numValues), ParseError); + } else { + StructVectorBatch batch(numValues, *getDefaultPool()); + StringVectorBatch* strings = new StringVectorBatch(numValues, *getDefaultPool()); + batch.fields.push_back(strings); + EXPECT_THROW(reader->next(batch, numValues, nullptr), ParseError); + } + } + class TestColumnReaderEncoded : public TestWithParam { void SetUp() override; @@ -882,6 +920,25 @@ namespace orc { EXPECT_THROW(reader->next(batch, 100, 0), ParseError); } + TEST(TestColumnReader, testStringDirectRejectsNegativeLength) { + // RLEv1 literal run with one unsigned value UINT64_MAX, which becomes -1 + // when decoded into int64_t. + const unsigned char lengthData[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x01}; + expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 1, false); + expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 1, true); + } + + TEST(TestColumnReader, testStringDirectRejectsLengthOverflow) { + // RLEv1 literal run with three INT64_MAX lengths. + const unsigned char lengthData[] = { + 0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}; + expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, false); + expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, true); + } + TEST_P(TestColumnReaderEncoded, testStringDirectShortBuffer) { MockStripeStreams streams; From d51a2aa07b879394ffe01f5a37ee47260fa29ddf Mon Sep 17 00:00:00 2001 From: ffacs Date: Tue, 12 May 2026 22:50:12 +0800 Subject: [PATCH 2/3] Format TestColumnReader --- c++/test/TestColumnReader.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc index 6ef2de09f3..7528c9f769 100644 --- a/c++/test/TestColumnReader.cc +++ b/c++/test/TestColumnReader.cc @@ -63,8 +63,7 @@ namespace orc { EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly(testing::Return(new SeekableArrayInputStream(blob, ARRAY_SIZE(blob)))); EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly( - testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize))); + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize))); std::unique_ptr rowType = createStructType(); rowType->addStructField("col0", createPrimitiveType(STRING)); @@ -931,10 +930,9 @@ namespace orc { TEST(TestColumnReader, testStringDirectRejectsLengthOverflow) { // RLEv1 literal run with three INT64_MAX lengths. - const unsigned char lengthData[] = { - 0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}; + const unsigned char lengthData[] = {0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}; expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, false); expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, true); } From 266b26bbd9cac623a59791de9aaf15da7f2a7f83 Mon Sep 17 00:00:00 2001 From: ffacs Date: Sun, 24 May 2026 21:35:17 +0800 Subject: [PATCH 3/3] Guard C++ integer arithmetic against overflow --- c++/src/BlockBuffer.cc | 28 +++++++-- c++/src/BlockBuffer.hh | 2 +- c++/src/ColumnReader.cc | 94 ++++++++++++++++++++++-------- c++/src/DictionaryLoader.cc | 24 ++++++-- c++/src/MemoryPool.cc | 45 +++++++++++---- c++/src/Utils.hh | 66 +++++++++++++++++++++ c++/src/Vector.cc | 18 ++++-- c++/test/CMakeLists.txt | 1 + c++/test/TestColumnReader.cc | 107 +++++++++++++++++++++++++++++++++++ c++/test/TestUtils.cc | 61 ++++++++++++++++++++ c++/test/meson.build | 1 + 11 files changed, 396 insertions(+), 51 deletions(-) create mode 100644 c++/test/TestUtils.cc diff --git a/c++/src/BlockBuffer.cc b/c++/src/BlockBuffer.cc index 09bf078c85..44cb68c706 100644 --- a/c++/src/BlockBuffer.cc +++ b/c++/src/BlockBuffer.cc @@ -17,10 +17,12 @@ */ #include "BlockBuffer.hh" +#include "Utils.hh" #include "orc/OrcFile.hh" #include "orc/Writer.hh" #include +#include namespace orc { @@ -51,10 +53,19 @@ namespace orc { if (currentSize_ < currentCapacity_) { Block emptyBlock(blocks_[currentSize_ / blockSize_] + currentSize_ % blockSize_, blockSize_ - currentSize_ % blockSize_); - currentSize_ = (currentSize_ / blockSize_ + 1) * blockSize_; + uint64_t nextBlockNumber = currentSize_ / blockSize_ + 1; + uint64_t nextSize = 0; + if (multiplyWithOverflow(nextBlockNumber, blockSize_, &nextSize)) { + throw std::length_error("Block buffer size overflow"); + } + currentSize_ = nextSize; return emptyBlock; } else { - resize(currentSize_ + blockSize_); + uint64_t nextSize = 0; + if (addWithOverflow(currentSize_, blockSize_, &nextSize)) { + throw std::length_error("Block buffer size overflow"); + } + resize(nextSize); return Block(blocks_.back(), blockSize_); } } @@ -70,10 +81,19 @@ namespace orc { void BlockBuffer::reserve(uint64_t newCapacity) { while (currentCapacity_ < newCapacity) { + uint64_t nextCapacity = 0; + if (addWithOverflow(currentCapacity_, blockSize_, &nextCapacity)) { + throw std::length_error("Block buffer capacity overflow"); + } char* newBlockPtr = memoryPool_.malloc(blockSize_); if (newBlockPtr != nullptr) { - blocks_.push_back(newBlockPtr); - currentCapacity_ += blockSize_; + try { + blocks_.push_back(newBlockPtr); + } catch (...) { + memoryPool_.free(newBlockPtr); + throw; + } + currentCapacity_ = nextCapacity; } else { break; } diff --git a/c++/src/BlockBuffer.hh b/c++/src/BlockBuffer.hh index 6d265b0e32..f288de4a01 100644 --- a/c++/src/BlockBuffer.hh +++ b/c++/src/BlockBuffer.hh @@ -94,7 +94,7 @@ namespace orc { * Get the number of blocks that are fully or partially occupied */ uint64_t getBlockNumber() const { - return (currentSize_ + blockSize_ - 1) / blockSize_; + return currentSize_ / blockSize_ + (currentSize_ % blockSize_ == 0 ? 0 : 1); } uint64_t size() const { diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc index fe1eed65c4..516b227647 100644 --- a/c++/src/ColumnReader.cc +++ b/c++/src/ColumnReader.cc @@ -27,11 +27,13 @@ #include "DictionaryLoader.hh" #include "RLE.hh" #include "SchemaEvolution.hh" +#include "Utils.hh" #include "orc/Exceptions.hh" #include "orc/Int128.hh" #include #include +#include #include namespace orc { @@ -127,6 +129,26 @@ namespace orc { } } + void addLengthToTotal(uint64_t* total, int64_t length, const char* columnKind) { + if (length < 0) { + throw ParseError(std::string("Negative length in ") + columnKind + " column"); + } + uint64_t nextTotal = 0; + if (addWithOverflow(*total, static_cast(length), &nextTotal) || + nextTotal > static_cast((std::numeric_limits::max)())) { + throw ParseError(std::string("Length overflow in ") + columnKind + " column"); + } + *total = nextTotal; + } + + void incrementUnionChildCount(int64_t* counts, size_t tag) { + int64_t nextCount = 0; + if (addWithOverflow(counts[tag], static_cast(1), &nextCount)) { + throw ParseError("Union child count overflow"); + } + counts[tag] = nextCount; + } + template class BooleanColumnReader : public ColumnReader { private: @@ -428,11 +450,24 @@ namespace orc { uint64_t numValues) { numValues = ColumnReader::skip(numValues); - if (static_cast(bufferEnd_ - bufferPointer_) >= bytesPerValue_ * numValues) { - bufferPointer_ += bytesPerValue_ * numValues; + if (numValues > static_cast((std::numeric_limits::max)())) { + throw ParseError("Double column skip size overflow"); + } + size_t bytesToSkip = 0; + if (multiplyWithOverflow(static_cast(bytesPerValue_), static_cast(numValues), + &bytesToSkip)) { + throw ParseError("Double column skip size overflow"); + } + if (bytesToSkip == 0) { + return numValues; + } + + size_t bufferedBytes = + bufferPointer_ == nullptr ? 0 : static_cast(bufferEnd_ - bufferPointer_); + if (bufferedBytes >= bytesToSkip) { + bufferPointer_ += bytesToSkip; } else { - size_t sizeToSkip = - bytesPerValue_ * numValues - static_cast(bufferEnd_ - bufferPointer_); + size_t sizeToSkip = bytesToSkip - bufferedBytes; const size_t cap = static_cast(std::numeric_limits::max()); while (sizeToSkip != 0) { size_t step = sizeToSkip > cap ? cap : sizeToSkip; @@ -498,7 +533,7 @@ namespace orc { if (!stream->Next(&chunk, &length)) { throw ParseError("bad read in readFully"); } - if (posn + length > bufferSize) { + if (length < 0 || length > bufferSize - posn) { throw ParseError("Corrupt dictionary blob in StringDictionaryColumn"); } memcpy(buffer + posn, chunk, static_cast(length)); @@ -667,10 +702,11 @@ namespace orc { uint64_t step = std::min(BUFFER_SIZE, static_cast(numValues - done)); lengthRle_->next(buffer, step, nullptr); size_t stepBytes = computeSize(buffer, nullptr, step); - if (totalBytes > std::numeric_limits::max() - stepBytes) { + size_t nextTotalBytes = 0; + if (addWithOverflow(totalBytes, stepBytes, &nextTotalBytes)) { throw ParseError("String length overflow in StringDirectColumn"); } - totalBytes += stepBytes; + totalBytes = nextTotalBytes; done += step; } if (totalBytes <= lastBufferLength_) { @@ -701,11 +737,17 @@ namespace orc { bool hasNegativeLength = false; bool hasLengthOverflow = false; auto addLength = [&](int64_t value) { - hasNegativeLength |= value < 0; + if (value < 0) { + hasNegativeLength = true; + return; + } size_t length = static_cast(value); - size_t nextTotalLength = totalLength + length; - hasLengthOverflow |= nextTotalLength < totalLength; - totalLength = nextTotalLength; + size_t nextTotalLength = 0; + bool overflow = addWithOverflow(totalLength, length, &nextTotalLength); + hasLengthOverflow |= overflow; + if (!overflow) { + totalLength = nextTotalLength; + } }; if (notNull) { for (size_t i = 0; i < numValues; ++i) { @@ -747,7 +789,7 @@ namespace orc { size_t bytesBuffered = 0; byteBatch.blob.resize(totalLength); char* ptr = byteBatch.blob.data(); - while (bytesBuffered + lastBufferLength_ < totalLength) { + while (bytesBuffered < totalLength && lastBufferLength_ < totalLength - bytesBuffered) { if (lastBuffer_ != nullptr) { memcpy(ptr + bytesBuffered, lastBuffer_, lastBufferLength_); } @@ -941,7 +983,7 @@ namespace orc { uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE); rle_->next(buffer, chunk, nullptr); for (size_t i = 0; i < chunk; ++i) { - childrenElements += static_cast(buffer[i]); + addLengthToTotal(&childrenElements, buffer[i], "List"); } lengthsRead += chunk; } @@ -973,18 +1015,18 @@ namespace orc { if (notNull) { for (size_t i = 0; i < numValues; ++i) { if (notNull[i]) { - uint64_t tmp = static_cast(offsets[i]); + int64_t length = offsets[i]; offsets[i] = static_cast(totalChildren); - totalChildren += tmp; + addLengthToTotal(&totalChildren, length, "List"); } else { offsets[i] = static_cast(totalChildren); } } } else { for (size_t i = 0; i < numValues; ++i) { - uint64_t tmp = static_cast(offsets[i]); + int64_t length = offsets[i]; offsets[i] = static_cast(totalChildren); - totalChildren += tmp; + addLengthToTotal(&totalChildren, length, "List"); } } offsets[numValues] = static_cast(totalChildren); @@ -1069,7 +1111,7 @@ namespace orc { uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE); rle_->next(buffer, chunk, nullptr); for (size_t i = 0; i < chunk; ++i) { - childrenElements += static_cast(buffer[i]); + addLengthToTotal(&childrenElements, buffer[i], "Map"); } lengthsRead += chunk; } @@ -1106,18 +1148,18 @@ namespace orc { if (notNull) { for (size_t i = 0; i < numValues; ++i) { if (notNull[i]) { - uint64_t tmp = static_cast(offsets[i]); + int64_t length = offsets[i]; offsets[i] = static_cast(totalChildren); - totalChildren += tmp; + addLengthToTotal(&totalChildren, length, "Map"); } else { offsets[i] = static_cast(totalChildren); } } } else { for (size_t i = 0; i < numValues; ++i) { - uint64_t tmp = static_cast(offsets[i]); + int64_t length = offsets[i]; offsets[i] = static_cast(totalChildren); - totalChildren += tmp; + addLengthToTotal(&totalChildren, length, "Map"); } } offsets[numValues] = static_cast(totalChildren); @@ -1218,7 +1260,7 @@ namespace orc { uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE); rle_->next(reinterpret_cast(buffer), chunk, nullptr); for (size_t i = 0; i < chunk; ++i) { - counts[getCheckedUnionTag(buffer[i], numChildren_)] += 1; + incrementUnionChildCount(counts, getCheckedUnionTag(buffer[i], numChildren_)); } lengthsRead += chunk; } @@ -1255,13 +1297,15 @@ namespace orc { for (size_t i = 0; i < numValues; ++i) { if (notNull[i]) { size_t tag = getCheckedUnionTag(tags[i], numChildren_); - offsets[i] = static_cast(counts[tag]++); + offsets[i] = static_cast(counts[tag]); + incrementUnionChildCount(counts, tag); } } } else { for (size_t i = 0; i < numValues; ++i) { size_t tag = getCheckedUnionTag(tags[i], numChildren_); - offsets[i] = static_cast(counts[tag]++); + offsets[i] = static_cast(counts[tag]); + incrementUnionChildCount(counts, tag); } } // read the right number of each child column diff --git a/c++/src/DictionaryLoader.cc b/c++/src/DictionaryLoader.cc index 428d288d57..24e59e7529 100644 --- a/c++/src/DictionaryLoader.cc +++ b/c++/src/DictionaryLoader.cc @@ -18,6 +18,7 @@ #include "DictionaryLoader.hh" #include "RLE.hh" +#include "Utils.hh" namespace orc { @@ -32,7 +33,7 @@ namespace orc { if (!stream->Next(&chunk, &length)) { throw ParseError("bad read in readFully"); } - if (posn + length > bufferSize) { + if (length < 0 || length > bufferSize - posn) { throw ParseError("Corrupt dictionary blob"); } memcpy(buffer + posn, chunk, static_cast(length)); @@ -64,19 +65,32 @@ namespace orc { createRleDecoder(std::move(stream), false, rleVersion, pool, stripe.getReaderMetrics()); // Decode dictionary entry lengths - dictionary->dictionaryOffset.resize(dictSize + 1); + uint64_t dictionaryOffsetSize = 0; + if (addWithOverflow(static_cast(dictSize), static_cast(1), + &dictionaryOffsetSize)) { + std::stringstream ss; + ss << "Dictionary size overflow for column " << columnId; + throw ParseError(ss.str()); + } + dictionary->dictionaryOffset.resize(dictionaryOffsetSize); int64_t* lengthArray = dictionary->dictionaryOffset.data(); lengthDecoder->next(lengthArray + 1, dictSize, nullptr); lengthArray[0] = 0; // Convert lengths to cumulative offsets - for (uint32_t i = 1; i < dictSize + 1; ++i) { + for (uint64_t i = 1; i < dictionaryOffsetSize; ++i) { if (lengthArray[i] < 0) { std::stringstream ss; ss << "Negative dictionary entry length for column " << columnId; throw ParseError(ss.str()); } - lengthArray[i] += lengthArray[i - 1]; + int64_t nextOffset = 0; + if (addWithOverflow(lengthArray[i], lengthArray[i - 1], &nextOffset)) { + std::stringstream ss; + ss << "Dictionary entry length overflow for column " << columnId; + throw ParseError(ss.str()); + } + lengthArray[i] = nextOffset; } int64_t blobSize = lengthArray[dictSize]; @@ -97,4 +111,4 @@ namespace orc { return dictionary; } -} // namespace orc \ No newline at end of file +} // namespace orc diff --git a/c++/src/MemoryPool.cc b/c++/src/MemoryPool.cc index 43f5b6212b..c7e190bd6a 100644 --- a/c++/src/MemoryPool.cc +++ b/c++/src/MemoryPool.cc @@ -20,10 +20,14 @@ #include "orc/Int128.hh" #include "Adaptor.hh" +#include "Utils.hh" #include #include #include +#include +#include +#include #include namespace orc { @@ -52,6 +56,15 @@ namespace orc { // PASS } + template + uint64_t checkedBufferSize(uint64_t count) { + uint64_t bytes = 0; + if (multiplyWithOverflow(static_cast(sizeof(T)), count, &bytes)) { + throw std::length_error("DataBuffer allocation size overflow"); + } + return bytes; + } + template DataBuffer::DataBuffer(MemoryPool& pool, uint64_t newSize, bool ownBuf) : memoryPool_(pool), buf_(nullptr), currentSize_(0), currentCapacity_(0), ownBuffer_(ownBuf) { @@ -113,13 +126,21 @@ namespace orc { return; } if (newCapacity > currentCapacity_ || !buf_) { + uint64_t newBytes = checkedBufferSize(newCapacity); if (buf_) { T* buf_old = buf_; - buf_ = reinterpret_cast(memoryPool_.malloc(sizeof(T) * newCapacity)); - memcpy(buf_, buf_old, sizeof(T) * currentSize_); + T* newBuf = reinterpret_cast(memoryPool_.malloc(newBytes)); + if (newBuf == nullptr && newBytes != 0) { + throw std::bad_alloc(); + } + buf_ = newBuf; + memcpy(buf_, buf_old, checkedBufferSize(currentSize_)); memoryPool_.free(reinterpret_cast(buf_old)); } else { - buf_ = reinterpret_cast(memoryPool_.malloc(sizeof(T) * newCapacity)); + buf_ = reinterpret_cast(memoryPool_.malloc(newBytes)); + if (buf_ == nullptr && newBytes != 0) { + throw std::bad_alloc(); + } } currentCapacity_ = newCapacity; } @@ -127,7 +148,7 @@ namespace orc { template void DataBuffer::zeroOut() { - memset(buf_, 0, sizeof(T) * currentCapacity_); + memset(buf_, 0, checkedBufferSize(currentCapacity_)); } template @@ -187,7 +208,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(char*)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -208,7 +229,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(double)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -229,7 +250,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(float)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -250,7 +271,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(int64_t)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -271,7 +292,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(int32_t)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -292,7 +313,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(int16_t)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -313,7 +334,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(int8_t)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } @@ -334,7 +355,7 @@ namespace orc { } reserve(newSize); if (newSize > currentSize_) { - memset(buf_ + currentSize_, 0, (newSize - currentSize_) * sizeof(uint64_t)); + memset(buf_ + currentSize_, 0, checkedBufferSize(newSize - currentSize_)); } currentSize_ = newSize; } diff --git a/c++/src/Utils.hh b/c++/src/Utils.hh index 851d0af15c..87f3ea85ac 100644 --- a/c++/src/Utils.hh +++ b/c++/src/Utils.hh @@ -21,10 +21,76 @@ #include #include +#include #include +#include namespace orc { + template + inline bool addWithOverflow(T left, T right, T* result) { + static_assert(std::is_integral::value, "addWithOverflow requires an integral type"); +#if defined(__GNUC__) || defined(__clang__) + return __builtin_add_overflow(left, right, result); +#else + if constexpr (std::is_unsigned::value) { + *result = left + right; + return *result < left; + } else { + if ((right > 0 && left > (std::numeric_limits::max)() - right) || + (right < 0 && left < (std::numeric_limits::min)() - right)) { + return true; + } + *result = left + right; + return false; + } +#endif + } + + template + inline bool multiplyWithOverflow(T left, T right, T* result) { + static_assert(std::is_integral::value, "multiplyWithOverflow requires an integral type"); +#if defined(__GNUC__) || defined(__clang__) + return __builtin_mul_overflow(left, right, result); +#else + if constexpr (std::is_unsigned::value) { + if (right != 0 && left > (std::numeric_limits::max)() / right) { + return true; + } + *result = left * right; + return false; + } else { + if (left == 0 || right == 0) { + *result = 0; + return false; + } + if ((left == -1 && right == (std::numeric_limits::min)()) || + (right == -1 && left == (std::numeric_limits::min)())) { + return true; + } + if (left > 0) { + if (right > 0) { + if (left > (std::numeric_limits::max)() / right) { + return true; + } + } else if (right < (std::numeric_limits::min)() / left) { + return true; + } + } else { + if (right > 0) { + if (left < (std::numeric_limits::min)() / right) { + return true; + } + } else if (right < (std::numeric_limits::max)() / left) { + return true; + } + } + *result = left * right; + return false; + } +#endif + } + class AutoStopwatch { std::chrono::high_resolution_clock::time_point start_; std::atomic* latencyUs_; diff --git a/c++/src/Vector.cc b/c++/src/Vector.cc index 49f47aeb03..87a7edef97 100644 --- a/c++/src/Vector.cc +++ b/c++/src/Vector.cc @@ -19,15 +19,25 @@ #include "orc/Vector.hh" #include "Adaptor.hh" +#include "Utils.hh" #include "orc/Exceptions.hh" #include "orc/MemoryPool.hh" #include #include #include +#include namespace orc { + uint64_t checkedOffsetCapacity(uint64_t capacity) { + uint64_t result = 0; + if (addWithOverflow(capacity, static_cast(1), &result)) { + throw std::length_error("Vector offset capacity overflow"); + } + return result; + } + ColumnVectorBatch::ColumnVectorBatch(uint64_t cap, MemoryPool& pool) : capacity(cap), numElements(0), @@ -200,7 +210,7 @@ namespace orc { } ListVectorBatch::ListVectorBatch(uint64_t cap, MemoryPool& pool) - : ColumnVectorBatch(cap, pool), offsets(pool, cap + 1) { + : ColumnVectorBatch(cap, pool), offsets(pool, checkedOffsetCapacity(cap)) { offsets.zeroOut(); } @@ -218,7 +228,7 @@ namespace orc { void ListVectorBatch::resize(uint64_t cap) { if (capacity < cap) { ColumnVectorBatch::resize(cap); - offsets.resize(cap + 1); + offsets.resize(checkedOffsetCapacity(cap)); } } @@ -241,7 +251,7 @@ namespace orc { } MapVectorBatch::MapVectorBatch(uint64_t cap, MemoryPool& pool) - : ColumnVectorBatch(cap, pool), offsets(pool, cap + 1) { + : ColumnVectorBatch(cap, pool), offsets(pool, checkedOffsetCapacity(cap)) { offsets.zeroOut(); } @@ -260,7 +270,7 @@ namespace orc { void MapVectorBatch::resize(uint64_t cap) { if (capacity < cap) { ColumnVectorBatch::resize(cap); - offsets.resize(cap + 1); + offsets.resize(checkedOffsetCapacity(cap)); } } diff --git a/c++/test/CMakeLists.txt b/c++/test/CMakeLists.txt index 93b066db64..2c5e7dfb26 100644 --- a/c++/test/CMakeLists.txt +++ b/c++/test/CMakeLists.txt @@ -55,6 +55,7 @@ add_executable (orc-test TestTimezone.cc TestType.cc TestUtil.cc + TestUtils.cc TestWriter.cc TestCache.cc ${SIMD_TEST_SRCS} diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc index 7528c9f769..7181f88322 100644 --- a/c++/test/TestColumnReader.cc +++ b/c++/test/TestColumnReader.cc @@ -79,6 +79,73 @@ namespace orc { } } + const unsigned char EMPTY_DATA[] = {0x00}; + + std::unique_ptr buildListLengthReader(MockStripeStreams& streams, + const unsigned char* lengthData, + uint64_t lengthDataSize) { + std::vector selectedColumns(3, true); + EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns)); + + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding)); + EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr)); + + EXPECT_CALL(streams, getStreamProxy(testing::_, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return(nullptr)); + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize))); + EXPECT_CALL(streams, getStreamProxy(2, proto::Stream_Kind_DATA, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(EMPTY_DATA, 0))); + + std::unique_ptr rowType = createStructType(); + rowType->addStructField("col0", createListType(createPrimitiveType(LONG))); + return buildReader(*rowType, streams); + } + + std::unique_ptr buildMapLengthReader(MockStripeStreams& streams, + const unsigned char* lengthData, + uint64_t lengthDataSize) { + std::vector selectedColumns(4, true); + EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns)); + + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding)); + EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr)); + + EXPECT_CALL(streams, getStreamProxy(testing::_, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return(nullptr)); + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(lengthData, lengthDataSize))); + EXPECT_CALL(streams, getStreamProxy(2, proto::Stream_Kind_DATA, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(EMPTY_DATA, 0))); + EXPECT_CALL(streams, getStreamProxy(3, proto::Stream_Kind_DATA, true)) + .WillRepeatedly(testing::Return(new SeekableArrayInputStream(EMPTY_DATA, 0))); + + std::unique_ptr rowType = createStructType(); + rowType->addStructField("col0", + createMapType(createPrimitiveType(LONG), createPrimitiveType(LONG))); + return buildReader(*rowType, streams); + } + + template + void expectCollectionLengthError(std::unique_ptr (*buildReaderFunc)( + MockStripeStreams&, const unsigned char*, uint64_t), + const unsigned char* lengthData, uint64_t lengthDataSize, + uint64_t numValues, bool skip, BatchType* collectionBatch) { + MockStripeStreams streams; + std::unique_ptr reader = buildReaderFunc(streams, lengthData, lengthDataSize); + if (skip) { + EXPECT_THROW(reader->skip(numValues), ParseError); + } else { + StructVectorBatch batch(numValues, *getDefaultPool()); + batch.fields.push_back(collectionBatch); + EXPECT_THROW(reader->next(batch, numValues, nullptr), ParseError); + } + } + class TestColumnReaderEncoded : public TestWithParam { void SetUp() override; @@ -937,6 +1004,46 @@ namespace orc { expectStringDirectLengthError(lengthData, ARRAY_SIZE(lengthData), 3, true); } + TEST(TestColumnReader, testListRejectsInvalidLengths) { + const unsigned char negativeLengthData[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x01}; + expectCollectionLengthError(buildListLengthReader, negativeLengthData, + ARRAY_SIZE(negativeLengthData), 1, false, + new ListVectorBatch(1, *getDefaultPool())); + expectCollectionLengthError(buildListLengthReader, negativeLengthData, + ARRAY_SIZE(negativeLengthData), 1, true, nullptr); + + // RLEv1 literal run with three INT64_MAX lengths. + const unsigned char overflowLengthData[] = { + 0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}; + expectCollectionLengthError(buildListLengthReader, overflowLengthData, + ARRAY_SIZE(overflowLengthData), 2, false, + new ListVectorBatch(2, *getDefaultPool())); + expectCollectionLengthError(buildListLengthReader, overflowLengthData, + ARRAY_SIZE(overflowLengthData), 2, true, nullptr); + } + + TEST(TestColumnReader, testMapRejectsInvalidLengths) { + const unsigned char negativeLengthData[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x01}; + expectCollectionLengthError(buildMapLengthReader, negativeLengthData, + ARRAY_SIZE(negativeLengthData), 1, false, + new MapVectorBatch(1, *getDefaultPool())); + expectCollectionLengthError(buildMapLengthReader, negativeLengthData, + ARRAY_SIZE(negativeLengthData), 1, true, nullptr); + + // RLEv1 literal run with three INT64_MAX lengths. + const unsigned char overflowLengthData[] = { + 0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}; + expectCollectionLengthError(buildMapLengthReader, overflowLengthData, + ARRAY_SIZE(overflowLengthData), 2, false, + new MapVectorBatch(2, *getDefaultPool())); + expectCollectionLengthError(buildMapLengthReader, overflowLengthData, + ARRAY_SIZE(overflowLengthData), 2, true, nullptr); + } + TEST_P(TestColumnReaderEncoded, testStringDirectShortBuffer) { MockStripeStreams streams; diff --git a/c++/test/TestUtils.cc b/c++/test/TestUtils.cc new file mode 100644 index 0000000000..55caf6e512 --- /dev/null +++ b/c++/test/TestUtils.cc @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Utils.hh" +#include "wrap/gtest-wrapper.h" + +#include + +namespace orc { + + TEST(Utils, testAddWithOverflow) { + uint64_t unsignedResult = 0; + EXPECT_FALSE( + addWithOverflow(static_cast(1), static_cast(2), &unsignedResult)); + EXPECT_EQ(3, unsignedResult); + EXPECT_TRUE(addWithOverflow((std::numeric_limits::max)(), static_cast(1), + &unsignedResult)); + + int64_t signedResult = 0; + EXPECT_FALSE(addWithOverflow(static_cast(-2), static_cast(1), &signedResult)); + EXPECT_EQ(-1, signedResult); + EXPECT_TRUE(addWithOverflow((std::numeric_limits::max)(), static_cast(1), + &signedResult)); + EXPECT_TRUE(addWithOverflow((std::numeric_limits::min)(), static_cast(-1), + &signedResult)); + } + + TEST(Utils, testMultiplyWithOverflow) { + uint64_t unsignedResult = 0; + EXPECT_FALSE( + multiplyWithOverflow(static_cast(6), static_cast(7), &unsignedResult)); + EXPECT_EQ(42, unsignedResult); + EXPECT_TRUE(multiplyWithOverflow((std::numeric_limits::max)(), + static_cast(2), &unsignedResult)); + + int64_t signedResult = 0; + EXPECT_FALSE( + multiplyWithOverflow(static_cast(-6), static_cast(7), &signedResult)); + EXPECT_EQ(-42, signedResult); + EXPECT_TRUE(multiplyWithOverflow((std::numeric_limits::max)(), static_cast(2), + &signedResult)); + EXPECT_TRUE(multiplyWithOverflow((std::numeric_limits::min)(), + static_cast(-1), &signedResult)); + } + +} // namespace orc diff --git a/c++/test/meson.build b/c++/test/meson.build index a8d30a6b94..e9bb9eed94 100644 --- a/c++/test/meson.build +++ b/c++/test/meson.build @@ -56,6 +56,7 @@ test_sources = [ 'TestTimezone.cc', 'TestType.cc', 'TestUtil.cc', + 'TestUtils.cc', 'TestWriter.cc', 'TestCache.cc', ]