diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 1a4784fcf591..b27dcee33b18 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -66,6 +66,7 @@ namespace ipc { using internal::FieldPosition; using internal::IoRecordedRandomAccessFile; +using MetadataVector = std::vector>; namespace test { @@ -1018,8 +1019,9 @@ struct FileWriterHelper { return Status::OK(); } - Status WriteBatch(const std::shared_ptr& batch) { - RETURN_NOT_OK(writer_->WriteRecordBatch(*batch)); + Status WriteBatch(const std::shared_ptr& batch, + const std::shared_ptr& metadata = nullptr) { + RETURN_NOT_OK(writer_->WriteRecordBatch(*batch, metadata)); num_batches_written_++; return Status::OK(); } @@ -1042,16 +1044,22 @@ struct FileWriterHelper { virtual Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches, - ReadStats* out_stats = nullptr) { + ReadStats* out_stats = nullptr, + MetadataVector* out_metadata_list = nullptr) { auto buf_reader = std::make_shared(buffer_); ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchFileReader::Open( buf_reader.get(), footer_offset_, options)); EXPECT_EQ(num_batches_written_, reader->num_record_batches()); for (int i = 0; i < num_batches_written_; ++i) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr chunk, - reader->ReadRecordBatch(i)); + ARROW_ASSIGN_OR_RAISE(auto chunk_with_metadata, + reader->ReadRecordBatchWithCustomMetadata(i)); + auto chunk = chunk_with_metadata.batch; out_batches->push_back(chunk); + if (out_metadata_list) { + auto metadata = chunk_with_metadata.custom_metadata; + out_metadata_list->push_back(metadata); + } } if (out_stats) { *out_stats = reader->stats(); @@ -1096,7 +1104,8 @@ class NoZeroCopyBufferReader : public io::BufferReader { template struct FileGeneratorWriterHelper : public FileWriterHelper { Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches, - ReadStats* out_stats = nullptr) override { + ReadStats* out_stats = nullptr, + MetadataVector* out_metadata_list = nullptr) override { std::shared_ptr buf_reader; if (kCoalesce) { // Use a non-zero-copy enabled BufferReader so we can test paths properly @@ -1145,8 +1154,9 @@ struct StreamWriterHelper { return Status::OK(); } - Status WriteBatch(const std::shared_ptr& batch) { - RETURN_NOT_OK(writer_->WriteRecordBatch(*batch)); + Status WriteBatch(const std::shared_ptr& batch, + const std::shared_ptr& metadata = nullptr) { + RETURN_NOT_OK(writer_->WriteRecordBatch(*batch, metadata)); return Status::OK(); } @@ -1165,10 +1175,23 @@ struct StreamWriterHelper { virtual Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches, - ReadStats* out_stats = nullptr) { + ReadStats* out_stats = nullptr, + MetadataVector* out_metadata_list = nullptr) { auto buf_reader = std::make_shared(buffer_); ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(buf_reader, options)) - ARROW_ASSIGN_OR_RAISE(*out_batches, reader->ToRecordBatches()); + if (out_metadata_list) { + while (true) { + ARROW_ASSIGN_OR_RAISE(auto chunk_with_metadata, reader->ReadNext()); + if (chunk_with_metadata.batch == nullptr) { + break; + } + out_batches->push_back(chunk_with_metadata.batch); + out_metadata_list->push_back(chunk_with_metadata.custom_metadata); + } + } else { + ARROW_ASSIGN_OR_RAISE(*out_batches, reader->ToRecordBatches()); + } + if (out_stats) { *out_stats = reader->stats(); } @@ -1195,7 +1218,8 @@ struct StreamWriterHelper { struct StreamDecoderWriterHelper : public StreamWriterHelper { Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches, - ReadStats* out_stats = nullptr) override { + ReadStats* out_stats = nullptr, + MetadataVector* out_metadata_list = nullptr) override { auto listener = std::make_shared(); StreamDecoder decoder(listener, options); RETURN_NOT_OK(DoConsume(&decoder)); @@ -1420,6 +1444,57 @@ class ReaderWriterMixin : public ExtensionTypesMixin { ASSERT_TRUE(out_batches[0]->schema()->Equals(*schema)); } + void TestWriteBatchWithMetadata() { + std::shared_ptr batch; + ASSERT_OK(MakeIntRecordBatch(&batch)); + + WriterHelper writer_helper; + ASSERT_OK(writer_helper.Init(batch->schema(), IpcWriteOptions::Defaults())); + + auto metadata = key_value_metadata({"some_key"}, {"some_value"}); + ASSERT_OK(writer_helper.WriteBatch(batch, metadata)); + ASSERT_OK(writer_helper.Finish()); + + RecordBatchVector out_batches; + MetadataVector out_metadata_list; + ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches, nullptr, + &out_metadata_list)); + ASSERT_EQ(out_batches.size(), 1); + ASSERT_EQ(out_metadata_list.size(), 1); + CompareBatch(*out_batches[0], *batch, false /* compare_metadata */); + ASSERT_TRUE(out_metadata_list[0]->Equals(*metadata)); + } + + // write multiple batches and each of them with different metadata + void TestWriteDifferentMetadata() { + std::shared_ptr batch_0; + std::shared_ptr batch_1; + auto metadata_0 = key_value_metadata({"some_key"}, {"0"}); + auto metadata_1 = key_value_metadata({"some_key"}, {"1"}); + ASSERT_OK(MakeIntRecordBatch(&batch_0)); + ASSERT_OK(MakeIntRecordBatch(&batch_1)); + + WriterHelper writer_helper; + ASSERT_OK(writer_helper.Init(batch_0->schema(), IpcWriteOptions::Defaults())); + + ASSERT_OK(writer_helper.WriteBatch(batch_0, metadata_0)); + + // Write a batch with different metadata + ASSERT_OK(writer_helper.WriteBatch(batch_1, metadata_1)); + ASSERT_OK(writer_helper.Finish()); + + RecordBatchVector out_batches; + MetadataVector out_metadata_list; + ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches, nullptr, + &out_metadata_list)); + ASSERT_EQ(out_batches.size(), 2); + ASSERT_EQ(out_metadata_list.size(), 2); + CompareBatch(*out_batches[0], *batch_0, true /* compare_metadata */); + CompareBatch(*out_batches[1], *batch_1, true /* compare_metadata */); + ASSERT_TRUE(out_metadata_list[0]->Equals(*metadata_0)); + ASSERT_TRUE(out_metadata_list[1]->Equals(*metadata_1)); + } + void TestWriteNoRecordBatches() { // Test writing no batches. auto schema = arrow::schema({field("a", int32())}); @@ -1800,6 +1875,15 @@ TEST_F(TestFileFormatGeneratorCoalesced, DictionaryRoundTrip) { } TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); } + +TEST_F(TestFileFormat, BatchWithMetadata) { TestWriteBatchWithMetadata(); } + +TEST_F(TestStreamFormat, BatchWithMetadata) { TestWriteBatchWithMetadata(); } + +TEST_F(TestFileFormat, DifferentMetadataBatches) { TestWriteDifferentMetadata(); } + +TEST_F(TestStreamFormat, DifferentMetadataBatches) { TestWriteDifferentMetadata(); } + TEST_F(TestFileFormat, DifferentSchema) { TestWriteDifferentSchema(); } TEST_F(TestFileFormatGenerator, DifferentSchema) { TestWriteDifferentSchema(); } TEST_F(TestFileFormatGeneratorCoalesced, DifferentSchema) { TestWriteDifferentSchema(); } diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index a5f31d74febf..0b4620379593 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -654,7 +654,7 @@ Result> ReadRecordBatch( reader.get()); } -Result> ReadRecordBatchInternal( +Result ReadRecordBatchInternal( const Buffer& metadata, const std::shared_ptr& schema, const std::vector& inclusion_mask, IpcReadContext& context, io::RandomAccessFile* file) { @@ -676,7 +676,15 @@ Result> ReadRecordBatchInternal( } context.compression = compression; context.metadata_version = internal::GetMetadataVersion(message->version()); - return LoadRecordBatch(batch, schema, inclusion_mask, context, file); + + std::shared_ptr custom_metadata; + if (message->custom_metadata() != nullptr) { + RETURN_NOT_OK( + internal::GetKeyValueMetadata(message->custom_metadata(), &custom_metadata)); + } + ARROW_ASSIGN_OR_RAISE(auto record_batch, + LoadRecordBatch(batch, schema, inclusion_mask, context, file)); + return RecordBatchWithMetadata{record_batch, custom_metadata}; } // If we are selecting only certain fields, populate an inclusion mask for fast lookups. @@ -756,7 +764,10 @@ Result> ReadRecordBatch( IpcReadContext context(const_cast(dictionary_memo), options, false); RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields, &inclusion_mask, &out_schema)); - return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file); + ARROW_ASSIGN_OR_RAISE( + auto batch_and_custom_metadata, + ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file)); + return batch_and_custom_metadata.batch; } Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context, @@ -852,15 +863,21 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { } Status ReadNext(std::shared_ptr* batch) override { + ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadNext()); + *batch = std::move(batch_with_metadata.batch); + return Status::OK(); + } + + Result ReadNext() override { if (!have_read_initial_dictionaries_) { RETURN_NOT_OK(ReadInitialDictionaries()); } + RecordBatchWithMetadata batch_with_metadata; if (empty_stream_) { // ARROW-6006: Degenerate case where stream contains no data, we do not // bother trying to read a RecordBatch message from the stream - *batch = nullptr; - return Status::OK(); + return batch_with_metadata; } // Continue to read other dictionaries, if any @@ -874,16 +891,14 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { if (message == nullptr) { // End of stream - *batch = nullptr; - return Status::OK(); + return batch_with_metadata; } CHECK_HAS_BODY(*message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); IpcReadContext context(&dictionary_memo_, options_, swap_endian_); return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, - context, reader.get()) - .Value(batch); + context, reader.get()); } std::shared_ptr schema() const override { return out_schema_; } @@ -1158,12 +1173,26 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader { } Result> ReadRecordBatch(int i) override { + ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadRecordBatchWithCustomMetadata(i)); + return batch_with_metadata.batch; + } + + Result ReadRecordBatchWithCustomMetadata(int i) override { DCHECK_GE(i, 0); DCHECK_LT(i, num_record_batches()); auto cached_metadata = cached_metadata_.find(i); if (cached_metadata != cached_metadata_.end()) { - return ReadCachedRecordBatch(i, cached_metadata->second).result(); + auto result = ReadCachedRecordBatch(i, cached_metadata->second).result(); + ARROW_ASSIGN_OR_RAISE(auto batch, result); + ARROW_ASSIGN_OR_RAISE(auto message_obj, cached_metadata->second.result()); + ARROW_ASSIGN_OR_RAISE(auto message, GetFlatbufMessage(message_obj)); + std::shared_ptr custom_metadata; + if (message->custom_metadata() != nullptr) { + RETURN_NOT_OK( + internal::GetKeyValueMetadata(message->custom_metadata(), &custom_metadata)); + } + return RecordBatchWithMetadata{std::move(batch), std::move(custom_metadata)}; } RETURN_NOT_OK(WaitForDictionaryReadFinished()); @@ -1185,11 +1214,12 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader { CHECK_HAS_BODY(*message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); IpcReadContext context(&dictionary_memo_, options_, swap_endian_); - ARROW_ASSIGN_OR_RAISE(auto batch, ReadRecordBatchInternal( - *message->metadata(), schema_, - field_inclusion_mask_, context, reader.get())); + ARROW_ASSIGN_OR_RAISE( + auto batch_with_metadata, + ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + context, reader.get())); ++stats_.num_record_batches; - return batch; + return batch_with_metadata; } Result CountRows() override { @@ -1832,8 +1862,11 @@ Result> WholeIpcFileRecordBatchGenerator::ReadRecor CHECK_HAS_BODY(*message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); IpcReadContext context(&state->dictionary_memo_, state->options_, state->swap_endian_); - return ReadRecordBatchInternal(*message->metadata(), state->schema_, - state->field_inclusion_mask_, context, reader.get()); + ARROW_ASSIGN_OR_RAISE( + auto batch_with_metadata, + ReadRecordBatchInternal(*message->metadata(), state->schema_, + state->field_inclusion_mask_, context, reader.get())); + return batch_with_metadata.batch; } Status Listener::OnEOS() { return Status::OK(); } @@ -1938,11 +1971,11 @@ class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener { ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); IpcReadContext context(&dictionary_memo_, options_, swap_endian_); ARROW_ASSIGN_OR_RAISE( - auto batch, + auto batch_with_metadata, ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, context, reader.get())); ++stats_.num_record_batches; - return listener_->OnRecordBatchDecoded(std::move(batch)); + return listener_->OnRecordBatchDecoded(std::move(batch_with_metadata.batch)); } } diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 4bdbccc5097a..ad7969b31c99 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -190,6 +190,13 @@ class ARROW_EXPORT RecordBatchFileReader /// \return the read batch virtual Result> ReadRecordBatch(int i) = 0; + /// \brief Read a particular record batch along with its custom metadada from the file. + /// Does not copy memory if the input source supports zero-copy. + /// + /// \param[in] i the index of the record batch to return + /// \return a struct containing the read batch and its custom metadata + virtual Result ReadRecordBatchWithCustomMetadata(int i) = 0; + /// \brief Return current read statistics virtual ReadStats stats() const = 0; diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index cf5a08bf3bf6..4a7671e158fc 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -128,9 +128,11 @@ static inline bool NeedTruncate(int64_t offset, const Buffer* buffer, class RecordBatchSerializer { public: - RecordBatchSerializer(int64_t buffer_start_offset, const IpcWriteOptions& options, - IpcPayload* out) + RecordBatchSerializer(int64_t buffer_start_offset, + const std::shared_ptr& custom_metadata, + const IpcWriteOptions& options, IpcPayload* out) : out_(out), + custom_metadata_(custom_metadata), options_(options), max_recursion_depth_(options.max_recursion_depth), buffer_start_offset_(buffer_start_offset) { @@ -175,13 +177,6 @@ class RecordBatchSerializer { field_nodes_, buffer_meta_, options_, &out_->metadata); } - void AppendCustomMetadata(const std::string& key, const std::string& value) { - if (!custom_metadata_) { - custom_metadata_ = std::make_shared(); - } - custom_metadata_->Append(key, value); - } - Status CompressBuffer(const Buffer& buffer, util::Codec* codec, std::shared_ptr* out) { // Convert buffer to uncompressed-length-prefixed compressed buffer @@ -540,7 +535,7 @@ class RecordBatchSerializer { // Destination for output buffers IpcPayload* out_; - std::shared_ptr custom_metadata_; + std::shared_ptr custom_metadata_; std::vector field_nodes_; std::vector buffer_meta_; @@ -554,7 +549,7 @@ class DictionarySerializer : public RecordBatchSerializer { public: DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset, const IpcWriteOptions& options, IpcPayload* out) - : RecordBatchSerializer(buffer_start_offset, options, out), + : RecordBatchSerializer(buffer_start_offset, NULLPTR, options, out), dictionary_id_(dictionary_id), is_delta_(is_delta) {} @@ -636,8 +631,16 @@ Status GetDictionaryPayload(int64_t id, bool is_delta, Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, IpcPayload* out) { + return GetRecordBatchPayload(batch, NULLPTR, options, out); +} + +Status GetRecordBatchPayload( + const RecordBatch& batch, + const std::shared_ptr& custom_metadata, + const IpcWriteOptions& options, IpcPayload* out) { out->type = MessageType::RECORD_BATCH; - RecordBatchSerializer assembler(/*buffer_start_offset=*/0, options, out); + RecordBatchSerializer assembler(/*buffer_start_offset=*/0, custom_metadata, options, + out); return assembler.Assemble(batch); } @@ -645,7 +648,7 @@ Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, const IpcWriteOptions& options) { IpcPayload payload; - RecordBatchSerializer assembler(buffer_start_offset, options, &payload); + RecordBatchSerializer assembler(buffer_start_offset, NULLPTR, options, &payload); RETURN_NOT_OK(assembler.Assemble(batch)); // TODO: it's a rough edge that the metadata and body length here are @@ -1000,6 +1003,12 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { } Status WriteRecordBatch(const RecordBatch& batch) override { + return WriteRecordBatch(batch, NULLPTR); + } + + Status WriteRecordBatch( + const RecordBatch& batch, + const std::shared_ptr& custom_metadata) override { if (!batch.schema()->Equals(schema_, false /* check_metadata */)) { return Status::Invalid("Tried to write record batch with different schema"); } @@ -1009,7 +1018,7 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { RETURN_NOT_OK(WriteDictionaries(batch)); IpcPayload payload; - RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, &payload)); + RETURN_NOT_OK(GetRecordBatchPayload(batch, custom_metadata, options_, &payload)); RETURN_NOT_OK(WritePayload(payload)); ++stats_.num_record_batches; diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 31937fa69c26..eb5827a60981 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -96,6 +96,18 @@ class ARROW_EXPORT RecordBatchWriter { /// \return Status virtual Status WriteRecordBatch(const RecordBatch& batch) = 0; + /// \brief Write a record batch with custom metadata to the stream + /// + /// \param[in] batch the record batch to write to the stream + /// \param[in] custom_metadata the record batch's custom metadata to write to the stream + /// \return Status + virtual Status WriteRecordBatch( + const RecordBatch& batch, + const std::shared_ptr& custom_metadata) { + return Status::NotImplemented( + "Write record batch with custom metadata not implemented"); + } + /// \brief Write possibly-chunked table by creating sequence of record batches /// \param[in] table table to write /// \return Status @@ -389,6 +401,18 @@ ARROW_EXPORT Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, IpcPayload* out); +/// \brief Compute IpcPayload for the given record batch and custom metadata +/// \param[in] batch the RecordBatch that is being serialized +/// \param[in] custom_metadata the custom metadata to be serialized with the record batch +/// \param[in] options options for serialization +/// \param[out] out the returned IpcPayload +/// \return Status +ARROW_EXPORT +Status GetRecordBatchPayload( + const RecordBatch& batch, + const std::shared_ptr& custom_metadata, + const IpcWriteOptions& options, IpcPayload* out); + /// \brief Write an IPC payload to the given stream. /// \param[in] payload the payload to write /// \param[in] options options for serialization diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index fb7eea25eec1..60aa9ad9c941 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -210,6 +210,11 @@ class ARROW_EXPORT RecordBatch { ARROW_DISALLOW_COPY_AND_ASSIGN(RecordBatch); }; +struct ARROW_EXPORT RecordBatchWithMetadata { + std::shared_ptr batch; + std::shared_ptr custom_metadata; +}; + /// \brief Abstract interface for reading stream of record batches class ARROW_EXPORT RecordBatchReader { public: @@ -227,6 +232,10 @@ class ARROW_EXPORT RecordBatchReader { /// \return Status virtual Status ReadNext(std::shared_ptr* batch) = 0; + virtual Result ReadNext() { + return Status::NotImplemented("ReadNext with custom metadata"); + } + /// \brief Iterator interface Result> Next() { std::shared_ptr batch;