From 376b519704cce762e14e3ee293ab76190fa3a27b Mon Sep 17 00:00:00 2001 From: Yue Ni Date: Sun, 1 May 2022 14:28:35 +0800 Subject: [PATCH 1/3] ARROW-16430, support reading record batch custom metadata API in pyarrow. --- cpp/src/arrow/ipc/writer.cc | 10 +++++ cpp/src/arrow/ipc/writer.h | 5 +-- python/pyarrow/_flight.pyx | 6 ++- python/pyarrow/includes/libarrow.pxd | 10 +++++ python/pyarrow/ipc.pxi | 55 +++++++++++++++++++++++++++- python/pyarrow/tests/test_ipc.py | 24 ++++++++++++ 6 files changed, 103 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 674ddfe7a061..0dcab2d5e7ff 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -951,6 +951,16 @@ Status GetTensorSize(const Tensor& tensor, int64_t* size) { RecordBatchWriter::~RecordBatchWriter() {} +Status RecordBatchWriter::WriteRecordBatch( + const RecordBatch& batch, + const std::shared_ptr& custom_metadata) { + if (custom_metadata == nullptr) { + return WriteRecordBatch(batch); + } + return Status::NotImplemented( + "Write record batch with custom metadata not implemented"); +} + Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) { TableBatchReader reader(table); diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 6dc62f41761e..9e18a213ba3f 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -103,10 +103,7 @@ class ARROW_EXPORT RecordBatchWriter { /// \return Status virtual Status WriteRecordBatch( const RecordBatch& batch, - const std::shared_ptr& custom_metadata) { - return Status::NotImplemented( - "Write record batch with custom metadata not implemented"); - } + const std::shared_ptr& custom_metadata); /// \brief Write possibly-chunked table by creating sequence of record batches /// \param[in] table table to write diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index b6c9177195a1..538510ba2744 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -1086,14 +1086,18 @@ cdef class MetadataRecordBatchWriter(_CRecordBatchWriter): ---------- batch : RecordBatch """ + cdef: + shared_ptr[const CKeyValueMetadata] custom_metadata + # Override superclass method to use check_flight_status so we # can generate FlightWriteSizeExceededError. We don't do this # for write_table as callers who intend to handle the error # and retry with a smaller batch should be working with # individual batches to have control. + with nogil: check_flight_status( - self._writer().WriteRecordBatch(deref(batch.batch))) + self._writer().WriteRecordBatch(deref(batch.batch), custom_metadata)) def write_table(self, Table table, max_chunksize=None, **kwargs): """ diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e44fa2615e29..e512676b4309 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -820,6 +820,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: shared_ptr[CRecordBatch] Slice(int64_t offset) shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length) + cdef cppclass CRecordBatchWithMetadata" arrow::RecordBatchWithMetadata": + shared_ptr[CRecordBatch] batch + # The struct in C++ does not actually have these two `const` qualifiers, but adding `const` gets Cython to not complain + const shared_ptr[const CKeyValueMetadata] custom_metadata + cdef cppclass CTable" arrow::Table": CTable(const shared_ptr[CSchema]& schema, const vector[shared_ptr[CChunkedArray]]& columns) @@ -1584,6 +1589,9 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: cdef cppclass CRecordBatchWriter" arrow::ipc::RecordBatchWriter": CStatus Close() CStatus WriteRecordBatch(const CRecordBatch& batch) + CStatus WriteRecordBatch( + const CRecordBatch& batch, + const shared_ptr[const CKeyValueMetadata]& metadata) CStatus WriteTable(const CTable& table, int64_t max_chunksize) CIpcWriteStats stats() @@ -1619,6 +1627,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(int i) + CResult[CRecordBatchWithMetadata] ReadRecordBatchWithCustomMetadata(int i) + CIpcReadStats stats() CResult[shared_ptr[CRecordBatchWriter]] MakeStreamWriter( diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index b5cbbfb62cf8..80158a66d9d6 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -472,17 +472,22 @@ cdef class _CRecordBatchWriter(_Weakrefable): else: raise ValueError(type(table_or_batch)) - def write_batch(self, RecordBatch batch): + def write_batch(self, RecordBatch batch, custom_metadata=None): """ Write RecordBatch to stream. Parameters ---------- batch : RecordBatch + custom_metadata : dict + Keys and values must be string-like / coercible to bytes """ + metadata = ensure_metadata(custom_metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + with nogil: check_status(self.writer.get() - .WriteRecordBatch(deref(batch.batch))) + .WriteRecordBatch(deref(batch.batch), c_meta)) def write_table(self, Table table, max_chunksize=None): """ @@ -832,6 +837,26 @@ cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): self.writer = GetResultValue( MakeFileWriter(c_sink, schema.sp_schema, self.options)) +_RecordBatchWithMetadata = namedtuple( + 'RecordBatchWithMetadata', + ('batch', 'custom_metadata')) + + +class RecordBatchWithMetadata(_RecordBatchWithMetadata): + """RecordBatch with its custom metadata + + Parameters + ---------- + batch: record batch + custom_metadata: record batch's custom metadata + """ + __slots__ = () + + +@staticmethod +cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c): + return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch), pyarrow_wrap_metadata(c.custom_metadata)) + cdef class _RecordBatchFileReader(_Weakrefable): cdef: @@ -908,6 +933,32 @@ cdef class _RecordBatchFileReader(_Weakrefable): # time has passed get_record_batch = get_batch + def get_batch_with_custom_metadata(self, int i): + """ + Read the record batch with the given index along with its custom metadata + + Parameters + ---------- + i : int + The index of the record batch in the IPC file. + + Returns + ------- + batch : RecordBatch + custom_metadata : KeyValueMetadata or dict + """ + cdef: + CRecordBatchWithMetadata batch_with_metadata + + if i < 0 or i >= self.num_record_batches: + raise ValueError('Batch number {0} out of range'.format(i)) + + with nogil: + batch_with_metadata = GetResultValue( + self.reader.get().ReadRecordBatchWithCustomMetadata(i)) + + return _wrap_record_batch_with_metadata(batch_with_metadata) + def read_all(self): """ Read all record batches as a pyarrow.Table diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index b7192867dcf0..3e3a4ed91632 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -945,6 +945,30 @@ def test_ipc_zero_copy_numpy(): assert_frame_equal(df, rdf) +@pytest.mark.pandas +def test_ipc_batch_with_custom_metadata_roundtrip(): + df = pd.DataFrame({'foo': [1.5]}) + + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + + batch_count = 2 + with pa.ipc.new_file(sink, batch.schema) as writer: + for i in range(batch_count): + writer.write_batch(batch, {"batch_id": str(i)}) + + buffer = sink.getvalue() + source = pa.BufferReader(buffer) + + with pa.ipc.open_file(source) as reader: + batch_with_metas = [reader.get_batch_with_custom_metadata( + i) for i in range(reader.num_record_batches)] + + for i in range(batch_count): + assert batch_with_metas[i].batch.num_rows == 1 + assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)} + + def test_ipc_stream_no_batches(): # ARROW-2307 table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]), From 8f5d119288ea6d2abc22f726bec737c7ed1fb1a7 Mon Sep 17 00:00:00 2001 From: Yue Ni Date: Wed, 17 Aug 2022 18:07:58 +0800 Subject: [PATCH 2/3] Add ReadNext with custom metadata API for RecordBatchReader in pyarrow so that pyarrow can read record batch along with its custom metadata. --- python/pyarrow/includes/libarrow.pxd | 1 + python/pyarrow/ipc.pxi | 59 +++++++++++++++++++++++++--- python/pyarrow/tests/test_ipc.py | 38 ++++++++++++++++++ 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e512676b4309..0085d3e4037f 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -889,6 +889,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CRecordBatchReader" arrow::RecordBatchReader": shared_ptr[CSchema] schema() CStatus Close() + CResult[CRecordBatchWithMetadata] ReadNext() CStatus ReadNext(shared_ptr[CRecordBatch]* batch) CResult[shared_ptr[CTable]] ToTable() diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index 80158a66d9d6..eb9be99592ff 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -479,7 +479,7 @@ cdef class _CRecordBatchWriter(_Weakrefable): Parameters ---------- batch : RecordBatch - custom_metadata : dict + custom_metadata : KeyValueMetadata Keys and values must be string-like / coercible to bytes """ metadata = ensure_metadata(custom_metadata, allow_none=True) @@ -692,6 +692,51 @@ cdef class RecordBatchReader(_Weakrefable): return pyarrow_wrap_batch(batch) + def read_next_batch_with_custom_metadata(self): + """ + Read next RecordBatch from the stream along with its custom metadata. + + Raises + ------ + StopIteration: + At end of stream. + + Returns + ------- + batch : RecordBatch + custom_metadata : KeyValueMetadata + """ + cdef: + CRecordBatchWithMetadata batch_with_metadata + + with nogil: + batch_with_metadata = GetResultValue(self.reader.get().ReadNext()) + + if batch_with_metadata.batch.get() == NULL: + raise StopIteration + + return _wrap_record_batch_with_metadata(batch_with_metadata) + + def iter_batches_with_custom_metadata(self): + """ + Read next RecordBatch from the stream along with its custom metadata + as a generator + + Raises + ------ + StopIteration: + At end of stream. + + Yields + ------- + RecordBatch with optional custom metadata + """ + while True: + try: + yield self.read_next_batch_with_custom_metadata() + except StopIteration: + return + def read_all(self): """ Read all record batches as a pyarrow.Table. @@ -847,15 +892,16 @@ class RecordBatchWithMetadata(_RecordBatchWithMetadata): Parameters ---------- - batch: record batch - custom_metadata: record batch's custom metadata + batch : RecordBatch + custom_metadata : KeyValueMetadata """ __slots__ = () @staticmethod cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c): - return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch), pyarrow_wrap_metadata(c.custom_metadata)) + return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch), + pyarrow_wrap_metadata(c.custom_metadata)) cdef class _RecordBatchFileReader(_Weakrefable): @@ -935,7 +981,8 @@ cdef class _RecordBatchFileReader(_Weakrefable): def get_batch_with_custom_metadata(self, int i): """ - Read the record batch with the given index along with its custom metadata + Read the record batch with the given index along with + its custom metadata Parameters ---------- @@ -945,7 +992,7 @@ cdef class _RecordBatchFileReader(_Weakrefable): Returns ------- batch : RecordBatch - custom_metadata : KeyValueMetadata or dict + custom_metadata : KeyValueMetadata """ cdef: CRecordBatchWithMetadata batch_with_metadata diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 3e3a4ed91632..3082d70ed42e 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -956,6 +956,8 @@ def test_ipc_batch_with_custom_metadata_roundtrip(): with pa.ipc.new_file(sink, batch.schema) as writer: for i in range(batch_count): writer.write_batch(batch, {"batch_id": str(i)}) + # write a batch without custom metadata + writer.write_batch(batch) buffer = sink.getvalue() source = pa.BufferReader(buffer) @@ -966,8 +968,44 @@ def test_ipc_batch_with_custom_metadata_roundtrip(): for i in range(batch_count): assert batch_with_metas[i].batch.num_rows == 1 + assert isinstance( + batch_with_metas[i].custom_metadata, pa.KeyValueMetadata) assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)} + # the last batch has no custom metadata + assert batch_with_metas[batch_count].batch.num_rows == 1 + assert batch_with_metas[batch_count].custom_metadata is None + + +@pytest.mark.pandas +def test_record_batch_reader_with_custom_metadata_roundtrip(): + df = pd.DataFrame({'foo': [1.5]}) + + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + + batch_count = 2 + with pa.ipc.new_stream(sink, batch.schema) as writer: + for i in range(batch_count): + writer.write_batch(batch, {"batch_id": str(i)}) + # write a batch without custom metadata + writer.write_batch(batch) + + buffer = sink.getvalue() + stream_contents = pa.BufferReader(buffer) + + with pa.ipc.open_stream(stream_contents) as reader: + batch_meta_gen = reader.iter_batches_with_custom_metadata() + batch_with_metas = list(batch_meta_gen) + + for i in range(batch_count): + assert batch_with_metas[i].batch.num_rows == 1 + assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)} + + # the last batch has no custom metadata + assert batch_with_metas[batch_count].batch.num_rows == 1 + assert batch_with_metas[batch_count].custom_metadata is None + def test_ipc_stream_no_batches(): # ARROW-2307 From cf65477a2082d06d1cb407ac33240cfb15615d9e Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 12 Dec 2022 18:20:09 +0100 Subject: [PATCH 3/3] Parametrize test, some docstring nits --- python/pyarrow/ipc.pxi | 15 ++++------ python/pyarrow/tests/test_ipc.py | 51 +++++++++----------------------- 2 files changed, 19 insertions(+), 47 deletions(-) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index f2b2d3db401b..9b13e71dde96 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -484,7 +484,7 @@ cdef class _CRecordBatchWriter(_Weakrefable): Parameters ---------- batch : RecordBatch - custom_metadata : KeyValueMetadata + custom_metadata : mapping or KeyValueMetadata Keys and values must be string-like / coercible to bytes """ metadata = ensure_metadata(custom_metadata, allow_none=True) @@ -715,17 +715,12 @@ cdef class RecordBatchReader(_Weakrefable): def iter_batches_with_custom_metadata(self): """ - Read next RecordBatch from the stream along with its custom metadata - as a generator + Iterate over record batches from the stream along with their custom + metadata. - Raises + Yields ------ - StopIteration: - At end of stream. - - Yields - ------- - RecordBatch with optional custom metadata + RecordBatchWithMetadata """ while True: try: diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 9b4752b9cb3a..0df302a8de7f 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -947,25 +947,32 @@ def test_ipc_zero_copy_numpy(): @pytest.mark.pandas -def test_ipc_batch_with_custom_metadata_roundtrip(): +@pytest.mark.parametrize("ipc_type", ["stream", "file"]) +def test_batches_with_custom_metadata_roundtrip(ipc_type): df = pd.DataFrame({'foo': [1.5]}) batch = pa.RecordBatch.from_pandas(df) sink = pa.BufferOutputStream() batch_count = 2 - with pa.ipc.new_file(sink, batch.schema) as writer: + file_factory = {"stream": pa.ipc.new_stream, + "file": pa.ipc.new_file}[ipc_type] + + with file_factory(sink, batch.schema) as writer: for i in range(batch_count): - writer.write_batch(batch, {"batch_id": str(i)}) + writer.write_batch(batch, custom_metadata={"batch_id": str(i)}) # write a batch without custom metadata writer.write_batch(batch) buffer = sink.getvalue() - source = pa.BufferReader(buffer) - with pa.ipc.open_file(source) as reader: - batch_with_metas = [reader.get_batch_with_custom_metadata( - i) for i in range(reader.num_record_batches)] + if ipc_type == "stream": + with pa.ipc.open_stream(buffer) as reader: + batch_with_metas = list(reader.iter_batches_with_custom_metadata()) + else: + with pa.ipc.open_file(buffer) as reader: + batch_with_metas = [reader.get_batch_with_custom_metadata(i) + for i in range(reader.num_record_batches)] for i in range(batch_count): assert batch_with_metas[i].batch.num_rows == 1 @@ -978,36 +985,6 @@ def test_ipc_batch_with_custom_metadata_roundtrip(): assert batch_with_metas[batch_count].custom_metadata is None -@pytest.mark.pandas -def test_record_batch_reader_with_custom_metadata_roundtrip(): - df = pd.DataFrame({'foo': [1.5]}) - - batch = pa.RecordBatch.from_pandas(df) - sink = pa.BufferOutputStream() - - batch_count = 2 - with pa.ipc.new_stream(sink, batch.schema) as writer: - for i in range(batch_count): - writer.write_batch(batch, {"batch_id": str(i)}) - # write a batch without custom metadata - writer.write_batch(batch) - - buffer = sink.getvalue() - stream_contents = pa.BufferReader(buffer) - - with pa.ipc.open_stream(stream_contents) as reader: - batch_meta_gen = reader.iter_batches_with_custom_metadata() - batch_with_metas = list(batch_meta_gen) - - for i in range(batch_count): - assert batch_with_metas[i].batch.num_rows == 1 - assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)} - - # the last batch has no custom metadata - assert batch_with_metas[batch_count].batch.num_rows == 1 - assert batch_with_metas[batch_count].custom_metadata is None - - def test_ipc_stream_no_batches(): # ARROW-2307 table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]),