diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index cab6dd34caac..a713d05e04bf 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -23,7 +23,7 @@ on: jobs: - docker: + integration: name: Integration Test runs-on: ubuntu-latest steps: @@ -46,3 +46,54 @@ jobs: run: pip install -e dev/archery[docker] - name: Execute Docker Build run: archery docker run -e ARCHERY_INTEGRATION_WITH_RUST=1 conda-integration + + # test FFI against the C-Data interface exposed by pyarrow + pyarrow-integration-test: + name: Test Pyarrow C Data Interface + runs-on: ubuntu-latest + strategy: + matrix: + rust: [stable] + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + rustup component add rustfmt clippy + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /home/runner/.cargo + key: cargo-maturin-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /home/runner/target + # this key is not equal because maturin uses different compilation flags. 
+ key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- + - uses: actions/setup-python@v2 + with: + python-version: '3.7' + - name: Upgrade pip and setuptools + run: pip install --upgrade pip setuptools wheel + - name: Install python dependencies + run: pip install maturin==0.8.2 toml==0.10.1 pytest pytz + - name: Install nightly pyarrow wheel + # this points to a nightly pyarrow build containing the necessary + # API for integration testing (https://github.com/apache/arrow/pull/10529) + # the hardcoded version is wrong and should be removed either + # after https://issues.apache.org/jira/browse/ARROW-13083 + # gets fixed or pyarrow 5.0 gets released + # (the hardcoded version is wrong, but the build contains the necessary API) + run: pip install --index-url https://pypi.fury.io/arrow-nightlies/ pyarrow==3.1.0.dev1030 + - name: Run tests + env: + CARGO_HOME: "/home/runner/.cargo" + CARGO_TARGET_DIR: "/home/runner/target" + working-directory: arrow-pyarrow-integration-testing + run: | + maturin develop + pytest -v . diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 559c7c8a3961..a041afc8b217 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -283,52 +283,6 @@ jobs: continue-on-error: true run: bash <(curl -s https://codecov.io/bash) - # test FFI against the C-Data interface exposed by pyarrow - pyarrow-integration-test: - name: Test Pyarrow C Data Interface - runs-on: ubuntu-latest - strategy: - matrix: - rust: [stable] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - rustup component add rustfmt clippy - - name: Cache Cargo - uses: actions/cache@v2 - with: - path: /home/runner/.cargo - key: cargo-maturin-cache- - - name: Cache Rust dependencies - uses: actions/cache@v2 - with: - path: /home/runner/target - # this key is not equal because maturin uses different compilation flags. 
- key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v2 - with: - python-version: '3.7' - - name: Install Python dependencies - run: python -m pip install --upgrade pip setuptools wheel - - name: Run tests - run: | - export CARGO_HOME="/home/runner/.cargo" - export CARGO_TARGET_DIR="/home/runner/target" - - cd arrow-pyarrow-integration-testing - - python -m venv venv - source venv/bin/activate - - pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz - maturin develop - python -m unittest discover tests - # test the arrow crate builds against wasm32 in stable rust wasm32-build: name: Build wasm32 on AMD64 Rust ${{ matrix.rust }} diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 5b5462d9c151..a601654d0bcd 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -18,6 +18,7 @@ //! This library demonstrates a minimal usage of Rust's C data interface to pass //! arrays from and to Python. 
+use std::convert::TryFrom; use std::error; use std::fmt; use std::sync::Arc; @@ -28,8 +29,10 @@ use pyo3::{libc::uintptr_t, prelude::*}; use arrow::array::{make_array_from_raw, ArrayRef, Int64Array}; use arrow::compute::kernels; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; use arrow::ffi; +use arrow::ffi::FFI_ArrowSchema; /// an error that bridges ArrowError with a Python error #[derive(Debug)] @@ -68,7 +71,107 @@ impl From for PyErr { } } -fn to_rust(ob: PyObject, py: Python) -> PyResult { +#[pyclass] +struct PyDataType { + inner: DataType, +} + +#[pyclass] +struct PyField { + inner: Field, +} + +#[pyclass] +struct PySchema { + inner: Schema, +} + +#[pymethods] +impl PyDataType { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(Self { inner: dtype }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("DataType")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +#[pymethods] +impl PyField { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(Self { inner: field }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = 
&c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Field")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +#[pymethods] +impl PySchema { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(Self { inner: schema }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Schema")?; + let schema = + class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(schema.into()) + } +} + +impl<'source> FromPyObject<'source> for PyDataType { + fn extract(value: &'source PyAny) -> PyResult { + PyDataType::from_pyarrow(value) + } +} + +impl<'source> FromPyObject<'source> for PyField { + fn extract(value: &'source PyAny) -> PyResult { + PyField::from_pyarrow(value) + } +} + +impl<'source> FromPyObject<'source> for PySchema { + fn extract(value: &'source PyAny) -> PyResult { + PySchema::from_pyarrow(value) + } +} + +fn array_to_rust(ob: PyObject, py: Python) -> PyResult { // prepare a pointer to receive the Array struct let (array_pointer, schema_pointer) = ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); @@ -82,13 +185,12 @@ fn to_rust(ob: PyObject, py: Python) -> PyResult { )?; let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; Ok(array) } -fn to_py(array: ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = - array.to_raw().map_err(|e| 
PyO3ArrowError::from(e))?; +fn array_to_py(array: ArrayRef, py: Python) -> PyResult { + let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?; let pa = py.import("pyarrow")?; @@ -103,22 +205,17 @@ fn to_py(array: ArrayRef, py: Python) -> PyResult { #[pyfunction] fn double(array: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // perform some operation - let array = - array - .as_any() - .downcast_ref::() - .ok_or(PyO3ArrowError::ArrowError(ArrowError::ParseError( - "Expects an int64".to_string(), - )))?; - let array = - kernels::arithmetic::add(&array, &array).map_err(|e| PyO3ArrowError::from(e))?; + let array = array.as_any().downcast_ref::().ok_or_else(|| { + PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) + })?; + let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?; let array = Arc::new(array); // export - to_py(array, py) + array_to_py(array, py) } /// calls a lambda function that receives and returns an array @@ -130,11 +227,9 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult { let expected = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)])) as ArrayRef; // to py - let array = to_py(array, py)?; - - let array = lambda.call1(py, (array,))?; - - let array = to_rust(array, py)?; + let pyarray = array_to_py(array, py)?; + let pyarray = lambda.call1(py, (pyarray,))?; + let array = array_to_rust(pyarray, py)?; Ok(array == expected) } @@ -143,42 +238,45 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult { #[pyfunction] fn substring(array: PyObject, start: i64, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // substring let array = kernels::substring::substring(array.as_ref(), start, &None) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export - to_py(array, py) + array_to_py(array, 
py) } /// Returns the concatenate #[pyfunction] fn concatenate(array: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // concat let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export - to_py(array, py) + array_to_py(array, py) } /// Converts to rust and back to python #[pyfunction] -fn round_trip(array: PyObject, py: Python) -> PyResult { +fn round_trip(pyarray: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(pyarray, py)?; // export - to_py(array, py) + array_to_py(array, py) } #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 5524c54ec178..301eac8d2a09 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -16,156 +16,252 @@ # specific language governing permissions and limitations # under the License. 
-import unittest -from datetime import date, datetime -from decimal import Decimal - -import arrow_pyarrow_integration_testing -import pyarrow -from pytz import timezone - - -class TestCase(unittest.TestCase): - def test_primitive_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([1, 2, 3]) - b = arrow_pyarrow_integration_testing.double(a) - self.assertEqual(b, pyarrow.array([2, 4, 6])) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_primitive_rust(self): - """ - Rust -> Python -> Rust - """ - old_allocated = pyarrow.total_allocated_bytes() - - def double(array): - array = array.to_pylist() - return pyarrow.array([x * 2 if x is not None else None for x in array]) - - is_correct = arrow_pyarrow_integration_testing.double_py(double) - self.assertTrue(is_correct) - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_string_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array(["a", None, "ccc"]) - b = arrow_pyarrow_integration_testing.substring(a, 1) - self.assertEqual(b, pyarrow.array(["", None, "cc"])) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_time32_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([None, 1, 2], pyarrow.time32("s")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s")) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_date32_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = 
[None, date(1990, 3, 9), date(2021, 6, 20)] - a = pyarrow.array(py_array, pyarrow.date32()) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array(py_array + py_array, pyarrow.date32()) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_timestamp_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ - None, - datetime(2021, 1, 1, 1, 1, 1, 1), - datetime(2020, 3, 9, 1, 1, 1, 1), +import contextlib +import datetime +import decimal +import string + +import pytest +import pyarrow as pa +import pytz + +from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema +import arrow_pyarrow_integration_testing as rust + + +@contextlib.contextmanager +def no_pyarrow_leak(): + # No leak of C++ memory + old_allocation = pa.total_allocated_bytes() + try: + yield + finally: + assert pa.total_allocated_bytes() == old_allocation + + +@pytest.fixture(autouse=True) +def assert_pyarrow_leak(): + # automatically applied to all test cases + with no_pyarrow_leak(): + yield + + +_supported_pyarrow_types = [ + pa.null(), + pa.bool_(), + pa.int32(), + pa.time32("s"), + pa.time64("us"), + pa.date32(), + pa.timestamp("us"), + pa.timestamp("us", tz="UTC"), + pa.timestamp("us", tz="Europe/Paris"), + pa.float16(), + pa.float32(), + pa.float64(), + pa.decimal128(19, 4), + pa.string(), + pa.binary(), + pa.large_string(), + pa.large_binary(), + pa.list_(pa.int32()), + pa.large_list(pa.uint16()), + pa.struct( + [ + pa.field("a", pa.int32()), + pa.field("b", pa.int8()), + pa.field("c", pa.string()), ] - a = pyarrow.array(py_array, pyarrow.timestamp("us")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us")) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - 
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_timestamp_tz_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ - None, - datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), - datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + ), + pa.struct( + [ + pa.field("a", pa.int32(), nullable=False), + pa.field("b", pa.int8(), nullable=False), + pa.field("c", pa.string()), ] - a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array( - py_array + py_array, pyarrow.timestamp("us", tz="America/New_York") - ) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_decimal_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None] - a = pyarrow.array(py_array, pyarrow.decimal128(6, 2)) - b = arrow_pyarrow_integration_testing.round_trip(a) - self.assertEqual(a, b) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_list_array(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([[], None, [1, 2], [4, 5, 6]], pyarrow.list_(pyarrow.int64())) - b = arrow_pyarrow_integration_testing.round_trip(a) - - b.validate(full=True) - assert a.to_pylist() == b.to_pylist() - assert a.type == b.type - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + ), +] + +_unsupported_pyarrow_types = [ + pa.decimal256(76, 38), + pa.duration("s"), + pa.binary(10), + pa.list_(pa.int32(), 2), + pa.map_(pa.string(), pa.int32()), + pa.union( + [pa.field("a", 
pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_DENSE, + ), + pa.union( + [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_DENSE, + type_codes=[4, 8], + ), + pa.union( + [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_SPARSE, + ), + pa.union( + [ + pa.field("a", pa.binary(10), nullable=False), + pa.field("b", pa.string()), + ], + mode=pa.lib.UnionMode_SPARSE, + ), +] + + +@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) +def test_type_roundtrip(pyarrow_type): + ty = PyDataType.from_pyarrow(pyarrow_type) + restored = ty.to_pyarrow() + assert restored == pyarrow_type + assert restored is not pyarrow_type + + +@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) +def test_type_roundtrip_raises(pyarrow_type): + with pytest.raises(Exception): + PyDataType.from_pyarrow(pyarrow_type) + + +def test_dictionary_type_roundtrip(): + # the dictionary type conversion is incomplete + pyarrow_type = pa.dictionary(pa.int32(), pa.string()) + ty = PyDataType.from_pyarrow(pyarrow_type) + assert ty.to_pyarrow() == pa.int32() + + +@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) +def test_field_roundtrip(pyarrow_type): + pyarrow_field = pa.field("test", pyarrow_type, nullable=True) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + if pyarrow_type != pa.null(): + # A null type field may not be non-nullable + pyarrow_field = pa.field("test", pyarrow_type, nullable=False) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + +def test_schema_roundtrip(): + pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types) + pyarrow_schema = pa.schema(pyarrow_fields) + schema = PySchema.from_pyarrow(pyarrow_schema) + assert schema.to_pyarrow() == pyarrow_schema + + +def test_primitive_python(): + """ + Python -> Rust -> Python + """ + 
a = pa.array([1, 2, 3]) + b = rust.double(a) + assert b == pa.array([2, 4, 6]) + del a + del b + + +def test_primitive_rust(): + """ + Rust -> Python -> Rust + """ + + def double(array): + array = array.to_pylist() + return pa.array([x * 2 if x is not None else None for x in array]) + + is_correct = rust.double_py(double) + assert is_correct + + +def test_string_python(): + """ + Python -> Rust -> Python + """ + a = pa.array(["a", None, "ccc"]) + b = rust.substring(a, 1) + assert b == pa.array(["", None, "cc"]) + del a + del b + + +def test_time32_python(): + """ + Python -> Rust -> Python + """ + a = pa.array([None, 1, 2], pa.time32("s")) + b = rust.concatenate(a) + expected = pa.array([None, 1, 2] + [None, 1, 2], pa.time32("s")) + assert b == expected + del a + del b + del expected + + +def test_list_array(): + """ + Python -> Rust -> Python + """ + a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64())) + b = rust.round_trip(a) + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + del a + del b + + +def test_timestamp_python(): + """ + Python -> Rust -> Python + """ + data = [ + None, + datetime.datetime(2021, 1, 1, 1, 1, 1, 1), + datetime.datetime(2020, 3, 9, 1, 1, 1, 1), + ] + a = pa.array(data, pa.timestamp("us")) + b = rust.concatenate(a) + expected = pa.array(data + data, pa.timestamp("us")) + assert b == expected + del a + del b + del expected + + +def test_timestamp_tz_python(): + """ + Python -> Rust -> Python + """ + tzinfo = pytz.timezone("America/New_York") + pyarrow_type = pa.timestamp("us", tz="America/New_York") + data = [ + None, + datetime.datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=tzinfo), + datetime.datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=tzinfo), + ] + a = pa.array(data, type=pyarrow_type) + b = rust.concatenate(a) + expected = pa.array(data * 2, type=pyarrow_type) + assert b == expected + del a + del b + del expected + + +def test_decimal_python(): + """ + Python -> Rust -> Python + """ + data = [ 
+ round(decimal.Decimal(123.45), 2), + round(decimal.Decimal(-123.45), 2), + None + ] + a = pa.array(data, pa.decimal128(6, 2)) + b = rust.round_trip(a) + assert a == b + del a + del b diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 0ed2a4526211..4a1016aab026 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -52,6 +52,7 @@ hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" multiversion = "0.6.1" +bitflags = "1.2.1" [features] default = ["csv", "ipc"] diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs new file mode 100644 index 000000000000..7e98508cf090 --- /dev/null +++ b/arrow/src/datatypes/ffi.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::convert::TryFrom; + +use crate::{ + datatypes::{DataType, Field, Schema, TimeUnit}, + error::{ArrowError, Result}, + ffi::{FFI_ArrowSchema, Flags}, +}; + +impl TryFrom<&FFI_ArrowSchema> for DataType { + type Error = ArrowError; + + /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + let dtype = match c_schema.format() { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "+l" => { + let c_child = c_schema.child(0); + DataType::List(Box::new(Field::try_from(c_child)?)) + } + "+L" => { + let c_child = c_schema.child(0); + DataType::LargeList(Box::new(Field::try_from(c_child)?)) + } + "+s" => { + let fields = c_schema.children().map(Field::try_from); + DataType::Struct(fields.collect::>>()?) 
+ } + // Parametrized types, requiring string parse + other => { + match other.splitn(2, ':').collect::>().as_slice() { + // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" + ["d", extra] => { + match extra.splitn(3, ',').collect::>().as_slice() { + [precision, scale] => { + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + }, + [precision, scale, bits] => { + if *bits != "128" { + return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); + } + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation", + extra + ))) + } + } + } + + // Timestamps in format "tss:" and "tss:America/New_York" for timestamps without and with a timezone, respectively. 
+ ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + ["tss", tz] => { + DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) + } + ["tsm", tz] => { + DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) + } + ["tsu", tz] => { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) + } + ["tsn", tz] => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))) + } + } + } + }; + Ok(dtype) + } +} + +impl TryFrom<&FFI_ArrowSchema> for Field { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + let dtype = DataType::try_from(c_schema)?; + let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); + Ok(field) + } +} + +impl TryFrom<&FFI_ArrowSchema> for Schema { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + // interpret it as a struct type then extract its fields + let dtype = DataType::try_from(c_schema)?; + if let DataType::Struct(fields) = dtype { + Ok(Schema::new(fields)) + } else { + Err(ArrowError::CDataInterface( + "Unable to interpret C data struct as a Schema".to_string(), + )) + } + } +} + +impl TryFrom<&DataType> for FFI_ArrowSchema { + type Error = ArrowError; + + /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + fn try_from(dtype: &DataType) -> Result { + let format = match dtype { + DataType::Null => "n".to_string(), + DataType::Boolean => "b".to_string(), + DataType::Int8 => "c".to_string(), + DataType::UInt8 => "C".to_string(), + DataType::Int16 => "s".to_string(), + DataType::UInt16 => "S".to_string(), + DataType::Int32 => 
"i".to_string(), + DataType::UInt32 => "I".to_string(), + DataType::Int64 => "l".to_string(), + DataType::UInt64 => "L".to_string(), + DataType::Float16 => "e".to_string(), + DataType::Float32 => "f".to_string(), + DataType::Float64 => "g".to_string(), + DataType::Binary => "z".to_string(), + DataType::LargeBinary => "Z".to_string(), + DataType::Utf8 => "u".to_string(), + DataType::LargeUtf8 => "U".to_string(), + DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale), + DataType::Date32 => "tdD".to_string(), + DataType::Date64 => "tdm".to_string(), + DataType::Time32(TimeUnit::Second) => "tts".to_string(), + DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), + DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), + DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), + DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(), + DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(), + DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(), + DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz), + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz), + DataType::List(_) => "+l".to_string(), + DataType::LargeList(_) => "+L".to_string(), + DataType::Struct(_) => "+s".to_string(), + other => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))) + } + }; + // allocate and hold the children + let children = match dtype { + DataType::List(child) | DataType::LargeList(child) => { + vec![FFI_ArrowSchema::try_from(child.as_ref())?] 
+ } + DataType::Struct(fields) => fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::>>()?, + _ => vec![], + }; + FFI_ArrowSchema::try_new(&format, children) + } +} + +impl TryFrom<&Field> for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(field: &Field) -> Result { + let flags = if field.is_nullable() { + Flags::NULLABLE + } else { + Flags::empty() + }; + FFI_ArrowSchema::try_from(field.data_type())? + .with_name(field.name())? + .with_flags(flags) + } +} + +impl TryFrom<&Schema> for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(schema: &Schema) -> Result { + let dtype = DataType::Struct(schema.fields().clone()); + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; + Ok(c_schema) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(dtype: DataType) -> Result { + FFI_ArrowSchema::try_from(&dtype) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(field: Field) -> Result { + FFI_ArrowSchema::try_from(&field) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(schema: Schema) -> Result { + FFI_ArrowSchema::try_from(&schema) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::datatypes::{DataType, Field, TimeUnit}; + use crate::error::Result; + use std::convert::TryFrom; + + fn round_trip_type(dtype: DataType) -> Result<()> { + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; + let restored = DataType::try_from(&c_schema)?; + assert_eq!(restored, dtype); + Ok(()) + } + + fn round_trip_field(field: Field) -> Result<()> { + let c_schema = FFI_ArrowSchema::try_from(&field)?; + let restored = Field::try_from(&c_schema)?; + assert_eq!(restored, field); + Ok(()) + } + + fn round_trip_schema(schema: Schema) -> Result<()> { + let c_schema = FFI_ArrowSchema::try_from(&schema)?; + let restored = Schema::try_from(&c_schema)?; + assert_eq!(restored, schema); + Ok(()) + } + + #[test] + fn test_type() -> Result<()> { + 
round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float64)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::List(Box::new(Field::new( + "a", + DataType::Int16, + false, + ))))?; + round_trip_type(DataType::Struct(vec![Field::new( + "a", + DataType::Utf8, + true, + )]))?; + Ok(()) + } + + #[test] + fn test_field() -> Result<()> { + let dtype = DataType::Struct(vec![Field::new("a", DataType::Utf8, true)]); + round_trip_field(Field::new("test", dtype, true))?; + Ok(()) + } + + #[test] + fn test_schema() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]); + round_trip_schema(schema)?; + + // test that we can interpret struct types as schema + let dtype = DataType::Struct(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int16, false), + ]); + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; + let schema = Schema::try_from(&c_schema)?; + assert_eq!(schema.fields().len(), 2); + + // test that we assert the input type + let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64)?; + let result = Schema::try_from(&c_schema); + assert!(result.is_err()); + Ok(()) + } +} diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 6a2d0dcfe27e..51b33dc667e3 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -36,6 +36,8 @@ mod types; pub use types::*; mod datatype; pub use datatype::*; +mod ffi; +pub use ffi::*; /// A reference-counted reference to a [`Schema`](crate::datatypes::Schema). 
pub type SchemaRef = Arc; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index b804dd2db74a..e3589cacdd43 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -77,24 +77,30 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new]. */ use std::{ + convert::TryFrom, ffi::CStr, ffi::CString, iter, mem::size_of, + os::raw::{c_char, c_void}, ptr::{self, NonNull}, sync::Arc, }; +use bitflags::bitflags; + use crate::array::ArrayData; use crate::buffer::Buffer; -use crate::datatypes::{DataType, Field, TimeUnit}; +use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -#[allow(dead_code)] -struct SchemaPrivateData { - field: Field, - children_ptr: Box<[*mut FFI_ArrowSchema]>, +bitflags! { + pub struct Flags: i64 { + const DICTIONARY_ORDERED = 0b00000001; + const NULLABLE = 0b00000010; + const MAP_KEYS_SORTED = 0b00000100; + } } /// ABI-compatible struct for `ArrowSchema` from C Data Interface @@ -103,15 +109,19 @@ struct SchemaPrivateData { #[repr(C)] #[derive(Debug)] pub struct FFI_ArrowSchema { - format: *const ::std::os::raw::c_char, - name: *const ::std::os::raw::c_char, - metadata: *const ::std::os::raw::c_char, + format: *const c_char, + name: *const c_char, + metadata: *const c_char, flags: i64, n_children: i64, children: *mut *mut FFI_ArrowSchema, dictionary: *mut FFI_ArrowSchema, - release: ::std::option::Option, - private_data: *mut ::std::os::raw::c_void, + release: Option, + private_data: *mut c_void, +} + +struct SchemaPrivateData { + children: Box<[*mut FFI_ArrowSchema]>, } // callback used to drop [FFI_ArrowSchema] when it is exported. @@ -122,11 +132,16 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { let schema = &mut *schema; // take ownership back to release it. 
- CString::from_raw(schema.format as *mut std::os::raw::c_char); - CString::from_raw(schema.name as *mut std::os::raw::c_char); - let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData); - for child in private.children_ptr.iter() { - let _ = Box::from_raw(*child); + CString::from_raw(schema.format as *mut c_char); + if !schema.name.is_null() { + CString::from_raw(schema.name as *mut c_char); + } + if !schema.private_data.is_null() { + let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private_data.children.iter() { + drop(Box::from_raw(*child)) + } + drop(private_data); } schema.release = None; @@ -134,54 +149,39 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { impl FFI_ArrowSchema { /// create a new [`Ffi_ArrowSchema`]. This fails if the fields' [`DataType`] is not supported. - fn try_new(field: Field) -> Result { - let format = to_format(field.data_type())?; - let name = field.name().clone(); - - // allocate (and hold) the children - let children_vec = match field.data_type() { - DataType::List(field) => { - vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)] - } - DataType::LargeList(field) => { - vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)] - } - DataType::Struct(fields) => fields - .iter() - .map(|field| Ok(Box::new(FFI_ArrowSchema::try_new(field.clone())?))) - .collect::>>()?, - _ => vec![], - }; - // note: this cannot be done along with the above because the above is fallible and this op leaks. 
- let children_ptr = children_vec + pub fn try_new(format: &str, children: Vec) -> Result { + let mut this = Self::empty(); + + let mut children_ptr = children .into_iter() + .map(Box::new) .map(Box::into_raw) .collect::>(); - let n_children = children_ptr.len() as i64; - let flags = field.is_nullable() as i64 * 2; + this.format = CString::new(format).unwrap().into_raw(); + this.release = Some(release_schema); + this.n_children = children_ptr.len() as i64; + this.children = children_ptr.as_mut_ptr(); - let mut private = Box::new(SchemaPrivateData { - field, - children_ptr, + let private_data = Box::new(SchemaPrivateData { + children: children_ptr, }); + this.private_data = Box::into_raw(private_data) as *mut c_void; - // - Ok(FFI_ArrowSchema { - format: CString::new(format).unwrap().into_raw(), - name: CString::new(name).unwrap().into_raw(), - metadata: std::ptr::null_mut(), - flags, - n_children, - children: private.children_ptr.as_mut_ptr(), - dictionary: std::ptr::null_mut(), - release: Some(release_schema), - private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, - }) + Ok(this) } - /// create an empty [FFI_ArrowSchema] - fn empty() -> Self { + pub fn with_name(mut self, name: &str) -> Result { + self.name = CString::new(name).unwrap().into_raw(); + Ok(self) + } + + pub fn with_flags(mut self, flags: Flags) -> Result { + self.flags = flags.bits(); + Ok(self) + } + + pub fn empty() -> Self { Self { format: std::ptr::null_mut(), name: std::ptr::null_mut(), @@ -208,15 +208,24 @@ impl FFI_ArrowSchema { pub fn name(&self) -> &str { assert!(!self.name.is_null()); // safe because the lifetime of `self.name` equals `self` - unsafe { CStr::from_ptr(self.name) }.to_str().unwrap() + unsafe { CStr::from_ptr(self.name) } + .to_str() + .expect("The external API has a non-utf8 as name") + } + + pub fn flags(&self) -> Option { + Flags::from_bits(self.flags) } pub fn child(&self, index: usize) -> &Self { assert!(index < self.n_children as usize); - 
assert!(!self.name.is_null()); unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } } + pub fn children(&self) -> impl Iterator { + (0..self.n_children as usize).map(move |i| self.child(i)) + } + pub fn nullable(&self) -> bool { (self.flags / 2) & 1 == 1 } @@ -231,178 +240,6 @@ impl Drop for FFI_ArrowSchema { } } -/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings -fn to_field(schema: &FFI_ArrowSchema) -> Result { - let data_type = match schema.format() { - "n" => DataType::Null, - "b" => DataType::Boolean, - "c" => DataType::Int8, - "C" => DataType::UInt8, - "s" => DataType::Int16, - "S" => DataType::UInt16, - "i" => DataType::Int32, - "I" => DataType::UInt32, - "l" => DataType::Int64, - "L" => DataType::UInt64, - "e" => DataType::Float16, - "f" => DataType::Float32, - "g" => DataType::Float64, - "z" => DataType::Binary, - "Z" => DataType::LargeBinary, - "u" => DataType::Utf8, - "U" => DataType::LargeUtf8, - "tdD" => DataType::Date32, - "tdm" => DataType::Date64, - "tts" => DataType::Time32(TimeUnit::Second), - "ttm" => DataType::Time32(TimeUnit::Millisecond), - "ttu" => DataType::Time64(TimeUnit::Microsecond), - "ttn" => DataType::Time64(TimeUnit::Nanosecond), - "+l" => { - let child = schema.child(0); - DataType::List(Box::new(to_field(child)?)) - } - "+L" => { - let child = schema.child(0); - DataType::LargeList(Box::new(to_field(child)?)) - } - "+s" => { - let children = (0..schema.n_children as usize) - .map(|x| to_field(schema.child(x))) - .collect::>>()?; - DataType::Struct(children) - } - // Parametrized types, requiring string parse - other => { - match other.splitn(2, ':').collect::>().as_slice() { - // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" - ["d", extra] => { - match extra.splitn(3, ',').collect::>().as_slice() { - [precision, scale] => { - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type 
requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - DataType::Decimal(parsed_precision, parsed_scale) - }, - [precision, scale, bits] => { - if *bits != "128" { - return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); - } - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - DataType::Decimal(parsed_precision, parsed_scale) - } - _ => { - return Err(ArrowError::CDataInterface(format!( - "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation", - extra - ))) - } - } - } - - // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp. 
- ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), - ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), - ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), - ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), - ["tss", tz] => { - DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) - } - ["tsm", tz] => { - DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) - } - ["tsu", tz] => { - DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) - } - ["tsn", tz] => { - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) - } - - _ => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - other - ))) - } - } - } - }; - Ok(Field::new(schema.name(), data_type, schema.nullable())) -} - -/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings -fn to_format(data_type: &DataType) -> Result { - Ok(match data_type { - DataType::Null => "n", - DataType::Boolean => "b", - DataType::Int8 => "c", - DataType::UInt8 => "C", - DataType::Int16 => "s", - DataType::UInt16 => "S", - DataType::Int32 => "i", - DataType::UInt32 => "I", - DataType::Int64 => "l", - DataType::UInt64 => "L", - DataType::Float16 => "e", - DataType::Float32 => "f", - DataType::Float64 => "g", - DataType::Binary => "z", - DataType::LargeBinary => "Z", - DataType::Utf8 => "u", - DataType::LargeUtf8 => "U", - DataType::Decimal(precision, scale) => { - return Ok(format!("d:{},{}", precision, scale)) - } - DataType::Date32 => "tdD", - DataType::Date64 => "tdm", - DataType::Time32(TimeUnit::Second) => "tts", - DataType::Time32(TimeUnit::Millisecond) => "ttm", - DataType::Time64(TimeUnit::Microsecond) => "ttu", - DataType::Time64(TimeUnit::Nanosecond) => "ttn", - DataType::Timestamp(TimeUnit::Second, None) => "tss:", - DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:", - 
DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:", - DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:", - DataType::Timestamp(TimeUnit::Second, Some(tz)) => { - return Ok(format!("tss:{}", tz)) - } - DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => { - return Ok(format!("tsm:{}", tz)) - } - DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { - return Ok(format!("tsu:{}", tz)) - } - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - return Ok(format!("tsn:{}", tz)) - } - DataType::List(_) => "+l", - DataType::LargeList(_) => "+L", - DataType::Struct(_) => "+s", - z => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - z - ))) - } - } - .to_string()) -} - // returns the number of bits that buffer `i` (in the C data interface) is expected to have. // This is set by the Arrow specification fn bit_width(data_type: &DataType, i: usize) -> Result { @@ -482,16 +319,16 @@ pub struct FFI_ArrowArray { pub(crate) offset: i64, pub(crate) n_buffers: i64, pub(crate) n_children: i64, - pub(crate) buffers: *mut *const ::std::os::raw::c_void, + pub(crate) buffers: *mut *const c_void, children: *mut *mut FFI_ArrowArray, dictionary: *mut FFI_ArrowArray, - release: ::std::option::Option, + release: Option, // When exported, this MUST contain everything that is owned by this array. - // for example, any buffer pointed to in `buffers` must be here, as well as the `buffers` pointer - // itself. - // In other words, everything in [FFI_ArrowArray] must be owned by `private_data` and can assume - // that they do not outlive `private_data`. - private_data: *mut ::std::os::raw::c_void, + // for example, any buffer pointed to in `buffers` must be here, as well + // as the `buffers` pointer itself. + // In other words, everything in [FFI_ArrowArray] must be owned by + // `private_data` and can assume that they do not outlive `private_data`. 
+ private_data: *mut c_void, } impl Drop for FFI_ArrowArray { @@ -511,7 +348,7 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { let array = &mut *array; // take ownership of `private_data`, therefore dropping it` - let private = Box::from_raw(array.private_data as *mut PrivateData); + let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); for child in private.children.iter() { let _ = Box::from_raw(*child); } @@ -519,9 +356,9 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { array.release = None; } -struct PrivateData { +struct ArrayPrivateData { buffers: Vec>, - buffers_ptr: Box<[*const std::os::raw::c_void]>, + buffers_ptr: Box<[*const c_void]>, children: Box<[*mut FFI_ArrowArray]>, } @@ -542,7 +379,7 @@ impl FFI_ArrowArray { .iter() .map(|maybe_buffer| match maybe_buffer { // note that `raw_data` takes into account the buffer's offset - Some(b) => b.as_ptr() as *const std::os::raw::c_void, + Some(b) => b.as_ptr() as *const c_void, None => std::ptr::null(), }) .collect::>(); @@ -556,7 +393,7 @@ impl FFI_ArrowArray { // create the private data owning everything. // any other data must be added here, e.g. via a struct, to track lifetime. 
- let mut private_data = Box::new(PrivateData { + let mut private_data = Box::new(ArrayPrivateData { buffers, buffers_ptr, children, @@ -572,7 +409,7 @@ impl FFI_ArrowArray { children: private_data.children.as_mut_ptr(), dictionary: std::ptr::null_mut(), release: Some(release_array), - private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -814,7 +651,7 @@ pub struct ArrowArrayChild<'a> { impl ArrowArrayRef for ArrowArray { /// the data_type as declared in the schema fn data_type(&self) -> Result { - to_field(&self.schema).map(|x| x.data_type().clone()) + DataType::try_from(self.schema.as_ref()) } fn array(&self) -> &FFI_ArrowArray { @@ -833,7 +670,7 @@ impl ArrowArrayRef for ArrowArray { impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { /// the data_type as declared in the schema fn data_type(&self) -> Result { - to_field(self.schema).map(|x| x.data_type().clone()) + DataType::try_from(self.schema) } fn array(&self) -> &FFI_ArrowArray { @@ -855,10 +692,8 @@ impl ArrowArray { /// See safety of [ArrowArray] #[allow(clippy::too_many_arguments)] pub unsafe fn try_new(data: ArrayData) -> Result { - let field = Field::new("", data.data_type().clone(), data.null_count() != 0); let array = Arc::new(FFI_ArrowArray::new(&data)); - let schema = Arc::new(FFI_ArrowSchema::try_new(field)?); - + let schema = Arc::new(FFI_ArrowSchema::try_from(data.data_type())?); Ok(ArrowArray { array, schema }) }