Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 95 additions & 11 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ namespace ipc {

using internal::FieldPosition;
using internal::IoRecordedRandomAccessFile;
using MetadataVector = std::vector<std::shared_ptr<KeyValueMetadata>>;

namespace test {

Expand Down Expand Up @@ -1018,8 +1019,9 @@ struct FileWriterHelper {
return Status::OK();
}

Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
Status WriteBatch(const std::shared_ptr<RecordBatch>& batch,
const std::shared_ptr<const KeyValueMetadata>& metadata = nullptr) {
RETURN_NOT_OK(writer_->WriteRecordBatch(*batch, metadata));
num_batches_written_++;
return Status::OK();
}
Expand All @@ -1042,16 +1044,22 @@ struct FileWriterHelper {

virtual Status ReadBatches(const IpcReadOptions& options,
RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr) {
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) {
auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchFileReader::Open(
buf_reader.get(), footer_offset_, options));

EXPECT_EQ(num_batches_written_, reader->num_record_batches());
for (int i = 0; i < num_batches_written_; ++i) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> chunk,
reader->ReadRecordBatch(i));
ARROW_ASSIGN_OR_RAISE(auto chunk_with_metadata,
reader->ReadRecordBatchWithCustomMetadata(i));
auto chunk = chunk_with_metadata.batch;
out_batches->push_back(chunk);
if (out_metadata_list) {
auto metadata = chunk_with_metadata.custom_metadata;
out_metadata_list->push_back(metadata);
}
}
if (out_stats) {
*out_stats = reader->stats();
Expand Down Expand Up @@ -1096,7 +1104,8 @@ class NoZeroCopyBufferReader : public io::BufferReader {
template <bool kCoalesce>
struct FileGeneratorWriterHelper : public FileWriterHelper {
Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr) override {
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) override {
std::shared_ptr<io::RandomAccessFile> buf_reader;
if (kCoalesce) {
// Use a non-zero-copy enabled BufferReader so we can test paths properly
Expand Down Expand Up @@ -1145,8 +1154,9 @@ struct StreamWriterHelper {
return Status::OK();
}

Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
Status WriteBatch(const std::shared_ptr<RecordBatch>& batch,
const std::shared_ptr<const KeyValueMetadata>& metadata = nullptr) {
RETURN_NOT_OK(writer_->WriteRecordBatch(*batch, metadata));
return Status::OK();
}

Expand All @@ -1165,10 +1175,23 @@ struct StreamWriterHelper {

virtual Status ReadBatches(const IpcReadOptions& options,
RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr) {
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) {
auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(buf_reader, options))
ARROW_ASSIGN_OR_RAISE(*out_batches, reader->ToRecordBatches());
if (out_metadata_list) {
while (true) {
ARROW_ASSIGN_OR_RAISE(auto chunk_with_metadata, reader->ReadNext());
if (chunk_with_metadata.batch == nullptr) {
break;
}
out_batches->push_back(chunk_with_metadata.batch);
out_metadata_list->push_back(chunk_with_metadata.custom_metadata);
}
} else {
ARROW_ASSIGN_OR_RAISE(*out_batches, reader->ToRecordBatches());
}

if (out_stats) {
*out_stats = reader->stats();
}
Expand All @@ -1195,7 +1218,8 @@ struct StreamWriterHelper {

struct StreamDecoderWriterHelper : public StreamWriterHelper {
Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr) override {
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) override {
auto listener = std::make_shared<CollectListener>();
StreamDecoder decoder(listener, options);
RETURN_NOT_OK(DoConsume(&decoder));
Expand Down Expand Up @@ -1420,6 +1444,57 @@ class ReaderWriterMixin : public ExtensionTypesMixin {
ASSERT_TRUE(out_batches[0]->schema()->Equals(*schema));
}

void TestWriteBatchWithMetadata() {
std::shared_ptr<RecordBatch> batch;
ASSERT_OK(MakeIntRecordBatch(&batch));

WriterHelper writer_helper;
ASSERT_OK(writer_helper.Init(batch->schema(), IpcWriteOptions::Defaults()));

auto metadata = key_value_metadata({"some_key"}, {"some_value"});
ASSERT_OK(writer_helper.WriteBatch(batch, metadata));
ASSERT_OK(writer_helper.Finish());

RecordBatchVector out_batches;
MetadataVector out_metadata_list;
ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches, nullptr,
&out_metadata_list));
ASSERT_EQ(out_batches.size(), 1);
ASSERT_EQ(out_metadata_list.size(), 1);
CompareBatch(*out_batches[0], *batch, false /* compare_metadata */);
ASSERT_TRUE(out_metadata_list[0]->Equals(*metadata));
}

// write multiple batches and each of them with different metadata
void TestWriteDifferentMetadata() {
std::shared_ptr<RecordBatch> batch_0;
std::shared_ptr<RecordBatch> batch_1;
auto metadata_0 = key_value_metadata({"some_key"}, {"0"});
auto metadata_1 = key_value_metadata({"some_key"}, {"1"});
ASSERT_OK(MakeIntRecordBatch(&batch_0));
ASSERT_OK(MakeIntRecordBatch(&batch_1));

WriterHelper writer_helper;
ASSERT_OK(writer_helper.Init(batch_0->schema(), IpcWriteOptions::Defaults()));

ASSERT_OK(writer_helper.WriteBatch(batch_0, metadata_0));

// Write a batch with different metadata
ASSERT_OK(writer_helper.WriteBatch(batch_1, metadata_1));
ASSERT_OK(writer_helper.Finish());

RecordBatchVector out_batches;
MetadataVector out_metadata_list;
ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches, nullptr,
&out_metadata_list));
ASSERT_EQ(out_batches.size(), 2);
ASSERT_EQ(out_metadata_list.size(), 2);
CompareBatch(*out_batches[0], *batch_0, true /* compare_metadata */);
CompareBatch(*out_batches[1], *batch_1, true /* compare_metadata */);
ASSERT_TRUE(out_metadata_list[0]->Equals(*metadata_0));
ASSERT_TRUE(out_metadata_list[1]->Equals(*metadata_1));
}

void TestWriteNoRecordBatches() {
// Test writing no batches.
auto schema = arrow::schema({field("a", int32())});
Expand Down Expand Up @@ -1800,6 +1875,15 @@ TEST_F(TestFileFormatGeneratorCoalesced, DictionaryRoundTrip) {
}

TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); }

TEST_F(TestFileFormat, BatchWithMetadata) { TestWriteBatchWithMetadata(); }

TEST_F(TestStreamFormat, BatchWithMetadata) { TestWriteBatchWithMetadata(); }

TEST_F(TestFileFormat, DifferentMetadataBatches) { TestWriteDifferentMetadata(); }

TEST_F(TestStreamFormat, DifferentMetadataBatches) { TestWriteDifferentMetadata(); }

TEST_F(TestFileFormat, DifferentSchema) { TestWriteDifferentSchema(); }
TEST_F(TestFileFormatGenerator, DifferentSchema) { TestWriteDifferentSchema(); }
TEST_F(TestFileFormatGeneratorCoalesced, DifferentSchema) { TestWriteDifferentSchema(); }
Expand Down
69 changes: 51 additions & 18 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
reader.get());
}

Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
Result<RecordBatchWithMetadata> ReadRecordBatchInternal(
const Buffer& metadata, const std::shared_ptr<Schema>& schema,
const std::vector<bool>& inclusion_mask, IpcReadContext& context,
io::RandomAccessFile* file) {
Expand All @@ -676,7 +676,15 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
}
context.compression = compression;
context.metadata_version = internal::GetMetadataVersion(message->version());
return LoadRecordBatch(batch, schema, inclusion_mask, context, file);

std::shared_ptr<KeyValueMetadata> custom_metadata;
if (message->custom_metadata() != nullptr) {
RETURN_NOT_OK(
internal::GetKeyValueMetadata(message->custom_metadata(), &custom_metadata));
}
ARROW_ASSIGN_OR_RAISE(auto record_batch,
LoadRecordBatch(batch, schema, inclusion_mask, context, file));
return RecordBatchWithMetadata{record_batch, custom_metadata};
}

