diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index b089ab25e947..75a8cd7a11b1 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "arrow/buffer.h" #include "arrow/device.h" @@ -183,33 +184,35 @@ Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_lengt Result> Message::ReadFrom(std::shared_ptr metadata, io::InputStream* stream) { - RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); - int64_t body_length = -1; - RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); - - ARROW_ASSIGN_OR_RAISE(auto body, stream->Read(body_length)); - if (body->size() < body_length) { - return Status::IOError("Expected to be able to read ", body_length, + std::unique_ptr result; + auto listener = std::make_shared(&result); + MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size()); + ARROW_RETURN_NOT_OK(decoder.Consume(metadata)); + + ARROW_ASSIGN_OR_RAISE(auto body, stream->Read(decoder.next_required_size())); + if (body->size() < decoder.next_required_size()) { + return Status::IOError("Expected to be able to read ", decoder.next_required_size(), " bytes for message body, got ", body->size()); } - - return Message::Open(metadata, body); + RETURN_NOT_OK(decoder.Consume(body)); + return std::move(result); } Result> Message::ReadFrom(const int64_t offset, std::shared_ptr metadata, io::RandomAccessFile* file) { - RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); - int64_t body_length = -1; - RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); - - ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset, body_length)); - if (body->size() < body_length) { - return Status::IOError("Expected to be able to read ", body_length, + std::unique_ptr result; + auto listener = std::make_shared(&result); + MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size()); + ARROW_RETURN_NOT_OK(decoder.Consume(metadata)); + + ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset, decoder.next_required_size())); + if (body->size() < decoder.next_required_size()) { + return Status::IOError("Expected to be able to read ", decoder.next_required_size(), " bytes for message body, got ", body->size()); } - - return Message::Open(metadata, body); + RETURN_NOT_OK(decoder.Consume(body)); + return std::move(result); } Status WritePadding(io::OutputStream* stream, int64_t nbytes) { @@ -263,53 +266,48 @@ std::string FormatMessageType(Message::Type type) { Result> ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file) { - if (static_cast(metadata_length) < sizeof(int32_t)) { - return Status::Invalid("metadata_length should be at least 4"); - } + std::unique_ptr result; + auto listener = std::make_shared(&result); + MessageDecoder decoder(listener); - ARROW_ASSIGN_OR_RAISE(auto buffer, file->ReadAt(offset, metadata_length)); + if (metadata_length < decoder.next_required_size()) { + return Status::Invalid("metadata_length should be at least ", + decoder.next_required_size()); + } - if (buffer->size() < metadata_length) { + ARROW_ASSIGN_OR_RAISE(auto metadata, file->ReadAt(offset, metadata_length)); + if (metadata->size() < metadata_length) { return Status::Invalid("Expected to read ", metadata_length, - " metadata bytes but got ", buffer->size()); + " metadata bytes but got ", metadata->size()); } - - const int32_t continuation = util::SafeLoadAs(buffer->data()); - - // The size of the Flatbuffer including padding - int32_t flatbuffer_length = -1; - int32_t prefix_size = -1; - if (continuation == internal::kIpcContinuationToken) { - if (metadata_length < 8) { - return Status::Invalid( - "Corrupted IPC message, had continuation token " - " but length ", - metadata_length); + ARROW_RETURN_NOT_OK(decoder.Consume(metadata)); + + switch (decoder.state()) { + case MessageDecoder::State::INITIAL: + return std::move(result); + case MessageDecoder::State::METADATA_LENGTH: + return Status::Invalid("metadata length is missing. File offset: ", offset, + ", metadata length: ", metadata_length); + case MessageDecoder::State::METADATA: + return Status::Invalid("flatbuffer size ", decoder.next_required_size(), + " invalid. File offset: ", offset, + ", metadata length: ", metadata_length); + case MessageDecoder::State::BODY: { + ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset + metadata_length, + decoder.next_required_size())); + if (body->size() < decoder.next_required_size()) { + return Status::IOError("Expected to be able to read ", + decoder.next_required_size(), + " bytes for message body, got ", body->size()); + } + RETURN_NOT_OK(decoder.Consume(body)); + return std::move(result); } - - // Valid IPC message, parse the message length now - flatbuffer_length = util::SafeLoadAs(buffer->data() + 4); - prefix_size = 8; - } else { - // ARROW-6314: Backwards compatibility for reading old IPC - // messages produced prior to version 0.15.0 - flatbuffer_length = continuation; - prefix_size = 4; - } - - if (flatbuffer_length == 0) { - return Status::Invalid("Unexpected empty message in IPC file format"); - } - - if (flatbuffer_length != metadata_length - prefix_size) { - return Status::Invalid("flatbuffer size ", flatbuffer_length, - " invalid. File offset: ", offset, - ", metadata length: ", metadata_length); + case MessageDecoder::State::EOS: + return Status::Invalid("Unexpected empty message in IPC file format"); + default: + return Status::Invalid("Unexpected state: ", decoder.state()); } - - std::shared_ptr metadata = - SliceBuffer(buffer, prefix_size, buffer->size() - prefix_size); - return Message::ReadFrom(offset + metadata_length, metadata, file); } Status AlignStream(io::InputStream* stream, int32_t alignment) { @@ -336,43 +334,71 @@ Status CheckAligned(io::FileInterface* stream, int32_t alignment) { } } -Result> ReadMessage(io::InputStream* file, MemoryPool* pool) { - int32_t continuation = 0; - ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, file->Read(sizeof(int32_t), &continuation)); - - if (bytes_read == 0) { - // EOS without indication - return nullptr; - } else if (bytes_read != sizeof(int32_t)) { - return Status::Invalid("Corrupted message, only ", bytes_read, " bytes available"); +Status DecodeMessage(MessageDecoder* decoder, io::InputStream* file) { + if (decoder->state() == MessageDecoder::State::INITIAL) { + uint8_t continuation[sizeof(int32_t)]; + ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, file->Read(sizeof(int32_t), &continuation)); + if (bytes_read == 0) { + // EOS without indication + return Status::OK(); + } else if (bytes_read != decoder->next_required_size()) { + return Status::Invalid("Corrupted message, only ", bytes_read, " bytes available"); + } + ARROW_RETURN_NOT_OK(decoder->Consume(continuation, bytes_read)); } - int32_t flatbuffer_length = -1; - if (continuation == internal::kIpcContinuationToken) { + if (decoder->state() == MessageDecoder::State::METADATA_LENGTH) { // Valid IPC message, read the message length now - ARROW_ASSIGN_OR_RAISE(bytes_read, file->Read(sizeof(int32_t), &flatbuffer_length)); - } else { - // ARROW-6314: Backwards compatibility for reading old IPC - // messages produced prior to version 0.15.0 - flatbuffer_length = continuation; + uint8_t metadata_length[sizeof(int32_t)]; + ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, + file->Read(sizeof(int32_t), &metadata_length)); + if (bytes_read != decoder->next_required_size()) { + return Status::Invalid("Corrupted metadata length, only ", bytes_read, + " bytes available"); + } + ARROW_RETURN_NOT_OK(decoder->Consume(metadata_length, bytes_read)); } - if (flatbuffer_length == 0) { - // EOS - return nullptr; + if (decoder->state() == MessageDecoder::State::EOS) { + return Status::OK(); } - ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(flatbuffer_length)); - bytes_read = metadata->size(); - if (bytes_read != flatbuffer_length) { - return Status::Invalid("Expected to read ", flatbuffer_length, - " metadata bytes, but ", "only read ", bytes_read); + auto metadata_length = decoder->next_required_size(); + ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(metadata_length)); + if (metadata->size() != metadata_length) { + return Status::Invalid("Expected to read ", metadata_length, " metadata bytes, but ", + "only read ", metadata->size()); + } + ARROW_RETURN_NOT_OK(decoder->Consume(metadata)); + + if (decoder->state() == MessageDecoder::State::BODY) { + ARROW_ASSIGN_OR_RAISE(auto body, file->Read(decoder->next_required_size())); + if (body->size() < decoder->next_required_size()) { + return Status::IOError("Expected to be able to read ", + decoder->next_required_size(), + " bytes for message body, got ", body->size()); + } + ARROW_RETURN_NOT_OK(decoder->Consume(body)); + } + + if (decoder->state() == MessageDecoder::State::INITIAL || + decoder->state() == MessageDecoder::State::EOS) { + return Status::OK(); + } else { + return Status::Invalid("Failed to decode message"); } - // The buffer could be a non-CPU buffer (e.g. CUDA) - ARROW_ASSIGN_OR_RAISE(metadata, - Buffer::ViewOrCopy(metadata, CPUDevice::memory_manager(pool))); +} - return Message::ReadFrom(metadata, file); +Result> ReadMessage(io::InputStream* file, MemoryPool* pool) { + std::unique_ptr message; + auto listener = std::make_shared(&message); + MessageDecoder decoder(listener, pool); + ARROW_RETURN_NOT_OK(DecodeMessage(&decoder, file)); + if (!message) { + return nullptr; + } else { + return std::move(message); + } } Status WriteMessage(const Buffer& message, const IpcWriteOptions& options, @@ -407,13 +433,391 @@ Status WriteMessage(const Buffer& message, const IpcWriteOptions& options, return Status::OK(); } +// ---------------------------------------------------------------------- +// Implement MessageDecoder + +Status MessageDecoderListener::OnInitial() { return Status::OK(); } +Status MessageDecoderListener::OnMetadataLength() { return Status::OK(); } +Status MessageDecoderListener::OnMetadata() { return Status::OK(); } +Status MessageDecoderListener::OnBody() { return Status::OK(); } +Status MessageDecoderListener::OnEOS() { return Status::OK(); } + +static constexpr auto kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t); +static constexpr auto kMessageDecoderNextRequiredSizeMetadataLength = sizeof(int32_t); + +class MessageDecoder::MessageDecoderImpl { + public: + explicit MessageDecoderImpl(std::shared_ptr listener, + State initial_state, int64_t initial_next_required_size, + MemoryPool* pool) + : listener_(std::move(listener)), + pool_(pool), + state_(initial_state), + next_required_size_(initial_next_required_size), + chunks_(), + buffered_size_(0), + metadata_(nullptr) {} + + Status ConsumeData(const uint8_t* data, int64_t size) { + if (buffered_size_ == 0) { + while (size > 0 && size >= next_required_size_) { + auto used_size = next_required_size_; + switch (state_) { + case State::INITIAL: + RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_)); + break; + case State::METADATA_LENGTH: + RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_)); + break; + case State::METADATA: { + auto buffer = std::make_shared(data, next_required_size_); + RETURN_NOT_OK(ConsumeMetadataBuffer(&buffer)); + } break; + case State::BODY: { + auto buffer = std::make_shared(data, next_required_size_); + RETURN_NOT_OK(ConsumeBodyBuffer(&buffer)); + } break; + case State::EOS: + return Status::OK(); + } + data += used_size; + size -= used_size; + } + } + + if (size == 0) { + return Status::OK(); + } + + chunks_.push_back(std::make_shared(data, size)); + buffered_size_ += size; + return ConsumeChunks(); + } + + Status ConsumeBuffer(std::shared_ptr* buffer) { + if (buffered_size_ == 0) { + while ((*buffer)->size() >= next_required_size_) { + auto used_size = next_required_size_; + switch (state_) { + case State::INITIAL: + RETURN_NOT_OK(ConsumeInitialBuffer(buffer)); + break; + case State::METADATA_LENGTH: + RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer)); + break; + case State::METADATA: + if ((*buffer)->size() == next_required_size_) { + return ConsumeMetadataBuffer(buffer); + } else { + auto sliced_buffer = SliceBuffer(*buffer, 0, next_required_size_); + RETURN_NOT_OK(ConsumeMetadataBuffer(&sliced_buffer)); + } + break; + case State::BODY: + if ((*buffer)->size() == next_required_size_) { + return ConsumeBodyBuffer(buffer); + } else { + auto sliced_buffer = SliceBuffer(*buffer, 0, next_required_size_); + RETURN_NOT_OK(ConsumeBodyBuffer(&sliced_buffer)); + } + break; + case State::EOS: + return Status::OK(); + } + if ((*buffer)->size() == used_size) { + return Status::OK(); + } + *buffer = SliceBuffer(*buffer, used_size); + } + } + + if ((*buffer)->size() == 0) { + return Status::OK(); + } + + buffered_size_ += (*buffer)->size(); + chunks_.push_back(std::move(*buffer)); + return ConsumeChunks(); + } + + int64_t next_required_size() const { return next_required_size_ - buffered_size_; } + + MessageDecoder::State state() const { return state_; } + + private: + Status ConsumeChunks() { + while (state_ != State::EOS) { + if (buffered_size_ < next_required_size_) { + return Status::OK(); + } + + switch (state_) { + case State::INITIAL: + RETURN_NOT_OK(ConsumeInitialChunks()); + break; + case State::METADATA_LENGTH: + RETURN_NOT_OK(ConsumeMetadataLengthChunks()); + break; + case State::METADATA: + RETURN_NOT_OK(ConsumeMetadataChunks()); + break; + case State::BODY: + RETURN_NOT_OK(ConsumeBodyChunks()); + break; + case State::EOS: + return Status::OK(); + } + } + + return Status::OK(); + } + + Status ConsumeInitialData(const uint8_t* data, int64_t size) { + return ConsumeInitial(util::SafeLoadAs(data)); + } + + Status ConsumeInitialBuffer(std::shared_ptr* buffer) { + ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer)); + return ConsumeInitial(continuation); + } + + Status ConsumeInitialChunks() { + int32_t continuation = 0; + RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation)); + return ConsumeInitial(continuation); + } + + Status ConsumeInitial(int32_t continuation) { + if (continuation == internal::kIpcContinuationToken) { + state_ = State::METADATA_LENGTH; + next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength; + RETURN_NOT_OK(listener_->OnMetadataLength()); + // Valid IPC message, read the message length now + return Status::OK(); + } else if (continuation == 0) { + state_ = State::EOS; + next_required_size_ = 0; + RETURN_NOT_OK(listener_->OnEOS()); + return Status::OK(); + } else { + state_ = State::METADATA; + // ARROW-6314: Backwards compatibility for reading old IPC + // messages produced prior to version 0.15.0 + next_required_size_ = continuation; + RETURN_NOT_OK(listener_->OnMetadata()); + return Status::OK(); + } + } + + Status ConsumeMetadataLengthData(const uint8_t* data, int64_t size) { + return ConsumeMetadataLength(util::SafeLoadAs(data)); + } + + Status ConsumeMetadataLengthBuffer(std::shared_ptr* buffer) { + ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer)); + return ConsumeMetadataLength(metadata_length); + } + + Status ConsumeMetadataLengthChunks() { + int32_t metadata_length = 0; + RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length)); + return ConsumeMetadataLength(metadata_length); + } + + Status ConsumeMetadataLength(int32_t metadata_length) { + if (metadata_length == 0) { + state_ = State::EOS; + next_required_size_ = 0; + RETURN_NOT_OK(listener_->OnEOS()); + return Status::OK(); + } else { + state_ = State::METADATA; + next_required_size_ = metadata_length; + RETURN_NOT_OK(listener_->OnMetadata()); + return Status::OK(); + } + } + + Status ConsumeMetadataBuffer(std::shared_ptr* buffer) { + if ((*buffer)->is_cpu()) { + metadata_ = std::move(*buffer); + } else { + ARROW_ASSIGN_OR_RAISE( + metadata_, Buffer::ViewOrCopy(*buffer, CPUDevice::memory_manager(pool_))); + } + return ConsumeMetadata(); + } + + Status ConsumeMetadataChunks() { + if (chunks_[0]->size() >= next_required_size_) { + if (chunks_[0]->size() == next_required_size_) { + if (chunks_[0]->is_cpu()) { + metadata_ = std::move(chunks_[0]); + } else { + ARROW_ASSIGN_OR_RAISE( + metadata_, + Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_))); + } + chunks_.erase(chunks_.begin()); + } else { + metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_); + if (!chunks_[0]->is_cpu()) { + ARROW_ASSIGN_OR_RAISE( + metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_))); + } + chunks_[0] = SliceBuffer(chunks_[0], next_required_size_); + } + buffered_size_ -= next_required_size_; + } else { + ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_)); + metadata_ = std::shared_ptr(metadata.release()); + RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data())); + } + return ConsumeMetadata(); + } + + Status ConsumeMetadata() { + RETURN_NOT_OK(MaybeAlignMetadata(&metadata_)); + int64_t body_length = -1; + RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length)); + + state_ = State::BODY; + next_required_size_ = body_length; + RETURN_NOT_OK(listener_->OnBody()); + if (next_required_size_ == 0) { + ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_)); + std::shared_ptr shared_body(body.release()); + return ConsumeBody(&shared_body); + } else { + return Status::OK(); + } + } + + Status ConsumeBodyBuffer(std::shared_ptr* buffer) { + return ConsumeBody(buffer); + } + + Status ConsumeBodyChunks() { + if (chunks_[0]->size() >= next_required_size_) { + auto used_size = next_required_size_; + if (chunks_[0]->size() == next_required_size_) { + RETURN_NOT_OK(ConsumeBody(&chunks_[0])); + chunks_.erase(chunks_.begin()); + } else { + auto body = SliceBuffer(chunks_[0], 0, next_required_size_); + RETURN_NOT_OK(ConsumeBody(&body)); + chunks_[0] = SliceBuffer(chunks_[0], used_size); + } + buffered_size_ -= used_size; + return Status::OK(); + } else { + ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_)); + RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data())); + std::shared_ptr shared_body(body.release()); + return ConsumeBody(&shared_body); + } + } + + Status ConsumeBody(std::shared_ptr* buffer) { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, + Message::Open(metadata_, *buffer)); + + RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message))); + state_ = State::INITIAL; + next_required_size_ = kMessageDecoderNextRequiredSizeInitial; + RETURN_NOT_OK(listener_->OnInitial()); + return Status::OK(); + } + + Result ConsumeDataBufferInt32(std::shared_ptr* buffer) { + if ((*buffer)->is_cpu()) { + return util::SafeLoadAs((*buffer)->data()); + } else { + ARROW_ASSIGN_OR_RAISE( + auto cpu_buffer, Buffer::ViewOrCopy(*buffer, CPUDevice::memory_manager(pool_))); + return util::SafeLoadAs(cpu_buffer->data()); + } + } + + Status ConsumeDataChunks(int64_t nbytes, void* out) { + size_t offset = 0; + size_t n_used_chunks = 0; + auto required_size = nbytes; + std::shared_ptr last_chunk; + for (auto& chunk : chunks_) { + if (!chunk->is_cpu()) { + ARROW_ASSIGN_OR_RAISE( + chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_))); + } + auto data = chunk->data(); + auto data_size = chunk->size(); + auto copy_size = std::min(required_size, data_size); + memcpy(static_cast(out) + offset, data, copy_size); + n_used_chunks++; + offset += copy_size; + required_size -= copy_size; + if (required_size == 0) { + if (data_size != copy_size) { + last_chunk = SliceBuffer(chunk, copy_size); + } + break; + } + } + chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks); + if (last_chunk.get() != nullptr) { + chunks_.insert(chunks_.begin(), std::move(last_chunk)); + } + buffered_size_ -= offset; + return Status::OK(); + } + + std::shared_ptr listener_; + MemoryPool* pool_; + State state_; + int64_t next_required_size_; + std::vector> chunks_; + int64_t buffered_size_; + std::shared_ptr metadata_; // Must be CPU buffer +}; + +MessageDecoder::MessageDecoder(std::shared_ptr listener, + MemoryPool* pool) { + impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL, + kMessageDecoderNextRequiredSizeInitial, pool)); +} + +MessageDecoder::MessageDecoder(std::shared_ptr listener, + State initial_state, int64_t initial_next_required_size, + MemoryPool* pool) { + impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state, + initial_next_required_size, pool)); +} + +MessageDecoder::~MessageDecoder() {} + +Status MessageDecoder::Consume(const uint8_t* data, int64_t size) { + return impl_->ConsumeData(data, size); +} + +Status MessageDecoder::Consume(std::shared_ptr buffer) { + return impl_->ConsumeBuffer(&buffer); +} + +int64_t MessageDecoder::next_required_size() const { return impl_->next_required_size(); } + +MessageDecoder::State MessageDecoder::state() const { return impl_->state(); } + // ---------------------------------------------------------------------- // Implement InputStream message reader /// \brief Implementation of MessageReader that reads from InputStream -class InputStreamMessageReader : public MessageReader { +class InputStreamMessageReader : public MessageReader, public MessageDecoderListener { public: - explicit InputStreamMessageReader(io::InputStream* stream) : stream_(stream) {} + explicit InputStreamMessageReader(io::InputStream* stream) + : stream_(stream), + owned_stream_(), + message_(), + decoder_(std::shared_ptr(this, [](void*) {})) {} explicit InputStreamMessageReader(const std::shared_ptr& owned_stream) : InputStreamMessageReader(owned_stream.get()) { @@ -422,11 +826,21 @@ class InputStreamMessageReader : public MessageReader { ~InputStreamMessageReader() {} - Result> ReadNextMessage() { return ReadMessage(stream_); } + Status OnMessageDecoded(std::unique_ptr message) override { + message_ = std::move(message); + return Status::OK(); + } + + Result> ReadNextMessage() override { + ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_)); + return std::move(message_); + } private: io::InputStream* stream_; std::shared_ptr owned_stream_; + std::unique_ptr message_; + MessageDecoder decoder_; }; std::unique_ptr MessageReader::Open(io::InputStream* stream) { diff --git a/cpp/src/arrow/ipc/message.h b/cpp/src/arrow/ipc/message.h index 9698806fd212..1b3fa8f861aa 100644 --- a/cpp/src/arrow/ipc/message.h +++ b/cpp/src/arrow/ipc/message.h @@ -180,6 +180,288 @@ class ARROW_EXPORT Message { ARROW_EXPORT std::string FormatMessageType(Message::Type type); +/// \class MessageDecoderListener +/// \brief An abstract class to listen events from MessageDecoder. +/// +/// This API is EXPERIMENTAL. +/// +/// \since 0.17.0 +class ARROW_EXPORT MessageDecoderListener { + public: + virtual ~MessageDecoderListener() = default; + + /// \brief Called when a message is decoded. + /// + /// MessageDecoder calls this method when it decodes a message. This + /// method is called multiple times when the target stream has + /// multiple messages. + /// + /// \param[in] message a decoded message + /// \return Status + virtual Status OnMessageDecoded(std::unique_ptr message) = 0; + + /// \brief Called when the decoder state is changed to + /// MessageDecoder::State::INITIAL. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + virtual Status OnInitial(); + + /// \brief Called when the decoder state is changed to + /// MessageDecoder::State::METADATA_LENGTH. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + virtual Status OnMetadataLength(); + + /// \brief Called when the decoder state is changed to + /// MessageDecoder::State::METADATA. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + virtual Status OnMetadata(); + + /// \brief Called when the decoder state is changed to + /// MessageDecoder::State::BODY. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + virtual Status OnBody(); + + /// \brief Called when the decoder state is changed to + /// MessageDecoder::State::EOS. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + virtual Status OnEOS(); +}; + +/// \class AssignMessageDecoderListener +/// \brief Assign a message decoded by MessageDecoder. +/// +/// This API is EXPERIMENTAL. +/// +/// \since 0.17.0 +class ARROW_EXPORT AssignMessageDecoderListener : public MessageDecoderListener { + public: + /// \brief Construct a listener that assigns a decoded message to the + /// specified location. + /// + /// \param[in] message a location to store the received message + explicit AssignMessageDecoderListener(std::unique_ptr* message) + : message_(message) {} + + virtual ~AssignMessageDecoderListener() = default; + + Status OnMessageDecoded(std::unique_ptr message) override { + *message_ = std::move(message); + return Status::OK(); + } + + private: + std::unique_ptr* message_; + + ARROW_DISALLOW_COPY_AND_ASSIGN(AssignMessageDecoderListener); +}; + +/// \class MessageDecoder +/// \brief Push style message decoder that receives data from user. +/// +/// This API is EXPERIMENTAL. +/// +/// \since 0.17.0 +class ARROW_EXPORT MessageDecoder { + public: + /// \brief State for reading a message + enum State { + /// The initial state. It requires one of the followings as the next data: + /// + /// * int32_t continuation token + /// * int32_t end-of-stream mark (== 0) + /// * int32_t metadata length (backward compatibility for + /// reading old IPC messages produced prior to version 0.15.0 + INITIAL, + + /// It requires int32_t metadata length. + METADATA_LENGTH, + + /// It requires metadata. + METADATA, + + /// It requires message body. + BODY, + + /// The end-of-stream state. No more data is processed. + EOS, + }; + + /// \brief Construct a message decoder. + /// + /// \param[in] listener a MessageDecoderListener that responds events from + /// the decoder + /// \param[in] pool an optional MemoryPool to copy metadata on the + /// CPU, if required + explicit MessageDecoder(std::shared_ptr listener, + MemoryPool* pool = default_memory_pool()); + + /// \brief Construct a message decoder with the specified state. + /// + /// This is a construct for advanced users that know how to decode + /// Message. + /// + /// \param[in] listener a MessageDecoderListener that responds events from + /// the decoder + /// \param[in] initial_state an initial state of the decode + /// \param[in] initial_next_required_size the number of bytes needed + /// to run the next action + /// \param[in] pool an optional MemoryPool to copy metadata on the + /// CPU, if required + MessageDecoder(std::shared_ptr listener, State initial_state, + int64_t initial_next_required_size, + MemoryPool* pool = default_memory_pool()); + + virtual ~MessageDecoder(); + + /// \brief Feed data to the decoder as a raw data. + /// + /// If the decoder can decode one or more messages by the data, the + /// decoder calls listener->OnMessageDecoded() with a decoded + /// message multiple times. + /// + /// If the state of the decoder is changed, corresponding callbacks + /// on listener is called: + /// + /// * MessageDecoder::State::INITIAL: listener->OnInitial() + /// * MessageDecoder::State::METADATA_LENGTH: listener->OnMetadataLength() + /// * MessageDecoder::State::METADATA: listener->OnMetadata() + /// * MessageDecoder::State::BODY: listener->OnBody() + /// * MessageDecoder::State::EOS: listener->OnEOS() + /// + /// \param[in] data a raw data to be processed. This data isn't + /// copied. The passed memory must be kept alive through message + /// processing. + /// \param[in] size raw data size. + /// \return Status + Status Consume(const uint8_t* data, int64_t size); + + /// \brief Feed data to the decoder as a Buffer. + /// + /// If the decoder can decode one or more messages by the Buffer, + /// the decoder calls listener->OnMessageDecoded() with a decoded + /// message multiple times. + /// + /// \param[in] buffer a Buffer to be processed. + /// \return Status + Status Consume(std::shared_ptr buffer); + + /// \brief Return the number of bytes needed to advance the state of + /// the decoder. + /// + /// This method is provided for users who want to optimize performance. + /// Normal users don't need to use this method. + /// + /// Here is an example usage for normal users: + /// + /// ~~~{.cpp} + /// decoder.Consume(buffer1); + /// decoder.Consume(buffer2); + /// decoder.Consume(buffer3); + /// ~~~ + /// + /// Decoder has internal buffer. If consumed data isn't enough to + /// advance the state of the decoder, consumed data is buffered to + /// the internal buffer. It causes performance overhead. + /// + /// If you pass next_required_size() size data to each Consume() + /// call, the decoder doesn't use its internal buffer. It improves + /// performance. + /// + /// Here is an example usage to avoid using internal buffer: + /// + /// ~~~{.cpp} + /// buffer1 = get_data(decoder.next_required_size()); + /// decoder.Consume(buffer1); + /// buffer2 = get_data(decoder.next_required_size()); + /// decoder.Consume(buffer2); + /// ~~~ + /// + /// Users can use this method to avoid creating small + /// chunks. Message body must be contiguous data. If users pass + /// small chunks to the decoder, the decoder needs concatenate small + /// chunks internally. It causes performance overhead. + /// + /// Here is an example usage to reduce small chunks: + /// + /// ~~~{.cpp} + /// buffer = AllocateResizableBuffer(); + /// while ((small_chunk = get_data(&small_chunk_size))) { + /// auto current_buffer_size = buffer->size(); + /// buffer->Resize(current_buffer_size + small_chunk_size); + /// memcpy(buffer->mutable_data() + current_buffer_size, + /// small_chunk, + /// small_chunk_size); + /// if (buffer->size() < decoder.next_requied_size()) { + /// continue; + /// } + /// std::shared_ptr chunk(buffer.release()); + /// decoder.Consume(chunk); + /// buffer = AllocateResizableBuffer(); + /// } + /// if (buffer->size() > 0) { + /// std::shared_ptr chunk(buffer.release()); + /// decoder.Consume(chunk); + /// } + /// ~~~ + /// + /// \return the number of bytes needed to advance the state of the + /// decoder + int64_t next_required_size() const; + + /// \brief Return the current state of the decoder. + /// + /// This method is provided for users who want to optimize performance. + /// Normal users don't need to use this method. + /// + /// Decoder doesn't need Buffer to process data on the + /// MessageDecoder::State::INITIAL state and the + /// MessageDecoder::State::METADATA_LENGTH. Creating Buffer has + /// performance overhead. Advanced users can avoid creating Buffer + /// by checking the current state of the decoder: + /// + /// ~~~{.cpp} + /// switch (decoder.state()) { + /// MessageDecoder::State::INITIAL: + /// MessageDecoder::State::METADATA_LENGTH: + /// { + /// uint8_t data[sizeof(int32_t)]; + /// auto data_size = input->Read(decoder.next_required_size(), data); + /// decoder.Consume(data, data_size); + /// } + /// break; + /// default: + /// { + /// auto buffer = input->Read(decoder.next_required_size()); + /// decoder.Consume(buffer); + /// } + /// break; + /// } + /// ~~~ + /// + /// \return the current state + State state() const; + + private: + class MessageDecoderImpl; + std::unique_ptr impl_; + + ARROW_DISALLOW_COPY_AND_ASSIGN(MessageDecoder); +}; + /// \brief Abstract interface for a sequence of messages /// \since 0.5.0 class ARROW_EXPORT MessageReader { @@ -257,6 +539,19 @@ ARROW_EXPORT Result> ReadMessage(io::InputStream* stream, MemoryPool* pool = default_memory_pool()); +/// \brief Feed data from InputStream to MessageDecoder to decode an +/// encapsulated IPC message (metadata and body) +/// +/// This API is EXPERIMENTAL. +/// +/// \param[in] decoder a decoder +/// \param[in] stream an input stream +/// \return Status +/// +/// \since 0.17.0 +ARROW_EXPORT +Status DecodeMessage(MessageDecoder* decoder, io::InputStream* stream); + /// Write encapsulated IPC message Does not make assumptions about /// whether the stream is aligned already. Can write legacy (pre /// version 0.15.0) IPC message if option set diff --git a/cpp/src/arrow/ipc/read_write_benchmark.cc b/cpp/src/arrow/ipc/read_write_benchmark.cc index c76c166dfcd5..1c4ee406fcfb 100644 --- a/cpp/src/arrow/ipc/read_write_benchmark.cc +++ b/cpp/src/arrow/ipc/read_write_benchmark.cc @@ -59,11 +59,8 @@ static void WriteRecordBatch(benchmark::State& state) { // NOLINT non-const ref io::BufferOutputStream stream(buffer); int32_t metadata_length; int64_t body_length; - if (!ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, &body_length, - options) - .ok()) { - state.SkipWithError("Failed to write!"); - } + ABORT_NOT_OK(ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, + &body_length, options)); } state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize); } @@ -80,25 +77,85 @@ static void ReadRecordBatch(benchmark::State& state) { // NOLINT non-const refe int32_t metadata_length; int64_t body_length; - if (!ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, &body_length, - options) - .ok()) { - state.SkipWithError("Failed to write!"); - } + ABORT_NOT_OK(ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, + &body_length, options)); ipc::DictionaryMemo empty_memo; while (state.KeepRunning()) { io::BufferReader reader(buffer); - if (!ipc::ReadRecordBatch(record_batch->schema(), &empty_memo, - ipc::IpcReadOptions::Defaults(), &reader) - .ok()) { - state.SkipWithError("Failed to read!"); + ABORT_NOT_OK(ipc::ReadRecordBatch(record_batch->schema(), &empty_memo, + ipc::IpcReadOptions::Defaults(), &reader)); + } + state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize); +} + +static void ReadStream(benchmark::State& state) { // NOLINT non-const reference + // 1MB + constexpr int64_t kTotalSize = 1 << 20; + auto options = ipc::IpcWriteOptions::Defaults(); + + std::shared_ptr buffer = *AllocateResizableBuffer(kTotalSize & 2); + auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); + + io::BufferOutputStream stream(buffer); + + auto writer_result = ipc::NewStreamWriter(&stream, record_batch->schema(), options); + ABORT_NOT_OK(writer_result); + auto writer = *writer_result; + ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch)); + ABORT_NOT_OK(writer->Close()); + + ipc::DictionaryMemo empty_memo; + while (state.KeepRunning()) { + io::BufferReader input(buffer); + auto reader_result = + ipc::RecordBatchStreamReader::Open(&input, ipc::IpcReadOptions::Defaults()); + ABORT_NOT_OK(reader_result); + auto reader = *reader_result; + while (true) { + std::shared_ptr batch; + ABORT_NOT_OK(reader->ReadNext(&batch)); + if (batch.get() == nullptr) { + break; + } } } state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize); } +static void DecodeStream(benchmark::State& state) { // NOLINT non-const reference + // 1MB + constexpr int64_t kTotalSize = 1 << 20; + auto options = ipc::IpcWriteOptions::Defaults(); + + std::shared_ptr buffer = *AllocateResizableBuffer(kTotalSize & 2); + auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); + + io::BufferOutputStream stream(buffer); + + auto writer_result = ipc::NewStreamWriter(&stream, record_batch->schema(), options); + ABORT_NOT_OK(writer_result); + auto writer = *writer_result; + ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch)); + ABORT_NOT_OK(writer->Close()); + + ipc::DictionaryMemo empty_memo; + while (state.KeepRunning()) { + class NullListener : public ipc::Listener { + Status OnRecordBatchDecoded(std::shared_ptr batch) override { + return Status::OK(); + } + } listener; + ipc::StreamDecoder decoder(std::shared_ptr(&listener, [](void*) {}), + ipc::IpcReadOptions::Defaults()); + ABORT_NOT_OK(decoder.Consume(buffer)); + } + state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize); +} + BENCHMARK(WriteRecordBatch)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime(); BENCHMARK(ReadRecordBatch)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime(); +BENCHMARK(ReadStream)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime(); +BENCHMARK(DecodeStream)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime(); } // namespace arrow diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index f9500be40af0..d30eed754abb 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -884,6 +884,84 @@ struct StreamWriterHelper { std::shared_ptr writer_; }; +struct StreamDecoderDataWriterHelper : public StreamWriterHelper { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener, options); + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_->data(), buffer_->size())); + *out_batches = listener->record_batches(); + return Status::OK(); + } + + Status ReadSchema(std::shared_ptr* out) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener); + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_->data(), buffer_->size())); + *out = listener->schema(); + return Status::OK(); + } +}; + +struct StreamDecoderBufferWriterHelper : public StreamWriterHelper { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener, options); + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_)); + *out_batches = listener->record_batches(); + return Status::OK(); + } + + Status ReadSchema(std::shared_ptr* out) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener); + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_)); + *out = listener->schema(); + return Status::OK(); + } +}; + +struct StreamDecoderSmallChunksWriterHelper : public StreamWriterHelper { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener, options); + for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) { + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_->data() + offset, 1)); + } + *out_batches = listener->record_batches(); + return Status::OK(); + } + + Status ReadSchema(std::shared_ptr* out) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener); + for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) { + ARROW_RETURN_NOT_OK(decoder.Consume(buffer_->data() + offset, 1)); + } + *out = listener->schema(); + return Status::OK(); + } +}; + +struct StreamDecoderLargeChunksWriterHelper : public StreamWriterHelper { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener, options); + ARROW_RETURN_NOT_OK(decoder.Consume(SliceBuffer(buffer_, 0, 1))); + ARROW_RETURN_NOT_OK(decoder.Consume(SliceBuffer(buffer_, 1))); + *out_batches = listener->record_batches(); + return Status::OK(); + } + + Status ReadSchema(std::shared_ptr* out) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener); + ARROW_RETURN_NOT_OK(decoder.Consume(SliceBuffer(buffer_, 0, 1))); + ARROW_RETURN_NOT_OK(decoder.Consume(SliceBuffer(buffer_, 1))); + *out = listener->schema(); + return Status::OK(); + } +}; + // Parameterized mixin with tests for stream / file writer template @@ -902,8 +980,9 @@ class ReaderWriterMixin { BatchVector in_batches = {batch1, batch2}; BatchVector out_batches; - ASSERT_OK( - RoundTripHelper(in_batches, options, IpcReadOptions::Defaults(), &out_batches)); + WriterHelper writer_helper; + ASSERT_OK(RoundTripHelper(writer_helper, in_batches, options, + IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), in_batches.size()); // Compare batches @@ -924,8 +1003,9 @@ class ReaderWriterMixin { BatchVector in_batches = {batch1, batch2}; BatchVector out_batches; - ASSERT_OK( - RoundTripHelper(in_batches, options, IpcReadOptions::Defaults(), &out_batches)); + WriterHelper writer_helper; + ASSERT_OK(RoundTripHelper(writer_helper, in_batches, options, + IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), in_batches.size()); // Compare batches @@ -938,8 +1018,9 @@ class ReaderWriterMixin { std::shared_ptr batch; ASSERT_OK(MakeDictionary(&batch)); + WriterHelper writer_helper; BatchVector out_batches; - ASSERT_OK(RoundTripHelper({batch}, IpcWriteOptions::Defaults(), + ASSERT_OK(RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(), IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), 1); @@ -967,22 +1048,35 @@ class ReaderWriterMixin { options.included_fields = {1, 3}; - BatchVector out_batches; - ASSERT_OK( - RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, &out_batches)); + { + WriterHelper writer_helper; + BatchVector out_batches; + ASSERT_OK(RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(), + options, &out_batches)); - auto ex_schema = schema({field("a1", utf8()), field("a3", utf8())}, - key_value_metadata({"key1"}, {"value1"})); - auto ex_batch = RecordBatch::Make(ex_schema, a0->length(), {a1, a3}); - AssertBatchesEqual(*ex_batch, *out_batches[0], /*check_metadata=*/true); + auto ex_schema = schema({field("a1", utf8()), field("a3", utf8())}, + key_value_metadata({"key1"}, {"value1"})); + auto ex_batch = RecordBatch::Make(ex_schema, a0->length(), {a1, a3}); + AssertBatchesEqual(*ex_batch, *out_batches[0], /*check_metadata=*/true); + } // Out of bounds cases options.included_fields = {1, 3, 5}; - ASSERT_RAISES(Invalid, RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, - &out_batches)); + { + WriterHelper writer_helper; + BatchVector out_batches; + ASSERT_RAISES(Invalid, + RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(), + options, &out_batches)); + } options.included_fields = {1, 3, -1}; - ASSERT_RAISES(Invalid, RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, - &out_batches)); + { + WriterHelper writer_helper; + BatchVector out_batches; + ASSERT_RAISES(Invalid, + RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(), + options, &out_batches)); + } } void TestWriteDifferentSchema() { @@ -1031,10 +1125,9 @@ class ReaderWriterMixin { } private: - Status RoundTripHelper(const BatchVector& in_batches, + Status RoundTripHelper(WriterHelper& writer_helper, const BatchVector& in_batches, const IpcWriteOptions& write_options, const IpcReadOptions& read_options, BatchVector* out_batches) { - WriterHelper writer_helper; RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), write_options)); for (const auto& batch : in_batches) { RETURN_NOT_OK(writer_helper.WriteBatch(batch)); @@ -1069,6 +1162,17 @@ class TestFileFormat : public ReaderWriterMixin, class TestStreamFormat : public ReaderWriterMixin, public ::testing::TestWithParam {}; +class TestStreamDecoderData : public ReaderWriterMixin, + public ::testing::TestWithParam {}; +class TestStreamDecoderBuffer : public ReaderWriterMixin, + public ::testing::TestWithParam {}; +class TestStreamDecoderSmallChunks + : public ReaderWriterMixin, + public ::testing::TestWithParam {}; +class TestStreamDecoderLargeChunks + : public ReaderWriterMixin, + public ::testing::TestWithParam {}; + TEST_P(TestFileFormat, RoundTrip) { TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); @@ -1089,9 +1193,57 @@ TEST_P(TestStreamFormat, RoundTrip) { TestZeroLengthRoundTrip(*GetParam(), options); } +TEST_P(TestStreamDecoderData, RoundTrip) { + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + + IpcWriteOptions options; + options.write_legacy_ipc_format = true; + TestRoundTrip(*GetParam(), options); + TestZeroLengthRoundTrip(*GetParam(), options); +} + +TEST_P(TestStreamDecoderBuffer, RoundTrip) { + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + + IpcWriteOptions options; + options.write_legacy_ipc_format = true; + TestRoundTrip(*GetParam(), options); + TestZeroLengthRoundTrip(*GetParam(), options); +} + +TEST_P(TestStreamDecoderSmallChunks, RoundTrip) { + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + + IpcWriteOptions options; + options.write_legacy_ipc_format = true; + TestRoundTrip(*GetParam(), options); + TestZeroLengthRoundTrip(*GetParam(), options); +} + +TEST_P(TestStreamDecoderLargeChunks, RoundTrip) { + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + + IpcWriteOptions options; + options.write_legacy_ipc_format = true; + TestRoundTrip(*GetParam(), options); + TestZeroLengthRoundTrip(*GetParam(), options); +} + INSTANTIATE_TEST_SUITE_P(GenericIpcRoundTripTests, TestIpcRoundTrip, BATCH_CASES()); INSTANTIATE_TEST_SUITE_P(FileRoundTripTests, TestFileFormat, BATCH_CASES()); INSTANTIATE_TEST_SUITE_P(StreamRoundTripTests, TestStreamFormat, BATCH_CASES()); +INSTANTIATE_TEST_SUITE_P(StreamDecoderDataRoundTripTests, TestStreamDecoderData, + BATCH_CASES()); +INSTANTIATE_TEST_SUITE_P(StreamDecoderBufferRoundTripTests, TestStreamDecoderBuffer, + BATCH_CASES()); +INSTANTIATE_TEST_SUITE_P(StreamDecoderSmallChunksRoundTripTests, + TestStreamDecoderSmallChunks, BATCH_CASES()); +INSTANTIATE_TEST_SUITE_P(StreamDecoderLargeChunksRoundTripTests, + TestStreamDecoderLargeChunks, BATCH_CASES()); TEST(TestIpcFileFormat, FooterMetaData) { // ARROW-6837 @@ -1692,6 +1844,15 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader)); } +TEST(TestStreamDecoder, NextRequiredSize) { + auto listener = std::make_shared(); + StreamDecoder decoder(listener); + auto next_required_size = decoder.next_required_size(); + const uint8_t data[1] = {0}; + ASSERT_OK(decoder.Consume(data, 1)); + ASSERT_EQ(next_required_size - 1, decoder.next_required_size()); +} + // ---------------------------------------------------------------------- // DictionaryMemo miscellanea diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 484d648bbd06..571aac33fcbf 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -502,7 +502,7 @@ Result> ReadRecordBatchInternal( const Buffer& metadata, const std::shared_ptr& schema, const std::vector& inclusion_mask, const DictionaryMemo* dictionary_memo, const IpcReadOptions& options, io::RandomAccessFile* file) { - const flatbuf::Message* message; + const flatbuf::Message* message = nullptr; RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); auto batch = message->header_as_RecordBatch(); if (batch == nullptr) { @@ -528,6 +528,24 @@ Status PopulateInclusionMask(const std::vector& included_indices, return Status::OK(); } +Status PrepareSchemaMessage(const Message& message, DictionaryMemo* dictionary_memo, + std::shared_ptr* schema, + const IpcReadOptions& options, + std::vector* field_inclusion_mask) { + CHECK_MESSAGE_TYPE(Message::SCHEMA, message.type()); + CHECK_HAS_NO_BODY(message); + + RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, schema)); + + // If we are selecting only certain fields, populate the inclusion mask now + // for fast lookups + if (options.included_fields) { + RETURN_NOT_OK(PopulateInclusionMask(*options.included_fields, (*schema)->num_fields(), + field_inclusion_mask)); + } + return Status::OK(); +} + Result> ReadRecordBatch( const Buffer& metadata, const std::shared_ptr& schema, const DictionaryMemo* dictionary_memo, const IpcReadOptions& options, @@ -544,7 +562,7 @@ Result> ReadRecordBatch( Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, const IpcReadOptions& options, io::RandomAccessFile* file) { - const flatbuf::Message* message; + const flatbuf::Message* message = nullptr; RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); auto dictionary_batch = message->header_as_DictionaryBatch(); if (dictionary_batch == nullptr) { @@ -580,6 +598,21 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, return dictionary_memo->AddDictionary(id, dictionary); } +Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo, + const IpcReadOptions& options) { + // Only invoke this method if we already know we have a dictionary message + DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH); + CHECK_HAS_BODY(message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); + return ReadDictionary(*message.metadata(), dictionary_memo, options, reader.get()); +} + +Status UpdateDictionaries(const Message& message, DictionaryMemo* dictionary_memo, + const IpcReadOptions& options) { + // TODO(wesm): implement delta dictionaries + return Status::NotImplemented("Delta dictionaries not yet implemented"); +} + // ---------------------------------------------------------------------- // RecordBatchStreamReader implementation @@ -596,18 +629,8 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { if (!message) { return Status::Invalid("Tried reading schema message, was null or length 0"); } - CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type()); - CHECK_HAS_NO_BODY(*message); - - RETURN_NOT_OK(internal::GetSchema(message->header(), &dictionary_memo_, &schema_)); - - // If we are selecting only certain fields, populate the inclusion mask now - // for fast lookups - if (options.included_fields) { - RETURN_NOT_OK(PopulateInclusionMask(*options.included_fields, schema_->num_fields(), - &field_inclusion_mask_)); - } - return Status::OK(); + return PrepareSchemaMessage(*message, &dictionary_memo_, &schema_, options, + &field_inclusion_mask_); } Status ReadNext(std::shared_ptr* batch) override { @@ -631,8 +654,7 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { } if (message->type() == Message::DICTIONARY_BATCH) { - // TODO(wesm): implement delta dictionaries - return Status::NotImplemented("Delta dictionaries not yet implemented"); + return UpdateDictionaries(*message, &dictionary_memo_, options_); } else { CHECK_HAS_BODY(*message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); @@ -645,14 +667,6 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { std::shared_ptr schema() const override { return schema_; } private: - Status ParseDictionary(const Message& message) { - // Only invoke this method if we already know we have a dictionary message - DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH); - CHECK_HAS_BODY(message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); - return ReadDictionary(*message.metadata(), &dictionary_memo_, options_, reader.get()); - } - Status ReadInitialDictionaries() { // We must receive all dictionaries before reconstructing the // first record batch. Subsequent dictionary deltas modify the memo @@ -684,7 +698,7 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { dictionary_memo_.num_fields(), ") of dictionaries at the start of the stream"); } - RETURN_NOT_OK(ParseDictionary(*message)); + RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_)); } have_read_initial_dictionaries_ = true; @@ -927,6 +941,142 @@ Result> RecordBatchFileReader::Open( return result; } +Status Listener::OnEOS() { return Status::OK(); } + +Status Listener::OnSchemaDecoded(std::shared_ptr schema) { return Status::OK(); } + +Status Listener::OnRecordBatchDecoded(std::shared_ptr record_batch) { + return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented"); +} + +class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener { + private: + enum State { + SCHEMA, + INITIAL_DICTIONARIES, + RECORD_BATCHES, + EOS, + }; + + public: + explicit StreamDecoderImpl(std::shared_ptr listener, + const IpcReadOptions& options) + : MessageDecoderListener(), + listener_(std::move(listener)), + options_(options), + state_(State::SCHEMA), + message_decoder_(std::shared_ptr(this, [](void*) {}), + options_.memory_pool), + field_inclusion_mask_(), + n_required_dictionaries_(0), + dictionary_memo_(), + schema_() {} + + Status OnMessageDecoded(std::unique_ptr message) override { + switch (state_) { + case State::SCHEMA: + ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message))); + break; + case State::INITIAL_DICTIONARIES: + ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message))); + break; + case State::RECORD_BATCHES: + ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message))); + break; + case State::EOS: + break; + } + return Status::OK(); + } + + Status OnEOS() override { + state_ = State::EOS; + return listener_->OnEOS(); + } + + Status Consume(const uint8_t* data, int64_t size) { + return message_decoder_.Consume(data, size); + } + + Status Consume(std::shared_ptr buffer) { + return message_decoder_.Consume(std::move(buffer)); + } + + std::shared_ptr schema() const { return schema_; } + + int64_t next_required_size() const { return message_decoder_.next_required_size(); } + + private: + Status OnSchemaMessageDecoded(std::unique_ptr message) { + RETURN_NOT_OK(PrepareSchemaMessage(*message, &dictionary_memo_, &schema_, options_, + &field_inclusion_mask_)); + n_required_dictionaries_ = dictionary_memo_.num_fields(); + if (n_required_dictionaries_ == 0) { + state_ = State::RECORD_BATCHES; + ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } else { + state_ = State::INITIAL_DICTIONARIES; + } + return Status::OK(); + } + + Status OnInitialDictionaryMessageDecoded(std::unique_ptr message) { + if (message->type() != Message::DICTIONARY_BATCH) { + return Status::Invalid("IPC stream did not have the expected number (", + dictionary_memo_.num_fields(), + ") of dictionaries at the start of the stream"); + } + RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_)); + n_required_dictionaries_--; + if (n_required_dictionaries_ == 0) { + state_ = State::RECORD_BATCHES; + ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } + return Status::OK(); + } + + Status OnRecordBatchMessageDecoded(std::unique_ptr message) { + if (message->type() == Message::DICTIONARY_BATCH) { + return UpdateDictionaries(*message, &dictionary_memo_, options_); + } else { + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + ARROW_ASSIGN_OR_RAISE( + auto batch, + ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + &dictionary_memo_, options_, reader.get())); + return listener_->OnRecordBatchDecoded(std::move(batch)); + } + } + + std::shared_ptr listener_; + IpcReadOptions options_; + State state_; + MessageDecoder message_decoder_; + std::vector field_inclusion_mask_; + int n_required_dictionaries_; + DictionaryMemo dictionary_memo_; + std::shared_ptr schema_; +}; + +StreamDecoder::StreamDecoder(std::shared_ptr listener, + const IpcReadOptions& options) { + impl_.reset(new StreamDecoderImpl(std::move(listener), options)); +} + +StreamDecoder::~StreamDecoder() {} + +Status StreamDecoder::Consume(const uint8_t* data, int64_t size) { + return impl_->Consume(data, size); +} +Status StreamDecoder::Consume(std::shared_ptr buffer) { + return impl_->Consume(std::move(buffer)); +} + +std::shared_ptr StreamDecoder::schema() const { return impl_->schema(); } + +int64_t StreamDecoder::next_required_size() const { return impl_->next_required_size(); } + Result> ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo) { std::unique_ptr reader = MessageReader::Open(stream); diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index b691e0df1225..45d2ee32951f 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include "arrow/ipc/message.h" #include "arrow/ipc/options.h" @@ -194,6 +196,201 @@ class ARROW_EXPORT RecordBatchFileReader { } }; +/// \class Listener +/// \brief A general listener class to receive events. +/// +/// You must implement callback methods for interested events. +/// +/// This API is EXPERIMENTAL. +/// +/// \since 0.17.0 +class ARROW_EXPORT Listener { + public: + virtual ~Listener() = default; + + /// \brief Called when end-of-stream is received. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \return Status + /// + /// \see StreamDecoder + virtual Status OnEOS(); + + /// \brief Called when a record batch is decoded. + /// + /// The default implementation just returns + /// arrow::Status::NotImplemented(). + /// + /// \param[in] record_batch a record batch decoded + /// \return Status + /// + /// \see StreamDecoder + virtual Status OnRecordBatchDecoded(std::shared_ptr record_batch); + + /// \brief Called when a schema is decoded. + /// + /// The default implementation just returns arrow::Status::OK(). + /// + /// \param[in] schema a schema decoded + /// \return Status + /// + /// \see StreamDecoder + virtual Status OnSchemaDecoded(std::shared_ptr schema); +}; + +/// \class CollectListener +/// \brief Collect schema and record batches decoded by StreamDecoder. +/// +/// This API is EXPERIMENTAL. +/// +/// \since 0.17.0 +class ARROW_EXPORT CollectListener : public Listener { + public: + CollectListener() : schema_(), record_batches_() {} + virtual ~CollectListener() = default; + + Status OnSchemaDecoded(std::shared_ptr schema) override { + schema_ = std::move(schema); + return Status::OK(); + } + + Status OnRecordBatchDecoded(std::shared_ptr record_batch) override { + record_batches_.push_back(std::move(record_batch)); + return Status::OK(); + } + + /// \return the decoded schema + std::shared_ptr schema() const { return schema_; } + + /// \return the all decoded record batches + std::vector> record_batches() const { + return record_batches_; + } + + private: + std::shared_ptr schema_; + std::vector> record_batches_; +}; + +/// \class StreamDecoder +/// \brief Push style stream decoder that receives data from user. +/// +/// This class decodes the Apache Arrow IPC streaming format data. +/// +/// This API is EXPERIMENTAL. +/// +/// \see https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format +/// +/// \since 0.17.0 +class ARROW_EXPORT StreamDecoder { + public: + /// \brief Construct a stream decoder. + /// + /// \param[in] listener a Listener that must implement + /// Listener::OnRecordBatchDecoded() to receive decoded record batches + /// \param[in] options any IPC reading options (optional) + StreamDecoder(std::shared_ptr listener, + const IpcReadOptions& options = IpcReadOptions::Defaults()); + + virtual ~StreamDecoder(); + + /// \brief Feed data to the decoder as a raw data. + /// + /// If the decoder can read one or more record batches by the data, + /// the decoder calls listener->OnRecordBatchDecoded() with a + /// decoded record batch multiple times. + /// + /// \param[in] data a raw data to be processed. This data isn't + /// copied. The passed memory must be kept alive through record + /// batch processing. + /// \param[in] size raw data size. + /// \return Status + Status Consume(const uint8_t* data, int64_t size); + + /// \brief Feed data to the decoder as a Buffer. + /// + /// If the decoder can read one or more record batches by the + /// Buffer, the decoder calls listener->RecordBatchReceived() with a + /// decoded record batch multiple times. + /// + /// \param[in] buffer a Buffer to be processed. + /// \return Status + Status Consume(std::shared_ptr buffer); + + /// \return the shared schema of the record batches in the stream + std::shared_ptr schema() const; + + /// \brief Return the number of bytes needed to advance the state of + /// the decoder. + /// + /// This method is provided for users who want to optimize performance. + /// Normal users don't need to use this method. + /// + /// Here is an example usage for normal users: + /// + /// ~~~{.cpp} + /// decoder.Consume(buffer1); + /// decoder.Consume(buffer2); + /// decoder.Consume(buffer3); + /// ~~~ + /// + /// Decoder has internal buffer. If consumed data isn't enough to + /// advance the state of the decoder, consumed data is buffered to + /// the internal buffer. It causes performance overhead. + /// + /// If you pass next_required_size() size data to each Consume() + /// call, the decoder doesn't use its internal buffer. It improves + /// performance. + /// + /// Here is an example usage to avoid using internal buffer: + /// + /// ~~~{.cpp} + /// buffer1 = get_data(decoder.next_required_size()); + /// decoder.Consume(buffer1); + /// buffer2 = get_data(decoder.next_required_size()); + /// decoder.Consume(buffer2); + /// ~~~ + /// + /// Users can use this method to avoid creating small chunks. Record + /// batch data must be contiguous data. If users pass small chunks + /// to the decoder, the decoder needs concatenate small chunks + /// internally. It causes performance overhead. + /// + /// Here is an example usage to reduce small chunks: + /// + /// ~~~{.cpp} + /// buffer = AllocateResizableBuffer(); + /// while ((small_chunk = get_data(&small_chunk_size))) { + /// auto current_buffer_size = buffer->size(); + /// buffer->Resize(current_buffer_size + small_chunk_size); + /// memcpy(buffer->mutable_data() + current_buffer_size, + /// small_chunk, + /// small_chunk_size); + /// if (buffer->size() < decoder.next_requied_size()) { + /// continue; + /// } + /// std::shared_ptr chunk(buffer.release()); + /// decoder.Consume(chunk); + /// buffer = AllocateResizableBuffer(); + /// } + /// if (buffer->size() > 0) { + /// std::shared_ptr chunk(buffer.release()); + /// decoder.Consume(chunk); + /// } + /// ~~~ + /// + /// \return the number of bytes needed to advance the state of the + /// decoder + int64_t next_required_size() const; + + private: + class StreamDecoderImpl; + std::unique_ptr impl_; + + ARROW_DISALLOW_COPY_AND_ASSIGN(StreamDecoder); +}; + // Generic read functions; does not copy data if the input supports zero copy reads /// \brief Read Schema from stream serialized as a single IPC message