From e4db00b8aa90d5a1f329aeae8238797cccb878a2 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 13 Sep 2022 10:01:20 +0200 Subject: [PATCH 01/27] Support casting to extension type --- python/pyarrow/array.pxi | 8 +++++++- python/pyarrow/tests/test_extension_type.py | 18 +++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 4b35697874ce..b218d3234ec3 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -916,7 +916,13 @@ cdef class Array(_PandasConvertible): ------- cast : Array """ - return _pc().cast(self, target_type, safe=safe, options=options) + if hasattr(target_type, "storage_type"): + arr = self.cast(target_type.storage_type, safe, options) + return ExtensionArray.from_buffers(target_type, len(arr), + arr.buffers(), arr.null_count, + arr.offset) + else: + return _pc().cast(self, target_type, safe=safe, options=options) def view(self, object target_type): """ diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 9c5a394f8955..65d0a5a5fd88 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -517,10 +517,22 @@ def test_cast_kernel_on_extension_arrays(): assert isinstance(casted, pa.ChunkedArray) -def test_casting_to_extension_type_raises(): +def test_casting_to_extension_type(): arr = pa.array([1, 2, 3, 4], pa.int64()) - with pytest.raises(pa.ArrowNotImplementedError): - arr.cast(IntegerType()) + out = arr.cast(IntegerType()) + assert isinstance(out, pa.ExtensionArray) + assert out.to_pylist() == [1, 2, 3, 4] + + +def test_casting_dict_array_to_extension_type(): + storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + arr = pa.ExtensionArray.from_storage(UuidType(), storage) + dict_arr = pa.DictionaryArray.from_arrays(pa.array([0, 0], pa.int32()), + arr) + out = dict_arr.cast(UuidType()) + assert isinstance(out, pa.ExtensionArray) + assert out.to_pylist() == [UUID('30313233-3435-3637-3839-616263646566'), + UUID('30313233-3435-3637-3839-616263646566')] def test_null_storage_type(): From b9d8cb6747fb496f70e64c89378d62717b7e7d9c Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 13 Sep 2022 13:17:29 +0200 Subject: [PATCH 02/27] Move extension casting impl from Python to C++ --- cpp/src/arrow/compute/cast.cc | 11 +++++++---- python/pyarrow/array.pxi | 8 +------- python/pyarrow/tests/test_extension_type.py | 2 +- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 52aecf3e45a6..ba4661b8f575 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -187,11 +187,14 @@ Result CastFunction::DispatchExact( Result> GetCastFunction(const DataType& to_type) { internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast(to_type.id())); - if (it == internal::g_cast_table.end()) { - return Status::NotImplemented("Unsupported cast to ", to_type); + auto ids = {to_type.id(), to_type.storage_id()}; + for (auto& id : ids) { + auto it = internal::g_cast_table.find(static_cast(id)); + if (it != internal::g_cast_table.end()) { + return it->second; + } } - return it->second; + return Status::NotImplemented("Unsupported cast to ", to_type); } } // namespace internal diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index b218d3234ec3..4b35697874ce 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -916,13 +916,7 @@ cdef class Array(_PandasConvertible): ------- cast : Array """ - if hasattr(target_type, "storage_type"): - arr = self.cast(target_type.storage_type, safe, options) - return ExtensionArray.from_buffers(target_type, len(arr), - arr.buffers(), arr.null_count, - arr.offset) - else: - return _pc().cast(self, target_type, safe=safe, options=options) + return _pc().cast(self, target_type, safe=safe, options=options) def view(self, object target_type): """ diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 65d0a5a5fd88..6f07c712f3f3 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -520,7 +520,7 @@ def test_cast_kernel_on_extension_arrays(): def test_casting_to_extension_type(): arr = pa.array([1, 2, 3, 4], pa.int64()) out = arr.cast(IntegerType()) - assert isinstance(out, pa.ExtensionArray) + assert isinstance(out, pa.Int64Array) assert out.to_pylist() == [1, 2, 3, 4] From 89aa16c0b78ad5c63cdc8dde1c498aef2ca018e4 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 14 Sep 2022 14:37:30 +0200 Subject: [PATCH 03/27] Attempt impl CastToExtension et al. [skip ci] --- cpp/src/arrow/compute/cast.cc | 12 +++--- cpp/src/arrow/compute/cast_internal.h | 1 + .../compute/kernels/scalar_cast_internal.h | 1 + .../compute/kernels/scalar_cast_numeric.cc | 38 ++++++++++++++++++- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index ba4661b8f575..2bfc963b0827 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -62,6 +62,7 @@ void InitCastTable() { AddCastFunctions(GetNumericCasts()); AddCastFunctions(GetTemporalCasts()); AddCastFunctions(GetDictionaryCasts()); + AddCastFunctions(GetExtensionCasts()); } void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } @@ -187,14 +188,11 @@ Result CastFunction::DispatchExact( Result> GetCastFunction(const DataType& to_type) { internal::EnsureInitCastTable(); - auto ids = {to_type.id(), to_type.storage_id()}; - for (auto& id : ids) { - auto it = internal::g_cast_table.find(static_cast(id)); - if (it != internal::g_cast_table.end()) { - return it->second; - } + auto it = internal::g_cast_table.find(static_cast(to_type.id())); + if (it == internal::g_cast_table.end()) { + return Status::NotImplemented("Unsupported cast to ", to_type); } - return Status::NotImplemented("Unsupported cast to ", to_type); + return it->second; } } // namespace internal diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index f00a6cdbf4d9..423b791e6a76 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -63,6 +63,7 @@ std::vector> GetTemporalCasts(); std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); std::vector> GetDictionaryCasts(); +std::vector> GetExtensionCasts(); ARROW_EXPORT Result> GetCastFunction(const DataType& to_type); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 4d9afab199ce..7d9fb742437d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -43,6 +43,7 @@ struct CastFunctor< }; Status CastFromExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out); +Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out); // Utility for numeric casts void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 8d36cff6ae95..a0c9eeea6191 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -16,11 +16,12 @@ // under the License. // Implementation of casting to integer, floating point, or decimal types - +#include #include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/scalar_cast_internal.h" #include "arrow/compute/kernels/util_internal.h" +#include "arrow/extension_type.h" #include "arrow/scalar.h" #include "arrow/util/bit_block_counter.h" #include "arrow/util/int_util.h" @@ -769,6 +770,41 @@ std::vector> GetNumericCasts() { return functions; } + +Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const CastOptions& options = checked_cast(ctx->state())->options; + + DCHECK(batch[0].is_array()); + std::shared_ptr array = batch[0].array.ToArray(); + std::shared_ptr result; + auto ext_ty = GetExtensionType(kOutputTargetType.type()->name()); + if (ext_ty == nullptr) { + return Status::Invalid("Could not find extension type: " + kOutputTargetType.type()->name()); + } + auto out_ty = ext_ty->storage_type(); + RETURN_NOT_OK(Cast(*array, out_ty, options, + ctx->exec_context()) + .Value(&result)); + ExtensionArray extension(ext_ty, result); + out->value = std::move(extension.data()); + return Status::OK(); +} + +std::shared_ptr GetCastToExtension(std::string name) { + auto func = std::make_shared(std::move(name), Type::EXTENSION); + auto out_ty = kOutputTargetType.type(); + std::cout << "About to add to kernel" << std::endl; + DCHECK_OK(func->AddKernel(Type::INT64, {InputType(Type::INT64)}, + out_ty, CastToExtension)); + return func; +} + +std::vector> GetExtensionCasts() { + auto func = GetCastToExtension("cast_extension"); + return {func}; +} + + } // namespace internal } // namespace compute } // namespace arrow From 444fd0219997accd892e722ff8f02918ecc55665 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 15 Sep 2022 14:44:59 +0200 Subject: [PATCH 04/27] Works with hard-coded extension type [skip ci] --- .../compute/kernels/scalar_cast_numeric.cc | 17 ++++++----------- python/pyarrow/tests/test_extension_type.py | 2 +- python/pyarrow/types.pxi | 5 +++-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index a0c9eeea6191..ebc76fe886eb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -16,7 +16,6 @@ // under the License. // Implementation of casting to integer, floating point, or decimal types -#include #include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/scalar_cast_internal.h" @@ -26,6 +25,7 @@ #include "arrow/util/bit_block_counter.h" #include "arrow/util/int_util.h" #include "arrow/util/value_parsing.h" +#include "arrow/util/logging.h" namespace arrow { @@ -773,29 +773,24 @@ std::vector> GetNumericCasts() { Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const CastOptions& options = checked_cast(ctx->state())->options; - DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); std::shared_ptr result; - auto ext_ty = GetExtensionType(kOutputTargetType.type()->name()); - if (ext_ty == nullptr) { - return Status::Invalid("Could not find extension type: " + kOutputTargetType.type()->name()); - } - auto out_ty = ext_ty->storage_type(); + + auto out_ty = GetExtensionType("arrow.py_integer_type")->storage_type(); RETURN_NOT_OK(Cast(*array, out_ty, options, ctx->exec_context()) .Value(&result)); - ExtensionArray extension(ext_ty, result); + + ExtensionArray extension(options.to_type.GetSharedPtr(), result); out->value = std::move(extension.data()); return Status::OK(); } std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); - auto out_ty = kOutputTargetType.type(); - std::cout << "About to add to kernel" << std::endl; DCHECK_OK(func->AddKernel(Type::INT64, {InputType(Type::INT64)}, - out_ty, CastToExtension)); + kOutputTargetType, CastToExtension)); return func; } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 6f07c712f3f3..a99e4512c4ff 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -28,7 +28,7 @@ class IntegerType(pa.PyExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64()) + pa.PyExtensionType.__init__(self, pa.int64(), "arrow.py_integer_type") def __reduce__(self): return IntegerType, () diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index d37363e06ff3..b241ce6e6d52 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1008,8 +1008,9 @@ cdef class PyExtensionType(ExtensionType): raise TypeError("Can only instantiate subclasses of " "PyExtensionType") - def __init__(self, DataType storage_type): - ExtensionType.__init__(self, storage_type, "arrow.py_extension_type") + def __init__(self, DataType storage_type, + extension_name="arrow.py_extension_type"): + ExtensionType.__init__(self, storage_type, extension_name) def __reduce__(self): raise NotImplementedError("Please implement {0}.__reduce__" From 6c05a1d620dbaea44e2db4520acb16330332df9a Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 16 Sep 2022 14:37:49 +0200 Subject: [PATCH 05/27] Static cast to extension type [skip ci] --- cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc | 10 +++++----- python/pyarrow/tests/test_extension_type.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index ebc76fe886eb..9c2edd09c669 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -770,14 +770,12 @@ std::vector> GetNumericCasts() { return functions; } - Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const CastOptions& options = checked_cast(ctx->state())->options; + auto out_ty = static_cast(*options.to_type.type).storage_type(); DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); std::shared_ptr result; - - auto out_ty = GetExtensionType("arrow.py_integer_type")->storage_type(); RETURN_NOT_OK(Cast(*array, out_ty, options, ctx->exec_context()) .Value(&result)); @@ -789,8 +787,10 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); - DCHECK_OK(func->AddKernel(Type::INT64, {InputType(Type::INT64)}, - kOutputTargetType, CastToExtension)); + for (auto in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, + kOutputTargetType, CastToExtension)); + } return func; } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index a99e4512c4ff..cd6c817e8d76 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -28,7 +28,7 @@ class IntegerType(pa.PyExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64(), "arrow.py_integer_type") + pa.PyExtensionType.__init__(self, pa.int64()) #, "arrow.py_integer_type") def __reduce__(self): return IntegerType, () From 296b4fbc0bbb5654fc24e6711184770c88a37297 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 16 Sep 2022 15:04:51 +0200 Subject: [PATCH 06/27] Move to scalar_cast_extension, need C++ tests [skip ci] --- cpp/src/arrow/CMakeLists.txt | 1 + .../compute/kernels/scalar_cast_extension.cc | 66 +++++++++++++++++++ .../compute/kernels/scalar_cast_numeric.cc | 33 +--------- python/pyarrow/tests/test_extension_type.py | 19 ++++-- python/pyarrow/types.pxi | 5 +- 5 files changed, 84 insertions(+), 40 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_extension.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 88d72b11832a..2396a5e3a1e0 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -425,6 +425,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_boolean.cc compute/kernels/scalar_cast_boolean.cc compute/kernels/scalar_cast_dictionary.cc + compute/kernels/scalar_cast_extension.cc compute/kernels/scalar_cast_internal.cc compute/kernels/scalar_cast_nested.cc compute/kernels/scalar_cast_numeric.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc new file mode 100644 index 000000000000..231684e06fba --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to extension types +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/scalar.h" + +namespace arrow { +namespace compute { +namespace internal { + +Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const CastOptions& options = checked_cast(ctx->state())->options; + auto out_ty = static_cast(*options.to_type.type).storage_type(); + + DCHECK(batch[0].is_array()); + std::shared_ptr array = batch[0].array.ToArray(); + std::shared_ptr result; + + RETURN_NOT_OK(Cast(*array, out_ty, options, + ctx->exec_context()) + .Value(&result)); + ExtensionArray extension(options.to_type.GetSharedPtr(), result); + out->value = std::move(extension.data()); + return Status::OK(); +} + +std::shared_ptr GetCastToExtension(std::string name) { + auto func = std::make_shared(std::move(name), Type::EXTENSION); + // TODO(milesgranger): Better way to add all types? `AllTypeIds` exists in tests... + for (auto types : {IntTypes(), FloatingPointTypes(), StringTypes(), BinaryTypes()}) { + for (auto in_ty : types) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, + kOutputTargetType, CastToExtension)); + } + } + DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, + kOutputTargetType, CastToExtension)); + return func; +} + +std::vector> GetExtensionCasts() { + auto func = GetCastToExtension("cast_extension"); + return {func}; +} + + +} // namespace internal +} // namespace compute +} // namespace arrow + diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 9c2edd09c669..8d36cff6ae95 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -16,16 +16,15 @@ // under the License. // Implementation of casting to integer, floating point, or decimal types + #include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/scalar_cast_internal.h" #include "arrow/compute/kernels/util_internal.h" -#include "arrow/extension_type.h" #include "arrow/scalar.h" #include "arrow/util/bit_block_counter.h" #include "arrow/util/int_util.h" #include "arrow/util/value_parsing.h" -#include "arrow/util/logging.h" namespace arrow { @@ -770,36 +769,6 @@ std::vector> GetNumericCasts() { return functions; } -Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const CastOptions& options = checked_cast(ctx->state())->options; - auto out_ty = static_cast(*options.to_type.type).storage_type(); - DCHECK(batch[0].is_array()); - std::shared_ptr array = batch[0].array.ToArray(); - std::shared_ptr result; - RETURN_NOT_OK(Cast(*array, out_ty, options, - ctx->exec_context()) - .Value(&result)); - - ExtensionArray extension(options.to_type.GetSharedPtr(), result); - out->value = std::move(extension.data()); - return Status::OK(); -} - -std::shared_ptr GetCastToExtension(std::string name) { - auto func = std::make_shared(std::move(name), Type::EXTENSION); - for (auto in_ty : IntTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, - kOutputTargetType, CastToExtension)); - } - return func; -} - -std::vector> GetExtensionCasts() { - auto func = GetCastToExtension("cast_extension"); - return {func}; -} - - } // namespace internal } // namespace compute } // namespace arrow diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index cd6c817e8d76..5c3c10339e0a 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -28,7 +28,7 @@ class IntegerType(pa.PyExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64()) #, "arrow.py_integer_type") + pa.PyExtensionType.__init__(self, pa.int64()) def __reduce__(self): return IntegerType, () @@ -517,11 +517,20 @@ def test_cast_kernel_on_extension_arrays(): assert isinstance(casted, pa.ChunkedArray) -def test_casting_to_extension_type(): - arr = pa.array([1, 2, 3, 4], pa.int64()) +@pytest.mark.parametrize("arr", ( + pa.array([1, 2], pa.int32()), + pa.array([1, 2], pa.int64()), + pa.array(["1", "2"], pa.string()), + pa.array([b"1", b"2"], pa.binary()), + pa.array([1.0, 2.0], pa.float32()), + pa.array([1.0, 2.0], pa.float64()) + +)) +def test_casting_to_extension_type(arr): out = arr.cast(IntegerType()) - assert isinstance(out, pa.Int64Array) - assert out.to_pylist() == [1, 2, 3, 4] + assert isinstance(out, pa.ExtensionArray) + assert out.type == IntegerType() + assert out.to_pylist() == [1, 2] def test_casting_dict_array_to_extension_type(): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index b241ce6e6d52..d37363e06ff3 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1008,9 +1008,8 @@ cdef class PyExtensionType(ExtensionType): raise TypeError("Can only instantiate subclasses of " "PyExtensionType") - def __init__(self, DataType storage_type, - extension_name="arrow.py_extension_type"): - ExtensionType.__init__(self, storage_type, extension_name) + def __init__(self, DataType storage_type): + ExtensionType.__init__(self, storage_type, "arrow.py_extension_type") def __reduce__(self): raise NotImplementedError("Please implement {0}.__reduce__" From 120204bc20a01e8b66a7eb8d0b0ef91e63816cbd Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Mon, 19 Sep 2022 09:58:39 +0200 Subject: [PATCH 07/27] Impl initial cpp test for casting to extension --- .../arrow/compute/kernels/scalar_cast_test.cc | 20 +++++++++++++++++++ python/pyarrow/tests/test_extension_type.py | 18 ++++++++--------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 963748c9f97d..53b611325064 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2765,6 +2765,26 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } } +TEST(Cast, PrimitiveToExtension) { + auto primitive_array = ArrayFromJSON(uint8(), "[0, 1, 3]"); + auto extension_array = SmallintArrayFromJSON("[0, 1, 3]"); + CastOptions options; + options.to_type = smallint(); + CheckCast(primitive_array, extension_array, options); +} + +TEST(Cast, DictTypeToExtension) { + auto extension_array = SmallintArrayFromJSON("[1, 2, 1]"); + auto indices_array = ArrayFromJSON(int32(), "[0, 1, 0]"); + + ASSERT_OK_AND_ASSIGN(auto dict_array, + DictionaryArray::FromArrays(indices_array, extension_array)); + + CastOptions options; + options.to_type = smallint(); + CheckCast(dict_array, extension_array, options); +} + TEST(Cast, DictTypeToAnotherDict) { auto check_cast = [&](const std::shared_ptr& in_type, const std::shared_ptr& out_type, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 5c3c10339e0a..0f174d1a7238 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -517,16 +517,16 @@ def test_cast_kernel_on_extension_arrays(): assert isinstance(casted, pa.ChunkedArray) -@pytest.mark.parametrize("arr", ( - pa.array([1, 2], pa.int32()), - pa.array([1, 2], pa.int64()), - pa.array(["1", "2"], pa.string()), - pa.array([b"1", b"2"], pa.binary()), - pa.array([1.0, 2.0], pa.float32()), - pa.array([1.0, 2.0], pa.float64()) - +@pytest.mark.parametrize("data,ty", ( + ([1, 2], pa.int32), + ([1, 2], pa.int64), + (["1", "2"], pa.string), + ([b"1", b"2"], pa.binary), + ([1.0, 2.0], pa.float32), + ([1.0, 2.0], pa.float64) )) -def test_casting_to_extension_type(arr): +def test_casting_to_extension_type(data, ty): + arr = pa.array(data, ty()) out = arr.cast(IntegerType()) assert isinstance(out, pa.ExtensionArray) assert out.type == IntegerType() From 34c99989c5b8123d5d1ddce25198d8a903261f67 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 08:43:59 +0200 Subject: [PATCH 08/27] Update cpp/src/arrow/compute/kernels/scalar_cast_test.cc [skip ci] Co-authored-by: Antoine Pitrou --- cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 53b611325064..4ba75df13b8f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2773,7 +2773,7 @@ TEST(Cast, PrimitiveToExtension) { CheckCast(primitive_array, extension_array, options); } -TEST(Cast, DictTypeToExtension) { +TEST(Cast, ExtensionDictToExtension) { auto extension_array = SmallintArrayFromJSON("[1, 2, 1]"); auto indices_array = ArrayFromJSON(int32(), "[0, 1, 0]"); From 20be1fec930a4be4eac6a1231327a6cf8f0e6b04 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 09:09:13 +0200 Subject: [PATCH 09/27] Move all but GetExtensionCasts to anonymous namespace [skip ci] --- .../compute/kernels/scalar_cast_extension.cc | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 231684e06fba..160bbbaa15f8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -24,17 +24,18 @@ namespace arrow { namespace compute { namespace internal { +namespace { Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const CastOptions& options = checked_cast(ctx->state())->options; - auto out_ty = static_cast(*options.to_type.type).storage_type(); - + const CastOptions& options = + checked_cast(ctx->state())->options; + auto out_ty = + static_cast(*options.to_type.type).storage_type(); + DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); std::shared_ptr result; - - RETURN_NOT_OK(Cast(*array, out_ty, options, - ctx->exec_context()) - .Value(&result)); + + RETURN_NOT_OK(Cast(*array, out_ty, options, ctx->exec_context()).Value(&result)); ExtensionArray extension(options.to_type.GetSharedPtr(), result); out->value = std::move(extension.data()); return Status::OK(); @@ -45,8 +46,8 @@ std::shared_ptr GetCastToExtension(std::string name) { // TODO(milesgranger): Better way to add all types? `AllTypeIds` exists in tests... for (auto types : {IntTypes(), FloatingPointTypes(), StringTypes(), BinaryTypes()}) { for (auto in_ty : types) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, - kOutputTargetType, CastToExtension)); + DCHECK_OK( + func->AddKernel(in_ty->id(), {in_ty}, kOutputTargetType, CastToExtension)); } } DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, @@ -54,13 +55,13 @@ std::shared_ptr GetCastToExtension(std::string name) { return func; } +}; // namespace + std::vector> GetExtensionCasts() { auto func = GetCastToExtension("cast_extension"); return {func}; } - } // namespace internal } // namespace compute } // namespace arrow - From c7e76c74a62b821e65c161e6533c4b859be7f067 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 09:11:50 +0200 Subject: [PATCH 10/27] Remove CastToExtension from header file [skip ci] --- cpp/src/arrow/compute/kernels/scalar_cast_internal.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 7d9fb742437d..4d9afab199ce 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -43,7 +43,6 @@ struct CastFunctor< }; Status CastFromExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out); -Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out); // Utility for numeric casts void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, From 4cd9b65b909c1b1ac9ac7232f8df64c555e5e379 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 11:19:18 +0200 Subject: [PATCH 11/27] Use ARROW_ASSIGN_OR_RAISE during cast [skip ci] --- .../arrow/compute/kernels/scalar_cast_extension.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 160bbbaa15f8..dc571faec8e6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -26,16 +26,15 @@ namespace internal { namespace { Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const CastOptions& options = - checked_cast(ctx->state())->options; - auto out_ty = - static_cast(*options.to_type.type).storage_type(); + const CastOptions& options = checked_cast(ctx->state())->options; + auto out_ty = static_cast(*options.to_type.type).storage_type(); DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); - std::shared_ptr result; - RETURN_NOT_OK(Cast(*array, out_ty, options, ctx->exec_context()).Value(&result)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + Cast(*array, out_ty, options, ctx->exec_context())); + ExtensionArray extension(options.to_type.GetSharedPtr(), result); out->value = std::move(extension.data()); return Status::OK(); From 2d1890437d229d5818ddf412c9a89c71f5885d67 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 11:44:13 +0200 Subject: [PATCH 12/27] Updates tests from review comments [skip ci] --- .../compute/kernels/scalar_cast_extension.cc | 1 - .../arrow/compute/kernels/scalar_cast_test.cc | 59 +++++++++++++++++-- cpp/src/arrow/testing/extension_type.h | 28 +++++++++ cpp/src/arrow/testing/gtest_util.cc | 29 +++++++++ 4 files changed, 111 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index dc571faec8e6..96786b8b9557 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -42,7 +42,6 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); - // TODO(milesgranger): Better way to add all types? `AllTypeIds` exists in tests... for (auto types : {IntTypes(), FloatingPointTypes(), StringTypes(), BinaryTypes()}) { for (auto in_ty : types) { DCHECK_OK( diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 4ba75df13b8f..6b67e385aed0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2728,6 +2728,13 @@ std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { return MakeArray(ext_data); } +std::shared_ptr TinyintArrayFromJSON(const std::string& json_data) { + auto arr = ArrayFromJSON(int8(), json_data); + auto ext_data = arr->data()->Copy(); + ext_data->type = tinyint(); + return MakeArray(ext_data); +} + TEST(Cast, ExtensionTypeToIntDowncast) { auto smallint = std::make_shared(); ExtensionTypeGuard smallint_guard(smallint); @@ -2766,11 +2773,18 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } TEST(Cast, PrimitiveToExtension) { - auto primitive_array = ArrayFromJSON(uint8(), "[0, 1, 3]"); - auto extension_array = SmallintArrayFromJSON("[0, 1, 3]"); - CastOptions options; - options.to_type = smallint(); - CheckCast(primitive_array, extension_array, options); + { + auto primitive_array = ArrayFromJSON(uint8(), "[0, 1, 3]"); + auto extension_array = SmallintArrayFromJSON("[0, 1, 3]"); + CastOptions options; + options.to_type = smallint(); + CheckCast(primitive_array, extension_array, options); + } + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(utf8(), "[\"hello\"]"), options); + } } TEST(Cast, ExtensionDictToExtension) { @@ -2785,6 +2799,41 @@ TEST(Cast, ExtensionDictToExtension) { CheckCast(dict_array, extension_array, options); } +TEST(Cast, IntToExtensionTypeDowncast) { + CheckCast(ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"), + SmallintArrayFromJSON("[0, 100, 200, 1, 2]")); + + // int32 to Smallint(int16), with overflow + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), + SmallintArrayFromJSON("[0, null, -32768, 1, 3]"), options); + } + + // int32 to Smallint(int16), with underflow + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), + SmallintArrayFromJSON("[0, null, 32767, 1, 3]"), options); + } + + // Cannot cast between extension types + { + CastOptions options; + options.to_type = smallint(); + auto tiny_array = TinyintArrayFromJSON("[0, 1, 3]"); + ASSERT_NOT_OK(Cast(tiny_array, smallint(), options)); + } +} + TEST(Cast, DictTypeToAnotherDict) { auto check_cast = [&](const std::shared_ptr& in_type, const std::shared_ptr& out_type, diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index 338b4cb4da05..846e3c7a1657 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -54,6 +54,11 @@ class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray { using ExtensionArray::ExtensionArray; }; +class ARROW_TESTING_EXPORT TinyintArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + class ARROW_TESTING_EXPORT ListExtensionArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; @@ -76,6 +81,23 @@ class ARROW_TESTING_EXPORT SmallintType : public ExtensionType { std::string Serialize() const override { return "smallint"; } }; +class ARROW_TESTING_EXPORT TinyintType : public ExtensionType { + public: + TinyintType() : ExtensionType(int8()) {} + + std::string extension_name() const override { return "tinyint"; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override; + + std::string Serialize() const override { return "tinyint"; } +}; + class ARROW_TESTING_EXPORT ListExtensionType : public ExtensionType { public: ListExtensionType() : ExtensionType(list(int32())) {} @@ -140,6 +162,9 @@ std::shared_ptr uuid(); ARROW_TESTING_EXPORT std::shared_ptr smallint(); +ARROW_TESTING_EXPORT +std::shared_ptr tinyint(); + ARROW_TESTING_EXPORT std::shared_ptr list_extension_type(); @@ -155,6 +180,9 @@ std::shared_ptr ExampleUuid(); ARROW_TESTING_EXPORT std::shared_ptr ExampleSmallint(); +ARROW_TESTING_EXPORT +std::shared_ptr ExampleTinyint(); + ARROW_TESTING_EXPORT std::shared_ptr ExampleDictExtension(); diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index a4d867088000..9c65d3b7fdd4 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -832,6 +832,28 @@ Result> SmallintType::Deserialize( return std::make_shared(); } +bool TinyintType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr TinyintType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("tinyint", static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> TinyintType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized) const { + if (serialized != "tinyint") { + return Status::Invalid("Type identifier did not match: '", serialized, "'"); + } + if (!storage_type->Equals(*int16())) { + return Status::Invalid("Invalid storage type for TinyintType: ", + storage_type->ToString()); + } + return std::make_shared(); +} + bool ListExtensionType::ExtensionEquals(const ExtensionType& other) const { return (other.extension_name() == this->extension_name()); } @@ -905,6 +927,8 @@ std::shared_ptr uuid() { return std::make_shared(); } std::shared_ptr smallint() { return std::make_shared(); } +std::shared_ptr tinyint() { return std::make_shared(); } + std::shared_ptr list_extension_type() { return std::make_shared(); } @@ -936,6 +960,11 @@ std::shared_ptr ExampleSmallint() { return ExtensionType::WrapArray(smallint(), arr); } +std::shared_ptr ExampleTinyint() { + auto arr = ArrayFromJSON(int8(), "[-128, null, 1, 2, 3, 4, 127]"); + return ExtensionType::WrapArray(tinyint(), arr); +} + std::shared_ptr ExampleDictExtension() { auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]", R"(["foo", "bar"])"); From 7e8d795012b95af90cd424b745bb75a7266005e3 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 14:08:48 +0200 Subject: [PATCH 13/27] Add entry to compute.rst docs --- docs/source/cpp/compute.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index a354f42a4b11..72da5827fd2e 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1291,6 +1291,8 @@ provided by a concrete function :func:`~arrow::compute::Cast`. +-----------------+------------+--------------------+------------------+--------------------------------+-------+ | strptime | Unary | String-like | Timestamp | :struct:`StrptimeOptions` | | +-----------------+------------+--------------------+------------------+--------------------------------+-------+ +| extension | Unary | Many | Extension | :struct:`CastOptions` | \(2) | ++-----------------+------------+--------------------+------------------+--------------------------------+-------+ The conversions available with ``cast`` are listed below. In all cases, a null input value is converted into a null output value. @@ -1303,6 +1305,10 @@ null input value is converted into a null output value. The character for the decimal point is localized according to the locale. See `detailed formatting documentation`_ for descriptions of other flags. +* \(2) Input types can be anything whose type can be cast to the + resulting Extension's storage_type. Casting between extensions, even with + compatible storage types is not supported. + .. _detailed formatting documentation: https://howardhinnant.github.io/date/date.html#to_stream_formatting **Truth value extraction** From 069bfd000eea16fb359f3fd2b8c387f3485346b9 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 14:22:02 +0200 Subject: [PATCH 14/27] Move compute.rst docs entry to generic conversions table --- docs/source/cpp/compute.rst | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 72da5827fd2e..b8847fa6ef06 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1291,8 +1291,6 @@ provided by a concrete function :func:`~arrow::compute::Cast`. +-----------------+------------+--------------------+------------------+--------------------------------+-------+ | strptime | Unary | String-like | Timestamp | :struct:`StrptimeOptions` | | +-----------------+------------+--------------------+------------------+--------------------------------+-------+ -| extension | Unary | Many | Extension | :struct:`CastOptions` | \(2) | -+-----------------+------------+--------------------+------------------+--------------------------------+-------+ The conversions available with ``cast`` are listed below. In all cases, a null input value is converted into a null output value. @@ -1305,10 +1303,6 @@ null input value is converted into a null output value. The character for the decimal point is localized according to the locale. See `detailed formatting documentation`_ for descriptions of other flags. -* \(2) Input types can be anything whose type can be cast to the - resulting Extension's storage_type. Casting between extensions, even with - compatible storage types is not supported. - .. _detailed formatting documentation: https://howardhinnant.github.io/date/date.html#to_stream_formatting **Truth value extraction** @@ -1383,6 +1377,8 @@ null input value is converted into a null output value. +-----------------------------+------------------------------------+---------+ | Null | Any | | +-----------------------------+------------------------------------+---------+ +| Any* | Extension | \(3) | ++-----------------------------+------------------------------------+---------+ * \(1) The dictionary indices are unchanged, the dictionary values are cast from the input value type to the output value type (if a conversion @@ -1392,6 +1388,9 @@ null input value is converted into a null output value. input value type to the output value type (if a conversion is available). +* \(3) Any input type except Extension types. Can cast any input type + which can be cast to the resulting extension's storage type. + Temporal component extraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 3e46cbda1d0e339a21c8a37f9099a2f8584b0d51 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 15:04:20 +0200 Subject: [PATCH 15/27] Refactor adding input types in GetCastToExtension [skip ci] --- cpp/src/arrow/compute/kernels/scalar_cast_extension.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 96786b8b9557..514e2c354ff4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -42,14 +42,17 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); - for (auto types : {IntTypes(), FloatingPointTypes(), StringTypes(), BinaryTypes()}) { + for (auto types : {PrimitiveTypes(), IntervalTypes(), TemporalTypes()}) { for (auto in_ty : types) { DCHECK_OK( func->AddKernel(in_ty->id(), {in_ty}, kOutputTargetType, CastToExtension)); } } - DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, - kOutputTargetType, CastToExtension)); + for (auto in_ty : + {Type::DICTIONARY, Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST}) { + DCHECK_OK( + func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension)); + } return func; } From 9a5a5142e31e876debe104a531aaef3511e28d02 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 22 Sep 2022 09:46:09 +0200 Subject: [PATCH 16/27] Support casting between extension types --- .../compute/kernels/scalar_cast_extension.cc | 4 ++-- .../arrow/compute/kernels/scalar_cast_test.cc | 5 +++-- docs/source/cpp/compute.rst | 6 +++--- python/pyarrow/tests/test_extension_type.py | 19 +++++++++++++++++++ 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 514e2c354ff4..0a2abaaa387e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -48,8 +48,8 @@ std::shared_ptr GetCastToExtension(std::string name) { func->AddKernel(in_ty->id(), {in_ty}, kOutputTargetType, CastToExtension)); } } - for (auto in_ty : - {Type::DICTIONARY, Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST}) { + for (auto in_ty : {Type::DICTIONARY, Type::LIST, Type::LARGE_LIST, + Type::FIXED_SIZE_LIST, Type::EXTENSION}) { DCHECK_OK( func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension)); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 6b67e385aed0..d095188b7cda 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -225,7 +225,8 @@ TEST(Cast, CanCast) { ExpectCanCast(smallint(), {int16()}); // cast storage ExpectCanCast(smallint(), kNumericTypes); // any cast which is valid for storage is supported - ExpectCannotCast(null(), {smallint()}); // FIXME missing common cast from null + ExpectCanCast(null(), {smallint()}); + ExpectCanCast(tinyint(), {smallint()}); // cast between compatible storage types ExpectCanCast(date32(), {utf8(), large_utf8()}); ExpectCanCast(date64(), {utf8(), large_utf8()}); @@ -2830,7 +2831,7 @@ TEST(Cast, IntToExtensionTypeDowncast) { CastOptions options; options.to_type = smallint(); auto tiny_array = TinyintArrayFromJSON("[0, 1, 3]"); - ASSERT_NOT_OK(Cast(tiny_array, smallint(), options)); + ASSERT_OK(Cast(tiny_array, smallint(), options)); } } diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index b8847fa6ef06..74db126aa7bc 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1377,7 +1377,7 @@ null input value is converted into a null output value. +-----------------------------+------------------------------------+---------+ | Null | Any | | +-----------------------------+------------------------------------+---------+ -| Any* | Extension | \(3) | +| Any | Extension | \(3) | +-----------------------------+------------------------------------+---------+ * \(1) The dictionary indices are unchanged, the dictionary values are @@ -1388,8 +1388,8 @@ null input value is converted into a null output value. input value type to the output value type (if a conversion is available). -* \(3) Any input type except Extension types. Can cast any input type - which can be cast to the resulting extension's storage type. +* \(3) Any input type where the type, or storage type if Extension type, can + be cast to the resulting extension's storage type. Temporal component extraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 0f174d1a7238..fb6f4cc369da 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -25,6 +25,15 @@ import pytest +class TinyIntType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int8()) + + def __reduce__(self): + return TinyIntType, () + + class IntegerType(pa.PyExtensionType): def __init__(self): @@ -533,6 +542,16 @@ def test_casting_to_extension_type(data, ty): assert out.to_pylist() == [1, 2] +def test_cast_between_extension_types(): + array = pa.array([1, 2, 3], pa.int8()) + + tiny_int_arr = array.cast(TinyIntType()) + assert tiny_int_arr.type == TinyIntType() + + int_arr = tiny_int_arr.cast(IntegerType()) + assert int_arr.type == IntegerType() + + def test_casting_dict_array_to_extension_type(): storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) arr = pa.ExtensionArray.from_storage(UuidType(), storage) From 29db94b04c8d5598002a4942088cd3a556ffae86 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 22 Sep 2022 13:08:34 +0200 Subject: [PATCH 17/27] Move and use AllTypeIds from gtest_util --- .../compute/kernels/scalar_cast_extension.cc | 9 +--- cpp/src/arrow/testing/gtest_util.cc | 41 ------------------ cpp/src/arrow/testing/gtest_util.h | 3 -- cpp/src/arrow/type_fwd.h | 42 +++++++++++++++++++ 4 files changed, 43 insertions(+), 52 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 0a2abaaa387e..25efc6e24409 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -42,14 +42,7 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); - for (auto types : {PrimitiveTypes(), IntervalTypes(), TemporalTypes()}) { - for (auto in_ty : types) { - DCHECK_OK( - func->AddKernel(in_ty->id(), {in_ty}, kOutputTargetType, CastToExtension)); - } - } - for (auto in_ty : {Type::DICTIONARY, Type::LIST, Type::LARGE_LIST, - Type::FIXED_SIZE_LIST, Type::EXTENSION}) { + for (Type::type in_ty : AllTypeIds()) { DCHECK_OK( func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension)); } diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 9c65d3b7fdd4..18f43da72a3f 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -64,47 +64,6 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -std::vector AllTypeIds() { - return {Type::NA, - Type::BOOL, - Type::INT8, - Type::INT16, - Type::INT32, - Type::INT64, - Type::UINT8, - Type::UINT16, - Type::UINT32, - Type::UINT64, - Type::HALF_FLOAT, - Type::FLOAT, - Type::DOUBLE, - Type::DECIMAL128, - Type::DECIMAL256, - Type::DATE32, - Type::DATE64, - Type::TIME32, - Type::TIME64, - Type::TIMESTAMP, - Type::INTERVAL_DAY_TIME, - Type::INTERVAL_MONTHS, - Type::DURATION, - Type::STRING, - Type::BINARY, - Type::LARGE_STRING, - Type::LARGE_BINARY, - Type::FIXED_SIZE_BINARY, - Type::STRUCT, - Type::LIST, - Type::LARGE_LIST, - Type::FIXED_SIZE_LIST, - Type::MAP, - Type::DENSE_UNION, - Type::SPARSE_UNION, - Type::DICTIONARY, - Type::EXTENSION, - Type::INTERVAL_MONTH_DAY_NANO}; -} - template void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) { if (!compare(actual, expected)) { diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 8ce5049452a9..1408042d994e 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -190,9 +190,6 @@ class RecordBatch; class Table; struct Datum; -ARROW_TESTING_EXPORT -std::vector AllTypeIds(); - #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs)) #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs)) #define ASSERT_BATCHES_APPROX_EQUAL(lhs, rhs) AssertBatchesApproxEqual((lhs), (rhs)) diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 84a50a12eb30..d28c8e950424 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -410,6 +410,48 @@ struct Type { }; }; +/// \brief Get a vector of all type ids +inline std::vector AllTypeIds() { + return {Type::NA, + Type::BOOL, + Type::INT8, + Type::INT16, + Type::INT32, + Type::INT64, + Type::UINT8, + Type::UINT16, + Type::UINT32, + Type::UINT64, + Type::HALF_FLOAT, + Type::FLOAT, + Type::DOUBLE, + Type::DECIMAL128, + Type::DECIMAL256, + Type::DATE32, + Type::DATE64, + Type::TIME32, + Type::TIME64, + Type::TIMESTAMP, + Type::INTERVAL_DAY_TIME, + Type::INTERVAL_MONTHS, + Type::DURATION, + Type::STRING, + Type::BINARY, + Type::LARGE_STRING, + Type::LARGE_BINARY, + Type::FIXED_SIZE_BINARY, + Type::STRUCT, + Type::LIST, + Type::LARGE_LIST, + Type::FIXED_SIZE_LIST, + Type::MAP, + Type::DENSE_UNION, + Type::SPARSE_UNION, + Type::DICTIONARY, + Type::EXTENSION, + Type::INTERVAL_MONTH_DAY_NANO}; +} + /// \defgroup type-factories Factory functions for creating data types /// /// Factory functions for creating data types From 2ea1e11b4a1c6f1b604f76c90e47215f8f440321 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 22 Sep 2022 14:36:22 +0200 Subject: [PATCH 18/27] Move AllTypeIds impl -> type.cc --- cpp/src/arrow/type.cc | 41 ++++++++++++++++++++++++++++++++++++++++ cpp/src/arrow/type_fwd.h | 41 +--------------------------------------- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index efff07db6671..a3285cf92f51 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -95,6 +95,47 @@ constexpr Type::type DurationType::type_id; constexpr Type::type DictionaryType::type_id; +std::vector AllTypeIds() { + return {Type::NA, + Type::BOOL, + Type::INT8, + Type::INT16, + Type::INT32, + Type::INT64, + Type::UINT8, + Type::UINT16, + Type::UINT32, + Type::UINT64, + Type::HALF_FLOAT, + Type::FLOAT, + Type::DOUBLE, + Type::DECIMAL128, + Type::DECIMAL256, + Type::DATE32, + Type::DATE64, + Type::TIME32, + Type::TIME64, + Type::TIMESTAMP, + Type::INTERVAL_DAY_TIME, + Type::INTERVAL_MONTHS, + Type::DURATION, + Type::STRING, + Type::BINARY, + Type::LARGE_STRING, + Type::LARGE_BINARY, + Type::FIXED_SIZE_BINARY, + Type::STRUCT, + Type::LIST, + Type::LARGE_LIST, + Type::FIXED_SIZE_LIST, + Type::MAP, + Type::DENSE_UNION, + Type::SPARSE_UNION, + Type::DICTIONARY, + Type::EXTENSION, + Type::INTERVAL_MONTH_DAY_NANO}; +} + namespace internal { struct TypeIdToTypeNameVisitor { diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index d28c8e950424..2b29fe310de1 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -411,46 +411,7 @@ struct Type { }; /// \brief Get a vector of all type ids -inline std::vector AllTypeIds() { - return {Type::NA, - Type::BOOL, - Type::INT8, - Type::INT16, - Type::INT32, - Type::INT64, - Type::UINT8, - Type::UINT16, - Type::UINT32, - Type::UINT64, - Type::HALF_FLOAT, - Type::FLOAT, - Type::DOUBLE, - Type::DECIMAL128, - Type::DECIMAL256, - Type::DATE32, - Type::DATE64, - Type::TIME32, - Type::TIME64, - Type::TIMESTAMP, - Type::INTERVAL_DAY_TIME, - Type::INTERVAL_MONTHS, - Type::DURATION, - Type::STRING, - Type::BINARY, - Type::LARGE_STRING, - Type::LARGE_BINARY, - Type::FIXED_SIZE_BINARY, - Type::STRUCT, - Type::LIST, - Type::LARGE_LIST, - Type::FIXED_SIZE_LIST, - Type::MAP, - Type::DENSE_UNION, - Type::SPARSE_UNION, - Type::DICTIONARY, - Type::EXTENSION, - Type::INTERVAL_MONTH_DAY_NANO}; -} +std::vector AllTypeIds(); /// \defgroup type-factories Factory functions for creating data types /// From e6ab97a07678272c1826608e747ded8bc86a24c3 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 27 Sep 2022 14:43:58 +0200 Subject: [PATCH 19/27] Add test for nested extension types casting --- python/pyarrow/tests/test_extension_type.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index fb6f4cc369da..c54733fd6c31 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -552,6 +552,25 @@ def test_cast_between_extension_types(): assert int_arr.type == IntegerType() +@pytest.mark.parametrize("data,type_factory", ( + # list + ([[1, 2, 3]], lambda: pa.list_(IntegerType())), + # struct + ([{"foo": 1}], lambda: pa.struct([("foo", IntegerType())])), + # list> + ([[{"foo": 1}]], lambda: pa.list_(pa.struct([("foo", IntegerType())]))), + # struct> + ([{"foo": [1, 2, 3]}], lambda: pa.struct( + [("foo", pa.list_(IntegerType()))])), +)) +def test_cast_nested_extension_types(data, type_factory): + ty = type_factory() + a = pa.array(data) + b = a.cast(ty) + assert b.type == ty # casted to target extension + assert b.cast(a.type) # and can cast back + + def test_casting_dict_array_to_extension_type(): storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) arr = pa.ExtensionArray.from_storage(UuidType(), storage) From 02013e8623d5cb5c0eabebc0668477fd3f06b950 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 29 Sep 2022 08:41:10 +0200 Subject: [PATCH 20/27] Fail cast between incompatible storage types [skip ci] --- .../compute/kernels/scalar_cast_extension.cc | 13 +++++++++++++ .../arrow/compute/kernels/scalar_cast_test.cc | 4 ++-- python/pyarrow/tests/test_extension_type.py | 16 ++++++++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 25efc6e24409..0985192d6779 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -32,6 +32,19 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); + // Try to prevent user errors by preventing casting between extensions w/ + // different storage types. Provide a tip on how to accomplish same outcome. + if (array->type()->id() == Type::EXTENSION) { + const auto& ext_arr = checked_cast(*array); + if (out_ty->id() != ext_arr.extension_type()->storage_id()) { + return Status::Invalid("Casting from '" + ext_arr.extension_type()->ToString() + + "' to extension with different storage type '" + + out_ty->ToString() + + "' not permitted. One can first cast to the storage " + "type, then to the extension type."); + } + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, Cast(*array, out_ty, options, ctx->exec_context())); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index d095188b7cda..6b172eaa1407 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2826,12 +2826,12 @@ TEST(Cast, IntToExtensionTypeDowncast) { SmallintArrayFromJSON("[0, null, 32767, 1, 3]"), options); } - // Cannot cast between extension types + // Cannot cast between extension types when storage types differ { CastOptions options; options.to_type = smallint(); auto tiny_array = TinyintArrayFromJSON("[0, 1, 3]"); - ASSERT_OK(Cast(tiny_array, smallint(), options)); + ASSERT_NOT_OK(Cast(tiny_array, smallint(), options)); } } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index c54733fd6c31..b0311b1b8c82 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -548,8 +548,20 @@ def test_cast_between_extension_types(): tiny_int_arr = array.cast(TinyIntType()) assert tiny_int_arr.type == TinyIntType() - int_arr = tiny_int_arr.cast(IntegerType()) - assert int_arr.type == IntegerType() + # Casting between extension types w/ different storage types not okay. + msg = ( + "Casting from 'extension>' " + "to extension with different storage type 'int64' not permitted. " + "One can first cast to the storage type, then to the extension type." + ) + with pytest.raises(pa.ArrowInvalid, match=msg): + tiny_int_arr.cast(IntegerType()) + tiny_int_arr.cast(pa.int64()).cast(IntegerType()) + + # Casting between extension types w/ same storage type is okay. + arr = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)) + uuid_arr = arr.cast(UuidType()) + uuid2_arr = uuid_arr.cast(UuidType2()) @pytest.mark.parametrize("data,type_factory", ( From da66bf423e97e40d9c7c5316427a9f000061c36f Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 29 Sep 2022 09:51:23 +0200 Subject: [PATCH 21/27] Only convert to ext where input type is same as storage type [skip ci] --- .../arrow/compute/kernels/scalar_cast_extension.cc | 11 +++++------ python/pyarrow/tests/test_extension_type.py | 5 ----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 0985192d6779..f9722ff2da44 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -34,20 +34,19 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou // Try to prevent user errors by preventing casting between extensions w/ // different storage types. Provide a tip on how to accomplish same outcome. + std::shared_ptr result = array; if (array->type()->id() == Type::EXTENSION) { - const auto& ext_arr = checked_cast(*array); - if (out_ty->id() != ext_arr.extension_type()->storage_id()) { - return Status::Invalid("Casting from '" + ext_arr.extension_type()->ToString() + + if (!array->type()->Equals(out_ty)) { + return Status::Invalid("Casting from '" + array->type()->ToString() + "' to extension with different storage type '" + out_ty->ToString() + "' not permitted. One can first cast to the storage " "type, then to the extension type."); } + } else { + ARROW_ASSIGN_OR_RAISE(result, Cast(*array, out_ty, options, ctx->exec_context())); } - ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, - Cast(*array, out_ty, options, ctx->exec_context())); - ExtensionArray extension(options.to_type.GetSharedPtr(), result); out->value = std::move(extension.data()); return Status::OK(); diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index b0311b1b8c82..4c648967ff9a 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -558,11 +558,6 @@ def test_cast_between_extension_types(): tiny_int_arr.cast(IntegerType()) tiny_int_arr.cast(pa.int64()).cast(IntegerType()) - # Casting between extension types w/ same storage type is okay. - arr = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)) - uuid_arr = arr.cast(UuidType()) - uuid2_arr = uuid_arr.cast(UuidType2()) - @pytest.mark.parametrize("data,type_factory", ( # list From 374301cbf10f4d9c350029c0e3b7aff3015edc35 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Thu, 29 Sep 2022 10:11:55 +0200 Subject: [PATCH 22/27] ARROW_EXPORT for AllTypeIds --- cpp/src/arrow/type_fwd.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 2b29fe310de1..e2bace974e2c 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -411,7 +411,7 @@ struct Type { }; /// \brief Get a vector of all type ids -std::vector AllTypeIds(); +ARROW_EXPORT std::vector AllTypeIds(); /// \defgroup type-factories Factory functions for creating data types /// From 44cef44d00f657f43f6a257677ddbe4493f8901c Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 30 Sep 2022 14:59:31 +0200 Subject: [PATCH 23/27] Move result assignments inside if branches --- cpp/src/arrow/compute/kernels/scalar_cast_extension.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index f9722ff2da44..2db3ba775dde 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -34,7 +34,7 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou // Try to prevent user errors by preventing casting between extensions w/ // different storage types. Provide a tip on how to accomplish same outcome. - std::shared_ptr result = array; + std::shared_ptr result; if (array->type()->id() == Type::EXTENSION) { if (!array->type()->Equals(out_ty)) { return Status::Invalid("Casting from '" + array->type()->ToString() + @@ -43,6 +43,7 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou "' not permitted. One can first cast to the storage " "type, then to the extension type."); } + result = array; } else { ARROW_ASSIGN_OR_RAISE(result, Cast(*array, out_ty, options, ctx->exec_context())); } From 8e5063f005e7d3bb22b5eaff7007bb7632a7c2f4 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 4 Oct 2022 13:18:47 +0200 Subject: [PATCH 24/27] Update error msg and test casting between same ext types --- .../compute/kernels/scalar_cast_extension.cc | 12 +++++----- python/pyarrow/tests/test_extension_type.py | 24 +++++++++++++------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index 2db3ba775dde..d2e2ab72f006 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -27,7 +27,8 @@ namespace internal { namespace { Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const CastOptions& options = checked_cast(ctx->state())->options; - auto out_ty = static_cast(*options.to_type.type).storage_type(); + const auto& ext_ty = static_cast(*options.to_type.type); + auto out_ty = ext_ty.storage_type(); DCHECK(batch[0].is_array()); std::shared_ptr array = batch[0].array.ToArray(); @@ -37,11 +38,10 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr result; if (array->type()->id() == Type::EXTENSION) { if (!array->type()->Equals(out_ty)) { - return Status::Invalid("Casting from '" + array->type()->ToString() + - "' to extension with different storage type '" + - out_ty->ToString() + - "' not permitted. One can first cast to the storage " - "type, then to the extension type."); + return Status::TypeError("Casting from '" + array->type()->ToString() + + "' to different extension type '" + ext_ty.ToString() + + "' not permitted. One can first cast to the storage " + "type, then to the extension type."); } result = array; } else { diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 4c648967ff9a..fc2f1dbd95d8 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -66,7 +66,7 @@ def __init__(self): pa.PyExtensionType.__init__(self, pa.binary(16)) def __reduce__(self): - return UuidType, () + return UuidType2, () class ParamExtType(pa.PyExtensionType): @@ -549,15 +549,25 @@ def test_cast_between_extension_types(): assert tiny_int_arr.type == TinyIntType() # Casting between extension types w/ different storage types not okay. - msg = ( - "Casting from 'extension>' " - "to extension with different storage type 'int64' not permitted. " - "One can first cast to the storage type, then to the extension type." - ) - with pytest.raises(pa.ArrowInvalid, match=msg): + msg = ("Casting from 'extension>' " + "to different extension type " + "'extension>' not permitted. " + "One can first cast to the storage type, then to the extension type." + ) + with pytest.raises(TypeError, match=msg): tiny_int_arr.cast(IntegerType()) tiny_int_arr.cast(pa.int64()).cast(IntegerType()) + # Between the same extension types is okay + array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(UuidType()) + out = array.cast(UuidType()) + assert out.type == UuidType() + + # Will still fail casting between extensions who share storage type, + # can only cast between exactly the same extension types. + with pytest.raises(TypeError, match='Casting from *'): + array.cast(UuidType2()) + @pytest.mark.parametrize("data,type_factory", ( # list From 2bc8a83035473592397490b81aeaf9138ae4e7f2 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 4 Oct 2022 15:28:23 +0200 Subject: [PATCH 25/27] Update compute.rst doc notes --- docs/source/cpp/compute.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 74db126aa7bc..2e3c6ab9de8c 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1388,8 +1388,8 @@ null input value is converted into a null output value. input value type to the output value type (if a conversion is available). -* \(3) Any input type where the type, or storage type if Extension type, can - be cast to the resulting extension's storage type. +* \(3) Any input type that can be cast to the resulting extension's storage type. + This excludes extension types, unless being cast to the same extension type. Temporal component extraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 2cef2ae0de801daf69cc59a7151dbe1fa141ad14 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 5 Oct 2022 08:57:31 +0200 Subject: [PATCH 26/27] Fixup: autopep8 format --- python/pyarrow/tests/test_extension_type.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index fc2f1dbd95d8..ee885221a8ac 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -552,7 +552,8 @@ def test_cast_between_extension_types(): msg = ("Casting from 'extension>' " "to different extension type " "'extension>' not permitted. " - "One can first cast to the storage type, then to the extension type." + "One can first cast to the storage type, " + "then to the extension type." ) with pytest.raises(TypeError, match=msg): tiny_int_arr.cast(IntegerType()) From 135c8a0e74bb8676f1ef5fb04136ab3e8ddce4a3 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 7 Oct 2022 11:33:37 +0200 Subject: [PATCH 27/27] Add test for casting to extension w/ extension storage --- python/pyarrow/tests/test_extension_type.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index ee885221a8ac..926790cfe064 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -43,6 +43,15 @@ def __reduce__(self): return IntegerType, () +class IntegerEmbeddedType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, IntegerType()) + + def __reduce__(self): + return IntegerEmbeddedType, () + + class UuidScalarType(pa.ExtensionScalar): def as_py(self): return None if self.value is None else UUID(bytes=self.value.as_py()) @@ -570,6 +579,13 @@ def test_cast_between_extension_types(): array.cast(UuidType2()) +def test_cast_to_extension_with_extension_storage(): + # Test casting directly, and IntegerType -> IntegerEmbeddedType + array = pa.array([1, 2, 3], pa.int64()) + array.cast(IntegerEmbeddedType()) + array.cast(IntegerType()).cast(IntegerEmbeddedType()) + + @pytest.mark.parametrize("data,type_factory", ( # list ([[1, 2, 3]], lambda: pa.list_(IntegerType())),