// If we are selecting only certain fields, populate an inclusion mask for fast lookups.
Expand Down Expand Up @@ -756,7 +764,10 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
IpcReadContext context(const_cast<DictionaryMemo*>(dictionary_memo), options, false);
RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
&inclusion_mask, &out_schema));
return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file);
ARROW_ASSIGN_OR_RAISE(
auto batch_and_custom_metadata,
ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file));
return batch_and_custom_metadata.batch;
}

Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
Expand Down Expand Up @@ -852,15 +863,21 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
}

Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadNext());
*batch = std::move(batch_with_metadata.batch);
return Status::OK();
}

Result<RecordBatchWithMetadata> ReadNext() override {
if (!have_read_initial_dictionaries_) {
RETURN_NOT_OK(ReadInitialDictionaries());
}

RecordBatchWithMetadata batch_with_metadata;
if (empty_stream_) {
// ARROW-6006: Degenerate case where stream contains no data, we do not
// bother trying to read a RecordBatch message from the stream
*batch = nullptr;
return Status::OK();
return batch_with_metadata;
}

// Continue to read other dictionaries, if any
Expand All @@ -874,16 +891,14 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {

if (message == nullptr) {
// End of stream
*batch = nullptr;
return Status::OK();
return batch_with_metadata;
}

CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
context, reader.get())
.Value(batch);
context, reader.get());
}

