From d047296f50dc6697c1c30b9dc31273f911c5fa76 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Tue, 16 Apr 2024 23:33:34 -0700 Subject: [PATCH] [ET-VK][2/n] select_copy.int equivalent to `select.int` Differential Revision: [D56092143](https://our.internmc.facebook.com/intern/diff/D56092143/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/impl/Select.cpp | 1 + backends/vulkan/test/op_tests/cases.py | 1 + 2 files changed, 2 insertions(+) diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp index c66d98af4a2..e0412450ed6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Select.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -127,6 +127,7 @@ void select_int(ComputeGraph& graph, const std::vector& args) { REGISTER_OPERATORS { VK_REGISTER_OP(aten.select.int, select_int); + VK_REGISTER_OP(aten.select_copy.int, select_int); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fe8d7c25e01..7ebe7bbcffa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -182,4 +182,5 @@ def get_select_int_inputs(): "aten.native_layer_norm.default": get_native_layer_norm_inputs(), "aten.full.default": get_full_inputs(), "aten.select.int": get_select_int_inputs(), + "aten.select_copy.int": get_select_int_inputs(), }