std::shared_ptr<Schema> schema() const override { return out_schema_; }
Expand Down Expand Up @@ -1158,12 +1173,26 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
}

Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) override {
ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadRecordBatchWithCustomMetadata(i));
return batch_with_metadata.batch;
}

Result<RecordBatchWithMetadata> ReadRecordBatchWithCustomMetadata(int i) override {
DCHECK_GE(i, 0);
DCHECK_LT(i, num_record_batches());

auto cached_metadata = cached_metadata_.find(i);
if (cached_metadata != cached_metadata_.end()) {
return ReadCachedRecordBatch(i, cached_metadata->second).result();
auto result = ReadCachedRecordBatch(i, cached_metadata->second).result();
ARROW_ASSIGN_OR_RAISE(auto batch, result);
ARROW_ASSIGN_OR_RAISE(auto message_obj, cached_metadata->second.result());
ARROW_ASSIGN_OR_RAISE(auto message, GetFlatbufMessage(message_obj));
std::shared_ptr<KeyValueMetadata> custom_metadata;
if (message->custom_metadata() != nullptr) {
RETURN_NOT_OK(
internal::GetKeyValueMetadata(message->custom_metadata(), &custom_metadata));
}
return RecordBatchWithMetadata{std::move(batch), std::move(custom_metadata)};
}

RETURN_NOT_OK(WaitForDictionaryReadFinished());
Expand All @@ -1185,11 +1214,12 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
ARROW_ASSIGN_OR_RAISE(auto batch, ReadRecordBatchInternal(
*message->metadata(), schema_,
field_inclusion_mask_, context, reader.get()));
ARROW_ASSIGN_OR_RAISE(
auto batch_with_metadata,
ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
context, reader.get()));
++stats_.num_record_batches;
return batch;
return batch_with_metadata;
}

Result<int64_t> CountRows() override {
Expand Down Expand Up @@ -1832,8 +1862,11 @@ Result<std::shared_ptr<RecordBatch>> WholeIpcFileRecordBatchGenerator::ReadRecor
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&state->dictionary_memo_, state->options_, state->swap_endian_);
return ReadRecordBatchInternal(*message->metadata(), state->schema_,
state->field_inclusion_mask_, context, reader.get());
ARROW_ASSIGN_OR_RAISE(
auto batch_with_metadata,
ReadRecordBatchInternal(*message->metadata(), state->schema_,
state->field_inclusion_mask_, context, reader.get()));
return batch_with_metadata.batch;
}

Status Listener::OnEOS() { return Status::OK(); }
Expand Down Expand Up @@ -1938,11 +1971,11 @@ class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener {
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
ARROW_ASSIGN_OR_RAISE(
auto batch,
auto batch_with_metadata,
ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
context, reader.get()));
++stats_.num_record_batches;
return listener_->OnRecordBatchDecoded(std::move(batch));
return listener_->OnRecordBatchDecoded(std::move(batch_with_metadata.batch));
}
}

Expand Down
7 changes: 7 additions & 0 deletions cpp/src/arrow/ipc/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ class ARROW_EXPORT RecordBatchFileReader
/// \return the read batch
virtual Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) = 0;

/// \brief Read a particular record batch along with its custom metadada from the file.
/// Does not copy memory if the input source supports zero-copy.
///
/// \param[in] i the index of the record batch to return
/// \return a struct containing the read batch and its custom metadata
virtual Result<RecordBatchWithMetadata> ReadRecordBatchWithCustomMetadata(int i) = 0;

/// \brief Return current read statistics
virtual ReadStats stats() const = 0;

Expand Down
Loading