diff --git a/.bazelrc b/.bazelrc index 9f33b600dcdd3e..1dd928acdb4455 100644 --- a/.bazelrc +++ b/.bazelrc @@ -496,7 +496,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe # TODO(gunan): Remove once we use MSVC 2019 with latest patches. build:rbe_win --define=override_eigen_strong_inline=true -build:rbe_win --jobs=500 +build:rbe_win --jobs=100 build:rbe_win_py37 --config=rbe build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python" diff --git a/README.md b/README.md index 1392da3f60bf58..9cf595bbf6173d 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ Build Type **Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) **Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) +**Linux aarch64 CPU** Nightly
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) **Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) diff --git a/RELEASE.md b/RELEASE.md index 5c05f2a4285ed2..69eca82c5f21a2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,10 @@ * C-API functions `TF_StringDecode`, `TF_StringEncode`, and `TF_StringEncodedSize` are no longer relevant and have been removed; see core/platform/ctstring.h for string access/modification in C. +* In batching library, rename parameter + SharedBatchScheduler::QueueOptions::max_batch_size to a more accurate name + (input_batch_size_limit) for a recent feature to enable split of large batch + sizes. ## Known Caveats @@ -27,7 +31,11 @@ * * * TF Core: - * + * + * `tf.Tensor` is now a subclass of `typing.Generic`, allowing type annotations + to be parameterized by dtype: `tf.Tensor[tf.Int32]`. This requires Python 3, + and will become fully compatible with static type checkers in the future. + * `tf.data`: * Added optional `exclude_cols` parameter to CsvDataset. This parameter is the complement of `select_cols`; at most one of these should be specified. @@ -50,6 +58,9 @@ * Tracing and Debugging: * * Other: + * We have replaced uses of "whitelist" with "allowlist" where possible. + Please see https://developers.google.com/style/word-list#blacklist for more + context. * ## Thanks to our Contributors diff --git a/SECURITY.md b/SECURITY.md index f3a6c148b2eaf5..6c722766b3a214 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the set of computation primitives available to TensorFlow is powerful enough that you should assume that the TensorFlow process effectively executes arbitrary -code. One common solution is to whitelist only a few safe Ops. While this is +code. One common solution is to allow only a few safe Ops. While this is possible in theory, we still recommend you sandbox the execution. It depends on the computation graph whether a user provided checkpoint is safe. diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d00608ccc98114..8a0918b416fb74 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -467,6 +467,13 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag enables experimental MLIR bridge support. +config_setting( + name = "enable_mlir_bridge", + values = {"define": "enable_mlir_bridge=true"}, + visibility = ["//visibility:public"], +) + # This flag enables experimental TPU support config_setting( name = "with_tpu_support", diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 9696a3415bf3ff..5c101bef85fc54 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -177,7 +177,9 @@ cc_library( visibility = [ "//tensorflow:internal", ], - deps = [], + deps = [ + "//tensorflow/core:protos_all_cc", + ], ) cc_library( @@ -480,6 +482,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/common_runtime:function_optimization_registry", + "//tensorflow/core/common_runtime:optimization_registry", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index ff17fcf3cea5a0..b332679cc7cc70 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -73,7 +73,8 @@ class AbstractOperation { virtual Status SetDeviceName(const char* name) = 0; virtual Status AddInput(AbstractTensorHandle* input) = 0; - virtual Status AddInputList(absl::Span inputs) = 0; + virtual Status AddInputList( + absl::Span inputs) = 0; virtual Status Execute(absl::Span retvals, int* num_retvals) = 0; diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index d50bd4530db8b0..de041690420552 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ #include + +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { // Abstract interface to a Tensor handle in either tracing or immediate @@ -27,6 +29,9 @@ class AbstractTensorHandle { virtual ~AbstractTensorHandle() {} public: + // Returns tensor dtype. + virtual tensorflow::DataType DataType() const = 0; + AbstractTensorHandleKind getKind() const { return kind_; } // Release any underlying resources, including the interface object. diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 4be3cdd7c2db34..70acd710166fd5 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -337,10 +337,13 @@ tensorflow::Status CreateRemoteContexts( }); } counter.Wait(); + tensorflow::StatusGroup sg; for (int i = 0; i < num_remote_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); + if (TF_PREDICT_FALSE(!statuses[i].ok())) { + sg.Update(statuses[i]); + } } - return tensorflow::Status::OK(); + return sg.as_summary_status(); } tensorflow::Status UpdateRemoteContexts( @@ -611,10 +614,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. if (reset_context) { - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + const tensorflow::Status s = CreateRemoteContexts( ctx, remote_workers, context_id, context_view_id, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); + context->LazyCopyFunctionRemoteInputs(), base_request); + // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause + // the CreateRemoteContexts to fail. We currently only log instead of + // directly returning the error, since returning here will cause the server + // object to be destroyed (which currently CHECK-fails). The client will + // see additional errors if ops are subsequently sent to the failed workers. + if (TF_PREDICT_FALSE(!s.ok())) { + LOG(ERROR) << "Error when creating contexts on remote targets: " + << s.error_message() + << "\nExecuting remote ops or functions on these remote " + "targets will fail."; + } } else { // The master's context_view_id will be incremented by one // the UpdateRemoteMaster call later. We want all new workers and @@ -644,15 +658,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( grpc_server->worker_env()->rendezvous_mgr->Find(context_id); auto* device_mgr = grpc_server->worker_env()->device_mgr; std::shared_ptr worker_session; - TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, server_def, base_request.cluster_device_attributes(), - true)); - TF_RETURN_IF_ERROR( + LOG_AND_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->CreateSession( + session_name, server_def, base_request.cluster_device_attributes(), + true)); + LOG_AND_RETURN_IF_ERROR( grpc_server->worker_env()->session_mgr->WorkerSessionForSession( session_name, &worker_session)); // Initialize remote tensor communication based on worker session. - TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); tensorflow::DistributedFunctionLibraryRuntime* cluster_flr = tensorflow::eager::CreateClusterFLR(context_id, context, diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 65f8d3cc646328..a6547e23454817 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -316,6 +317,114 @@ string VariableAddFunction() { return def.SerializeAsString(); } +// A graph optimization pass that would fail when triggered for more than once. +class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass { + public: + static bool enabled_; + GraphErrorInjectionPass() {} + + tensorflow::Status Run( + const tensorflow::GraphOptimizationPassOptions& options) override { + if (!enabled_) { + return tensorflow::Status::OK(); + } + if (first_call_) { + first_call_ = false; + return tensorflow::Status::OK(); + } + return tensorflow::errors::Internal("Graph pass runs for more than once!"); + } + + private: + bool first_call_ = true; +}; + +// After the graph pass is registered, it takes effect globally and can affect +// other test cases. Define a static variable to switch it on and off. +bool GraphErrorInjectionPass::enabled_ = false; + +// Test to ensure that a registered graph optimization pass is only executed +// once (i.e., on the main function side) in running distributed functions. +// This test creates a cluster with two workers, create a variable on the +// second worker, and run a distributed function (VariableAddFunction) whose ops +// span the local and remote workers. If the graph optimization pass is executed +// on both the main function side and the component function side, an error will +// be thrown in the registered graph optimization pass. +TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) { + // Register graph pass that will raise error if called more than once. + tensorflow::optimization_registration::OptimizationPassRegistration + register_test_pass(tensorflow::OptimizationPassRegistry::PRE_PLACEMENT, 0, + std::make_unique(), + "error_injector"); + GraphErrorInjectionPass::enabled_ = true; + + tensorflow::ServerDef server_def = GetServerDef(3); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle, nullptr); + + const string function_def = VariableAddFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, var_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + ASSERT_EQ(sum, 4.0); + + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(var_handle); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); + + // Disable the test graph pass so it does not affect other test cases. + GraphErrorInjectionPass::enabled_ = false; +} + class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { public: FunctionErrorInjectionPass(string error_node, string error_device) diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index e7e1ef8486821e..6165a7d14a3637 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -49,6 +49,10 @@ class GraphTensor : public TracingTensorHandle { explicit GraphTensor(TF_Output output) : TracingTensorHandle(kGraph), output_(output) {} void Release() override { delete this; } + + tensorflow::DataType DataType() const override { + return static_cast(TF_OperationOutputType(output_)); + } TF_Output output_; // For LLVM style RTTI. @@ -102,9 +106,18 @@ class GraphOperation : public TracingOperation { TF_AddInput(op_.get(), t->output_); return Status::OK(); } - Status AddInputList(absl::Span inputs) override { - return tensorflow::errors::Unimplemented( - "AddInputList has not been implemented yet."); + Status AddInputList(absl::Span inputs) override { + std::vector tf_outputs(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + GraphTensor* t = dyn_cast(inputs[i]); + if (!t) { + return tensorflow::errors::InvalidArgument( + "Unable to cast input to GraphTensor"); + } + tf_outputs[i] = t->output_; + } + TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size()); + return Status::OK(); } Status Execute(absl::Span retvals, int* num_retvals) override { diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index c9e39a80663077..f7c77aa06db38f 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -33,8 +33,6 @@ namespace tensorflow { // is needed a static_cast can be applied. class ImmediateExecutionTensorHandle : public AbstractTensorHandle { public: - // Returns tensor dtype. - virtual tensorflow::DataType DataType() const = 0; // Returns number of dimensions. virtual Status NumDims(int* num_dims) const = 0; // Returns number of elements across all dimensions. diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index 5740fc4631e257..d0e9f351478817 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -243,8 +243,10 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context, const char* target_device_name, TF_Status* status, void* device_info) { - TF_SetStatus(status, TF_INTERNAL, - "Trying to copy a tensor out of a parallel device."); + TF_SetStatus(status, TF_UNIMPLEMENTED, + "Trying to copy a tensor out of a parallel device. Since there " + "are multiple components to parallel tensors, they must be " + "unpacked explicitly."); return nullptr; } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index 2fa183d50f6829..06a26ab2710092 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -157,7 +157,7 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { // Copies off of parallel devices must be explicit. TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice( device_value.get(), context.get(), first_device_name, status.get())); - ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL); + ASSERT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED); } TEST(PARALLEL_DEVICE, TestDifferentShapes) { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 2a886dee4cb24e..a0c137017664c8 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -52,6 +52,40 @@ cc_library( ], ) +cc_library( + name = "cleanup", + hdrs = ["cleanup.h"], +) + +cc_library( + name = "ram_file_block_cache", + srcs = ["ram_file_block_cache.cc"], + hdrs = ["ram_file_block_cache.h"], + deps = [ + ":cleanup", + ":file_block_cache", + "//tensorflow/c:env", + "//tensorflow/c:tf_status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "ram_file_block_cache_test", + size = "small", + srcs = ["ram_file_block_cache_test.cc"], + deps = [ + ":ram_file_block_cache", + "//tensorflow/c:tf_status_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform/cloud:now_seconds_env", + ], +) + tf_cc_test( name = "gcs_filesystem_test", srcs = [ @@ -69,3 +103,29 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "expiring_lru_cache", + hdrs = ["expiring_lru_cache.h"], + deps = [ + "//tensorflow/c:env", + "//tensorflow/c:tf_status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "expiring_lru_cache_test", + size = "small", + srcs = ["expiring_lru_cache_test.cc"], + deps = [ + ":expiring_lru_cache", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform/cloud:now_seconds_env", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h b/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h new file mode 100644 index 00000000000000..cc7a7451bb8fa8 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its +// destructor. The easiest way to use MakeCleanup is with a lambda argument, +// capturing the return value in an 'auto' local variable. Most users will not +// need more sophisticated syntax than that. +// +// Example: +// void func() { +// FILE* fp = fopen("data.txt", "r"); +// if (fp == nullptr) return; +// auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); }); +// // No matter what, fclose(fp) will happen. +// DataObject d; +// while (ReadDataObject(fp, &d)) { +// if (d.IsBad()) { +// LOG(ERROR) << "Bad Data"; +// return; +// } +// PushGoodData(d); +// } +// } +// +// You can use Cleanup directly, instead of using MakeCleanup and auto, +// but there's rarely a reason to do that. +// +// You can call 'release()' on a Cleanup object to cancel the cleanup. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ + +#include +#include + +namespace tf_gcs_filesystem { + +// A move-only RAII object that calls a stored cleanup functor when +// destroyed. Cleanup is the return type of gtl::MakeCleanup(F). +template +class Cleanup { + public: + Cleanup() : released_(true), f_() {} + + template + explicit Cleanup(G&& f) // NOLINT + : f_(std::forward(f)) {} // NOLINT(build/c++11) + + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Implicitly move-constructible from any compatible Cleanup. + // The source will be released as if src.release() were called. + // A moved-from Cleanup can be safely destroyed or reassigned. + template + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Assignment to a Cleanup object behaves like destroying it + // and making a new one in its place, analogous to unique_ptr + // semantics. + Cleanup& operator=(Cleanup&& src) { // NOLINT + if (!released_) f_(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) f_(); + } + + // Releases the cleanup function instead of running it. + // Hint: use c.release()() to run early. + F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template ::type> +Cleanup MakeCleanup(F&& f) { + return Cleanup(std::forward(f)); +} + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h new file mode 100644 index 00000000000000..c0347faa16dff4 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h @@ -0,0 +1,191 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/tf_status.h" + +namespace tf_gcs_filesystem { + +/// \brief An LRU cache of string keys and arbitrary values, with configurable +/// max item age (in seconds) and max entries. +/// +/// This class is thread safe. +template +class ExpiringLRUCache { + public: + /// A `max_age` of 0 means that nothing is cached. A `max_entries` of 0 means + /// that there is no limit on the number of entries in the cache (however, if + /// `max_age` is also 0, the cache will not be populated). + ExpiringLRUCache(uint64_t max_age, size_t max_entries, + std::function timer_seconds = TF_NowSeconds) + : max_age_(max_age), + max_entries_(max_entries), + timer_seconds_(timer_seconds) {} + + /// Insert `value` with key `key`. This will replace any previous entry with + /// the same key. + void Insert(const std::string& key, const T& value) { + if (max_age_ == 0) { + return; + } + absl::MutexLock lock(&mu_); + InsertLocked(key, value); + } + + // Delete the entry with key `key`. Return true if the entry was found for + // `key`, false if the entry was not found. In both cases, there is no entry + // with key `key` existed after the call. + bool Delete(const std::string& key) { + absl::MutexLock lock(&mu_); + return DeleteLocked(key); + } + + /// Look up the entry with key `key` and copy it to `value` if found. Returns + /// true if an entry was found for `key`, and its timestamp is not more than + /// max_age_ seconds in the past. + bool Lookup(const std::string& key, T* value) { + if (max_age_ == 0) { + return false; + } + absl::MutexLock lock(&mu_); + return LookupLocked(key, value); + } + + typedef std::function ComputeFunc; + + /// Look up the entry with key `key` and copy it to `value` if found. If not + /// found, call `compute_func`. If `compute_func` set `status` to `TF_OK`, + /// store a copy of the output parameter in the cache, and another copy in + /// `value`. + void LookupOrCompute(const std::string& key, T* value, + const ComputeFunc& compute_func, TF_Status* status) { + if (max_age_ == 0) { + return compute_func(key, value, status); + } + + // Note: we hold onto mu_ for the rest of this function. In practice, this + // is okay, as stat requests are typically fast, and concurrent requests are + // often for the same file. Future work can split this up into one lock per + // key if this proves to be a significant performance bottleneck. + absl::MutexLock lock(&mu_); + if (LookupLocked(key, value)) { + return TF_SetStatus(status, TF_OK, ""); + } + compute_func(key, value, status); + if (TF_GetCode(status) == TF_OK) { + InsertLocked(key, *value); + } + } + + /// Clear the cache. + void Clear() { + absl::MutexLock lock(&mu_); + cache_.clear(); + lru_list_.clear(); + } + + /// Accessors for cache parameters. + uint64_t max_age() const { return max_age_; } + size_t max_entries() const { return max_entries_; } + + private: + struct Entry { + /// The timestamp (seconds) at which the entry was added to the cache. + uint64_t timestamp; + + /// The entry's value. + T value; + + /// A list iterator pointing to the entry's position in the LRU list. + std::list::iterator lru_iterator; + }; + + bool LookupLocked(const std::string& key, T* value) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto it = cache_.find(key); + if (it == cache_.end()) { + return false; + } + lru_list_.erase(it->second.lru_iterator); + if (timer_seconds_() - it->second.timestamp > max_age_) { + cache_.erase(it); + return false; + } + *value = it->second.value; + lru_list_.push_front(it->first); + it->second.lru_iterator = lru_list_.begin(); + return true; + } + + void InsertLocked(const std::string& key, const T& value) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + lru_list_.push_front(key); + Entry entry{timer_seconds_(), value, lru_list_.begin()}; + auto insert = cache_.insert(std::make_pair(key, entry)); + if (!insert.second) { + lru_list_.erase(insert.first->second.lru_iterator); + insert.first->second = entry; + } else if (max_entries_ > 0 && cache_.size() > max_entries_) { + cache_.erase(lru_list_.back()); + lru_list_.pop_back(); + } + } + + bool DeleteLocked(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto it = cache_.find(key); + if (it == cache_.end()) { + return false; + } + lru_list_.erase(it->second.lru_iterator); + cache_.erase(it); + return true; + } + + /// The maximum age of entries in the cache, in seconds. A value of 0 means + /// that no entry is ever placed in the cache. + const uint64_t max_age_; + + /// The maximum number of entries in the cache. A value of 0 means there is no + /// limit on entry count. + const size_t max_entries_; + + /// The callback to read timestamps. + std::function timer_seconds_; + + /// Guards access to the cache and the LRU list. + absl::Mutex mu_; + + /// The cache (a map from string key to Entry). + std::map cache_ ABSL_GUARDED_BY(mu_); + + /// The LRU list of entries. The front of the list identifies the most + /// recently accessed entry. + std::list lru_list_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc new file mode 100644 index 00000000000000..b0d283fff82d9b --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc @@ -0,0 +1,213 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h" + +#include + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/cloud/now_seconds_env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(ExpiringLRUCacheTest, MaxAge) { + const string key = "a"; + std::unique_ptr env(new NowSecondsEnv); + tf_gcs_filesystem::ExpiringLRUCache cache( + 1, 0, [&env]() { return env->NowSeconds(); }); + env->SetNowSeconds(1); + // Verify that replacement of an existing element works, and updates the + // timestamp of the entry. + cache.Insert(key, 41); + env->SetNowSeconds(2); + cache.Insert(key, 42); + // 1 second after the most recent insertion, the entry is still valid. + env->SetNowSeconds(3); + int value = 0; + EXPECT_TRUE(cache.Lookup(key, &value)); + EXPECT_EQ(value, 42); + // 2 seconds after the most recent insertion, the entry is no longer valid. + env->SetNowSeconds(4); + EXPECT_FALSE(cache.Lookup(key, &value)); + // Re-insert the entry. + cache.Insert(key, 43); + EXPECT_TRUE(cache.Lookup(key, &value)); + EXPECT_EQ(value, 43); + // The entry is valid 1 second after the insertion... + env->SetNowSeconds(5); + value = 0; + EXPECT_TRUE(cache.Lookup(key, &value)); + EXPECT_EQ(value, 43); + // ...but is no longer valid 2 seconds after the insertion. + env->SetNowSeconds(6); + EXPECT_FALSE(cache.Lookup(key, &value)); +} + +TEST(ExpiringLRUCacheTest, MaxEntries) { + // max_age of 0 means nothing will be cached. + tf_gcs_filesystem::ExpiringLRUCache cache1(0, 4); + cache1.Insert("a", 1); + int value = 0; + EXPECT_FALSE(cache1.Lookup("a", &value)); + // Now set max_age = 1 and verify the LRU eviction logic. + tf_gcs_filesystem::ExpiringLRUCache cache2(1, 4); + cache2.Insert("a", 1); + cache2.Insert("b", 2); + cache2.Insert("c", 3); + cache2.Insert("d", 4); + EXPECT_TRUE(cache2.Lookup("a", &value)); + EXPECT_EQ(value, 1); + EXPECT_TRUE(cache2.Lookup("b", &value)); + EXPECT_EQ(value, 2); + EXPECT_TRUE(cache2.Lookup("c", &value)); + EXPECT_EQ(value, 3); + EXPECT_TRUE(cache2.Lookup("d", &value)); + EXPECT_EQ(value, 4); + // Insertion of "e" causes "a" to be evicted, but the other entries are still + // there. + cache2.Insert("e", 5); + EXPECT_FALSE(cache2.Lookup("a", &value)); + EXPECT_TRUE(cache2.Lookup("b", &value)); + EXPECT_EQ(value, 2); + EXPECT_TRUE(cache2.Lookup("c", &value)); + EXPECT_EQ(value, 3); + EXPECT_TRUE(cache2.Lookup("d", &value)); + EXPECT_EQ(value, 4); + EXPECT_TRUE(cache2.Lookup("e", &value)); + EXPECT_EQ(value, 5); +} + +TEST(ExpiringLRUCacheTest, LookupOrCompute) { + // max_age of 0 means we should always compute. + uint64 num_compute_calls = 0; + tf_gcs_filesystem::ExpiringLRUCache::ComputeFunc compute_func = + [&num_compute_calls](const string& key, int* value, TF_Status* status) { + *value = num_compute_calls; + num_compute_calls++; + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::ExpiringLRUCache cache1(0, 4); + + int value = -1; + TF_Status status; + cache1.LookupOrCompute("a", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 0); + EXPECT_EQ(num_compute_calls, 1); + // re-read the same value, expect another lookup + cache1.LookupOrCompute("a", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 1); + EXPECT_EQ(num_compute_calls, 2); + + // Define a new cache with max_age > 0 and verify correct behavior. + tf_gcs_filesystem::ExpiringLRUCache cache2(2, 4); + num_compute_calls = 0; + value = -1; + + // Read our first value + cache2.LookupOrCompute("a", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 0); + EXPECT_EQ(num_compute_calls, 1); + // Re-read, exepct no additional function compute_func calls. + cache2.LookupOrCompute("a", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 0); + EXPECT_EQ(num_compute_calls, 1); + + // Read a sequence of additional values, eventually evicting "a". + cache2.LookupOrCompute("b", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 1); + EXPECT_EQ(num_compute_calls, 2); + cache2.LookupOrCompute("c", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 2); + EXPECT_EQ(num_compute_calls, 3); + cache2.LookupOrCompute("d", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 3); + EXPECT_EQ(num_compute_calls, 4); + cache2.LookupOrCompute("e", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 4); + EXPECT_EQ(num_compute_calls, 5); + // Verify the other values remain in the cache. + cache2.LookupOrCompute("b", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 1); + EXPECT_EQ(num_compute_calls, 5); + cache2.LookupOrCompute("c", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 2); + EXPECT_EQ(num_compute_calls, 5); + cache2.LookupOrCompute("d", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 3); + EXPECT_EQ(num_compute_calls, 5); + + // Re-read "a", ensure it is re-computed. + cache2.LookupOrCompute("a", &value, compute_func, &status); + TF_EXPECT_OK(status.status); + EXPECT_EQ(value, 5); + EXPECT_EQ(num_compute_calls, 6); +} + +TEST(ExpiringLRUCacheTest, Clear) { + tf_gcs_filesystem::ExpiringLRUCache cache(1, 4); + cache.Insert("a", 1); + cache.Insert("b", 2); + cache.Insert("c", 3); + cache.Insert("d", 4); + int value = 0; + EXPECT_TRUE(cache.Lookup("a", &value)); + EXPECT_EQ(value, 1); + EXPECT_TRUE(cache.Lookup("b", &value)); + EXPECT_EQ(value, 2); + EXPECT_TRUE(cache.Lookup("c", &value)); + EXPECT_EQ(value, 3); + EXPECT_TRUE(cache.Lookup("d", &value)); + EXPECT_EQ(value, 4); + cache.Clear(); + EXPECT_FALSE(cache.Lookup("a", &value)); + EXPECT_FALSE(cache.Lookup("b", &value)); + EXPECT_FALSE(cache.Lookup("c", &value)); + EXPECT_FALSE(cache.Lookup("d", &value)); +} + +TEST(ExpiringLRUCacheTest, Delete) { + // Insert an entry. + tf_gcs_filesystem::ExpiringLRUCache cache(1, 4); + cache.Insert("a", 1); + int value = 0; + EXPECT_TRUE(cache.Lookup("a", &value)); + EXPECT_EQ(value, 1); + + // Delete the entry. + EXPECT_TRUE(cache.Delete("a")); + EXPECT_FALSE(cache.Lookup("a", &value)); + + // Try deleting the entry again. + EXPECT_FALSE(cache.Delete("a")); + EXPECT_FALSE(cache.Lookup("a", &value)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h index aa45e71e9b4e4c..3ba7d8d7993f14 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h @@ -1,8 +1,11 @@ /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index d3bbb8d9e889dd..7861a5708b5a49 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -73,6 +73,14 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok, } } +/// Appends a trailing slash if the name doesn't already have one. +static void MaybeAppendSlash(std::string* name) { + if (name->empty()) + *name = "/"; + else if (name->back() != '/') + name->push_back('/'); +} + // SECTION 1. Implementation for `TF_RandomAccessFile` // ---------------------------------------------------------------------------- namespace tf_random_access_file { @@ -169,6 +177,12 @@ static void SyncImpl(const std::string& bucket, const std::string& object, TF_SetStatusFromGCSStatus(metadata.status(), status); return; } + // We have to delete the temporary object after composing. + auto delete_status = gcs_client->DeleteObject(bucket, temporary_object); + if (!delete_status.ok()) { + TF_SetStatusFromGCSStatus(delete_status, status); + return; + } // We truncate the data that are already uploaded. if (!outfile->truncate()) { TF_SetStatus(status, TF_INTERNAL, @@ -276,6 +290,10 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) { // ---------------------------------------------------------------------------- namespace tf_gcs_filesystem { // TODO(vnvo2409): Add lazy-loading and customizing parameters. +// TODO(vnvo2409): Use partial reponse for better performance. +// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`. +// TODO(vnvo2409): Refactor the filesystem implementation when +// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done. void Init(TF_Filesystem* filesystem, TF_Status* status) { google::cloud::StatusOr client = gcs::Client::CreateDefaultClient(); @@ -410,6 +428,134 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, } } +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); + if (object.empty()) { + auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); + return; + } + + MaybeAppendSlash(&object); + auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + TF_SetStatusFromGCSStatus(object_metadata.status(), status); + if (TF_GetCode(status) == TF_NOT_FOUND) { + auto insert_metadata = + gcs_file->gcs_client.InsertObject(bucket, object, ""); + TF_SetStatusFromGCSStatus(insert_metadata.status(), status); + } else if (TF_GetCode(status) == TF_OK) { + TF_SetStatus(status, TF_ALREADY_EXISTS, path); + } +} + +// TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the +// default implementation. Because we could create an empty object whose +// key is equal to the `path` and Google Cloud Console will automatically +// display it as a directory tree. + +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); + TF_SetStatusFromGCSStatus(gcs_status, status); +} + +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + MaybeAppendSlash(&object); + auto gcs_file = static_cast(filesystem->plugin_filesystem); + int object_count = 0; + for (auto&& metadata : + gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) { + if (!metadata) { + TF_SetStatusFromGCSStatus(metadata.status(), status); + return; + } + ++object_count; + // We consider a path is a non-empty directory in two cases: + // - There are more than two objects whose keys start with the name of this + // directory. + // - There is one object whose key contains the name of this directory ( but + // not equal ). + if (object_count > 1 || metadata->name() != object) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Cannot delete a non-empty directory."); + return; + } + } + auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); + TF_SetStatusFromGCSStatus(gcs_status, status); +} + +// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be +// some differents compared to the default implementation. Will be refactored. +static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path, + uint64_t* undeleted_files, + uint64_t* undeleted_dirs, TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object); + TF_SetStatusFromGCSStatus(gcs_status, status); + if (TF_GetCode(status) != TF_OK) return; + *undeleted_dirs = 0; + *undeleted_files = 0; +} + +// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND` +// if the object does not exist. In that case, we will have to check if the +// `src` is a directory or not to set the correspondent `status` (i.e +// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if +// path `src` is a directory). +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + std::string bucket_src, object_src; + ParseGCSPath(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string bucket_dst, object_dst; + ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( + bucket_src, object_src, bucket_dst, object_dst); + if (!metadata) { + TF_SetStatusFromGCSStatus(metadata.status(), status); + return; + } + auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src); + TF_SetStatusFromGCSStatus(gcs_status, status); +} + +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status) { + std::string bucket_src, object_src; + ParseGCSPath(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string bucket_dst, object_dst; + ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( + bucket_src, object_src, bucket_dst, object_dst); + TF_SetStatusFromGCSStatus(metadata.status(), status); +} + } // namespace tf_gcs_filesystem static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc new file mode 100644 index 00000000000000..102c7fa175cf7a --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc @@ -0,0 +1,317 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" + +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h" + +namespace tf_gcs_filesystem { + +bool RamFileBlockCache::BlockNotStale(const std::shared_ptr& block) { + absl::MutexLock l(&block->mu); + if (block->state != FetchState::FINISHED) { + return true; // No need to check for staleness. + } + if (max_staleness_ == 0) return true; // Not enforcing staleness. + return timer_seconds_() - block->timestamp <= max_staleness_; +} + +std::shared_ptr RamFileBlockCache::Lookup( + const Key& key) { + absl::MutexLock lock(&mu_); + auto entry = block_map_.find(key); + if (entry != block_map_.end()) { + if (BlockNotStale(entry->second)) { + if (cache_stats_ != nullptr) { + cache_stats_->RecordCacheHitBlockSize(entry->second->data.size()); + } + return entry->second; + } else { + // Remove the stale block and continue. + RemoveFile_Locked(key.first); + } + } + + // Insert a new empty block, setting the bookkeeping to sentinel values + // in order to update them as appropriate. + auto new_entry = std::make_shared(); + lru_list_.push_front(key); + lra_list_.push_front(key); + new_entry->lru_iterator = lru_list_.begin(); + new_entry->lra_iterator = lra_list_.begin(); + new_entry->timestamp = timer_seconds_(); + block_map_.emplace(std::make_pair(key, new_entry)); + return new_entry; +} + +// Remove blocks from the cache until we do not exceed our maximum size. +void RamFileBlockCache::Trim() { + while (!lru_list_.empty() && cache_size_ > max_bytes_) { + RemoveBlock(block_map_.find(lru_list_.back())); + } +} + +/// Move the block to the front of the LRU list if it isn't already there. +void RamFileBlockCache::UpdateLRU(const Key& key, + const std::shared_ptr& block, + TF_Status* status) { + absl::MutexLock lock(&mu_); + if (block->timestamp == 0) { + // The block was evicted from another thread. Allow it to remain evicted. + return TF_SetStatus(status, TF_OK, ""); + } + if (block->lru_iterator != lru_list_.begin()) { + lru_list_.erase(block->lru_iterator); + lru_list_.push_front(key); + block->lru_iterator = lru_list_.begin(); + } + + // Check for inconsistent state. If there is a block later in the same file + // in the cache, and our current block is not block size, this likely means + // we have inconsistent state within the cache. Note: it's possible some + // incomplete reads may still go undetected. + if (block->data.size() < block_size_) { + Key fmax = std::make_pair(key.first, std::numeric_limits::max()); + auto fcmp = block_map_.upper_bound(fmax); + if (fcmp != block_map_.begin() && key < (--fcmp)->first) { + return TF_SetStatus(status, TF_INTERNAL, + "Block cache contents are inconsistent."); + } + } + + Trim(); + + return TF_SetStatus(status, TF_OK, ""); +} + +void RamFileBlockCache::MaybeFetch(const Key& key, + const std::shared_ptr& block, + TF_Status* status) { + bool downloaded_block = false; + auto reconcile_state = MakeCleanup([this, &downloaded_block, &key, &block] { + // Perform this action in a cleanup callback to avoid locking mu_ after + // locking block->mu. + if (downloaded_block) { + absl::MutexLock l(&mu_); + // Do not update state if the block is already to be evicted. + if (block->timestamp != 0) { + // Use capacity() instead of size() to account for all memory + // used by the cache. + cache_size_ += block->data.capacity(); + // Put to beginning of LRA list. + lra_list_.erase(block->lra_iterator); + lra_list_.push_front(key); + block->lra_iterator = lra_list_.begin(); + block->timestamp = timer_seconds_(); + } + } + }); + // Loop until either block content is successfully fetched, or our request + // encounters an error. + absl::MutexLock l(&block->mu); + TF_SetStatus(status, TF_OK, ""); + while (true) { + switch (block->state) { + case FetchState::ERROR: + // TF_FALLTHROUGH_INTENDED + case FetchState::CREATED: + block->state = FetchState::FETCHING; + block->mu.Unlock(); // Release the lock while making the API call. + block->data.clear(); + block->data.resize(block_size_, 0); + size_t bytes_transferred; + block_fetcher_(key.first, key.second, block_size_, block->data.data(), + &bytes_transferred, status); + if (cache_stats_ != nullptr) { + cache_stats_->RecordCacheMissBlockSize(bytes_transferred); + } + block->mu.Lock(); // Reacquire the lock immediately afterwards + if (TF_GetCode(status) == TF_OK) { + block->data.resize(bytes_transferred, 0); + // Shrink the data capacity to the actual size used. + // NOLINTNEXTLINE: shrink_to_fit() may not shrink the capacity. + std::vector(block->data).swap(block->data); + downloaded_block = true; + block->state = FetchState::FINISHED; + } else { + block->state = FetchState::ERROR; + } + block->cond_var.SignalAll(); + return; + case FetchState::FETCHING: + block->cond_var.WaitWithTimeout(&block->mu, absl::Minutes(1)); + if (block->state == FetchState::FINISHED) { + return TF_SetStatus(status, TF_OK, ""); + } + // Re-loop in case of errors. + break; + case FetchState::FINISHED: + return TF_SetStatus(status, TF_OK, ""); + } + } + return TF_SetStatus( + status, TF_INTERNAL, + "Control flow should never reach the end of RamFileBlockCache::Fetch."); +} + +void RamFileBlockCache::Read(const std::string& filename, size_t offset, + size_t n, char* buffer, size_t* bytes_transferred, + TF_Status* status) { + *bytes_transferred = 0; + if (n == 0) { + return TF_SetStatus(status, TF_OK, ""); + } + if (!IsCacheEnabled() || (n > max_bytes_)) { + // The cache is effectively disabled, so we pass the read through to the + // fetcher without breaking it up into blocks. + return block_fetcher_(filename, offset, n, buffer, bytes_transferred, + status); + } + // Calculate the block-aligned start and end of the read. + size_t start = block_size_ * (offset / block_size_); + size_t finish = block_size_ * ((offset + n) / block_size_); + if (finish < offset + n) { + finish += block_size_; + } + size_t total_bytes_transferred = 0; + // Now iterate through the blocks, reading them one at a time. + for (size_t pos = start; pos < finish; pos += block_size_) { + Key key = std::make_pair(filename, pos); + // Look up the block, fetching and inserting it if necessary, and update the + // LRU iterator for the key and block. + std::shared_ptr block = Lookup(key); + if (!block) { + std::cerr << "No block for key " << key.first << "@" << key.second; + abort(); + } + MaybeFetch(key, block, status); + if (TF_GetCode(status) != TF_OK) return; + UpdateLRU(key, block, status); + if (TF_GetCode(status) != TF_OK) return; + // Copy the relevant portion of the block into the result buffer. + const auto& data = block->data; + if (offset >= pos + data.size()) { + // The requested offset is at or beyond the end of the file. This can + // happen if `offset` is not block-aligned, and the read returns the last + // block in the file, which does not extend all the way out to `offset`. + *bytes_transferred = total_bytes_transferred; + std::stringstream os; + os << "EOF at offset " << offset << " in file " << filename + << " at position " << pos << " with data size " << data.size(); + return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str()); + } + auto begin = data.begin(); + if (offset > pos) { + // The block begins before the slice we're reading. + begin += offset - pos; + } + auto end = data.end(); + if (pos + data.size() > offset + n) { + // The block extends past the end of the slice we're reading. + end -= (pos + data.size()) - (offset + n); + } + if (begin < end) { + size_t bytes_to_copy = end - begin; + memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy); + total_bytes_transferred += bytes_to_copy; + } + if (data.size() < block_size_) { + // The block was a partial block and thus signals EOF at its upper bound. + break; + } + } + *bytes_transferred = total_bytes_transferred; + return TF_SetStatus(status, TF_OK, ""); +} + +bool RamFileBlockCache::ValidateAndUpdateFileSignature( + const std::string& filename, int64_t file_signature) { + absl::MutexLock lock(&mu_); + auto it = file_signature_map_.find(filename); + if (it != file_signature_map_.end()) { + if (it->second == file_signature) { + return true; + } + // Remove the file from cache if the signatures don't match. + RemoveFile_Locked(filename); + it->second = file_signature; + return false; + } + file_signature_map_[filename] = file_signature; + return true; +} + +size_t RamFileBlockCache::CacheSize() const { + absl::MutexLock lock(&mu_); + return cache_size_; +} + +void RamFileBlockCache::Prune() { + while (!stop_pruning_thread_.WaitForNotificationWithTimeout( + absl::Microseconds(1000000))) { + absl::MutexLock lock(&mu_); + uint64_t now = timer_seconds_(); + while (!lra_list_.empty()) { + auto it = block_map_.find(lra_list_.back()); + if (now - it->second->timestamp <= max_staleness_) { + // The oldest block is not yet expired. Come back later. + break; + } + // We need to make a copy of the filename here, since it could otherwise + // be used within RemoveFile_Locked after `it` is deleted. + RemoveFile_Locked(std::string(it->first.first)); + } + } +} + +void RamFileBlockCache::Flush() { + absl::MutexLock lock(&mu_); + block_map_.clear(); + lru_list_.clear(); + lra_list_.clear(); + cache_size_ = 0; +} + +void RamFileBlockCache::RemoveFile(const std::string& filename) { + absl::MutexLock lock(&mu_); + RemoveFile_Locked(filename); +} + +void RamFileBlockCache::RemoveFile_Locked(const std::string& filename) { + Key begin = std::make_pair(filename, 0); + auto it = block_map_.lower_bound(begin); + while (it != block_map_.end() && it->first.first == filename) { + auto next = std::next(it); + RemoveBlock(it); + it = next; + } +} + +void RamFileBlockCache::RemoveBlock(BlockMap::iterator entry) { + // This signals that the block is removed, and should not be inadvertently + // reinserted into the cache in UpdateLRU. + entry->second->timestamp = 0; + lru_list_.erase(entry->second->lru_iterator); + lra_list_.erase(entry->second->lra_iterator); + cache_size_ -= entry->second->data.capacity(); + block_map_.erase(entry); +} + +} // namespace tf_gcs_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h new file mode 100644 index 00000000000000..5a82f65db418b0 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -0,0 +1,267 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h" +#include "tensorflow/c/tf_status.h" + +namespace tf_gcs_filesystem { + +/// \brief An LRU block cache of file contents, keyed by {filename, offset}. +/// +/// This class should be shared by read-only random access files on a remote +/// filesystem (e.g. GCS). +class RamFileBlockCache : public FileBlockCache { + public: + /// The callback executed when a block is not found in the cache, and needs to + /// be fetched from the backing filesystem. This callback is provided when the + /// cache is constructed. The `status` should be `TF_OK` as long as the + /// read from the remote filesystem succeeded (similar to the semantics of the + /// read(2) system call). + typedef std::function + BlockFetcher; + + RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, + BlockFetcher block_fetcher, + std::function timer_seconds = TF_NowSeconds) + : block_size_(block_size), + max_bytes_(max_bytes), + max_staleness_(max_staleness), + block_fetcher_(block_fetcher), + timer_seconds_(timer_seconds), + pruning_thread_(nullptr, + [](TF_Thread* thread) { TF_JoinThread(thread); }) { + if (max_staleness_ > 0) { + TF_ThreadOptions thread_options; + TF_DefaultThreadOptions(&thread_options); + pruning_thread_.reset( + TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); + } + std::cout << "GCS file block cache is " + << (IsCacheEnabled() ? "enabled" : "disabled"); + } + + ~RamFileBlockCache() override { + if (pruning_thread_) { + stop_pruning_thread_.Notify(); + // Destroying pruning_thread_ will block until Prune() receives the above + // notification and returns. + pruning_thread_.reset(); + } + } + + /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This + /// method will set `status` to: + /// + /// 1) The error from the remote filesystem, if the read from the remote + /// filesystem failed. + /// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem + /// succeeded, + /// but the read returned a partial block, and the LRU cache contained a + /// block at a higher offset (indicating that the partial block should have + /// been a full block). + /// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but + /// the file contents do not extend past `offset` and thus nothing was + /// placed in `out`. + /// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was + /// placed + /// in `buffer`). + /// + /// Caller is responsible for allocating memory for `buffer`. + /// `buffer` will be left unchanged in case of errors. + void Read(const std::string& filename, size_t offset, size_t n, char* buffer, + size_t* bytes_transferred, TF_Status* status) override; + + // Validate the given file signature with the existing file signature in the + // cache. Returns true if the signature doesn't change or the file doesn't + // exist before. If the signature changes, update the existing signature with + // the new one and remove the file from cache. + bool ValidateAndUpdateFileSignature(const std::string& filename, + int64_t file_signature) override + ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached blocks for `filename`. + void RemoveFile(const std::string& filename) override + ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached data. + void Flush() override ABSL_LOCKS_EXCLUDED(mu_); + + /// Accessors for cache parameters. + size_t block_size() const override { return block_size_; } + size_t max_bytes() const override { return max_bytes_; } + uint64_t max_staleness() const override { return max_staleness_; } + + /// The current size (in bytes) of the cache. + size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_); + + // Returns true if the cache is enabled. If false, the BlockFetcher callback + // is always executed during Read. + bool IsCacheEnabled() const override { + return block_size_ > 0 && max_bytes_ > 0; + } + + // We can not pass a lambda with capture as a function pointer to + // `TF_StartThread`, so we have to wrap `Prune` inside a static function. + static void PruneThread(void* param) { + auto ram_file_block_cache = static_cast(param); + ram_file_block_cache->Prune(); + } + + private: + /// The size of the blocks stored in the LRU cache, as well as the size of the + /// reads from the underlying filesystem. + const size_t block_size_; + /// The maximum number of bytes (sum of block sizes) allowed in the LRU cache. + const size_t max_bytes_; + /// The maximum staleness of any block in the LRU cache, in seconds. + const uint64_t max_staleness_; + /// The callback to read a block from the underlying filesystem. + const BlockFetcher block_fetcher_; + /// The callback to read timestamps. + const std::function timer_seconds_; + + /// \brief The key type for the file block cache. + /// + /// The file block cache key is a {filename, offset} pair. + typedef std::pair Key; + + /// \brief The state of a block. + /// + /// A block begins in the CREATED stage. The first thread will attempt to read + /// the block from the filesystem, transitioning the state of the block to + /// FETCHING. After completing, if the read was successful the state should + /// be FINISHED. Otherwise the state should be ERROR. A subsequent read can + /// re-fetch the block if the state is ERROR. + enum class FetchState { + CREATED, + FETCHING, + FINISHED, + ERROR, + }; + + /// \brief A block of a file. + /// + /// A file block consists of the block data, the block's current position in + /// the LRU cache, the timestamp (seconds since epoch) at which the block + /// was cached, a coordination lock, and state & condition variables. + /// + /// Thread safety: + /// The iterator and timestamp fields should only be accessed while holding + /// the block-cache-wide mu_ instance variable. The state variable should only + /// be accessed while holding the Block's mu lock. The data vector should only + /// be accessed after state == FINISHED, and it should never be modified. + /// + /// In order to prevent deadlocks, never grab the block-cache-wide mu_ lock + /// AFTER grabbing any block's mu lock. It is safe to grab mu without locking + /// mu_. + struct Block { + /// The block data. + std::vector data; + /// A list iterator pointing to the block's position in the LRU list. + std::list::iterator lru_iterator; + /// A list iterator pointing to the block's position in the LRA list. + std::list::iterator lra_iterator; + /// The timestamp (seconds since epoch) at which the block was cached. + uint64_t timestamp; + /// Mutex to guard state variable + absl::Mutex mu; + /// The state of the block. + FetchState state ABSL_GUARDED_BY(mu) = FetchState::CREATED; + /// Wait on cond_var if state is FETCHING. + absl::CondVar cond_var; + }; + + /// \brief The block map type for the file block cache. + /// + /// The block map is an ordered map from Key to Block. + typedef std::map> BlockMap; + + /// Prune the cache by removing files with expired blocks. + void Prune() ABSL_LOCKS_EXCLUDED(mu_); + + bool BlockNotStale(const std::shared_ptr& block) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Look up a Key in the block cache. + std::shared_ptr Lookup(const Key& key) ABSL_LOCKS_EXCLUDED(mu_); + + void MaybeFetch(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Trim the block cache to make room for another entry. + void Trim() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Update the LRU iterator for the block at `key`. + void UpdateLRU(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all blocks of a file, with mu_ already held. + void RemoveFile_Locked(const std::string& filename) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Remove the block `entry` from the block map and LRU list, and update the + /// cache size accordingly. + void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// The cache pruning thread that removes files with expired blocks. + std::unique_ptr> pruning_thread_; + + /// Notification for stopping the cache pruning thread. + absl::Notification stop_pruning_thread_; + + /// Guards access to the block map, LRU list, and cached byte count. + mutable absl::Mutex mu_; + + /// The block map (map from Key to Block). + BlockMap block_map_ ABSL_GUARDED_BY(mu_); + + /// The LRU list of block keys. The front of the list identifies the most + /// recently accessed block. + std::list lru_list_ ABSL_GUARDED_BY(mu_); + + /// The LRA (least recently added) list of block keys. The front of the list + /// identifies the most recently added block. + /// + /// Note: blocks are added to lra_list_ only after they have successfully been + /// fetched from the underlying block store. + std::list lra_list_ ABSL_GUARDED_BY(mu_); + + /// The combined number of bytes in all of the cached blocks. + size_t cache_size_ ABSL_GUARDED_BY(mu_) = 0; + + // A filename->file_signature map. + std::map file_signature_map_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc new file mode 100644 index 00000000000000..8436b1a1b68d12 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc @@ -0,0 +1,600 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" + +#include + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/blocking_counter.h" +#include "tensorflow/core/platform/cloud/now_seconds_env.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache, + const string& filename, size_t offset, size_t n, + std::vector* out) { + out->clear(); + out->resize(n, 0); + size_t bytes_transferred = 0; + TF_Status status; + cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status); + EXPECT_LE(bytes_transferred, n); + out->resize(bytes_transferred, n); + return status.status; +} + +TEST(RamFileBlockCacheTest, IsCacheEnabled) { + auto fetcher = [](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + // Do nothing. + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache3(0, 32, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache4(16, 32, 0, fetcher); + + EXPECT_FALSE(cache1.IsCacheEnabled()); + EXPECT_FALSE(cache2.IsCacheEnabled()); + EXPECT_FALSE(cache3.IsCacheEnabled()); + EXPECT_TRUE(cache4.IsCacheEnabled()); +} + +TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { + int calls = 0; + auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + calls++; + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + string filename = "file"; + tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); + std::vector out; + + // First read. + EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123)); + TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out)); + EXPECT_EQ(calls, 1); + + // Second read. Hit cache. + EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123)); + TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out)); + EXPECT_EQ(calls, 1); + + // Third read. File signatures are different. + EXPECT_FALSE(cache.ValidateAndUpdateFileSignature(filename, 321)); + TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out)); + EXPECT_EQ(calls, 2); +} + +TEST(RamFileBlockCacheTest, PassThrough) { + const string want_filename = "foo/bar"; + const size_t want_offset = 42; + const size_t want_n = 1024; + int calls = 0; + auto fetcher = [&calls, want_filename, want_offset, want_n]( + const string& got_filename, size_t got_offset, + size_t got_n, char* buffer, size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(got_filename, want_filename); + EXPECT_EQ(got_offset, want_offset); + EXPECT_EQ(got_n, want_n); + calls++; + memset(buffer, 'x', got_n); + *bytes_transferred = got_n; + return TF_SetStatus(status, TF_OK, ""); + }; + // If block_size, max_bytes, or both are zero, or want_n is larger than + // max_bytes the cache is a pass-through. + tf_gcs_filesystem::RamFileBlockCache cache1(1, 0, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache2(0, 1, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache3(0, 0, 0, fetcher); + tf_gcs_filesystem::RamFileBlockCache cache4(1000, 1000, 0, fetcher); + std::vector out; + TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out)); + EXPECT_EQ(calls, 1); + TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out)); + EXPECT_EQ(calls, 2); + TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out)); + EXPECT_EQ(calls, 3); + TF_EXPECT_OK(ReadCache(&cache4, want_filename, want_offset, want_n, &out)); + EXPECT_EQ(calls, 4); +} + +TEST(RamFileBlockCacheTest, BlockAlignment) { + // Initialize a 256-byte buffer. This is the file underlying the reads we'll + // do in this test. + const size_t size = 256; + std::vector buf; + for (int i = 0; i < size; i++) { + buf.push_back(i); + } + // The fetcher just fetches slices of the buffer. + auto fetcher = [&buf](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + if (offset < buf.size()) { + size_t bytes_to_copy = std::min(buf.size() - offset, n); + memcpy(buffer, buf.data() + offset, bytes_to_copy); + *bytes_transferred = bytes_to_copy; + } else { + *bytes_transferred = 0; + } + return TF_SetStatus(status, TF_OK, ""); + }; + for (size_t block_size = 2; block_size <= 4; block_size++) { + // Make a cache of N-byte block size (1 block) and verify that reads of + // varying offsets and lengths return correct data. + tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, + fetcher); + for (size_t offset = 0; offset < 10; offset++) { + for (size_t n = block_size - 2; n <= block_size + 2; n++) { + std::vector got; + TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got)); + // Verify the size of the read. + if (offset + n <= size) { + // Expect a full read. + EXPECT_EQ(got.size(), n) << "block size = " << block_size + << ", offset = " << offset << ", n = " << n; + } else { + // Expect a partial read. + EXPECT_EQ(got.size(), size - offset) + << "block size = " << block_size << ", offset = " << offset + << ", n = " << n; + } + // Verify the contents of the read. + std::vector::const_iterator begin = buf.begin() + offset; + std::vector::const_iterator end = + offset + n > buf.size() ? buf.end() : begin + n; + std::vector want(begin, end); + EXPECT_EQ(got, want) << "block size = " << block_size + << ", offset = " << offset << ", n = " << n; + } + } + } +} + +TEST(RamFileBlockCacheTest, CacheHits) { + const size_t block_size = 16; + std::set calls; + auto fetcher = [&calls, block_size](const string& filename, size_t offset, + size_t n, char* buffer, + size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(n, block_size); + EXPECT_EQ(offset % block_size, 0); + EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset; + calls.insert(offset); + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + const uint32 block_count = 256; + tf_gcs_filesystem::RamFileBlockCache cache( + block_size, block_count * block_size, 0, fetcher); + std::vector out; + out.resize(block_count, 0); + // The cache has space for `block_count` blocks. The loop with i = 0 should + // fill the cache, and the loop with i = 1 should be all cache hits. The + // fetcher checks that it is called once and only once for each offset (to + // fetch the corresponding block). + for (int i = 0; i < 2; i++) { + for (int j = 0; j < block_count; j++) { + TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out)); + } + } +} + +TEST(RamFileBlockCacheTest, OutOfRange) { + // Tests reads of a 24-byte file with block size 16. + const size_t block_size = 16; + const size_t file_size = 24; + bool first_block = false; + bool second_block = false; + auto fetcher = [block_size, &first_block, &second_block]( + const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(n, block_size); + EXPECT_EQ(offset % block_size, 0); + size_t bytes_to_copy = 0; + if (offset == 0) { + // The first block (16 bytes) of the file. + memset(buffer, 'x', n); + bytes_to_copy = n; + first_block = true; + } else if (offset == block_size) { + // The second block (8 bytes) of the file. + bytes_to_copy = file_size - block_size; + memset(buffer, 'x', bytes_to_copy); + second_block = true; + } + *bytes_transferred = bytes_to_copy; + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, + fetcher); + std::vector out; + // Reading the first 16 bytes should be fine. + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out)); + EXPECT_TRUE(first_block); + EXPECT_EQ(out.size(), block_size); + // Reading at offset file_size + 4 will read the second block (since the read + // at file_size + 4 = 28 will be aligned to an offset of 16) but will return + // OutOfRange because the offset is past the end of the 24-byte file. + Status status = ReadCache(&cache, "", file_size + 4, 4, &out); + EXPECT_EQ(status.code(), error::OUT_OF_RANGE); + EXPECT_TRUE(second_block); + // Reading the second full block will return 8 bytes, from a cache hit. + second_block = false; + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); + EXPECT_FALSE(second_block); + EXPECT_EQ(out.size(), file_size - block_size); +} + +TEST(RamFileBlockCacheTest, Inconsistent) { + // Tests the detection of interrupted reads leading to partially filled blocks + // where we expected complete blocks. + const size_t block_size = 16; + // This fetcher returns OK but only fills in one byte for any offset. + auto fetcher = [block_size](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(n, block_size); + EXPECT_EQ(offset % block_size, 0); + EXPECT_GE(n, 1); + memset(buffer, 'x', 1); + *bytes_transferred = 1; + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0, + fetcher); + std::vector out; + // Read the second block; this should yield an OK status and a single byte. + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); + EXPECT_EQ(out.size(), 1); + // Now read the first block; this should yield an INTERNAL error because we + // had already cached a partial block at a later position. + Status status = ReadCache(&cache, "", 0, block_size, &out); + EXPECT_EQ(status.code(), error::INTERNAL); +} + +TEST(RamFileBlockCacheTest, LRU) { + const size_t block_size = 16; + std::list calls; + auto fetcher = [&calls, block_size](const string& filename, size_t offset, + size_t n, char* buffer, + size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(n, block_size); + EXPECT_FALSE(calls.empty()) << "at offset = " << offset; + if (!calls.empty()) { + EXPECT_EQ(offset, calls.front()); + calls.pop_front(); + } + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + const uint32 block_count = 2; + tf_gcs_filesystem::RamFileBlockCache cache( + block_size, block_count * block_size, 0, fetcher); + std::vector out; + // Read blocks from the cache, and verify the LRU behavior based on the + // fetcher calls that the cache makes. + calls.push_back(0); + // Cache miss - drains an element from `calls`. + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); + // Cache hit - does not drain an element from `calls`. + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); + calls.push_back(block_size); + // Cache miss followed by cache hit. + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); + calls.push_back(2 * block_size); + // Cache miss followed by cache hit. Causes eviction of LRU element. + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); + // LRU element was at offset 0. Cache miss. + calls.push_back(0); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); + // Element at 2 * block_size is still in cache, and this read should update + // its position in the LRU list so it doesn't get evicted by the next read. + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); + // Element at block_size was evicted. Reading this element will also cause + // the LRU element (at 0) to be evicted. + calls.push_back(block_size); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); + // Element at 0 was evicted again. + calls.push_back(0); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); +} + +TEST(RamFileBlockCacheTest, MaxStaleness) { + int calls = 0; + auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + calls++; + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + std::vector out; + std::unique_ptr env(new NowSecondsEnv); + // Create a cache with max staleness of 2 seconds, and verify that it works as + // expected. + tf_gcs_filesystem::RamFileBlockCache cache1( + 8, 16, 2 /* max staleness */, fetcher, + [&env]() { return env->NowSeconds(); }); + // Execute the first read to load the block. + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); + EXPECT_EQ(calls, 1); + // Now advance the clock one second at a time and redo the read. The call + // count should advance every 3 seconds (i.e. every time the staleness is + // greater than 2). + for (int i = 1; i <= 10; i++) { + env->SetNowSeconds(i + 1); + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); + EXPECT_EQ(calls, 1 + i / 3); + } + // Now create a cache with max staleness of 0, and verify that it also works + // as expected. + calls = 0; + env->SetNowSeconds(0); + tf_gcs_filesystem::RamFileBlockCache cache2( + 8, 16, 0 /* max staleness */, fetcher, + [&env]() { return env->NowSeconds(); }); + // Execute the first read to load the block. + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); + EXPECT_EQ(calls, 1); + // Advance the clock by a huge amount and verify that the cached block is + // used to satisfy the read. + env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun. + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); + EXPECT_EQ(calls, 1); +} + +TEST(RamFileBlockCacheTest, RemoveFile) { + int calls = 0; + auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + calls++; + char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; + if (offset > 0) { + // The first block is lower case and all subsequent blocks are upper case. + c = toupper(c); + } + memset(buffer, c, n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + // This cache has space for 4 blocks; we'll read from two files. + const size_t n = 3; + tf_gcs_filesystem::RamFileBlockCache cache(8, 32, 0, fetcher); + std::vector out; + std::vector a(n, 'a'); + std::vector b(n, 'b'); + std::vector A(n, 'A'); + std::vector B(n, 'B'); + // Fill the cache. + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); + EXPECT_EQ(out, a); + EXPECT_EQ(calls, 1); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); + EXPECT_EQ(out, A); + EXPECT_EQ(calls, 2); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); + EXPECT_EQ(out, b); + EXPECT_EQ(calls, 3); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); + EXPECT_EQ(out, B); + EXPECT_EQ(calls, 4); + // All four blocks should be in the cache now. + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); + EXPECT_EQ(out, a); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); + EXPECT_EQ(out, A); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); + EXPECT_EQ(out, b); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); + EXPECT_EQ(out, B); + EXPECT_EQ(calls, 4); + // Remove the blocks from "a". + cache.RemoveFile("a"); + // Both blocks from "b" should still be there. + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); + EXPECT_EQ(out, b); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); + EXPECT_EQ(out, B); + EXPECT_EQ(calls, 4); + // The blocks from "a" should not be there. + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); + EXPECT_EQ(out, a); + EXPECT_EQ(calls, 5); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); + EXPECT_EQ(out, A); + EXPECT_EQ(calls, 6); +} + +TEST(RamFileBlockCacheTest, Prune) { + int calls = 0; + auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + calls++; + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + std::vector out; + // Our fake environment is initialized with the current timestamp. + std::unique_ptr env(new NowSecondsEnv); + uint64 now = Env::Default()->NowSeconds(); + env->SetNowSeconds(now); + tf_gcs_filesystem::RamFileBlockCache cache( + 8, 32, 1 /* max staleness */, fetcher, + [&env]() { return env->NowSeconds(); }); + // Read three blocks into the cache, and advance the timestamp by one second + // with each read. Start with a block of "a" at the current timestamp `now`. + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); + // Now load a block of a different file "b" at timestamp `now` + 1 + env->SetNowSeconds(now + 1); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); + // Now load a different block of file "a" at timestamp `now` + 1. When the + // first block of "a" expires, this block should also be removed because it + // also belongs to file "a". + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); + // Ensure that all blocks are in the cache (i.e. reads are cache hits). + EXPECT_EQ(cache.CacheSize(), 24); + EXPECT_EQ(calls, 3); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); + EXPECT_EQ(calls, 3); + // Advance the fake timestamp so that "a" becomes stale via its first block. + env->SetNowSeconds(now + 2); + // The pruning thread periodically compares env->NowSeconds() with the oldest + // block's timestamp to see if it should evict any files. At the current fake + // timestamp of `now` + 2, file "a" is stale because its first block is stale, + // but file "b" is not stale yet. Thus, once the pruning thread wakes up (in + // one second of wall time), it should remove "a" and leave "b" alone. + uint64 start = Env::Default()->NowSeconds(); + do { + Env::Default()->SleepForMicroseconds(100000); + } while (cache.CacheSize() == 24 && Env::Default()->NowSeconds() - start < 3); + // There should be one block left in the cache, and it should be the first + // block of "b". + EXPECT_EQ(cache.CacheSize(), 8); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); + EXPECT_EQ(calls, 3); + // Advance the fake time to `now` + 3, at which point "b" becomes stale. + env->SetNowSeconds(now + 3); + // Wait for the pruner to remove "b". + start = Env::Default()->NowSeconds(); + do { + Env::Default()->SleepForMicroseconds(100000); + } while (cache.CacheSize() == 8 && Env::Default()->NowSeconds() - start < 3); + // The cache should now be empty. + EXPECT_EQ(cache.CacheSize(), 0); +} + +TEST(RamFileBlockCacheTest, ParallelReads) { + // This fetcher won't respond until either `callers` threads are calling it + // concurrently (at which point it will respond with success to all callers), + // or 10 seconds have elapsed (at which point it will respond with an error). + const int callers = 4; + BlockingCounter counter(callers); + auto fetcher = [&counter](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + counter.DecrementCount(); + if (!counter.WaitFor(std::chrono::seconds(10))) { + // This avoids having the test time out, which is harder to debug. + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "desired concurrency not reached"); + } + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + const int block_size = 8; + tf_gcs_filesystem::RamFileBlockCache cache( + block_size, 2 * callers * block_size, 0, fetcher); + std::vector> threads; + for (int i = 0; i < callers; i++) { + threads.emplace_back( + Env::Default()->StartThread({}, "caller", [&cache, i]() { + std::vector out; + TF_EXPECT_OK( + ReadCache(&cache, "a", i * block_size, block_size, &out)); + std::vector x(block_size, 'x'); + EXPECT_EQ(out, x); + })); + } + // The `threads` destructor blocks until the threads can be joined, once their + // respective reads finish (which happens once they are all concurrently being + // executed, or 10 seconds have passed). +} + +TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { + // Concurrent reads to the same file blocks should be de-duplicated. + const size_t block_size = 16; + int num_requests = 0; + Notification notification; + auto fetcher = [&num_requests, ¬ification, block_size]( + const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + EXPECT_EQ(n, block_size); + EXPECT_EQ(offset, 0); + num_requests++; + memset(buffer, 'x', n); + *bytes_transferred = n; + notification.Notify(); + // Wait for other thread to issue read. + Env::Default()->SleepForMicroseconds(100000); // 0.1 secs + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, + fetcher); + // Fork off thread for parallel read. + std::unique_ptr concurrent( + Env::Default()->StartThread({}, "concurrent", [&cache] { + std::vector out; + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out)); + EXPECT_EQ(out.size(), block_size / 2); + })); + notification.WaitForNotification(); + std::vector out; + TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out)); + EXPECT_EQ(out.size(), block_size / 2); + + EXPECT_EQ(1, num_requests); +} + +TEST(RamFileBlockCacheTest, Flush) { + int calls = 0; + auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred, + TF_Status* status) { + calls++; + memset(buffer, 'x', n); + *bytes_transferred = n; + return TF_SetStatus(status, TF_OK, ""); + }; + tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); + std::vector out; + TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out)); + EXPECT_EQ(calls, 1); + cache.Flush(); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out)); + EXPECT_EQ(calls, 2); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 5452907f3e8fbf..5931e229e28fac 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -19,9 +19,6 @@ package( cc_library( name = "concrete_function", - srcs = [ - "concrete_function.cc", - ], hdrs = [ "concrete_function.h", ], @@ -29,7 +26,6 @@ cc_library( ":function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/core:protos_all_cc", ], ) @@ -60,11 +56,16 @@ cc_library( "saved_model_utils.h", ], deps = [ + ":function_metadata", "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -92,16 +93,28 @@ cc_library( ) cc_library( - name = "tf_saved_model_impl", + name = "tf_concrete_function_test_protos", + testonly = True, + srcs = ["tf_concrete_function_test_protos.cc"], + hdrs = ["tf_concrete_function_test_protos.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tf_saved_model_api", srcs = [ - "tf_saved_model_impl.cc", + "tf_saved_model_api.cc", ], - hdrs = ["tf_saved_model_impl.h"], + hdrs = ["tf_saved_model_api.h"], deps = [ ":concrete_function", ":saved_model_api", + "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/core:lib", - "//tensorflow/core/common_runtime/eager:context", "@com_google_absl//absl/types:optional", ], ) @@ -114,19 +127,11 @@ cc_library( "saved_model_api.h", ], visibility = ["//tensorflow/python:__pkg__"], -) - -filegroup( - name = "mobile_srcs_only_runtime", - srcs = [ - "concrete_function.cc", - "concrete_function.h", - "function_metadata.h", - "saved_model_api.h", - "tf_saved_model_impl.cc", - "tf_saved_model_impl.h", + deps = [ + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", ], - visibility = ["//tensorflow/core:__pkg__"], ) tf_cc_test( @@ -150,6 +155,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "object_graph_traversal_test", + srcs = [ + "object_graph_traversal_test.cc", + ], + deps = [ + ":saved_model_utils", + ":test_utils", + ":tf_concrete_function_test_protos", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "saved_variable_loading_test", srcs = [ @@ -172,3 +198,28 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:core", ], ) + +tf_cc_test( + name = "tf_concrete_function_loading_test", + srcs = [ + "tf_concrete_function_loading_test.cc", + ], + deps = [ + ":saved_model_utils", + ":test_utils", + ":tf_concrete_function_test_protos", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 22535641ef5d65..2cc627bcf2732f 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ +#include #include #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/core/framework/function.pb.h" namespace tensorflow { @@ -35,19 +35,14 @@ namespace tensorflow { // and have only a single implementation. class ConcreteFunction { public: - virtual ~ConcreteFunction() = 0; + virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual ImmediateExecutionOperation* GetCallOp() = 0; + virtual Status GetCallOp(ImmediateOpPtr* out) = 0; - const std::vector& GetCaptures() - const; - const FunctionMetadata& GetFunctionMetadata() const; - - private: - FunctionMetadata metadata_; - std::vector captures_; - FunctionDef* function_; + virtual const std::vector& GetCaptures() + const = 0; + virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc new file mode 100644 index 00000000000000..1c70d40cadac4c --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc @@ -0,0 +1,380 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { +namespace { + +SavedObjectGraph ParseSavedObjectGraph(StringPiece text_proto) { + SavedObjectGraph value; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), + &value)); + return value; +} + +constexpr absl::string_view kSingleChildFoo = R"( +nodes { + children { + node_id: 1 + local_name: "foo" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +)"; + +constexpr absl::string_view kSingleChildFooWithFuncBar = R"( +nodes { + children { + node_id: 1 + local_name: "foo" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + children { + node_id: 2 + local_name: "bar" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + function { + concrete_functions: "__inference_my_func_5" + function_spec { + fullargspec { + named_tuple_value { + name: "FullArgSpec" + values { + key: "args" + value { + list_value { + } + } + } + values { + key: "varargs" + value { + none_value { + } + } + } + values { + key: "varkw" + value { + none_value { + } + } + } + values { + key: "defaults" + value { + none_value { + } + } + } + values { + key: "kwonlyargs" + value { + list_value { + } + } + } + values { + key: "kwonlydefaults" + value { + none_value { + } + } + } + values { + key: "annotations" + value { + dict_value { + } + } + } + } + } + input_signature { + tuple_value { + } + } + } + } +} +concrete_functions { + key: "__inference_my_func_5" + value { + canonicalized_input_signature { + tuple_value { + values { + tuple_value { + } + } + values { + dict_value { + } + } + } + } + output_signature { + tensor_spec_value { + shape { + } + dtype: DT_FLOAT + } + } + } +} +)"; + +// In this graph, foo.baz and bar.wombat should point to the same object. +constexpr absl::string_view kMultiplePathsToChild = R"( +nodes { + children { + node_id: 1 + local_name: "foo" + } + children { + node_id: 2 + local_name: "bar" + } + children { + node_id: 3 + local_name: "signatures" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + children { + node_id: 4 + local_name: "baz" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + children { + node_id: 4 + local_name: "wombat" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + user_object { + identifier: "signature_map" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +)"; + +// `foo` has edge `bar`, which has edge `parent` pointing back to `foo`. +constexpr absl::string_view kCycleBetweenParentAndChild = R"( +nodes { + children { + node_id: 1 + local_name: "foo" + } + children { + node_id: 2 + local_name: "signatures" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + children { + node_id: 3 + local_name: "bar" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + user_object { + identifier: "signature_map" + version { + producer: 1 + min_consumer: 1 + } + } +} +nodes { + children { + node_id: 1 + local_name: "parent" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } +} +)"; + +TEST(ObjectGraphTraversalTest, Success) { + SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); + const SavedObject* obj = internal::FindNodeAtPath("foo", object_graph); + ASSERT_NE(nullptr, obj); + EXPECT_EQ(obj->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(obj->user_object().identifier(), "_generic_user_object"); +} + +TEST(ObjectGraphTraversalTest, ObjectNotFound) { + SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); + const SavedObject* obj = internal::FindNodeAtPath("bar", object_graph); + EXPECT_EQ(nullptr, obj); +} + +TEST(ObjectGraphTraversalTest, CaseSensitiveMismatch) { + SavedObjectGraph object_graph = ParseSavedObjectGraph(kSingleChildFoo); + const SavedObject* obj = internal::FindNodeAtPath("FOO", object_graph); + EXPECT_EQ(nullptr, obj); +} + +TEST(ObjectGraphTraversalTest, NestedObjectFound) { + SavedObjectGraph object_graph = + ParseSavedObjectGraph(kSingleChildFooWithFuncBar); + const SavedObject* obj = internal::FindNodeAtPath("foo.bar", object_graph); + ASSERT_NE(nullptr, obj); + EXPECT_EQ(obj->kind_case(), SavedObject::kFunction); + EXPECT_EQ(obj->function().concrete_functions_size(), 1); + EXPECT_EQ(obj->function().concrete_functions(0), "__inference_my_func_5"); +} + +TEST(ObjectGraphTraversalTest, MultiplePathsAliasSameObject) { + SavedObjectGraph object_graph = ParseSavedObjectGraph(kMultiplePathsToChild); + const SavedObject* foo_baz = + internal::FindNodeAtPath("foo.baz", object_graph); + ASSERT_NE(nullptr, foo_baz); + EXPECT_EQ(foo_baz->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(foo_baz->user_object().identifier(), "_generic_user_object"); + + const SavedObject* bar_wombat = + internal::FindNodeAtPath("bar.wombat", object_graph); + ASSERT_NE(nullptr, bar_wombat); + EXPECT_EQ(bar_wombat->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(bar_wombat->user_object().identifier(), "_generic_user_object"); + + EXPECT_EQ(foo_baz, bar_wombat); +} + +TEST(ObjectGraphTraversalTest, CyclesAreOK) { + SavedObjectGraph object_graph = + ParseSavedObjectGraph(kCycleBetweenParentAndChild); + const SavedObject* foo = internal::FindNodeAtPath("foo", object_graph); + ASSERT_NE(nullptr, foo); + EXPECT_EQ(foo->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(foo->user_object().identifier(), "_generic_user_object"); + + const SavedObject* foo_bar = + internal::FindNodeAtPath("foo.bar", object_graph); + ASSERT_NE(nullptr, foo_bar); + EXPECT_EQ(foo_bar->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(foo_bar->user_object().identifier(), "_generic_user_object"); + + const SavedObject* foo_bar_parent = + internal::FindNodeAtPath("foo.bar.parent", object_graph); + ASSERT_NE(nullptr, foo_bar_parent); + EXPECT_EQ(foo_bar_parent->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(foo_bar_parent->user_object().identifier(), "_generic_user_object"); + + const SavedObject* foo_bar_parent_bar = + internal::FindNodeAtPath("foo.bar.parent.bar", object_graph); + ASSERT_NE(nullptr, foo_bar_parent_bar); + EXPECT_EQ(foo_bar_parent_bar->kind_case(), SavedObject::kUserObject); + EXPECT_EQ(foo_bar_parent_bar->user_object().identifier(), + "_generic_user_object"); + + EXPECT_EQ(foo, foo_bar_parent); + EXPECT_EQ(foo_bar, foo_bar_parent_bar); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index 9dd87e945786af..673ea1a80e261e 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -14,6 +14,27 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "restore_ops", + srcs = [ + "restore_ops.cc", + ], + hdrs = [ + "restore_ops.h", + ], + deps = [ + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + ], +) + cc_library( name = "variable_ops", srcs = [ @@ -36,6 +57,34 @@ cc_library( ], ) +tf_cc_test( + name = "restore_ops_test", + srcs = [ + "restore_ops_test.cc", + ], + data = [ + "//tensorflow/cc/saved_model:saved_model_half_plus_two", + ], + deps = [ + ":restore_ops", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:test_utils", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + ], +) + tf_cc_test( name = "variable_ops_test", srcs = [ @@ -47,6 +96,7 @@ tf_cc_test( "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:test_utils", "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc new file mode 100644 index 00000000000000..6609ecee508fb1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace internal { + +namespace { + +// Creates a scalar string tensorhandle containing a single string `s` +Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx, + const std::string& s, + ImmediateTensorHandlePtr* out) { + AbstractTensorPtr tensor(ctx->CreateStringScalar(s)); + if (tensor.get() == nullptr) { + return errors::Internal( + "Failed to create scalar string tensor for checkpoint restore"); + } + + out->reset(ctx->CreateLocalHandle(tensor.get())); + return Status(); +} + +// Creates a Rank 1 string tensorhandle containing a single string `s` +Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx, + const std::string& s, + ImmediateTensorHandlePtr* out) { + int64 flat_shape[] = {1}; + AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape)); + if (tensor.get() == nullptr) { + return errors::Internal( + "Failed to create vector string tensor for checkpoint restore"); + } + // Use placement new to construct the string, since we don't have + // access to Tensor::flat. This is conceptually equivalent to: + // tensor.flat()(0) = s + new (tensor->Data()) tstring(s); + + out->reset(ctx->CreateLocalHandle(tensor.get())); + return Status(); +} + +} // namespace + +Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix, + const std::string& checkpoint_key, DataType dtype, + ImmediateTensorHandlePtr* out) { + // Create the EagerOp + ImmediateOpPtr restore_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0")); + TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1)); + + ImmediateTensorHandlePtr prefix_handle; + TF_RETURN_IF_ERROR( + CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle)); + + ImmediateTensorHandlePtr names_handle; + TF_RETURN_IF_ERROR( + CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle)); + + // Note that empty string is the slice spec used for a non-partitioned + // ResourceVariable: + // https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194 + ImmediateTensorHandlePtr shapes_and_slices_handle; + TF_RETURN_IF_ERROR( + CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle)); + + TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get())); + TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get())); + TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get())); + + AbstractTensorHandle* restored_handle = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(restore_op->Execute( + absl::MakeSpan(&restored_handle, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_restored_handle(restored_handle); + if (!tensorflow::isa( + owned_restored_handle.get())) { + return errors::Internal("Unexpected tensor handle kind."); + } + out->reset(reinterpret_cast( + owned_restored_handle.release())); + return Status(); +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h new file mode 100644 index 00000000000000..f215bc9e7ab8ca --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace internal { + +// TODO(bmzhao): Add a function to restore multiple tensors in one call. + +// Restores a single non-partioned tensorhandle of dtype `dtype`, using +// checkpoint at `prefix`, with a value stored in `checkpoint_key`. +Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix, + const std::string& checkpoint_key, DataType dtype, + ImmediateTensorHandlePtr* out); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc new file mode 100644 index 00000000000000..52a652a90efe23 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/test_utils.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +std::string CheckpointPrefix(StringPiece saved_model_dir) { + return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", + saved_model_dir, kSavedModelVariablesDirectory, + kSavedModelVariablesFilename); +} + +class RestoreOpsTest : public ::testing::Test { + public: + RestoreOpsTest() + : device_mgr_(testing::CreateTestingDeviceMgr()), + ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + + EagerContext* context() { return ctx_.get(); } + + private: + std::unique_ptr device_mgr_; + EagerContextPtr ctx_; +}; + +// One way of obtaining the checkpointa checkpoint's tensor names is: +// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors +// --file_name="$CKPT_PREFIX". +// Here are the values for VarsAndArithmeticObjectGraph: +// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 3.0 +// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 1.0 +// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 2.0 + +TEST_F(RestoreOpsTest, RestoreSuccessful) { + ImmediateTensorHandlePtr x_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle)); + AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get()); + EXPECT_EQ(x->Type(), DT_FLOAT); + EXPECT_EQ(x->NumElements(), 1); + EXPECT_EQ(x->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(x->Data()), 1.0f); + + ImmediateTensorHandlePtr y_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle)); + AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get()); + EXPECT_EQ(y->Type(), DT_FLOAT); + EXPECT_EQ(y->NumElements(), 1); + EXPECT_EQ(y->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(y->Data()), 2.0f); + + ImmediateTensorHandlePtr z_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle)); + AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get()); + EXPECT_EQ(z->Type(), DT_FLOAT); + EXPECT_EQ(z->NumElements(), 1); + EXPECT_EQ(z->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(z->Data()), 3.0f); +} + +TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) { + ImmediateTensorHandlePtr x_handle; + Status status = internal::SingleRestore( + context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"), + "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle); + EXPECT_FALSE(status.ok()) << status.error_message(); +} + +TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) { + ImmediateTensorHandlePtr x_handle; + Status status = internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "bad_checkpoint_key", DT_FLOAT, &x_handle); + EXPECT_FALSE(status.ok()) << status.error_message(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 09c45332efc873..55a4a32e983eff 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -39,17 +40,8 @@ ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context, class VariableOpsTest : public ::testing::Test { public: VariableOpsTest() - : device_mgr_(std::make_unique(DeviceFactory::NewDevice( - "CPU", {}, "/job:localhost/replica:0/task:0"))), - ctx_(new EagerContext( - SessionOptions(), - tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, - /* async= */ false, - /* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(), - /* device_mgr_owned= */ false, /* rendezvous= */ nullptr, - /* custom_kernel_creator= */ nullptr, - /* cluster_flr= */ nullptr)) {} + : device_mgr_(testing::CreateTestingDeviceMgr()), + ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} EagerContext* context() { return ctx_.get(); } diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 84fad2ea8f68f7..8bb15674db0e9f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -58,3 +58,24 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", ], ) + +cc_library( + name = "tf_concrete_function", + srcs = [ + "tf_concrete_function.cc", + ], + hdrs = [ + "tf_concrete_function.h", + ], + deps = [ + ":tensorhandle_convertible", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:concrete_function", + "//tensorflow/c/experimental/saved_model/core:function_metadata", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc new file mode 100644 index 00000000000000..aa6f0e7205e8db --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +TFConcreteFunction::TFConcreteFunction( + const std::string& name, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx) + : name_(name), + captures_(std::move(captures)), + metadata_(std::move(metadata)), + ctx_(ctx) {} + +TFConcreteFunction::~TFConcreteFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " + << status.error_message(); + } +} + +Status TFConcreteFunction::Create( + const FunctionDef* function_def, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + out->reset(new TFConcreteFunction(function_def->signature().name(), + std::move(captures), std::move(metadata), + ctx)); + return Status(); +} + +const std::vector& +TFConcreteFunction::GetCaptures() const { + return captures_; +} + +const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { + return metadata_; +} + +Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h new file mode 100644 index 00000000000000..71c8322414d2c6 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a +// saved model. +class TFConcreteFunction : public ConcreteFunction { + public: + // Factory function for creating a TFConcreteFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFConcreteFunction. TFConcreteFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - The FunctionMetadata associated with this TFConcreteFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output TFConcreteFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + FunctionMetadata metadata, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method returns the "Call" Op used to execute the function. + Status GetCallOp(ImmediateOpPtr* out) override; + + const std::vector& GetCaptures() + const override; + + const FunctionMetadata& GetFunctionMetadata() const override; + + ~TFConcreteFunction() override; + + private: + TFConcreteFunction(const std::string& name, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx); + + TFConcreteFunction(const TFConcreteFunction&) = delete; + TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + FunctionMetadata metadata_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 196420eb537d08..2037c4886de502 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -15,16 +15,133 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include #include +#include "absl/strings/str_split.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" namespace tensorflow { namespace internal { +namespace { + +// This returns the size of `tf.nest.flatten(value)`, on values that are +// used in tf.function's input_signatures. +int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) { + // This follows the logic from + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 + switch (value.kind_case()) { + case StructuredValue::kDictValue: { + const DictValue& dict = value.dict_value(); + int size = 0; + for (const auto& field : dict.fields()) { + size += FlattenedSize(field.second, status); + } + return size; + } + case StructuredValue::kTupleValue: { + const TupleValue& tuple = value.tuple_value(); + int size = 0; + for (const StructuredValue& value : tuple.values()) { + size += FlattenedSize(value, status); + } + return size; + } + case StructuredValue::kListValue: { + const ListValue& list = value.list_value(); + int size = 0; + for (const StructuredValue& value : list.values()) { + size += FlattenedSize(value, status); + } + return size; + } + case StructuredValue::kTensorSpecValue: { + return 1; + } + case StructuredValue::kNoneValue: { + // Base case: do nothing. + // This arises, for example, as the top-level object of an output + // signature when there are no return values. + return 0; + } + default: { + status->Update(errors::Internal("Unhandled structured value kind ", + value.kind_case())); + return 0; + } + } +} + +// Perform some basic sanity checks on SavedConcreteFunction's input and +// output signatures with respect to the corresponding FunctionDef's input +// and output args. +Status ValidateSavedFunctionCompatibleWithFunctionDef( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def) { + // tf.functions go through many transformations before becoming FunctionDefs + // 1. flatten user-provided inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675 + // 2. convert user-provided inputs to tensors: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688 + // 3. filter any non-tensor, non-variable inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841 + // 4. concatenate any captured inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912 + + // Since our API is limited to tf.functions annotated with input signatures, + // conditions 2 and 3 are trivially satisfied. + // We need to ensure that: + // flatten(input_signature).size() + captures.size() = fdef.signature().size() + // A concrete function's serialized "canonicalized_input_signature" comes + // from encoding its "structured_input_signature" field: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71 + // The "structured_input_signature" is guaranteed to be a tuple of the python + // args, kwargs that correspond to the tf.function: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 + + const std::string& name = function_def->signature().name(); + const StructuredValue& input_signature = + saved_concrete_function.canonicalized_input_signature(); + Status status; + int input_signature_size = FlattenedSize(input_signature, &status); + TF_RETURN_IF_ERROR(status); + if (input_signature_size + saved_concrete_function.bound_inputs_size() != + function_def->signature().input_arg_size()) { + return errors::FailedPrecondition( + "FunctionDef ", name, " has ", + function_def->signature().input_arg_size(), + " inputs, but the SavedConcreteFunction has ", input_signature_size, + " flattened user inputs and ", + saved_concrete_function.bound_inputs_size(), " captured inputs."); + } + + const StructuredValue& output_signature = + saved_concrete_function.output_signature(); + int output_signature_size = FlattenedSize(output_signature, &status); + TF_RETURN_IF_ERROR(status); + if (output_signature_size != function_def->signature().output_arg_size()) { + return errors::FailedPrecondition( + "FunctionDef ", name, " has ", + function_def->signature().output_arg_size(), + " outputs, but the SavedConcreteFunction has ", output_signature_size, + " flattened outputs."); + } + + return status; +} + +} // namespace Status TensorProtoToConstant(ImmediateExecutionContext* ctx, const TensorProto& proto, @@ -54,5 +171,80 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, return Status(); } +Status LoadTFConcreteFunction( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def, + const std::unordered_map>& + captured_objects, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef( + saved_concrete_function, function_def)); + + // Copy over captures + std::vector captures; + captures.reserve(saved_concrete_function.bound_inputs_size()); + for (int bound_input : saved_concrete_function.bound_inputs()) { + auto iter = captured_objects.find(bound_input); + if (iter == captured_objects.end()) { + return errors::FailedPrecondition("Failed to find bound_input ", + bound_input, + " for SavedConcreteFunction"); + } + captures.push_back(iter->second->handle()); + } + + return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx, + out); +} + +const SavedObject* FindNodeAtPath(StringPiece path, + const SavedObjectGraph& object_graph) { + const auto& nodes = object_graph.nodes(); + if (nodes.empty()) { + return nullptr; + } + + // Starting from the root, iterate through the saved object graph, matching + // object names as we go. + const SavedObject* current_node = &nodes.Get(0); + + for (absl::string_view object_name : absl::StrSplit(path, '.')) { + auto child_node_iter = std::find_if( + current_node->children().begin(), current_node->children().end(), + [object_name]( + const TrackableObjectGraph::TrackableObject::ObjectReference& obj) { + return object_name == obj.local_name(); + }); + if (child_node_iter == current_node->children().end()) { + return nullptr; + } + current_node = &nodes.Get(child_node_iter->node_id()); + } + + return current_node; +} + +std::unordered_map +NodeToAttrMap(const tensorflow::GraphDef& graphdef) { + std::unordered_map + result; + for (const tensorflow::NodeDef& node : graphdef.node()) { + result[node.name()] = &node.attr(); + } + return result; +} + +std::unordered_map +FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { + std::unordered_map + result; + for (const FunctionDef& function_def : library.function()) { + result[function_def.signature().name()] = &function_def; + } + return result; +} + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index ab1531709e43d1..57f30afa91bb5e 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -19,10 +19,18 @@ limitations under the License. // Some internal utility functions for the SavedModelAPI, factored out into a // separately unit-testable header. +#include +#include + #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" namespace tensorflow { @@ -43,6 +51,32 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const SavedVariable& variable, std::unique_ptr* output); +// Creates a TFConcreteFunction from a SavedConcreteFunction. +Status LoadTFConcreteFunction( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def, + const std::unordered_map>& + captured_objects, + ImmediateExecutionContext* ctx, std::unique_ptr* out); + +// Find the SavedObject in `object_graph` at location `path`. `path` must be a +// dot-delimited string of object names relative to the root object. If no +// object is found, returns nullptr. Callers must ensure `object_graph` outlives +// the returned pointer. +const SavedObject* FindNodeAtPath(StringPiece path, + const SavedObjectGraph& object_graph); + +// Maps each node in `graphdef` to its corresponding Attribute Map. +// Callers must ensure that `graphdef` outlives the returned map. +std::unordered_map +NodeToAttrMap(const tensorflow::GraphDef& graphdef); + +// Maps the name of each FunctionDef in `library` to its corresponding +// FunctionDef. Callers must ensure `library` outlives the returned map. +std::unordered_map +FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index 920b7dd01391f9..b803d129b909bf 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a, } } +AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) { + Status status; + AbstractTensorPtr tensor(handle->Resolve(&status)); + CHECK(status.ok()) << status.error_message(); + CHECK_NE(tensor.get(), nullptr); + return tensor; +} + } // namespace testing } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.h b/tensorflow/c/experimental/saved_model/core/test_utils.h index fe80a66064930c..bdc1ca762eed22 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.h +++ b/tensorflow/c/experimental/saved_model/core/test_utils.h @@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer, void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a, void* b); +// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should +// only be used for testing purposes. +AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle); + } // namespace testing } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc new file mode 100644 index 00000000000000..05fbac13077bcc --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc @@ -0,0 +1,271 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/test_utils.h" +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { +namespace { + +class SavedConcreteFunctionLoadingTest : public ::testing::Test { + public: + SavedConcreteFunctionLoadingTest() + : device_mgr_(testing::CreateTestingDeviceMgr()), + ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + + EagerContext* context() { return ctx_.get(); } + + private: + std::unique_ptr device_mgr_; + EagerContextPtr ctx_; +}; + +class DummyCapture : public TensorHandleConvertible { + public: + DummyCapture(ImmediateExecutionContext* ctx, int8 value) + : TensorHandleConvertible( + testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {} +}; + +FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) { + FunctionDef func; + OpDef* signature = func.mutable_signature(); + for (int i = 0; i < num_inputs; ++i) { + signature->add_input_arg(); + } + for (int i = 0; i < num_outputs; ++i) { + signature->add_output_arg(); + } + return func; +} + +// A SavedConcreteFunction whose canonicalized input signature +// has less inputs than its corresponding FunctionDef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) { + // `saved` has 1 input + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + + // `func` has 2 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature length + +// captures is less than its corresponding FunctionDef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, + TooFewInputsWithCapturesInSavedConcreteFunction) { + // `saved` has 1 input, and 1 capture, for a total of 2 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + saved.add_bound_inputs(5); + + // `func` has 3 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); + + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature +// has more inputs than its corresponding FunctionDef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) { + // `saved` has 3 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + + // `func` has 2 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature +// has the same number of inputs than its corresponding FunctionDef, but has +// additional captures should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, + TooManyInputsWithCaptureInSavedConcreteFunction) { + // `saved` has 3 inputs, and 1 capture, for a total of 4 inputs. + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + saved.add_bound_inputs(5); + + // `func` has 3 inputs. + FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); + + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose capture refers to an index not in the capture +// map should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) { + // `saved` has 3 inputs, 1 capture, for a total of 4 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + // Capture is at index "10" + saved.add_bound_inputs(10); + + // `func` has 4 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(4, 0); + + // `captures` only has a capture for index 5 + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose outputs are fewer than its corresponding +// functiondef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) { + // `saved` has 0 inputs, 1 output + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ZeroArgInputSignature(); + *saved.mutable_output_signature() = testing::SingleReturnOutputSignature(); + + // `func` has 0 inputs, 2 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(0, 2); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose outputs exceed its corresponding functiondef +// should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, + TooManyOutputsInSavedConcreteFunction) { + // `saved` has 1 input, 3 outputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); + + // `func` has 1 input, 2 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(1, 2); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs, +// and whose outputs = functiondef outputs should successfully load. +TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) { + // `saved` has 1 input, 2 captures, 3 outputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); + saved.add_bound_inputs(2); + saved.add_bound_inputs(5); + + // `func` has 3 inputs, 3 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(3, 3); + + std::unordered_map> captures; + captures[2] = std::make_unique(context(), 1); + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + TF_EXPECT_OK(status) << status.error_message(); +} + +// A TFConcreteFunction should register functiondefs on creation, and +// remove them upon deletion. +TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) { + std::string func_name = "FooBarBazWombatFunction"; + + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ZeroArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + FunctionDef func = FuncDefWithNumInputsOutputs(0, 0); + *func.mutable_signature()->mutable_name() = func_name; + + { + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + TF_EXPECT_OK(status) << status.error_message(); + // The function should be registered with context. + EXPECT_TRUE(context()->FindFunctionByName(func_name)); + } + + // After `result's` destructor runs, the function should no longer be + // registered with context. + EXPECT_FALSE(context()->FindFunctionByName(func_name)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc new file mode 100644 index 00000000000000..6250af6dba1359 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc @@ -0,0 +1,213 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace testing { +namespace { + +constexpr absl::string_view kZeroArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + } + } + values: { + dict_value: { + } + } +} +)"; + +constexpr absl::string_view kSingleArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + values: { + tensor_spec_value: { + name : "x" + shape: { + dim: { + size: 1 + } + dim: { + size: 10 + } + } + dtype: DT_FLOAT + } + } + } + } + values: { + dict_value: { + } + } +} +)"; + +constexpr absl::string_view kThreeArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + values: { + tensor_spec_value: { + name : "x" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + values: { + tensor_spec_value: { + name : "y" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + values: { + tensor_spec_value: { + name : "z" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + } + values: { + dict_value: { + } + } +} + +)"; + +constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"( +none_value: {} +)"; + +constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"( +tensor_spec_value: { + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT +} +)"; + +constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"( +tuple_value: { + values: { + dict_value: { + fields: { + key : "a" + value: { + tensor_spec_value: { + name : "0/a" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + fields: { + key : "b" + value: { + tensor_spec_value: { + name : "0/b" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + } + } + values: { + tensor_spec_value: { + name : "1" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } +} +)"; + +StructuredValue ParseStructuredValue(absl::string_view text_proto) { + StructuredValue value; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), + &value)); + return value; +} + +} // namespace + +StructuredValue ZeroArgInputSignature() { + return ParseStructuredValue(kZeroArgInputSignatureTextProto); +} + +StructuredValue SingleArgInputSignature() { + return ParseStructuredValue(kSingleArgInputSignatureTextProto); +} + +StructuredValue ThreeArgInputSignature() { + return ParseStructuredValue(kThreeArgInputSignatureTextProto); +} + +StructuredValue ZeroReturnOutputSignature() { + return ParseStructuredValue(kZeroReturnOutputSignatureTextProto); +} + +StructuredValue SingleReturnOutputSignature() { + return ParseStructuredValue(kSingleReturnOutputSignatureTextProto); +} + +StructuredValue ThreeReturnOutputSignature() { + return ParseStructuredValue(kThreeReturnOutputSignatureTextProto); +} + +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h new file mode 100644 index 00000000000000..8aa7d5694e1a12 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ + +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace testing { + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 0 inputs +StructuredValue ZeroArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 1 input +StructuredValue SingleArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 3 inputs +StructuredValue ThreeArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with no return values +StructuredValue ZeroReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with a single tensor output +StructuredValue SingleReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with three tensor outputs +StructuredValue ThreeReturnOutputSignature(); + +} // namespace testing +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc similarity index 80% rename from tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc rename to tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 4c36adc6de535c..225ba1db9f44fe 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -13,34 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h" +#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include #include #include #include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" -#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/platform/errors.h" namespace tensorflow { -Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path, - ConcreteFunction** function) { +Status TFSavedModelAPI::GetFunction(const std::string& function_path, + ConcreteFunction** function) { // TODO(bmzhao): Add support for retrieving a function. return errors::Unimplemented( "Retrieving functions is unimplemented currently"); } -Status TFSavedModelAPIImpl::GetSignatureDefFunction( +Status TFSavedModelAPI::GetSignatureDefFunction( const std::string& signature_def_key, ConcreteFunction** function) { // TODO(bmzhao): Add support for retrieving a signaturedef function. return errors::Unimplemented( "Retrieving functions is unimplemented currently"); } -std::vector TFSavedModelAPIImpl::ListFunctions() { +std::vector TFSavedModelAPI::ListFunctions() { std::vector result; result.reserve(functions_.size()); for (ConcreteFunction& function : functions_) { @@ -49,10 +49,10 @@ std::vector TFSavedModelAPIImpl::ListFunctions() { return result; } -Status TFSavedModelAPIImpl::Load( +Status TFSavedModelAPI::Load( const std::string& directory, const absl::optional>& tags, - EagerContext* context, std::unique_ptr* out) { + ImmediateExecutionContext* context, std::unique_ptr* out) { // TODO(bmzhao): Add support for loading a TFSavedModelImpl. return errors::Unimplemented( "TFSavedModelAPIImpl loading is unimplemented currently"); diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h similarity index 55% rename from tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h rename to tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index 4612c610d46478..cc631a9f3ae77f 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -21,14 +21,30 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" -#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { -class TFSavedModelAPIImpl : public SavedModelAPI { +// An implementation of the SavedModelAPI using the TF Eager runtime. See +// https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md +// Conceptually, there are many differences between a tf.function and +// a FunctionDef is executed by the C API. +// 1. A tf.function is polymorphic, meaning it can correspond to multiple +// ConcreteFunctions (of differing shapes, python arguments, etc). A +// FunctionDef corresponds to a single ConcreteFunction. +// 2. A tf.function can take arbitrary python inputs, whereas the FunctionDef +// only accepts tensors. +// 3. A tf.function is a closure that can contain captured inputs, whereas +// FunctionDefs loaded from SavedModels are "functional" (all inputs are +// explicitly passed as arguments). +// The SavedModelAPI only supports loading tf.functions annotated with input +// signatures so that we ensure that there is a 1:1 mapping between tf.function +// -> FunctionDef, and have a guarantee that all inputs are tensors. +// (https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/eager/def_function.py#L1167-L1171), +class TFSavedModelAPI : public SavedModelAPI { public: Status GetFunction(const std::string& function_path, ConcreteFunction** function) override; @@ -39,14 +55,15 @@ class TFSavedModelAPIImpl : public SavedModelAPI { static Status Load( const std::string& directory, const absl::optional>& tags, - EagerContext* context, std::unique_ptr* out); + ImmediateExecutionContext* context, + std::unique_ptr* out); std::vector ListFunctions() override; - ~TFSavedModelAPIImpl() override = default; + ~TFSavedModelAPI() override = default; private: - TFSavedModelAPIImpl() = default; + TFSavedModelAPI() = default; std::vector functions_; }; diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 888c284bb12d1e..b22718dfd04a7a 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -41,11 +41,13 @@ cc_library( ":tensorhandle_list", ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_internal", "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:function_metadata", + "//tensorflow/core:lib", ], ) @@ -144,7 +146,7 @@ cc_library( "//tensorflow/c:tf_status_internal", "//tensorflow/c/eager:tfe_context_internal", "//tensorflow/c/experimental/saved_model/core:saved_model_api", - "//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl", + "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:context", "@com_google_absl//absl/types:optional", @@ -205,9 +207,13 @@ tf_cc_test( ], deps = [ "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/c/experimental/saved_model/public:tensorhandle_list", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index dd54416ddf95f3..12d49212a881be 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,12 +15,15 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/platform/status.h" extern "C" { @@ -34,8 +37,11 @@ const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { - return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp()); +TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, + TF_Status* status) { + tensorflow::ImmediateOpPtr call_op(nullptr); + status->status = tensorflow::unwrap(func)->GetCallOp(&call_op); + return tensorflow::wrap(call_op.release()); } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 2768217fb49929..983c98affb232b 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" -#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h" +#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" @@ -44,8 +44,8 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, status->status = tensorflow::errors::Unimplemented( "TFRT SavedModel implementation will be added in the future"); } else { - std::unique_ptr saved_model; - status->status = tensorflow::TFSavedModelAPIImpl::Load( + std::unique_ptr saved_model; + status->status = tensorflow::TFSavedModelAPI::Load( dirname, absl::nullopt, tensorflow::down_cast( tensorflow::unwrap(ctx)), @@ -74,8 +74,8 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, status->status = tensorflow::errors::Unimplemented( "TFRT SavedModel implementation will be added in the future"); } else { - std::unique_ptr saved_model; - status->status = tensorflow::TFSavedModelAPIImpl::Load( + std::unique_ptr saved_model; + status->status = tensorflow::TFSavedModelAPI::Load( dirname, tagset, tensorflow::down_cast( tensorflow::unwrap(ctx)), diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 2a87214270c8bb..944ddecea16300 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -41,7 +41,7 @@ TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( // Returns a TFE_Op suitable for executing this function. TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( - TF_ConcreteFunction* func); + TF_ConcreteFunction* func, TF_Status* status); #ifdef __cplusplus } // end extern "C" diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index a0ed0d9f245cbc..3021a38e88874f 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -248,15 +248,22 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, size_t len, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index); - auto* allocator = cc_ctx->get_allocator(attr); - void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator); - TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len, - tensorflow::deallocate_buffer, allocator); - TF_SetOutput(context, index, result, status); - if (TF_GetCode(status) != TF_OK) { - TF_DeleteTensor(result); + + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + tensorflow::gtl::ArraySlice dimarray( + reinterpret_cast(dims), num_dims); + tensorflow::Tensor* tensor; + tensorflow::Status s = cc_ctx->allocate_output( + index, tensorflow::TensorShape(dimarray), &tensor); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); return nullptr; } - return result; + TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + return tf_tensor; } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0c959e327a8395..d091146c75a2c3 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available") +load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available") package( default_visibility = ["//visibility:private"], @@ -35,7 +35,9 @@ cc_library( "flags.h", "quantize.h", ], - defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]), + defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]) + if_llvm_system_z_available([ + "TF_LLVM_S390X_AVAILABLE=1", + ]), visibility = ["//tensorflow/python:__pkg__"], deps = [ ":aot_only_var_handle_op", @@ -73,7 +75,9 @@ cc_library( "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep "//tensorflow/core:regexp_internal", - ] + if_llvm_aarch64_available([ + ] + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep + ]) + if_llvm_aarch64_available([ "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep ]), ) @@ -114,7 +118,9 @@ cc_library( "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - ] + if_llvm_aarch64_available([ + ] + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep + ]) + if_llvm_aarch64_available([ "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep ]), ) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index fe0d6d5a074a6c..b22baa305e7d84 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -168,6 +168,12 @@ static void InitializeTargets() { LLVMInitializeAArch64TargetInfo(); LLVMInitializeAArch64TargetMC(); LLVMInitializeAArch64AsmPrinter(); +#endif +#if TF_LLVM_S390X_AVAILABLE + LLVMInitializeSystemZTarget(); + LLVMInitializeSystemZTargetInfo(); + LLVMInitializeSystemZTargetMC(); + LLVMInitializeSystemZAsmPrinter(); #endif LLVMInitializeARMTarget(); LLVMInitializeARMTargetInfo(); diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index f2b28e70ff1517..29f37bf749867a 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -434,5 +434,6 @@ def target_llvm_triple(): "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//tensorflow:macos": "x86_64-none-darwin", "//tensorflow:windows": "x86_64-none-windows", + "//tensorflow:linux_s390x": "systemz-none-linux-gnu", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 3beccf0881b0f7..347bae087dff0a 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -1,6 +1,7 @@ package( default_visibility = [ "//tensorflow/compiler/tf2xla:internal", + "//tensorflow/core/tpu:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index e3542586c8909d..48347a2915fa6c 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -350,7 +350,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; - ResourceVarsSnapshot variables; + ResourceVarsSnapshot variables_snapshot; { std::vector variable_infos; OP_REQUIRES_OK( @@ -361,8 +361,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { variable_infos, constants_, /*lazy=*/false, &client, &compilation_result, &executable); OP_REQUIRES_OK(ctx, s); - OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, - variable_infos, &variables)); + OP_REQUIRES_OK(ctx, + SnapshotResourceVariables(ctx, resources_, variable_infos, + &variables_snapshot)); } se::Stream* stream = @@ -377,7 +378,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { client, allocator, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), platform_info_.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, compilation_result, variables, + launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot, /*missing_ctx_input_prefix=*/0); // Execute the computation. @@ -415,10 +416,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { const xla::HloInputOutputAliasConfig& input_output_alias = executable->executable()->module().input_output_alias_config(); - OP_REQUIRES_OK( - ctx, launch_context.PopulateOutputs( - ctx, compilation_result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0, input_output_alias, variables)); + OP_REQUIRES_OK(ctx, + launch_context.PopulateOutputs( + ctx, compilation_result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0, input_output_alias, + variables_snapshot)); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index dc5df94e963cfc..55ff57a04c5e55 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1096,33 +1096,33 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { return true; } -absl::flat_hash_set GetOrCreateWhitelist() { - absl::flat_hash_map>* whitelist_table = - tensorflow::GetWhitelistTable(); +absl::flat_hash_set GetOrCreateAllowlist() { + absl::flat_hash_map>* allowlist_table = + tensorflow::GetAllowlistTable(); MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set whitelist; + absl::flat_hash_set allowlist; for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) { if (s == "FUSIBLE") { - for (auto pair : *whitelist_table) { - whitelist.insert(pair.second.begin(), pair.second.end()); + for (auto pair : *allowlist_table) { + allowlist.insert(pair.second.begin(), pair.second.end()); } - } else if (whitelist_table->contains(s)) { - auto v = whitelist_table->at(s); - whitelist.insert(v.begin(), v.end()); + } else if (allowlist_table->contains(s)) { + auto v = allowlist_table->at(s); + allowlist.insert(v.begin(), v.end()); } else if (!s.empty()) { // Should be a user provided TF operation. - whitelist.insert(string(s)); + allowlist.insert(string(s)); } } - if (VLOG_IS_ON(2) && !whitelist.empty()) { - std::vector vwhitelist(whitelist.begin(), whitelist.end()); - absl::c_sort(vwhitelist); + if (VLOG_IS_ON(2) && !allowlist.empty()) { + std::vector vallowlist(allowlist.begin(), allowlist.end()); + absl::c_sort(vallowlist); VLOG(2) << "XLA clustering will only consider the following TF operations: " - << absl::StrJoin(vwhitelist, " "); + << absl::StrJoin(vallowlist, " "); } - return whitelist; + return allowlist; } Status MarkForCompilationPassImpl::FindCompilationCandidates() { @@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size(); - auto whitelist = GetOrCreateWhitelist(); + auto allowlist = GetOrCreateAllowlist(); std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that user's provided TF operation really exists. - for (const auto& s : whitelist) { + for (const auto& s : allowlist) { if (!all_ops.contains(string(s))) { return errors::InvalidArgument( "The operation '", s, @@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - if (!whitelist.empty() && !whitelist.contains(node->def().op())) { + if (!allowlist.empty() && !allowlist.contains(node->def().op())) { VLOG(1) << "Rejecting TF operation " << node->def().op() << " as it is not listed in --tf_xla_ops_to_cluster."; continue; @@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } -absl::flat_hash_map>* GetWhitelistTable() { +absl::flat_hash_map>* GetAllowlistTable() { // Table format: category name: {list of TF operations in that category} static absl::flat_hash_map>* result = new absl::flat_hash_map>{ @@ -1845,7 +1845,7 @@ absl::flat_hash_map>* GetWhitelistTable() { namespace testing { void ResetClusterSequenceNumber() { cluster_sequence_num = 0; } -absl::flat_hash_set GetKnownXLAWhitelistOp() { +absl::flat_hash_set GetKnownXLAAllowlistOp() { absl::flat_hash_set result{"AdjustContrastv2", "AdjustHue", "AdjustSaturation", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 8b660710898017..0e9a64e7f28082 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_node_info = nullptr); -absl::flat_hash_map>* GetWhitelistTable(); +absl::flat_hash_map>* GetAllowlistTable(); namespace testing { // DO NOT USE IN PRODUCTION. @@ -66,8 +66,8 @@ namespace testing { // Resets some internal state to let us write reliable unit tests. void ResetClusterSequenceNumber(); -// Return a list of operation that we choose not to put into the whitelist. -absl::flat_hash_set GetKnownXLAWhitelistOp(); +// Return a list of operation that we choose not to put into the allowlist. +absl::flat_hash_set GetKnownXLAAllowlistOp(); } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 0e1cc2d19fee86..3ae72eb514cf3b 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { EXPECT_NE(clusters["relu0"], clusters["relu1"]); } } -TEST(XlaCompilationTest, XLALiteWhitelist) { - auto* whitelist_table = tensorflow::GetWhitelistTable(); - absl::flat_hash_set hwhitelist; +TEST(XlaCompilationTest, XLALiteAllowlist) { + auto* allowlist_table = tensorflow::GetAllowlistTable(); + absl::flat_hash_set hallowlist; std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that all the operations in the table are existing TF operations - for (auto pair : *whitelist_table) { - hwhitelist.insert(pair.second.begin(), pair.second.end()); + for (auto pair : *allowlist_table) { + hallowlist.insert(pair.second.begin(), pair.second.end()); for (auto op : pair.second) { ASSERT_TRUE(all_ops.contains(op)); } } - // Check that all registered XLA operation are in the whitelist + // Check that all registered XLA operation are in the allowlist // table or are known to not be in it. absl::flat_hash_set known_not_in_list = - tensorflow::testing::GetKnownXLAWhitelistOp(); + tensorflow::testing::GetKnownXLAAllowlistOp(); std::vector unknow_op; for (string op : vall_ops) { - if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) { + if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) { unknow_op.push_back(op); } } EXPECT_TRUE(unknow_op.empty()) << "Someone added support for a new TF opeations inside XLA. They must " - "be included in the XLALite whitelist or blacklist:\n" + "be included in the XLALite allowlist or blacklist:\n" << absl::StrJoin(unknow_op, "\n"); } } // namespace diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index eb31b23c99176f..7f107aaef11ecd 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -231,49 +231,6 @@ void XlaComputationLaunchContext::PopulateInputs( } } -static bool MustAliasOutput( - const xla::HloInputOutputAliasConfig& input_output_alias, int output_num) { - xla::ShapeIndex output_index; - if (input_output_alias.shape().IsTuple()) { - output_index = {output_num}; - } else { - DCHECK_EQ(output_num, 0) - << "output_num must be 0 for non-tuple shapes but is " << output_num; - output_index = {}; - } - if (input_output_alias.shape().tuple_shapes_size() == 0) { - return false; - } - return input_output_alias.OutputHasAlias(output_index) && - input_output_alias.GetAliasedParameter(output_index).value().kind == - xla::HloInputOutputAliasConfig::kUserAlias; -} - -// Returns an aliased tensor if it exists, nullptr otherwise. -static const Tensor* FindAliasedTensorForOutput( - int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix, - const xla::HloInputOutputAliasConfig& input_output_alias, - absl::Span input_mapping, - const ResourceVarsSnapshot& resource_var_snapshots) { - if (MustAliasOutput(input_output_alias, output_num)) { - int xla_param = input_output_alias.GetAliasedParameter({output_num}) - .value() - .parameter_number; - int tf_param = input_mapping[xla_param] - missing_ctx_input_prefix; - const Tensor* input_tensor = &ctx->input(tf_param); - - // If input tensor is a resource variable, alias to the snapshot we took at - // entry time. - if (input_tensor->dtype() == DT_RESOURCE) { - auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param); - CHECK(v.has_value()); - return &v.value(); - } - return input_tensor; - } - return nullptr; -} - // Construct the tensor for given type and buffer. static Tensor MakeTensor(DataType dtype, const TensorShape& shape, se::DeviceMemoryBase buffer, Allocator* allocator) { @@ -293,66 +250,40 @@ static Tensor GetOrCreateTensorForOutput( const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype, const TensorShape& output_shape, se::DeviceMemoryBase output_buffer, Allocator* output_allocator) { - if (const Tensor* aliased_tensor = FindAliasedTensorForOutput( - output_num, ctx, missing_ctx_input_prefix, input_output_alias, - input_mapping, resource_var_snapshots)) { - return *aliased_tensor; + xla::ShapeIndex output_index = input_output_alias.shape().IsTuple() + ? xla::ShapeIndex({output_num}) + : xla::ShapeIndex({}); + CHECK(input_output_alias.shape().IsTuple() || output_num == 0); + if (absl::optional alias = + input_output_alias.GetAliasedParameter(output_index)) { + int tf_param = + input_mapping[alias->parameter_number] - missing_ctx_input_prefix; + const Tensor* input_tensor = &ctx->input(tf_param); + + // If input tensor is a resource variable, alias to the snapshot we took at + // entry time. + if (input_tensor->dtype() == DT_RESOURCE) { + const absl::optional& v = + resource_var_snapshots.at(missing_ctx_input_prefix + tf_param); + CHECK(v.has_value()); + return *v; + } + return *input_tensor; } return MakeTensor(output_dtype, output_shape, output_buffer, output_allocator); } -static Status SetBufferForTensorUnderAllocateXlaTensors( - const xla::HloInputOutputAliasConfig& input_output_alias, int output_num, - OpKernelContext* ctx, int i, tensorflow::TensorShape shape, - xla::ScopedShapedBuffer* output, - std::shared_ptr definition_event, se::Stream* stream, - bool use_multiple_streams) { - if (MustAliasOutput(input_output_alias, output_num)) { - return errors::Unimplemented( - "Aliasing is not yet supported for allocate_xla_tensors_."); - } - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); +static void PopulateXlaTensor(Tensor* output_tensor, + xla::ScopedShapedBuffer* output, int output_num, + se::Stream* stream, bool use_multiple_streams, + std::shared_ptr definition_event) { XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num})); - if (use_multiple_streams) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); - } - } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. - CHECK_EQ(output_tensor->TotalBytes(), 0); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num})); + if (use_multiple_streams) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); } - return Status::OK(); -} - -static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors( - const xla::HloInputOutputAliasConfig& input_output_alias, int output_num, - OpKernelContext* ctx, int i, const XlaCompiler::ResourceUpdate& write, - xla::ScopedShapedBuffer* output, - std::shared_ptr definition_event, - absl::Span variable_infos, se::Stream* stream, - bool use_multiple_streams) { - if (MustAliasOutput(input_output_alias, output_num)) { - return errors::Unimplemented( - "Aliasing is not yet supported for allocate_xla_tensors_."); - } - Tensor output_tensor; - TF_RETURN_IF_ERROR( - ctx->allocate_temp(write.type, write.shape, &output_tensor)); - if (write.shape.num_elements() > 0) { - XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num})); - if (use_multiple_streams) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); - } - } - *variable_infos[i].var()->tensor() = output_tensor; - variable_infos[i].var()->is_initialized |= write.modified; - return Status::OK(); } // Sets output `output_num` for `ctx` provided it is known at a compile time. @@ -526,10 +457,12 @@ Status XlaComputationLaunchContext::PopulateOutputs( ctx->set_output(i, ctx->input(input_index)); } else { if (allocate_xla_tensors_) { - TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors( - input_output_alias, output_num, ctx, i, shape, &output, - definition_event, stream, use_multiple_streams_)); - + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); + if (output_tensor->TotalBytes() > 0) { + PopulateXlaTensor(output_tensor, &output, output_num, stream, + use_multiple_streams_, definition_event); + } } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); Tensor output_tensor = GetOrCreateTensorForOutput( @@ -561,20 +494,24 @@ Status XlaComputationLaunchContext::PopulateOutputs( return errors::Internal("Mismatched type in variable write"); } + Tensor output_tensor; if (allocate_xla_tensors_) { - TF_RETURN_IF_ERROR(SetBufferForResourceVarTensorUnderAllocateXlaTensors( - input_output_alias, output_num, ctx, i, write, &output, - definition_event, variable_infos, stream, use_multiple_streams_)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); + if (write.shape.num_elements() > 0) { + PopulateXlaTensor(&output_tensor, &output, output_num, stream, + use_multiple_streams_, definition_event); + } } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); output.set_buffer(se::OwningDeviceMemory(), {output_num}); - Tensor output_tensor = GetOrCreateTensorForOutput( + output_tensor = GetOrCreateTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, compilation_result->input_mapping, resource_var_snapshots, write.type, write.shape, buffer, allocator); - *variable_infos[i].var()->tensor() = output_tensor; - variable_infos[i].var()->is_initialized |= write.modified; } + *variable_infos[i].var()->tensor() = output_tensor; + variable_infos[i].var()->is_initialized |= write.modified; ++output_num; } return Status::OK(); diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 84314b9a4ca53e..c7bda887db037f 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -125,7 +125,7 @@ gentbl( #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl( - name = "xla_canonicalize_inc_gen", + name = "canonicalize_inc_gen", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_canonicalize.inc"), ], @@ -203,12 +203,12 @@ cc_library( ], includes = ["include"], deps = [ + ":canonicalize_inc_gen", ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_inc_gen", ":hlo_ops_inc_gen", ":infer_fusibility_op_interface", - ":xla_canonicalize_inc_gen", "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -265,7 +265,7 @@ cc_library( ) cc_library( - name = "xla_sink_constants_to_control_flow", + name = "sink_constants_to_control_flow", srcs = ["lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ @@ -280,8 +280,8 @@ cc_library( ) cc_library( - name = "map_xla_to_scalar_op", - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"], + name = "map_lmhlo_to_scalar_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", @@ -306,7 +306,7 @@ cc_library( deps = [ ":hlo", ":lhlo", - ":map_xla_to_scalar_op", + ":map_lmhlo_to_scalar_op", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", @@ -350,15 +350,16 @@ cc_library( ) cc_library( - name = "xla_legalize_to_linalg", - srcs = ["lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc"], + name = "legalize_to_linalg", + srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ ":hlo", ":lhlo", - ":map_xla_to_scalar_op", + ":map_lmhlo_to_scalar_op", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:Pass", @@ -369,8 +370,8 @@ cc_library( ) cc_library( - name = "xla_transform_unranked_hlo", - srcs = ["lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc"], + name = "transform_unranked_hlo", + srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ ":hlo", @@ -390,7 +391,7 @@ cc_library( deps = [ ":hlo", ":lhlo", - ":map_xla_to_scalar_op", + ":map_lmhlo_to_scalar_op", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:GPUDialect", @@ -482,8 +483,8 @@ tf_cc_test( ) cc_library( - name = "xla_hlo_fusion", - srcs = ["lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc"], + name = "mhlo_fusion", + srcs = ["lib/Dialect/mhlo/transforms/mhlo_fusion.cc"], deps = [ ":cycle_detector", ":hlo", @@ -499,7 +500,7 @@ cc_library( ) gentbl( - name = "xla_legalize_to_standard_inc_gen", + name = "legalize_to_standard_inc_gen", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"), ], @@ -515,7 +516,7 @@ gentbl( ) cc_library( - name = "xla_legalize_control_flow", + name = "legalize_control_flow", srcs = ["lib/Dialect/mhlo/transforms/legalize_control_flow.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ @@ -531,13 +532,13 @@ cc_library( ) cc_library( - name = "xla_legalize_to_standard", + name = "legalize_to_standard", srcs = ["lib/Dialect/mhlo/transforms/legalize_to_standard.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":hlo", - ":xla_legalize_tanh_to_approximation", - ":xla_legalize_to_standard_inc_gen", + ":legalize_tanh_to_approximation", + ":legalize_to_standard_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -548,7 +549,7 @@ cc_library( ) cc_library( - name = "xla_legalize_tanh_to_approximation", + name = "legalize_tanh_to_approximation", srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", @@ -566,7 +567,7 @@ cc_library( ) gentbl( - name = "xla_lower_complex_inc_gen", + name = "lower_complex_inc_gen", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_lower_complex.inc"), ], @@ -583,7 +584,8 @@ gentbl( ) cc_library( - name = "xla_lower", + #TODO(aminim): find a better name here? + name = "mhlo_to_mhlo_lowering_patterns", srcs = [ "lib/Dialect/mhlo/transforms/generated_lower_complex.inc", "lib/Dialect/mhlo/transforms/lower_complex.cc", @@ -608,7 +610,7 @@ cc_library( ) cc_library( - name = "xla_materialize_broadcasts", + name = "materialize_broadcasts", srcs = [ "lib/Dialect/mhlo/transforms/materialize_broadcasts.cc", ], @@ -626,7 +628,7 @@ cc_library( ) cc_library( - name = "xla_unfuse_batch_norm", + name = "unfuse_batch_norm", srcs = ["lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc"], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", @@ -653,7 +655,7 @@ cc_library( ) cc_library( - name = "xla_test_passes", + name = "test_passes", srcs = [ "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc", @@ -667,8 +669,10 @@ cc_library( ":hlo", ":lhlo", ":lhlo_legalize_to_llvm", # build-cleaner: keep - ":xla_materialize_broadcasts", # build-cleaner: keep - ":xla_unfuse_batch_norm", # build-cleaner: keep + ":materialize_broadcasts", # build-cleaner: keep + ":unfuse_batch_norm", # build-cleaner: keep + "@llvm-project//mlir:AffineToStandardTransforms", + "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", @@ -680,3 +684,40 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "all_passes_for_testing", + visibility = [ + "//tensorflow/compiler/mlir:__subpackages__", + ], + deps = [ + ":chlo_legalize_to_hlo", + ":hlo_dialect_registration", + ":hlo_legalize_to_lhlo", + ":legalize_control_flow", + ":legalize_tanh_to_approximation", + ":legalize_to_linalg", + ":legalize_to_standard", + ":lhlo", + ":lhlo_copy_removal", + ":lhlo_fuse_linalg", + ":lhlo_legalize_to_affine", + ":lhlo_legalize_to_gpu", + ":lhlo_legalize_to_parallel_loops", + ":mhlo_fusion", + ":mhlo_to_mhlo_lowering_patterns", + ":sink_constants_to_control_flow", + ":test_passes", + ":transform_unranked_hlo", + ], +) + +cc_binary( + name = "mlir-hlo-opt", + deps = [ + ":all_passes_for_testing", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:MlirOptMain", + ], +) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 453d755a485ac4..1fbf55ded83435 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -28,18 +28,18 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { -namespace xla_chlo { +namespace chlo { -class XlaHloClientDialect : public Dialect { +class HloClientDialect : public Dialect { public: - explicit XlaHloClientDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_chlo"; } + explicit HloClientDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "chlo"; } }; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" -} // namespace xla_chlo +} // namespace chlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index dc6161f756e6ea..79d6fb25318dc7 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -17,12 +17,12 @@ limitations under the License. // These ops are not necessarily orthogonal or optimized for transformation but // for ease of expression in certain cases deemed important for client // libraries (i.e. implicit broadcasting, helper ops, etc). -// This dialect is considered to exist in addition to augment the xla_hlo +// This dialect is considered to exist in addition to augment the mhlo // dialect for ergonomic needs, not duplicate/replace it. // // The typical use of this dialect is for client libraries to be able to emit // less constrained ops and rely on the conversion framework to lower any -// xla_chlo ops to canonical xla_hlo ops. +// chlo ops to canonical mhlo ops. // // See: https://www.tensorflow.org/xla/operation_semantics @@ -35,16 +35,16 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def HLOClient_Dialect : Dialect { - let name = "xla_chlo"; - let cppNamespace = "xla_chlo"; + let name = "chlo"; + let cppNamespace = "chlo"; let summary = [{ - XLA Client HLO Ops + Client HLO Ops }]; let description = [{ This dialect contains ops that align closely with the API surface area of the XlaBuilder C++ API, where such ops have semantics that go beyond - what exists in the lower level dialects (such as `xla_hlo`). Essentially, + what exists in the lower level dialects (such as `mhlo`). Essentially, whenever the client library uses syntactic sugar or composition of multiple ops for an API call, this dialect tries to model the API call and provide conversion patterns to fully materialize into lower level @@ -60,12 +60,12 @@ class HLOClient_Op traits> : } //===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. +// CHLO binary elementwise op definitions. // From the client perspective, each of these support both explicit rank // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // shape broadcasting. // -// These correspond to operations in the xla_hlo dialect without the +// These correspond to operations in the mhlo dialect without the // "broadcast_" prefix, except that those ops require same-shaped operands and // results. // diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 9f8b5e6e741a93..4de52639bca936 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the XLA dialect. +// This file defines the operations used in the MHLO dialect. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ @@ -37,12 +37,12 @@ class OpBuilder; #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" -namespace xla_hlo { +namespace mhlo { -class XlaHloDialect : public Dialect { +class MhloDialect : public Dialect { public: - explicit XlaHloDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_hlo"; } + explicit MhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "mhlo"; } // Registered hook to materialize a constant operation from a given attribute // value with the desired resultant type. @@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase { // %1 = index_cast %0 : index to i64 // %2 = dim %arg0, 1 : memref // %3 = index_cast %2 : index to i64 -// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) +// %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3) // : (i64, i64) -> tensor<2xi64> // // and returns %4 as the shape value. @@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand( #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 304bc1ef22e8f2..0ed4235e23f2b5 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is the operation definition file for XLA HLO ops which map to the -// traditional definition in xla_data.proto (or are aligned with the goals -// thereof). -// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto +// This is the operation definition file for MHLO ops. #ifndef HLO_OPS #define HLO_OPS @@ -29,8 +26,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLO_Dialect : Dialect { - let name = "xla_hlo"; - let cppNamespace = "xla_hlo"; + let name = "mhlo"; + let cppNamespace = "mhlo"; } class HLO_Op traits> : @@ -44,7 +41,7 @@ class HLO_Op traits> : } //===----------------------------------------------------------------------===// -// XLA nullary op definitions. +// MHLO nullary op definitions. //===----------------------------------------------------------------------===// def HLO_ConstOp : HLO_Op<"constant", @@ -78,6 +75,8 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { // TODO(b/130357376): Iota has special conversion logic to HLO. let hasCustomHLOConverter = 1; + let hasCanonicalizer = 1; + let hasFolder = 1; } def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { @@ -112,7 +111,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { } //===----------------------------------------------------------------------===// -// XLA unary elementwise op definitions. +// MHLO unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions @@ -263,7 +262,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; //===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. +// MHLO binary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations @@ -362,7 +361,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", } //===----------------------------------------------------------------------===// -// XLA binary logical elementwise op definitions. +// MHLO binary logical elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations @@ -380,7 +379,7 @@ def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; //===----------------------------------------------------------------------===// -// XLA communication op definitions. +// MHLO communication op definitions. //===----------------------------------------------------------------------===// // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'. @@ -480,7 +479,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { } //===----------------------------------------------------------------------===// -// XLA parallelism related op definitions. +// MHLO parallelism related op definitions. //===----------------------------------------------------------------------===// def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, @@ -491,7 +490,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, } //===----------------------------------------------------------------------===// -// XLA control flow op definitions. +// MHLO control flow op definitions. //===----------------------------------------------------------------------===// def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { @@ -639,7 +638,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ } //===----------------------------------------------------------------------===// -// XLA tuple op definitions. +// MHLO tuple op definitions. //===----------------------------------------------------------------------===// def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp { let arguments = (ins @@ -683,7 +682,7 @@ def HLO_CompareOp: HLO_Op<"compare", } //===----------------------------------------------------------------------===// -// XLA Slice definitions. +// MHLO Slice definitions. //===----------------------------------------------------------------------===// def HLO_SliceOp: HLO_Op< @@ -744,7 +743,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", //===----------------------------------------------------------------------===// -// XLA Other op definitions. +// MHLO Other op definitions. //===----------------------------------------------------------------------===// def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>, @@ -1319,7 +1318,7 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { } //===----------------------------------------------------------------------===// -// XLA RngUniform Operator. +// MHLO RngUniform Operator. //===----------------------------------------------------------------------===// def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let arguments = (ins @@ -1346,7 +1345,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { } //===----------------------------------------------------------------------===// -// XLA Quantize Operator. +// MHLO Quantize Operator. //===----------------------------------------------------------------------===// def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>, BASE_HLO_DequantizeOp { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 84045d25e3e3d9..7f9784d7f11ea7 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -35,7 +35,7 @@ def HLO_Complex : Complex>; defvar BroadcastDimAttr = I64ElementsAttr; //===----------------------------------------------------------------------===// -// XLA on tensors type definitions. +// MHLO on tensors type definitions. //===----------------------------------------------------------------------===// // Token type. @@ -78,7 +78,7 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[ AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; //===----------------------------------------------------------------------===// -// XLA on tensors combined type definitions. +// MHLO on tensors combined type definitions. //===----------------------------------------------------------------------===// // Any integer or floating-point tensor types @@ -97,7 +97,7 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>; def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; //===----------------------------------------------------------------------===// -// XLA nullary op definitions. +// MHLO nullary op definitions. //===----------------------------------------------------------------------===// class BASE_HLO_ConstOp { @@ -117,7 +117,7 @@ class BASE_HLO_IotaOp { } //===----------------------------------------------------------------------===// -// XLA unary elementwise op definitions. +// MHLO unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index c6ea1fe97493fe..e1ae9e1fb8916d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -25,19 +25,19 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; class ConstantSplat : NativeCodeCall< - "xla::getSplat(&$_builder, $0, " # value # ")">; + "hlo::getSplat(&$_builder, $0, " # value # ")">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< - "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; + "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; def BinBroadcastDimensionsNonEmpty : NativeCodeCall< - "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">; + "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">; // Here, the element type can be any integer or float type. But, note that only // 32 bit integers are supported for the value. class GetScalarOfType : NativeCodeCall< - "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; + "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; #endif // HLO_UTILS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index ad1aa78b7f8235..fd31bec44c0960 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -35,18 +35,18 @@ class OpBuilder; #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" -namespace xla_lhlo { +namespace lmhlo { -class XlaLhloDialect : public Dialect { +class LmhloDialect : public Dialect { public: - explicit XlaLhloDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_lhlo"; } + explicit LmhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo"; } }; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" -} // namespace xla_lhlo +} // namespace lmhlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 5a8c3ccd4a4449..87082219db706b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is the operation definition file for LXLA. +// This is the operation definition file for LMHLO, the "late" MHLO variant of +// the dialect, which operates on buffers instead of tensors. // -// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to +// This file largely overlaps with mhlo_ops.td at a logic level. It's tempting to // merge these two files together, but we need to consider the following // obstacles: // * We need to have a common representation for arguments. That is to say, @@ -38,12 +39,12 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def LHLO_Dialect : Dialect { - let name = "xla_lhlo"; - let cppNamespace = "xla_lhlo"; + let name = "lmhlo"; + let cppNamespace = "lmhlo"; } //===----------------------------------------------------------------------===// -// XLA type definitions. +// LMHLO type definitions. //===----------------------------------------------------------------------===// // Any integer tensor types @@ -66,7 +67,7 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; //===----------------------------------------------------------------------===// -// XLA nullary op definitions. +// LMHLO nullary op definitions. //===----------------------------------------------------------------------===// class LHLO_Op traits> : @@ -86,7 +87,7 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { } //===----------------------------------------------------------------------===// -// XLA unary elementwise op definitions. +// LMHLO unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions @@ -157,7 +158,7 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HL def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp; //===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. +// LMHLO binary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations @@ -212,7 +213,7 @@ def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp; def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp; //===----------------------------------------------------------------------===// -// XLA control flow op definitions. +// LMHLO control flow op definitions. //===----------------------------------------------------------------------===// // TODO(b/139813999): specify required function signature in a type-safe way. @@ -253,7 +254,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ // TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, // A tuple-like pattern match syntax could work: -// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { +// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { // ... // }, { // ... @@ -284,7 +285,7 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, } //===----------------------------------------------------------------------===// -// XLA tuple op definitions. +// LMHLO tuple op definitions. //===----------------------------------------------------------------------===// def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { @@ -298,7 +299,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { } //===----------------------------------------------------------------------===// -// XLA Slice definitions. +// LMHLO Slice definitions. //===----------------------------------------------------------------------===// def LHLO_SliceOp: LHLO_Op< @@ -337,7 +338,7 @@ def HLO_StaticMemRefCastOp: Op -> memref<5xf32, offset: 2, strides: [1]> // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and @@ -379,7 +380,7 @@ def HLO_DynamicMemRefCastOp: Op -> memref // The result of the op is a type-erased memref with `[%size_X, %size_Y]` // shape and `[%step_X, %step_Y]` strides. The offset will be inherited @@ -470,14 +471,6 @@ def ReshapeMemRefCastOp: Op]; - let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } }]; @@ -491,7 +484,7 @@ def ReshapeMemRefCastOp: Op, diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 718f0b80f0ba66..a0246f93180dec 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { template struct HloToLhloOpImpl { @@ -31,10 +31,10 @@ struct HloToLhloOpImpl { template using HloToLhloOp = typename HloToLhloOpImpl::Type; -#define MAP_HLO_TO_LHLO(OpName) \ - template <> \ - struct HloToLhloOpImpl { \ - using Type = xla_lhlo::OpName; \ +#define MAP_HLO_TO_LHLO(OpName) \ + template <> \ + struct HloToLhloOpImpl { \ + using Type = lmhlo::OpName; \ } MAP_HLO_TO_LHLO(AbsOp); @@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp); #undef MAP_HLO_TO_LHLO -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h similarity index 60% rename from tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h rename to tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index f443b2d943790f..5d2bffcec2ad92 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ -#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace impl { // A struct to map LhloBinaryOpTy type to the corresponding floating-point and @@ -33,32 +33,32 @@ template struct LhloToScalarOp; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = ::mlir::SignedRemIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; }; @@ -116,16 +116,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) + // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); @@ -133,30 +134,30 @@ inline Value MapLhloOpToStdScalarOp( b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); - auto neg_val = b->create>(loc, zero_intval, lhs); + auto neg_val = b->create>(loc, zero_intval, lhs); return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template -inline Optional getCmpPredicate( - StringRef xla_comparison_direction) { +inline Optional getCmpPredicate(StringRef comparison_direction) { return llvm::None; } template <> inline Optional getCmpPredicate( - StringRef xla_comparison_direction) { - return llvm::StringSwitch>(xla_comparison_direction) + StringRef comparison_direction) { + return llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpFPredicate::OEQ) .Case("NE", CmpFPredicate::ONE) .Case("GE", CmpFPredicate::OGE) @@ -168,8 +169,8 @@ inline Optional getCmpPredicate( template <> inline Optional getCmpPredicate( - StringRef xla_comparison_direction) { - return llvm::StringSwitch>(xla_comparison_direction) + StringRef comparison_direction) { + return llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpIPredicate::eq) .Case("NE", CmpIPredicate::ne) .Case("GE", CmpIPredicate::sge) @@ -179,11 +180,11 @@ inline Optional getCmpPredicate( .Default(llvm::None); } -template -inline Value MapXlaCompareOpToStdScalarOp(Location loc, - StringRef comparison_direction, - ArrayRef result_types, - ArrayRef args, OpBuilder* b) { +template +inline Value MapCompareOpToStdScalarOp(Location loc, + StringRef comparison_direction, + ArrayRef result_types, + ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; Type element_type = lhs.getType(); @@ -191,44 +192,47 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc, Optional predicate = getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(loc, predicate.getValue(), lhs, - rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } if (element_type.isa()) { Optional predicate = getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(loc, predicate.getValue(), lhs, - rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return args.front(); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, @@ -236,21 +240,23 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type sourceType = args.front().getType(); @@ -288,9 +294,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { // Dot Op converter from lhlo to affine only accepts float and integer types. const auto& lhs = args[0]; const auto& rhs = args[1]; @@ -312,25 +319,27 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } -/// Implements the conversion of XLA op to scalar op (to use within region of a +/// Implements the conversion of HLO op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template -struct XlaCompareSelectOpToStdScalarOp { +struct CompareSelectOpToStdScalarOp { static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { @@ -342,8 +351,8 @@ struct XlaCompareSelectOpToStdScalarOp { /// dialect with a given predicate based on the element type of the operand. template -struct XlaCompareSelectOpToStdScalarOp { +struct CompareSelectOpToStdScalarOp { static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { @@ -355,72 +364,75 @@ struct XlaCompareSelectOpToStdScalarOpcreate<::mlir::SelectOp>(loc, cmp, args[0], args[1]); } - return XlaCompareSelectOpToStdScalarOp::map( - loc, comparison_direction, result_types, args, b); + return CompareSelectOpToStdScalarOp::map(loc, comparison_direction, + result_types, args, b); } }; template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "GT", - result_types, args, - b); +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return CompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "LT", - result_types, args, - b); +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return CompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.neg(x, result) -> result = sub(0, x) + // lmhlo.neg(x, result) -> result = sub(0, x) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); auto zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); - return b->create>(loc, zero_intval, lhs); + return b->create>(loc, zero_intval, lhs); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, @@ -428,9 +440,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { FloatType float_type = element_type.cast(); @@ -442,69 +455,72 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } } // namespace impl -struct XlaOpToStdScalarOp { - // Implementation for LHLO ops except xla_lhlo::CompareOp. - template ::value && - std::is_same, + !std::is_same::value && + std::is_same, std::false_type>::value>> - static Value map(XlaOpTy op, ArrayRef result_types, + static Value map(HloOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, unsigned i = 0) { return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, args, b); } - // Implementation for HLO ops except xla_hlo::CompareOp. - template , + // Implementation for HLO ops except mhlo::CompareOp. + template , typename = std::enable_if_t< - !std::is_same::value && + !std::is_same::value && !std::is_same::value>> - static Value map(XlaOpTy op, ArrayRef result_types, + static Value map(HloOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, int i = 0) { return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, args, b); } - // Implementation for xla_lhlo::CompareOp. + // Implementation for lmhlo::CompareOp. template ::value>> - static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + LhloOpTy, lmhlo::CompareOp>::value>> + static Value map(lmhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } - // Implementation for xla_hlo::CompareOp. - template ::value>> - static Value map(xla_hlo::CompareOp op, ArrayRef result_types, + // Implementation for mhlo::CompareOp. + template ::value>> + static Value map(mhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } }; -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 312ac952368071..9ea39e95fef4c8 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -29,7 +29,7 @@ template class OperationPass; class Pass; -namespace xla_hlo { +namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); @@ -55,12 +55,12 @@ std::unique_ptr> createTransformUnrankedHloPass(); // necessary to export to XLA. std::unique_ptr> createSinkConstantsToControlFlowPass(); -// fuse xla_hlo ops to kLoop/kInput fusion patterns -std::unique_ptr> createXlaHloFusionPass(); +// fuse mhlo ops to kLoop/kInput fusion patterns +std::unique_ptr> createMhloFusionPass(); -} // namespace xla_hlo +} // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { // Lowers from LHLO dialect to Affine dialect. std::unique_ptr> createLegalizeToAffinePass(); @@ -92,14 +92,14 @@ std::unique_ptr createLhloCopyRemovalPass(); // Lowers from LHLO dialect to parallel loops. std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); -} // namespace xla_lhlo +} // namespace lmhlo -namespace xla { +namespace hlo { /// Lowers the standard TanhOp to an approximation that does not use intrinsics. std::unique_ptr> createLegalizeTanhToApproximationPass(); -} // namespace xla +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 6e85b45d14ce00..cb9a85a658a3ce 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -27,7 +27,7 @@ class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; class BufferAssignmentPlacer; -namespace xla_hlo { +namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, @@ -38,8 +38,8 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, void PopulateComplexLoweringPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, - MLIRContext *ctx); +void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, + MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. void populateHLOToLHLOConversionPattern( @@ -73,34 +73,34 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -} // namespace xla_hlo +} // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { /// Collect a set of patterns to convert from the LHLO dialect to LLVM. void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, LLVMTypeConverter *converter, OwningRewritePatternList *patterns); -} // namespace xla_lhlo +} // namespace lmhlo -namespace xla_chlo { +namespace chlo { // Populates a collection of conversion patterns for legalizing client-HLO to // HLO. void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -} // namespace xla_chlo +} // namespace chlo -namespace xla { +namespace hlo { // Populates a pattern that translates the standard TanhOp to an approximation // that does not use intrinsics. void PopulateTanhToApproximationPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -} // namespace xla +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h index 7c5b5e3311c83e..3be7d42cc25fb1 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_ // Utilities relating to implementing HLO broadcasting. // Note: This file should not depend on any non-MLIR TensorFlow libraries. @@ -27,7 +27,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { // Checks whether the given operand types and broadcast_dims attr represent a // legal combination for "numpy" style broadcasting (where 1-dims are prepended @@ -43,7 +43,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, Value rhs, OpBuilder& builder); -} // namespace xla +} // namespace hlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h index 5fe2f80561f3a4..a63df336d8f853 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { // Converts the given elements attr to the specified elements type. // Requires type of the elements and new_type to be either integer or float // type. mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::Type new_type); -} // namespace xla +} // namespace hlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h index 7afba7f4f2f90d..79b56b39d55502 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ #include @@ -162,4 +162,4 @@ class GraphCycles { } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h index 5db8ad38fccbbd..b31ba231acd3a9 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -23,7 +23,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { // Computes the broadcast dimensions attr for an elementwise binary operator // between two ranked tensors. @@ -68,7 +68,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) { // Requires `ty` to be either FloatType of IntegerType. DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); -} // namespace xla +} // namespace hlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index e43890979e5e1a..c6c193a9d89c00 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" namespace mlir { -namespace xla_chlo { +namespace chlo { template static LogicalResult Verify(T op) { @@ -137,7 +137,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( auto broadcast_dimensions = op->getAttr("broadcast_dimensions") .dyn_cast_or_null(); if (broadcast_dimensions && - !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) { + !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) { // Note: It is unclear whether the general specification of explicit // broadcast_dimensions on binary ops is a feature we want to carry // forward. While it can technically be implemented for ranked-dynamic, @@ -150,7 +150,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( << "broadcast_dimensions = " << broadcast_dimensions; } - Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents( + Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents( loc, lhs, rhs, builder); if (!computed_shape) return failure(); reifiedReturnShapes.push_back(computed_shape); @@ -263,10 +263,10 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" //===----------------------------------------------------------------------===// -// xla_chlo Dialect Constructor +// chlo Dialect Constructor //===----------------------------------------------------------------------===// -XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) +HloClientDialect::HloClientDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -274,5 +274,5 @@ XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) >(); } -} // namespace xla_chlo +} // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc index 65d200aa5f2e81..f4df946d11ab9a 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -// Static initialization for XLA dialect registration. -static mlir::DialectRegistration xla_hlo_ops; -static mlir::DialectRegistration - xla_chlo_ops; -static mlir::DialectRegistration xla_lhlo_ops; +// Static initialization for *HLO dialects registration. +static mlir::DialectRegistration mhlo_ops; +static mlir::DialectRegistration chlo_ops; +static mlir::DialectRegistration lmhlo_ops; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 03caf6272cd654..cbd478a0283a57 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the XLA dialect. +// This file defines the operations used in the MHLO dialect. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -60,16 +60,14 @@ limitations under the License. namespace mlir { #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" -namespace xla_hlo { +namespace mhlo { -Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, - Attribute value, Type type, - Location loc) { +Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, + Type type, Location loc) { // HLO dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. if (value.isa()) - return builder.create(loc, type, - value.cast()); + return builder.create(loc, type, value.cast()); return nullptr; } @@ -167,7 +165,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result, } // TODO: support other XLA specific types. - assert(type && "unsupported attribute type for building xla_hlo.constant"); + assert(type && "unsupported attribute type for building mhlo.constant"); result.types.push_back(type); result.addAttribute("value", value); } @@ -215,6 +213,52 @@ static LogicalResult Verify(IotaOp op) { return success(); } +// Iota operations across multiple dimensions can be reduced to an iota and a +// ranked broadcast. +struct IotaBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasRank() || result_ty.getRank() < 2) { + return failure(); + } + + auto iota_dimension = iota.iota_dimension(); + + auto iota_type = RankedTensorType::get( + {result_ty.getDimSize(iota_dimension.getLimitedValue())}, + result_ty.getElementType()); + + auto new_iota = rewriter.create(iota.getLoc(), iota_type, + rewriter.getI64IntegerAttr(0)); + + auto broadcast_attr = DenseIntElementsAttr::get( + RankedTensorType::get({1}, rewriter.getIntegerType(64)), + {iota_dimension}); + rewriter.replaceOpWithNewOp(iota, result_ty, new_iota, + broadcast_attr); + return success(); + } +}; + +void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + +OpFoldResult IotaOp::fold(ArrayRef operands) { + auto dimension = iota_dimension().getLimitedValue(); + auto result_ty = getResult().getType().cast(); + if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) { + Builder builder(getContext()); + return builder.getZeroAttr(result_ty); + } + + return {}; +} + //===----------------------------------------------------------------------===// // DynamicIotaOp //===----------------------------------------------------------------------===// @@ -236,11 +280,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern { } }; +// Dynamic Iota operations across multiple dimensions can be reduced to an iota +// and a ranked broadcast. +struct DynamicIotaBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasRank() || result_ty.getRank() < 2) { + return failure(); + } + + auto iota_dimension = iota.iota_dimension(); + auto iota_dimension_int = iota_dimension.getLimitedValue(); + + auto converted_shape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + iota.output_shape().getType().cast().getShape(), + rewriter.getI64Type()), + iota.output_shape()); + + auto sliced_shape = rewriter.create( + iota.getLoc(), converted_shape, + GetI64ElementsAttr(iota_dimension_int, &rewriter), + GetI64ElementsAttr(iota_dimension_int + 1, &rewriter), + GetI64ElementsAttr(1, &rewriter)); + + auto converted_sliced_shape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + {1}, + iota.output_shape().getType().cast().getElementType()), + sliced_shape); + + auto iota_type = RankedTensorType::get( + {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType()); + + auto new_iota = rewriter.create( + iota.getLoc(), iota_type, converted_sliced_shape, + rewriter.getI64IntegerAttr(0)); + + auto broadcast_attr = DenseIntElementsAttr::get( + RankedTensorType::get({1}, rewriter.getIntegerType(64)), + {iota_dimension}); + rewriter.replaceOpWithNewOp( + iota, result_ty, new_iota, iota.output_shape(), broadcast_attr); + return success(); + } +}; + } // namespace void DynamicIotaOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -319,7 +415,7 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) { - return xla::ConvertElementsAttr(elementsAttr, + return hlo::ConvertElementsAttr(elementsAttr, getElementTypeOrSelf(getResult())); } @@ -387,7 +483,7 @@ static LogicalResult Verify(GetTupleElementOp op) { OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { if (auto tupleOp = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return tupleOp.getOperand(index().getLimitedValue()); } @@ -535,17 +631,25 @@ static LogicalResult Verify(BroadcastInDimOp op) { return success(); } -OpFoldResult BroadcastInDimOp::fold(ArrayRef) { +OpFoldResult BroadcastInDimOp::fold(ArrayRef attrs) { auto type = getType().cast(); - if (type != getOperand().getType()) { - return nullptr; - } - auto broadcast_values = broadcast_dimensions().getValues(); - if (!std::equal(broadcast_values.begin(), broadcast_values.end(), - llvm::seq(0, type.getRank()).begin())) { - return nullptr; + if (type == getOperand().getType()) { + auto broadcast_values = broadcast_dimensions().getValues(); + if (!std::equal(broadcast_values.begin(), broadcast_values.end(), + llvm::seq(0, type.getRank()).begin())) { + return {}; + } + return getOperand(); } - return getOperand(); + + // Constant fold when an operand is a splat tensor attribute. + if (!attrs[0] || !type.hasStaticShape()) return {}; + auto splatOperandAttr = attrs[0].dyn_cast(); + if (!splatOperandAttr) return {}; + // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element + // attribute iterator not implemented for complex element types. + if (type.getElementType().isa()) return {}; + return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue()); } //===----------------------------------------------------------------------===// @@ -693,10 +797,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, } OpFoldResult ComplexOp::fold(ArrayRef operands) { - auto real_op = - dyn_cast_or_null(getOperand(0).getDefiningOp()); - auto imag_op = - dyn_cast_or_null(getOperand(1).getDefiningOp()); + auto real_op = dyn_cast_or_null(getOperand(0).getDefiningOp()); + auto imag_op = dyn_cast_or_null(getOperand(1).getDefiningOp()); if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { return real_op.getOperand(); } @@ -727,7 +829,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { OpFoldResult ImagOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(1); } @@ -740,7 +842,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { OpFoldResult RealOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(0); } @@ -1148,7 +1250,7 @@ static LogicalResult Verify(MapOp op) { // RecvOp //===----------------------------------------------------------------------===// -// Checks that the result type is of the form `tuple` +// Checks that the result type is of the form `tuple` static LogicalResult Verify(RecvOp op) { auto result_ty = op.getResult().getType().cast(); auto subtypes = result_ty.getTypes(); @@ -2020,7 +2122,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" //===----------------------------------------------------------------------===// -// xla_hlo Dialect Interfaces +// mhlo Dialect Interfaces //===----------------------------------------------------------------------===// namespace { @@ -2032,7 +2134,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface { BlockAndValueMapping& valueMapping) const final { return true; } - // Operations in xla_hlo dialect are always legal to inline since they are + // Operations in mhlo dialect are always legal to inline since they are // pure. bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final { return true; @@ -2041,10 +2143,10 @@ struct HLOInlinerInterface : public DialectInlinerInterface { } // end anonymous namespace //===----------------------------------------------------------------------===// -// xla_hlo Dialect Constructor +// mhlo Dialect Constructor //===----------------------------------------------------------------------===// -XlaHloDialect::XlaHloDialect(MLIRContext* context) +MhloDialect::MhloDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -2052,26 +2154,23 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context) >(); addInterfaces(); addTypes(); - // Support unknown operations because not all XLA operations are registered. - // allowUnknownOperations(); } -Type XlaHloDialect::parseType(DialectAsmParser& parser) const { +Type MhloDialect::parseType(DialectAsmParser& parser) const { StringRef data_type; if (parser.parseKeyword(&data_type)) return Type(); if (data_type == "token") return TokenType::get(getContext()); - parser.emitError(parser.getNameLoc()) - << "unknown xla_hlo type: " << data_type; + parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type; return nullptr; } -void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { +void MhloDialect::printType(Type type, DialectAsmPrinter& os) const { if (type.isa()) { os << "token"; return; } - os << ""; + os << ""; } //===----------------------------------------------------------------------===// @@ -2106,5 +2205,5 @@ LogicalResult deriveShapeFromFirstOperand( return success(); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index 3a0d7ebfc644d0..bd0dc224cccd87 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the XLA dialect. +// This file defines the operations used in the LMHLO dialect. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -46,9 +46,9 @@ limitations under the License. namespace mlir { #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" -namespace xla_lhlo { +namespace lmhlo { -XlaLhloDialect::XlaLhloDialect(MLIRContext *context) +LmhloDialect::LmhloDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index cdd429c674e4af..06e95e04c76def 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -25,12 +25,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" namespace mlir { -namespace xla_chlo { +namespace chlo { namespace { // Converts binary ops that statically are determined to not broadcast directly -// to the corresponding xla_hlo non-broadcasting op. +// to the corresponding mhlo non-broadcasting op. template struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { }; // Converts a binary op with ranked broadcasting operands to explicitly -// broadcast and invoke the corresponding xla_hlo non-broadcasting op. +// broadcast and invoke the corresponding mhlo non-broadcasting op. // Note that dynamic broadcasting supported by this pattern is only valid for // "numpy" broadcasting semantics as defined here: // https://docs.scipy.org/doc/numpy/reference/ufuncs.html @@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp // Check for "numpy"-style rank broadcast. auto broadcast_dimensions = op.broadcast_dimensions(); if (broadcast_dimensions && - !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { + !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { // Note: It is unclear whether the general specification of explicit // broadcast_dimensions on binary ops is a feature we want to carry // forward. While it can technically be implemented for ranked-dynamic, @@ -126,7 +126,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value result_extents = - xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, + hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, rewriter); // Note that we unconditionally emit DynamicBroadcastInDim ops and let @@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp // properly. auto lhs_broadcast_dimensions = llvm::to_vector<4>( llvm::seq(result_rank - lhs_type.getRank(), result_rank)); - Value broadcasted_lhs = rewriter.create( + Value broadcasted_lhs = rewriter.create( loc, RankedTensorType::get(result_type.getShape(), lhs_type.getElementType()), @@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp rewriter.getI64TensorAttr(lhs_broadcast_dimensions)); auto rhs_broadcast_dimensions = llvm::to_vector<4>( llvm::seq(result_rank - rhs_type.getRank(), result_rank)); - Value broadcasted_rhs = rewriter.create( + Value broadcasted_rhs = rewriter.create( loc, RankedTensorType::get(result_type.getShape(), rhs_type.getElementType()), @@ -182,23 +182,21 @@ struct HloBinaryElementwiseAdaptor { }; struct HloComplexAdaptor { - static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op, - Type result_type, Value broadcasted_lhs, - Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); + static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); } }; struct HloCompareAdaptor { - static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op, - Type result_type, Value broadcasted_lhs, - Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction()); + static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); } }; @@ -214,29 +212,28 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, HloBinaryElementwiseAdaptor>(context, \ patterns); - POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp); - POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp); - POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op); - POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp); - POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp); - POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp); - POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp); - POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp); - POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp); - POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp); - POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp); - POPULATE_BCAST(BroadcastShiftRightArithmeticOp, - xla_hlo::ShiftRightArithmeticOp); - POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp); - POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp); - POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp); + POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp); + POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp); + POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op); + POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp); + POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp); + POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp); + POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp); + POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp); + POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp); + POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp); + POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp); + POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp); + POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp); + POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp); + POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp); // Broadcasting ops requiring special construction. - PopulateForBinaryOp(context, patterns); - PopulateForBinaryOp(context, patterns); + PopulateForBinaryOp( + context, patterns); + PopulateForBinaryOp( + context, patterns); } -} // namespace xla_chlo +} // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index f3422614b94d97..48749c7d43d377 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_chlo { +namespace chlo { namespace { @@ -31,9 +31,9 @@ struct TestChloLegalizeToHloPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); - // Consider the xla_hlo dialect legal for tests. - conversionTarget.addLegalDialect(); + conversionTarget.addIllegalDialect(); + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); @@ -49,9 +49,9 @@ struct TestChloLegalizeToHloPass } // namespace -} // namespace xla_chlo +} // namespace chlo } // namespace mlir -static mlir::PassRegistration pass( - "test-xla-chlo-legalize-to-hlo", +static mlir::PassRegistration pass( + "mhlo-test-chlo-legalize-to-hlo", "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 5ff10fb419f507..4ee45d56a8e923 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -37,14 +37,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { template using BaseOpConversion = BufferAssignmentOpConversionPattern; using StdReturnOpConverter = detail::BufferAssignmentReturnOpConverter; + lmhlo::CopyOp, true>; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -128,20 +128,20 @@ class HloToLhloOpConverter : public BaseOpConversion { op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } - rewriter.create>(op->getLoc(), llvm::None, - buffer_args, op->getAttrs()); + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); return success(); } }; struct HloToLhloDynamicBroadcastInDimOpConverter - : public BaseOpConversion { + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, + mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); Value resultBuffer = InsertDynamicAllocAndDealloc( @@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter Value transformed_operand = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); - rewriter.create( + rewriter.create( loc, transformed_operand, resultBuffer, op.broadcast_dimensions()); rewriter.replaceOp(op, {resultBuffer}); @@ -161,8 +161,8 @@ struct HloToLhloDynamicBroadcastInDimOpConverter // Inserts dynamic memref to change the layout of the memref to put 0-stride // and size of the target dimension if size-1 dimension expansion is // necessary. - xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( - xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { + lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( + mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { auto loc = op.getLoc(); auto operand_type = operand.getType().cast(); auto operand_shape = operand_type.getShape(); @@ -214,18 +214,43 @@ struct HloToLhloDynamicBroadcastInDimOpConverter makeStridedLinearLayoutMap(dynamic_layout, /*offset=*/0, b->getContext())); - auto transformed_operand = b->create( + auto transformed_operand = b->create( loc, type_erased_memref_type, operand, sizes, strides); return transformed_operand; } }; -struct HloToLhloReduceOpConverter : public BaseOpConversion { +struct HloToLhloDynamicReshapeConverter + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - xla_hlo::ReduceOp op, ArrayRef operands, + mhlo::DynamicReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Type result_type; + if (auto ranked_type = op.getType().dyn_cast()) { + result_type = + MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); + } else if (auto unranked_type = + op.getType().dyn_cast()) { + result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); + } else { + return failure(); + } + mhlo::DynamicReshapeOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, result_type, adaptor.operand(), adaptor.output_shape()); + return success(); + } +}; + +struct HloToLhloReduceOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ReduceOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); // TODO(b/137624192) Implement variadic reduce. @@ -241,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { buffer_args.push_back( InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } - auto new_op = rewriter.create( - loc, llvm::None, buffer_args, op.getAttrs()); + auto new_op = rewriter.create(loc, llvm::None, buffer_args, + op.getAttrs()); // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); @@ -267,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } // Insert terminator at the end. rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + rewriter.create(loc); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -296,8 +321,8 @@ class HloToLhloTensorStoreOpConverter LogicalResult matchAndRewrite( mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - rewriter.replaceOpWithNewOp( - op, llvm::None, operands.front(), operands.back()); + rewriter.replaceOpWithNewOp(op, llvm::None, operands.front(), + operands.back()); return success(); } }; @@ -311,16 +336,16 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ({ +// "lmhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> -// %2 = "xla_hlo.add"(%0, %1) : +// %2 = "mhlo.add"(%0, %1) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // %3 = tensor_load %arg0 : memref<2x2xf32> -// %4 = "xla_hlo.multiply"(%2, %3) : +// %4 = "mhlo.multiply"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -330,13 +355,13 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ( { +// "lmhlo.fusion"() ( { // %0 = alloc() : memref<2x2xf32> -// "xla_lhlo.add"(%arg1, %arg2, %0) : +// "lmhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.multiply"(%0, %arg0, %arg3) : +// "lmhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -344,8 +369,8 @@ class HloToLhloTensorStoreOpConverter // FuncOp signature conversion example: // // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> -// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, +// %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>, // tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> // } // @@ -357,13 +382,13 @@ class HloToLhloTensorStoreOpConverter // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// "xla_lhlo.maximum"(%arg0, %arg1, %0) : +// "lmhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // %1 = alloc() : memref<4xf32> -// "xla_lhlo.add"(%arg0, %0, %1) : +// "lmhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// "lmhlo.terminator"() : () -> () // } struct HloLegalizeToLhlo @@ -381,26 +406,31 @@ struct HloLegalizeToLhlo OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addLegalOp(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); BufferAssignmentTypeConverter converter; + auto isMemRefType = [](Type type) { return type.isa(); }; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); - return llvm::all_of(inputs, - [](Type input) { return input.isa(); }) && + return llvm::all_of(inputs, isMemRefType) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return std::all_of(returnOp.operand_type_begin(), - returnOp.operand_type_end(), - [](Type type) { return type.isa(); }); + target.addDynamicallyLegalOp([&](CallOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType) && + std::all_of(op.result_type_begin(), op.result_type_end(), + isMemRefType); + }); + target.addDynamicallyLegalOp([&](mlir::ReturnOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType); }); auto module = getOperation(); @@ -411,12 +441,12 @@ struct HloLegalizeToLhlo &converter, &patterns); if (results_escape_function) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, &converter, &patterns); } else { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, &converter, &patterns); } @@ -442,38 +472,39 @@ void populateHLOToLHLOConversionPattern( // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, + HloToLhloDynamicReshapeConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter @@ -489,5 +520,5 @@ std::unique_ptr> createLegalizeToLhloPass( static PassRegistration legalize_pass( "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 8f93990e260034..440df7ec23f937 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file implements logic for lowering XLA dialect to Standard dialect. +// This file implements logic for lowering MHLO dialect to Standard dialect. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -35,7 +35,7 @@ limitations under the License. using mlir::PassRegistration; namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { struct LegalizeControlFlow : public mlir::PassWrapper { @@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, OpBuilder* builder) { for (auto& old_block : region->getBlocks()) { Block* block = mapper.lookup(&old_block); - auto return_op = dyn_cast(block->getTerminator()); + auto return_op = dyn_cast(block->getTerminator()); if (!return_op) continue; builder->setInsertionPointToEnd(block); builder->create(loc, target_block, return_op.getOperands()); @@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, return success(); } -LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { +LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) { Operation* op_inst = if_op.getOperation(); mlir::OpBuilder builder(if_op); auto orig_block = op_inst->getBlock(); @@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { return success(); } -LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { - // Converts an XLA while loop into control flow. This generates a set of MLIR - // blocks and branches, along with inlining the regions provided by the XLA +LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { + // Converts a MHLO while loop into control flow. This generates a set of MLIR + // blocks and branches, along with inlining the regions provided by the MHLO // while loop. The structure should be similar to below: // // - // %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}} + // %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}} // auto* op_inst = while_op.getOperation(); mlir::OpBuilder builder(while_op); @@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // extract_element and conditional branch. This changes the block below: // ^cond(%0): // - // "xla_hlo".return(%1) + // "mhlo".return(%1) // // Into: // ^cond(%0): @@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // cond_br %2, ^body(%0), ^tail(%0) // Branch. builder.setInsertionPointToStart(cond_block); - // Replace the xla_hlo::ReturnOp with a branch back to the condition block. - // This is required as the xla_hlo::ReturnOp is used to mark the end of a + // Replace the mhlo::ReturnOp with a branch back to the condition block. + // This is required as the mhlo::ReturnOp is used to mark the end of a // block for regions nested inside of a operations (MLIR ReturnOp cannot be // nested within an non-function region). for (auto& block : while_op.cond()) { auto new_block = mapper.lookup(&block); - auto return_op = dyn_cast(new_block->getTerminator()); + auto return_op = dyn_cast(new_block->getTerminator()); if (!return_op) continue; builder.setInsertionPointToEnd(new_block); @@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // conditional block. This changes the block below: // ^body(%0): // - // "xla_hlo".return(%1) + // "mhlo".return(%1) // // Into: // ^body(%0): @@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // br ^cond(%0) // Branch. for (auto& block : while_op.body()) { auto new_block = mapper.lookup(&block); - auto return_op = - dyn_cast(new_block->getTerminator()); + auto return_op = dyn_cast(new_block->getTerminator()); if (!return_op) continue; builder.setInsertionPointToEnd(new_block); builder.create(loc, cond_block, return_op.getOperands()); @@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() { } } } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir std::unique_ptr> -mlir::xla_hlo::createLegalizeControlFlowPass() { +mlir::mhlo::createLegalizeControlFlowPass() { return std::make_unique(); } -static PassRegistration legalize_cf_pass( - "xla-legalize-control-flow", - "Legalize from XLA control flow to MLIR control flow"); +static PassRegistration legalize_cf_pass( + "mhlo-legalize-control-flow", + "Legalize from MHLO control flow to CFG control flow"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc index 4f32bab025536a..1890646160ee62 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla { +namespace hlo { namespace { /// Emits the fast tanh approximation that is also used by XLA. @@ -149,8 +149,8 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, } static PassRegistration legalize_pass( - "xla-legalize-tanh-to-approximation", + "mhlo-legalize-tanh-to-approximation", "Legalize tanh from standard dialect to an approximation"); -} // namespace xla +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc similarity index 82% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index d9f168820129d2..717e96824368a9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. #include "absl/memory/memory.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -32,7 +33,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { @@ -49,12 +50,12 @@ Value getResultValue(Operation* op) { } template -ShapedType getXLAOpResultType(Operation* op) { +ShapedType getHloOpResultType(Operation* op) { return getResultValue(op).getType().template cast(); } template -bool verifyXLAOpBufferOrTensorSemantics(Operation* op) { +bool verifyHloOpBufferOrTensorSemantics(Operation* op) { auto verifyType = [&](Value val) -> bool { return (isLHLO && val.getType().isa()) || (!isLHLO && val.getType().isa()); @@ -131,9 +132,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern { loc, opResultTypes, args, args_count, results_count, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { - // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. + // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + Value opResult = lmhlo::HloOpToStdScalarOp::map( op, bodyResultTypes, llvm::to_vector<2>(args.take_front(args_count)), &rewriter); nestedBuilder.create(loc, opResult); @@ -162,8 +163,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { // Create two loads from the input. auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); - // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + // TODO(ravishankarm) : Move this method out of lmhlo namespace. + Value opResult = lmhlo::HloOpToStdScalarOp::map( lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); @@ -173,21 +174,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// xla_lhlo.convolution conversion pattern. +// lmhlo.convolution conversion pattern. //===----------------------------------------------------------------------===// -/// Converts xla_lhlo.convolution operation to a linalg.conv op. -struct ConvToLinalgConverter : public OpConversionPattern { +/// Converts lmhlo.convolution operation to a linalg.conv op. +struct ConvToLinalgConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's - // (https://github.com/google/iree/) xla_hlo -> linalg conversion. + // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( - xla_lhlo::ConvOp op, ArrayRef args, + lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers = + if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); @@ -274,7 +275,7 @@ struct ConvToLinalgConverter : public OpConversionPattern { } }; -/// Base class for lowering xla operations that have one operand and one result, +/// Base class for lowering HLO operations that have one operand and one result, /// and are semantically equivalent to a copy of the input to the output (like /// transpose, some reshape, etc.). The derived classes need to provide a method /// `getIndexingMaps` that returns AffineMaps for the index maps of the input @@ -287,8 +288,8 @@ class DataMovementOpConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); - auto resultType = getXLAOpResultType(op); + if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); + auto resultType = getHloOpResultType(op); SmallVector indexing_maps = Derived::getIndexingMaps(op, &rewriter); @@ -322,7 +323,7 @@ class BroadcastConverter ShapedType inputType = broadcastOp.operand().getType().template cast(); unsigned inputRank = inputType.getRank(); - unsigned nloops = getXLAOpResultType(broadcastOp).getRank(); + unsigned nloops = getHloOpResultType(broadcastOp).getRank(); // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // the input's dimensions. @@ -348,15 +349,15 @@ class BroadcastConverter class HloBroadcastInDimConverter : public DataMovementOpConverter { + mhlo::BroadcastInDimOp, false> { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps( - xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) { - auto resultType = getXLAOpResultType(broadcastOp); + mhlo::BroadcastInDimOp broadcastOp, Builder* b) { + auto resultType = getHloOpResultType(broadcastOp); auto operandType = broadcastOp.operand().getType().template cast(); unsigned nloops = resultType.getRank(); @@ -388,14 +389,14 @@ class HloBroadcastInDimConverter }; class LhloBroadcastInDimConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); auto result_type = operand_adaptor.output().getType().cast(); auto result_shape = result_type.getShape(); @@ -444,9 +445,9 @@ class LhloBroadcastInDimConverter // Inserts 'linalg.reshape' if there is a size-1 dim expansion. std::pair> InsertReshapeIfNecessary( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); Value operand = operand_adaptor.operand(); auto operand_type = operand_adaptor.operand().getType().cast(); auto operand_shape = operand_type.getShape(); @@ -512,7 +513,7 @@ class LhloBroadcastInDimConverter return std::make_pair(operand, broadcast_dims); } - SmallVector getIndexingMaps(xla_lhlo::BroadcastInDimOp op, + SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op, ArrayRef broadcastDims, ArrayRef resultShape, MemRefType operandType, @@ -555,7 +556,7 @@ class TransposeConverter isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto resultType = - getXLAOpResultType(op).template cast(); + getHloOpResultType(op).template cast(); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.resize(resultType.getRank()); @@ -579,11 +580,11 @@ class ReshapeOpConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy reshapeOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - if (!verifyXLAOpBufferOrTensorSemantics(reshapeOp)) + if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure(); ShapedType operandType = reshapeOp.operand().getType().template cast(); - ShapedType resultType = getXLAOpResultType(reshapeOp); + ShapedType resultType = getHloOpResultType(reshapeOp); if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -639,12 +640,12 @@ class ReshapeOpConverter : public OpConversionPattern { } }; -class IotaConverter : public OpConversionPattern { +class IotaConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::IotaOp iotaOp, ArrayRef args, + lmhlo::IotaOp iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto resultMemrefType = iotaOp.getOperand().getType().dyn_cast(); @@ -680,19 +681,20 @@ class IotaConverter : public OpConversionPattern { } }; -class ConstConverter : public OpConversionPattern { +class ConstConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ConstOp constOp, ArrayRef args, + lmhlo::ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); if (valueAttr.getType().getRank() != 0) return failure(); auto stdConstOp = rewriter.create(loc, valueAttr.getValue({})); - rewriter.create(loc, stdConstOp, constOp.getOperand()); + rewriter.create(loc, stdConstOp, constOp.getOperand(), + ValueRange()); rewriter.eraseOp(constOp); return success(); } @@ -708,7 +710,7 @@ class ReverseConverter isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto resultType = - getXLAOpResultType(op).template cast(); + getHloOpResultType(op).template cast(); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.reserve(nloops); @@ -726,12 +728,12 @@ class ReverseConverter } }; -class SliceConverter : public OpConversionPattern { +class SliceConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SliceOp sliceOp, ArrayRef args, + lmhlo::SliceOp sliceOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = sliceOp.getLoc(); auto argType = @@ -763,50 +765,50 @@ class SliceConverter : public OpConversionPattern { void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off - patterns->insert, + patterns->insert, ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - ScalarPointwiseToStandardConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + ScalarPointwiseToStandardConverter, SliceConverter >(context); // clang-format on } // Converts LHLO ops to Linalg generic. -// Sample result for xla_lhlo::AddOp. +// Sample result for lmhlo::AddOp. // -// "xla_lhlo.add"(%arg1, %arg2, %out) : +// "lmhlo.add"(%arg1, %arg2, %out) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // // will be converted to @@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -845,7 +848,7 @@ struct HloLegalizeToLinalg target.addLegalDialect(); auto func = getFunction(); - xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); + mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); } @@ -854,49 +857,49 @@ struct HloLegalizeToLinalg } // namespace -namespace xla_lhlo { +namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { return absl::make_unique(); } static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); -} // namespace xla_lhlo +} // namespace lmhlo -namespace xla_hlo { +namespace mhlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns->insert, + patterns->insert, HloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - TransposeConverter>(context); + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { @@ -905,5 +908,5 @@ std::unique_ptr> createLegalizeHloToLinalgPass() { static PassRegistration legalize_hlo_pass( "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index 5dd37084bd876e..c71aa1d0460155 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file implements logic for lowering XLA dialect to Standard dialect. +// This file implements logic for lowering MHLO dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -28,14 +28,14 @@ namespace mlir { namespace { #include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" } // end anonymous namespace -namespace xla_hlo { +namespace mhlo { namespace { -class CompareIConvert : public OpRewritePattern { +class CompareIConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); @@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern { } }; -class CompareFConvert : public OpRewritePattern { +class CompareFConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); @@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern { // convert the integer constant to iota result type. For complex types, the real // part is replaced with the generated constant and the imaginary part is // replaced with zero tensor. -class ConvertIotaOp : public OpRewritePattern { +class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::IotaOp op, + LogicalResult matchAndRewrite(mhlo::IotaOp op, PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); auto output_size = output_type.getNumElements(); @@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern { loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0))); auto imag_zeroes = rewriter.create(loc, int_or_float_shape_ty, zeroes); - rewriter.replaceOpWithNewOp(op, iota_const, - imag_zeroes); + rewriter.replaceOpWithNewOp(op, iota_const, imag_zeroes); return success(); } }; @@ -188,8 +187,8 @@ std::unique_ptr> createLegalizeToStdPass() { return std::make_unique(); } -void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, - mlir::MLIRContext *ctx) { +void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, + mlir::MLIRContext *ctx) { mlir::populateWithGenerated(ctx, patterns); patterns->insert(ctx); } @@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, /// Perform the lowering to standard dialect. void LegalizeToStandard::runOnFunction() { OwningRewritePatternList patterns; - mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext()); + mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } static PassRegistration legalize_pass( - "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); + "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect"); -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index ee467a312d6ea9..0e6fdf06701ca3 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is the legalization pattern definition file for XLA to StandardOps. +// This is the legalization pattern definition file for MHLO to StandardOps. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc index 145cd75b61c0b7..d2607887482dce 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Removes LHLO copy operations that copy from allocated buffers to block @@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); - operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + operation->walk([&](mlir::lmhlo::CopyOp copyOp) { // If this region contains more than one block, then ignore this copy // operation. if (copyOp.getParentRegion()->getBlocks().size() > 1) { @@ -101,5 +101,5 @@ std::unique_ptr createLhloCopyRemovalPass() { static PassRegistration copy_removal_pass( "lhlo-copy-removal", "Removes redundant LHLO copy operations"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 5efb0fa78e547e..d832b96bf7b83b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { using linalg::LinalgOp; @@ -147,5 +147,5 @@ static PassRegistration legalize_pass( "lhlo-fuse-linalg", "Greedily fuse linalg ops obtained after LHLO lowering."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index e87125df86dd65..a353472be4b2ae 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -25,10 +25,10 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit @@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern { auto r = builder.create(loc, rhs, rhs_indices); auto result = rewriter.create(loc, op.output(), result_indices); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::HloOpToStdScalarOp::map( op, element_type, {l, r, result}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern { ValueRange induction_vars) { auto l = builder.create(loc, lhs, induction_vars); auto r = builder.create(loc, rhs, induction_vars); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::HloOpToStdScalarOp::map( op, element_type, {l, r}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, DotOpConverter>(context); // clang-format on } @@ -157,5 +157,5 @@ std::unique_ptr> createLegalizeToAffinePass() { static PassRegistration legalize_pass( "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 7489a092c27e31..0ff491a93c3762 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -35,10 +35,10 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // A simple translation of LHLO reduce operations to a corresponding gpu @@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, LmhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); @@ -192,5 +192,5 @@ std::unique_ptr> createLegalizeToGpuPass() { static PassRegistration legalize_pass( "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 6ae3c334493ced..32606f068a8ae9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { struct StaticMemRefCastOpConverter @@ -123,14 +123,139 @@ struct DynamicMemRefCastOpConverter } }; +struct ReshapeMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + auto reshape_op = cast(op); + Type dst_type = reshape_op.getResult().getType(); + auto element_type = dst_type.cast().getElementType(); + + auto shape = reshape_op.shape(); + + ReshapeMemRefCastOp::Adaptor operands_adaptor(operands); + PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset( + loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter); + + MemRefDescriptor shape_desc(operands_adaptor.shape()); + + auto shape_memref_type = shape.getType().cast(); + + if (shape_memref_type.hasStaticShape()) { + auto shape_length = shape_memref_type.getDimSize(0); + + MemRefType targetMemRefType = MemRefType::get( + SmallVector(shape_length, 1), element_type); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = + MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr); + desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr); + desc.setOffset(rewriter, loc, ptrs_n_offset.offset); + + auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType()) + .cast(); + auto llvmIndexTyPtr = llvmIndexTy.getPointerTo(); + Value stride_carried = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + for (int i = shape_length - 1; i >= 0; --i) { + Value pos = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + Value ptr = rewriter.create( + loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc), + ValueRange{pos}); + Value extracted_size = rewriter.create(loc, ptr); + desc.setSize(rewriter, loc, i, extracted_size); + desc.setStride(rewriter, loc, i, stride_carried); + // Update stride + if (i > 0) { + stride_carried = + rewriter.create(loc, stride_carried, extracted_size); + } + } + if (dst_type.isa()) { + rewriter.replaceOp(op, {desc}); + } else { + Value rank = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length)); + Value alloca = + typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter); + Value void_ptr = + rewriter.create(loc, getVoidPtrType(), alloca); + auto unranked_desc = UnrankedMemRefDescriptor::pack( + rewriter, loc, typeConverter, dst_type.cast(), + {rank, void_ptr}); + rewriter.replaceOp(op, {unranked_desc}); + } + } else { + /* + * TODO(pifon, herhut): + * Compute strides with llvm.loop; + * Use UnrankedMemrefDescr::ComputeSize with Alloca; + * Set all the fields using getelementptr. + */ + return failure(); + } + return success(); + } + + private: + struct PtrsAndOffset { + Value allocated_ptr; + Value aligned_ptr; + Value offset; + }; + + PtrsAndOffset ExtractMemRefPtrsAndOffset( + Location loc, Value originalOperand, Value convertedOperand, + ConversionPatternRewriter *rewriter) const { + Type operandType = originalOperand.getType(); + Value descriptor_ptr; + if (operandType.isa()) { + descriptor_ptr = convertedOperand; + } else { + UnrankedMemRefDescriptor unranked_descriptor(convertedOperand); + Value underlying_desc_ptr = + unranked_descriptor.memRefDescPtr(*rewriter, loc); + + Type element_type = + operandType.cast().getElementType(); + LLVM::LLVMType memref_type_0d = + typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type)) + .cast(); + descriptor_ptr = rewriter->create( + loc, memref_type_0d.getPointerTo(), underlying_desc_ptr); + descriptor_ptr = rewriter->create(loc, descriptor_ptr); + } + MemRefDescriptor descriptor(descriptor_ptr); + PtrsAndOffset result; + result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc); + result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc); + result.offset = descriptor.offset(*rewriter, loc); + return result; + } +}; + } // namespace void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { - patterns->insert( - *converter, options); + patterns->insert(*converter, options); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index ade121423cf370..d6cda99a9123ca 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project @@ -23,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { class TestLhloToLLVMPass @@ -38,11 +40,14 @@ class TestLhloToLLVMPass populateStdToLLVMConversionPatterns(converter, patterns); PopulateLhloToLLVMConversionPatterns( LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns); + mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + + mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); if (failed(applyFullConversion(m, target, patterns))) { signalPassFailure(); @@ -55,5 +60,5 @@ class TestLhloToLLVMPass static PassRegistration legalize_lhlo_pass( "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 6f4da98db65190..4255d87d48ef6f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Clones and adapts the code in `lhlo_block` that works on buffers and has a @@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. +// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. // // Example: // -// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { +// "lmhlo.reduce"(%buffer, %init_buf, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): // // } ) {dimensions = dense<[1]> : tensor<1xi64>} @@ -187,27 +187,27 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // } : f32 // scf.yield // } -class ReduceOpConverter : public OpConversionPattern { +class ReduceOpConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, + lmhlo::ReduceOp reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. - if (xla_reduce_op.out().size() != 1) return failure(); + if (reduce_op.out().size() != 1) return failure(); - scf::ReduceOp reduce_op = - CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); - ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, - &xla_reduce_op.body().front(), &rewriter); - rewriter.replaceOp(xla_reduce_op, llvm::None); + scf::ReduceOp scf_reduce_op = + CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter); + ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op, + &reduce_op.body().front(), &rewriter); + rewriter.replaceOp(reduce_op, llvm::None); return success(); } private: // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp - // refers to the parallel dimensions of `xla_reduce_op` if any and the inner + // refers to the parallel dimensions of `reduce_op` if any and the inner // ParallelOp refers to the reduction dimensions. The scf.reduce op is // returned. // @@ -226,16 +226,15 @@ class ReduceOpConverter : public OpConversionPattern { // scf.yield // } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceOp xla_reduce_op, - ConversionPatternRewriter* rewriter) const { - auto loc = xla_reduce_op.getLoc(); + lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const { + auto loc = reduce_op.getLoc(); DenseSet reducing_dims; - for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) { + for (const auto& rdim : reduce_op.dimensions().getIntValues()) { reducing_dims.insert(rdim.getSExtValue()); } - Value operand = *xla_reduce_op.operands().begin(); - Value out = *xla_reduce_op.out().begin(); + Value operand = *reduce_op.operands().begin(); + Value out = *reduce_op.out().begin(); SmallVector parallel_lower, parallel_upper, parallel_step; SmallVector reduce_lower, reduce_upper, reduce_step; auto operand_shape = operand.getType().cast().getShape(); @@ -252,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern { } // Load initial value from memref. SmallVector init_value = { - rewriter->create(loc, *xla_reduce_op.init_values().begin())}; + rewriter->create(loc, *reduce_op.init_values().begin())}; // Outer ParallelOp is not needed if it is a reduction across all dims. scf::ParallelOp outer; if (!parallel_lower.empty()) { @@ -293,7 +292,7 @@ class ReduceOpConverter : public OpConversionPattern { rewriter->setInsertionPointToStart(inner.getBody()); Value elem = rewriter->create( - loc, *xla_reduce_op.operands().begin(), indices); + loc, *reduce_op.operands().begin(), indices); return rewriter->create(loc, elem); } }; @@ -314,7 +313,7 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a // scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse @@ -325,11 +324,11 @@ class ReduceOpConverter : public OpConversionPattern { // func @reduce_window(%arg: memref<112x112xf32>, // %init: memref, // %result: memref<56x56xf32>) { -// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// "lmhlo.reduce_window"(%arg, %init, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): -// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// "lmhlo.maximum"(%lhs, %rhs, %res) // : (memref, memref, memref) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) { // padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, // window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -359,47 +358,47 @@ class ReduceOpConverter : public OpConversionPattern { // return // } class ReduceWindowOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + lmhlo::ReduceWindowOp reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = - CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, + CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op, &rewriter); scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( - xla_reduce_window_op, output_loop, window_loop, &rewriter); + reduce_window_op, output_loop, window_loop, &rewriter); - ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, - &xla_reduce_window_op.body().front(), &rewriter); - rewriter.replaceOp(xla_reduce_window_op, llvm::None); + ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op, + &reduce_window_op.body().front(), &rewriter); + rewriter.replaceOp(reduce_window_op, llvm::None); return success(); } private: std::pair CreateParallelLoopsToTraverseOutputAndWindow( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, + lmhlo::ReduceWindowOp reduce_window_op, ConversionPatternRewriter* rewriter) const { - auto loc = xla_reduce_window_op.getLoc(); + auto loc = reduce_window_op.getLoc(); Value init_value = - rewriter->create(loc, xla_reduce_window_op.init_value()); + rewriter->create(loc, reduce_window_op.init_value()); Value zero = rewriter->create(loc, 0); Value one = rewriter->create(loc, 1); // Create an outer parallel loop that spans the output of ReduceWindowOp. - Value xla_output = xla_reduce_window_op.out(); - auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter); + Value output = reduce_window_op.out(); + auto output_loop = MakeLoopOverShape(loc, output, rewriter); // Create a nested loop that traverses the window. SmallVector window_lower, window_upper, window_step; rewriter->setInsertionPointToStart(output_loop.getBody()); - for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { + for (const auto& window_dim : reduce_window_op.window_dimensions()) { window_step.push_back(one); window_lower.push_back(zero); window_upper.push_back( @@ -410,39 +409,38 @@ class ReduceWindowOpConverter Value reduction_result = *window_loop.getResults().begin(); auto output_ivs = output_loop.getInductionVars(); - rewriter->create(loc, reduction_result, xla_output, output_ivs); + rewriter->create(loc, reduction_result, output, output_ivs); return std::make_pair(output_loop, window_loop); } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, - scf::ParallelOp output_loop, scf::ParallelOp window_loop, - ConversionPatternRewriter* rewriter) const { + lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop, + scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); - auto loc = xla_reduce_window_op.getLoc(); + auto loc = reduce_window_op.getLoc(); - if (xla_reduce_window_op.base_dilations().hasValue() || - xla_reduce_window_op.window_dilations().hasValue()) { - xla_reduce_window_op.emitRemark( + if (reduce_window_op.base_dilations().hasValue() || + reduce_window_op.window_dilations().hasValue()) { + reduce_window_op.emitRemark( "Lowering to parallel loops does not support `base_dilations` or " "`window_dilations` attributes yet. The attributes will be ignored."); } - Value xla_operand = xla_reduce_window_op.operand(); - auto xla_operand_type = xla_operand.getType().cast(); + Value operand = reduce_window_op.operand(); + auto operand_type = operand.getType().cast(); // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. - MappedIvs mapped_ivs = MapWindowIvsToInput( - xla_reduce_window_op, output_loop.getInductionVars(), - window_loop.getInductionVars(), rewriter); + MappedIvs mapped_ivs = + MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(), + window_loop.getInductionVars(), rewriter); auto elem_or_init = rewriter->create( - loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, + loc, operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( - loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); + loc, reduce_window_op.operand(), mapped_ivs.ivs); then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); @@ -481,12 +479,12 @@ class ReduceWindowOpConverter // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) class SelectAndScatterOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, + lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { auto loc = s_and_s_op.getLoc(); InitializeOutput(s_and_s_op, &rewriter); @@ -515,7 +513,7 @@ class SelectAndScatterOpConverter } private: - void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, + void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value init_value = b->create(loc, s_and_s_op.init_value()); @@ -533,7 +531,7 @@ class SelectAndScatterOpConverter SmallVector window_ivs; scf::ForOp inner_loop; }; - WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, + WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -598,7 +596,7 @@ class SelectAndScatterOpConverter SmallVector ivs_val_flag_; }; - SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, + SmallVector SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -636,9 +634,10 @@ class SelectAndScatterOpConverter return window_loops.selected_ivs; } - SmallVector SelectOrInitialize( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef operand_ivs, - IterArgs* ivs_val_flag, OpBuilder* b) const { + SmallVector SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op, + ArrayRef operand_ivs, + IterArgs* ivs_val_flag, + OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value true_i1 = b->create( loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); @@ -707,9 +706,9 @@ struct LhloLegalizeToParallelLoops ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + scf::SCFDialect, LmhloDialect>(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) { signalPassFailure(); @@ -727,5 +726,5 @@ static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-parallel-loops", "Legalize from LHLO dialect to parallel loops."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc index ae19953371b1a1..54ea4955573a67 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper { public: explicit LowerComplex() : PassWrapper() {} - /// Performs the lowering to XLA dialect. + /// Performs the lowering to MHLO dialect. void runOnFunction() override; }; } // end anonymous namespace namespace mlir { -namespace xla { +namespace hlo { namespace { #include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" @@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context, OwningRewritePatternList* patterns) { populateWithGenerated(context, patterns); } -} // end namespace xla +} // end namespace hlo } // end namespace mlir // Lowers the complex operations that can be represented using other operations. void LowerComplex::runOnFunction() { // Add lowering patterns to the list. OwningRewritePatternList patterns; - mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns); + mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); applyPatternsAndFoldGreedily(getFunction(), patterns); } static PassRegistration pass( - "test-xla-lower-complex", + "mhlo-test-lower-complex", "Lower complex operations into non-complex operations"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 0c308bfc75d122..32a6ce42e5e1ab 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file implements logic for lowering XLA general dot to a regular dot. +// This file implements logic for lowering MHLO general dot to a regular dot. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -84,14 +84,14 @@ Value TransposeReshape(Value arg, mlir::Location loc, transposed_shape.push_back(arg_shape[val]); } auto transpose_type = RankedTensorType::get(transposed_shape, element_type); - auto transpose_result = rewriter->create( + auto transpose_result = rewriter->create( loc, transpose_type, arg, transpose_permutation_attr); // Return the final result. auto reshaped_type = RankedTensorType::get({left_size, right_size}, element_type); - return rewriter->create(loc, reshaped_type, - transpose_result); + return rewriter->create(loc, reshaped_type, + transpose_result); } Value ProcessDotArg(Value arg, mlir::Location loc, @@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc, return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter); } -struct GeneralDotConvert - : public OpRewritePattern { +struct GeneralDotConvert : public OpRewritePattern { // Attempts to lower a General Dot operator to a standard Dot operator. // General dots include batching dimensions and can have collapsing // dimensions along any axis. Inserting correctly arrange transpose and @@ -138,7 +137,7 @@ struct GeneralDotConvert explicit GeneralDotConvert(MLIRContext *context) : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op, + LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op, PatternRewriter &rewriter) const override { auto dot_element_type = mlir::getElementTypeOrSelf(op); @@ -162,11 +161,11 @@ struct GeneralDotConvert auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); - auto new_dot_op = rewriter.create( + auto new_dot_op = rewriter.create( op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config())); - rewriter.replaceOpWithNewOp(op, op.getType(), - new_dot_op); + rewriter.replaceOpWithNewOp(op, op.getType(), + new_dot_op); return success(); } }; @@ -176,19 +175,18 @@ struct LegalizeGeneralDot /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { OwningRewritePatternList patterns; - mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, - &getContext()); + mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // namespace -void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns( +void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( OwningRewritePatternList *patterns, MLIRContext *ctx) { patterns->insert(ctx); } static PassRegistration legalize_pass( - "test-xla-lower-general-dot", + "mhlo-test-lower-general-dot", "Tests lowering general dot to a non-batched dot when possible"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc index d536b35c4569ca..c2f88ad5e3104e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, patterns->insert(context); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index 0c685ba31b5990..1d5d593bd43271 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -33,8 +33,8 @@ struct TestMaterializeBroadcastsPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - // Consider the xla_hlo dialect legal for tests. - conversionTarget.addLegalDialect(); + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); @@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration - pass("test-xla-materialize-broadcasts", - "Test pass for materializing 'broadcast_dimensions' attributes"); +static mlir::PassRegistration pass( + "mhlo-test-materialize-broadcasts", + "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc similarity index 97% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc index ce53d4fa1dfd1a..91f9344b8c5554 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc @@ -60,7 +60,7 @@ limitations under the License. // shape dialect once it is ready. namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { using llvm::EquivalenceClasses; @@ -479,7 +479,7 @@ class FusionPlanner { EquivalenceClasses leader_for_node_; }; -struct XlaHloFusion : public mlir::PassWrapper { +struct MhloFusion : public mlir::PassWrapper { void runOnFunction() override { FuncOp func = getFunction(); if (!IsTargetFunc(func)) { @@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper { } FusionOp fusion = - b.create(fused_loc, output_types, inputs); + b.create(fused_loc, output_types, inputs); Region& region = fusion.fused_computation(); region.push_back(new Block); Block& block = region.front(); @@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper { op->moveBefore(&block, block.end()); } b.setInsertionPoint(&block, block.end()); - b.create(fused_loc, outputs); + b.create(fused_loc, outputs); for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) { Value output = std::get<0>(output_and_result); @@ -568,12 +568,12 @@ struct XlaHloFusion : public mlir::PassWrapper { } // namespace -std::unique_ptr> createXlaHloFusion() { - return std::make_unique(); +std::unique_ptr> createMhloFusion() { + return std::make_unique(); } -static PassRegistration xla_hlo_fusion_pass( - "xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns."); +static PassRegistration mhlo_fusion_pass( + "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns."); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 16ad47f0ce277f..b05918030e9435 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -71,7 +71,7 @@ class SinkConstantsToControlFlow }; static mlir::PassRegistration pass( - "xla-hlo-sink-constants-to-control-flow", + "mhlo-sink-constants-to-control-flow", "Sink constants implicitly captured in control flow regions. This is " "necessary to export to XLA."); @@ -81,5 +81,5 @@ std::unique_ptr> createSinkConstantsToControlFlowPass() { return std::make_unique(); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index 71441656c080bd..184420bb8f70cb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -22,12 +22,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { namespace { struct InferReturnTypeComponentsPattern : public RewritePattern { InferReturnTypeComponentsPattern(MLIRContext *context) - : RewritePattern("xla_test.get_return_type_components", 1, context) {} + : RewritePattern("mhlo_test.get_return_type_components", 1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); @@ -44,7 +44,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern { } // Replace the op with another pass-through op with attributes added. - OperationState state(op->getLoc(), "xla_test.return_type_components", + OperationState state(op->getLoc(), "mhlo_test.return_type_components", op->getOperands(), op->getResultTypes(), op->getAttrs()); auto new_op = rewriter.createOperation(state); @@ -65,7 +65,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern { struct ReifyReturnTypeShapesPattern : public RewritePattern { ReifyReturnTypeShapesPattern(MLIRContext *context) - : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {} + : RewritePattern("mhlo_test.reify_return_type_shapes", 1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); @@ -92,9 +92,9 @@ struct TestInferShapedTypeMethodsPass }; } // namespace -} // namespace xla +} // namespace hlo } // namespace mlir -static mlir::PassRegistration pass( - "test-xla-infer-shaped-type-methods", +static mlir::PassRegistration pass( + "mhlo-test-infer-shaped-type-methods", "Uses test ops to invoke InferShapedTypeOpInterface methods"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc similarity index 97% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 2b9ea182cf223f..53947855cc7468 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { // TODO(frgossen): Make it variadic. @@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { rewriter.create(loc, numElementsAsIndex); auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandTy.getElementType()); - Value flatOperand = rewriter.create( + Value flatOperand = rewriter.create( loc, flatTensorTy, operand, flatShapeAsDimTensor); // Generate IR for the actual operation. @@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { rewriter.getIndexType()); Value shapeAsExtentTensor = rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( + Value result = rewriter.create( loc, operandTy, flatResult, shapeAsExtentTensor); rewriter.replaceOp(op, result); @@ -152,7 +152,7 @@ struct TransformUnrankedHloPass // Setup conversion target. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); target.addLegalOp(); AddLegalOpOnRankedTensor(&target); @@ -184,5 +184,5 @@ static PassRegistration transform_unranked_hlo_pass( "transform-unranked-hlo", "Realize element-wise operations on ranked tensors where possible"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 8d501a8dbdddd5..09c9c61119ee48 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -40,12 +40,12 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type, auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); if (shape_value) { - return rewriter.createOrFold( + return rewriter.createOrFold( loc, result_type, value_1d, shape_value, dims); } assert(result_type.hasStaticShape()); - return rewriter.create(loc, result_type, value_1d, - dims); + return rewriter.create(loc, result_type, value_1d, + dims); } // Calculate the shape value of operand, assuming it is a dynamic shape with @@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, auto epsilon_tensor_attr = DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); Value epsilon = - rewriter.create(op->getLoc(), epsilon_tensor_attr); + rewriter.create(op->getLoc(), epsilon_tensor_attr); auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); if (broadcast_to_type.hasStaticShape()) { - return rewriter.create( + return rewriter.create( op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims); } Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter); - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), broadcast_to_type, epsilon, shape_value, /*broadcast_dims=*/dims); } class UnfuseBatchNormInferencePattern - : public OpRewritePattern { + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op, + LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op, PatternRewriter& rewriter) const override { // Enforce type invariants. // Note that we deduce the actual element type from the variance, @@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create(bn_op.getLoc(), - bn_op.variance(), epsilon); - stddev = rewriter.create(bn_op.getLoc(), stddev); + Value stddev = + rewriter.create(bn_op.getLoc(), bn_op.variance(), epsilon); + stddev = rewriter.create(bn_op.getLoc(), stddev); // Broadcast all terms. Value shape_value; @@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset - Value result = rewriter.create( - bn_op.getLoc(), bn_op.operand(), broadcast_mean); - result = rewriter.create(bn_op.getLoc(), result, - broadcast_scale); - result = rewriter.create(bn_op.getLoc(), result, - broadcast_stddev); - rewriter.replaceOpWithNewOp(bn_op, result, - broadcast_offset); + Value result = rewriter.create(bn_op.getLoc(), bn_op.operand(), + broadcast_mean); + result = + rewriter.create(bn_op.getLoc(), result, broadcast_scale); + result = + rewriter.create(bn_op.getLoc(), result, broadcast_stddev); + rewriter.replaceOpWithNewOp(bn_op, result, broadcast_offset); return success(); } @@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context, patterns->insert(context); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index df7ebf19b93bbc..c26d73f3306e02 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration pass( - "test-xla-unfuse-batch-norm", +static mlir::PassRegistration pass( + "mhlo-test-unfuse-batch-norm", "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc index 7037013a78f9b5..e05ec3c3481f5b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc @@ -24,7 +24,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, DenseIntElementsAttr broadcast_dims) { @@ -70,5 +70,5 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, result_shape_v); } -} // namespace xla +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc index e5f3f65c60aa83..ea074c4907d23f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::Type new_type) { @@ -82,5 +82,5 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, })); } -} // namespace xla +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc index 331b4000740276..184d113fb9d230 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc @@ -20,7 +20,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project namespace mlir { -namespace xla { +namespace hlo { DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y, bool allow_empty) { @@ -66,5 +66,5 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { return DenseElementsAttr::get(scalar_ty, value); } -} // namespace xla +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/BUILD b/tensorflow/compiler/mlir/hlo/tests/BUILD new file mode 100644 index 00000000000000..2c3150a217af3f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/BUILD @@ -0,0 +1,19 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [":test_utilities"], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/hlo:mlir-hlo-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir new file mode 100644 index 00000000000000..87774129ffb3a8 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -0,0 +1,544 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: add_fold +func @add_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[6, 8, 10, 12]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_scalar_fold +func @add_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<1> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<6> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_fold_float +func @add_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: sub_scalar_fold +func @sub_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<1> : tensor<4xi64> + // CHECK: mhlo.constant dense<4> + %2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: multiply_scalar_fold +func @multiply_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<3> : tensor<4xi64> + // CHECK: mhlo.constant dense<15> + %2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: divide_scalar_fold +func @divide_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<1> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: divide_fold_float +func @divide_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: max_scalar_fold +func @max_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<7> + %2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: max_fold_float +func @max_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]> + %2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: min_scalar_fold +func @min_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<-5> : tensor<4xi64> + // CHECK: mhlo.constant dense<-5> + %2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: min_fold_float +func @min_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> + %2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: concatenate_noop +func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> + + // CHECK: return [[ARG]] + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_remove_operand +func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { + // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> + // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> + + // CHECK: return [[ARG0]] + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_empty_bool +func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + + return %0 : tensor<0xi1> +} + +// CHECK-LABEL: concatenate_empty_int +func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> + + return %0 : tensor<0xi32> +} + +// CHECK-LABEL: concatenate_empty_float +func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + + return %0 : tensor<0xf32> +} + +// CHECK-LABEL: concatenate_const_1D +func @concatenate_const_1D() -> tensor<4xi32> { + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]> + %0 = mhlo.constant dense<[0, 1]> : tensor<2xi32> + %1 = mhlo.constant dense<[2, 3]> : tensor<2xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_const_1D_float +func @concatenate_const_1D_float() -> tensor<4xf32> { + // CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> + + %0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + %1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: concatenate_const_2D_vertical +func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[ + // CHECK-SAME: [0, 1], [2, 3] + // CHECK-SAME: ]> + %0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32> + %1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} + +// CHECK-LABEL: concatenate_const_2D_horizontal +func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[ + // CHECK-SAME: [0, 2], [1, 3] + // CHECK-SAME: ]> + %0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32> + %1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} + +// CHECK-LABEL: dynamic_slice_variable_start +func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // CHECK: "mhlo.dynamic-slice" + %1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %1 : tensor<1x4xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start +func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64> + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} + // CHECK: return %[[RESULT]] : tensor<2xi32> + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + return %1 : tensor<2xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape +func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64> + // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64> + // CHECK: return %[[RESULT]] : tensor + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: slice_2D_noop +// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> +func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) + + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_1D_fold +func @slice_1D_fold() -> tensor<2xi64> { + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 9]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_1D_fp +func @slice_1D_fp() -> tensor<2xf32> { + %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) + return %1 : tensor<2xf32> +} + +// CHECK-LABEL: slice_1D_strided_fold +func @slice_1D_strided_fold() -> tensor<2xi64> { + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 10]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_2D_fold +func @slice_2D_fold() -> tensor<2x2xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [6, 7], + // CHECK-SAME: [10, 11] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_2D_fold_horizontal +func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + return %1 : tensor<1x4xi64> +} + +// CHECK-LABEL: slice_2D_fold_vertical +func @slice_2D_fold_vertical() -> tensor<4x1xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [2], [6], [10], [14] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + return %1 : tensor<4x1xi64> +} + +// CHECK-LABEL: slice_concat_fold_first +func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg0 + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second +func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg1 + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second_with_slice +func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<1x4xf32> +} + +// CHECK-LABEL: slice_concat_fold_middle +func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_two +func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { + // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + + // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<2x5xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_identity +func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // CHECK: return %arg0 + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts +func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + // CHECK: mhlo.broadcast_in_dim + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation +func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: mhlo.broadcast_in_dim + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic +func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: return %[[RESULT]] : tensor<5x4xf32> + return %0 : tensor<5x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d +func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { + %cst = mhlo.constant dense<0.000000e+00> : tensor + %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x64x224x224xf32> + return %b : tensor<1x64x224x224xf32> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> +// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32> + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold +func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> { + %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> + %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> + return %b : tensor<1x64x4x4xf32> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> +// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32> + +// CHECK-LABEL: @complex_expand_fold +func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) + %1 = "mhlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + // CHECK: return %arg0, %arg1 + return %1, %2 : tensor<4xf32>, tensor<4xf32> +} + +// CHECK-LABEL: @complex_collapse_fold +func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %0 = "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK: return %arg0 + return %2 : tensor<4xcomplex> +} + +// CHECK-LABEL: @dynamic_iota_is_static +func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "mhlo.iota" + // CHECK: return [[RESULT]] + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: @dynamic_iota_broadcast +func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<5x?xi32> +} + +// CHECK-LABEL: @dynamic_iota_broadcast_second +func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64> + // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex> + // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor + // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<5x?xi32> +} + +// CHECK-LABEL: @dynamic_iota_constant +func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> { + // CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<1x?xi32> +} + +// CHECK-LABEL: @iota_constant +func @iota_constant() -> tensor<1xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32> + + // CHECK: return [[CONST]] : tensor<1xi32> + return %0 : tensor<1xi32> +} + +// CHECK-LABEL: @iota_constant_multi +func @iota_constant_multi() -> tensor<1x4xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32> + + // CHECK: return [[CONST]] : tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// CHECK-LABEL: @iota_not_lowered_to_constant +func @iota_not_lowered_to_constant() -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "mhlo.iota" + // CHECK: return [[RESULT]] + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: @iota_broadcast +func @iota_broadcast() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32> + + return %0 : tensor<5x4xi32> +} + +// CHECK-LABEL: @iota_broadcast +func @iota_broadcast_second() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32> + + return %0 : tensor<5x4xi32> +} + +// CHECK-LABEL: @unary_einsum +func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} + %0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @fold_copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { + // CHECK: return [[ARG]] + %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic +func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { + // CHECK: mhlo.reshape + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> + return %0 : tensor<4x1xf32> +} + +// CHECK-LABEL: do_not_dce_while_with_outfeed +func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { + // CHECK: mhlo.while + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "mhlo.create_token"() : () -> !mhlo.token + // Side-effecting op outfeed present inside while. + %2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !mhlo.token) -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} + +// CHECK-LABEL: dce_while_without_side_effect +func @dce_while_without_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: mhlo.while + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "mhlo.create_token"() : () -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir similarity index 51% rename from tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir rename to tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir index db36686717b218..65074325563e12 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s // CHECK-LABEL: @broadcast_add // Note that all broadcast_ops are expanded from the same template, so @@ -11,46 +11,46 @@ func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xinde // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] // CHECK: return %[[EXTENTS]] - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> return %1 : tensor<1xindex> } // ----- // CHECK-LABEL: @complex_ranked_components func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor> { - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> - // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} - %1 = "xla_test.get_return_type_components"(%0) : (tensor>) -> tensor> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor>) -> tensor> return %1 : tensor> } // ----- // CHECK-LABEL: @compare_ranked_components func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor - // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} - %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor) -> tensor return %0 : tensor } // ----- // CHECK-LABEL: @broadcast_add_ranked_components_r1 func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} - %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor) -> tensor return %1 : tensor } // ----- // CHECK-LABEL: @broadcast_add_ranked_components_r1x2 func @broadcast_add_ranked_components_r1x2(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor // TODO: Overly broad shapes are being returned. Tighten the calculation // and update/extend these tests. - // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} - %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor) -> tensor return %1 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir similarity index 60% rename from tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir rename to tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index 65285021fd47db..2c0e2d7f17078f 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -1,11 +1,11 @@ -// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. // CHECK-LABEL: @addWithoutBroadcast func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.add %arg0, %arg1 - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.add %arg0, %arg1 + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -20,13 +20,13 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor : tensor<1xi64>} - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] // CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor return %0 : tensor } @@ -41,13 +41,13 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> // CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> return %0 : tensor> } @@ -62,13 +62,13 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor // CHECK: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK: return %[[FINAL_RESULT]] : tensor - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor return %0 : tensor } @@ -76,8 +76,8 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // Verifies that broadcast_dimensions validity checks are valid. // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK: xla_hlo.add - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + // CHECK: mhlo.add + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -85,8 +85,8 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor< // Verifies that broadcast_dimensions validity checks are valid. // CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { - // CHECK: xla_hlo.add - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + // CHECK: mhlo.add + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -95,7 +95,7 @@ func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -104,7 +104,7 @@ func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %a func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -113,127 +113,127 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: // expansions. Tests below merely verify that the op has an expansion. // CHECK-LABEL: @andWithoutBroadcast func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.and %arg0, %arg1 - %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + // CHECK: mhlo.and %arg0, %arg1 + %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } // ----- // CHECK-LABEL: @atan2WithoutBroadcast func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.atan2 %arg0, %arg1 - %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.atan2 %arg0, %arg1 + %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @compareWithoutBroadcast func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { - // CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } // ----- // CHECK-LABEL: @complexWithoutBroadcast func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { - // CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> return %0 : tensor<4xcomplex> } // ----- // CHECK-LABEL: @divideWithoutBroadcast func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.divide %arg0, %arg1 - %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.divide %arg0, %arg1 + %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @maximumWithoutBroadcast func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 - %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.maximum %arg0, %arg1 + %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @minimumWithoutBroadcast func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 - %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.minimum %arg0, %arg1 + %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @multiplyWithoutBroadcast func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.multiply %arg0, %arg1 - %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.multiply %arg0, %arg1 + %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @orWithoutBroadcast func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.or %arg0, %arg1 - %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + // CHECK: mhlo.or %arg0, %arg1 + %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } // ----- // CHECK-LABEL: @powerWithoutBroadcast func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.power %arg0, %arg1 - %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.power %arg0, %arg1 + %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @remainderWithoutBroadcast func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.remainder %arg0, %arg1 - %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.remainder %arg0, %arg1 + %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @shift_leftWithoutBroadcast func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.shift_left %arg0, %arg1 + %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 + %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @shift_right_logicalWithoutBroadcast func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_right_logical %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.shift_right_logical %arg0, %arg1 + %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @subWithoutBroadcast func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.subtract %arg0, %arg1 - %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.subtract %arg0, %arg1 + %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- // CHECK-LABEL: @xorWithoutBroadcast func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.xor %arg0, %arg1 - %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + // CHECK: mhlo.xor %arg0, %arg1 + %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } diff --git a/tensorflow/compiler/mlir/xla/tests/concatenate.mlir b/tensorflow/compiler/mlir/hlo/tests/concatenate.mlir similarity index 54% rename from tensorflow/compiler/mlir/xla/tests/concatenate.mlir rename to tensorflow/compiler/mlir/hlo/tests/concatenate.mlir index 5b1225e1e87d4c..aeefd68e2b7ada 100644 --- a/tensorflow/compiler/mlir/xla/tests/concatenate.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/concatenate.mlir @@ -1,9 +1,9 @@ -// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @single_operand // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK-NEXT: return [[ARG]] return %0 : tensor<1x2xf32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/hlo/tests/convert.mlir b/tensorflow/compiler/mlir/hlo/tests/convert.mlir new file mode 100644 index 00000000000000..dab395c52cdab8 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/convert.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// ----- + +// CHECK-LABEL: func @same_type +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @same_type(%arg: tensor) -> tensor { + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[ARG]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_widening +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_widening(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_narrowing +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_narrowing(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @float_int +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @float_int(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_float +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_float(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @high_rank_tensor +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> + %0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> + // CHECK-NEXT: return [[RES]] + return %0 : tensor<2x3xf32> +} + +// ----- + + +// CHECK-LABEL: func @const_same_type +func @const_same_type() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_float_int +func @const_float_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_float +func @const_int_float() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_negative_int_float +func @const_negative_int_float() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<-4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_bf16 +func @const_int_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_narrowing +func @const_int_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_widening +func @const_int_widening() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_negative_int_widening +func @const_negative_int_widening() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor + %cst = mhlo.constant dense<-42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_float_narrowing +func @const_float_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4.2> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_f32_bf16 +func @const_f32_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_f64 +func @const_bf16_f64() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor + %cst = mhlo.constant dense<4.2> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: func @const_high_rank_tensor +func @const_high_rank_tensor() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> + %0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir new file mode 100644 index 00000000000000..cc60217be657de --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s + +// CHECK-LABEL: func @func_op_unranked_arg_result +func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { + return %arg0 : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> +// CHECK-NEXT: return [[ARG]] : memref<*xf32> + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_from_unranked +func @dynamic_reshape_from_unranked( + %operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor<*xf32>, tensor<1xi32>) -> tensor + return %reshaped : tensor +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_to_unranked +func @dynamic_reshape_to_unranked( + %operand: tensor, %shape: tensor) -> tensor<*xf32> { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor, tensor) -> tensor<*xf32> + return %reshaped : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref, memref) -> memref<*xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir similarity index 76% rename from tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir rename to tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index f3ce29f1bd257a..aa5d800b82b3e2 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -1,13 +1,13 @@ -// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s -// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s // BOTH-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exponential"(%tensor_operand) + %tensor_result = "mhlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -18,40 +18,40 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } // PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) -// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () // PRE-NEXT: return // ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] -// ESC-NOT: "xla_lhlo.copy" +// ESC-NOT: "lmhlo.copy" // ESC-NEXT: return %[[ARG0]] // ----- // BOTH-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> - %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> - %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> + %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + %2 = mhlo.add %arg0, %1 : tensor<4xf32> + %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> + %4 = mhlo.subtract %arg1, %3 : tensor<4xf32> + %5 = mhlo.multiply %2, %4 : tensor<4xf32> return %5 : tensor<4xf32> } // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) // ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> // BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) // BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) // BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> -//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +//  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) // BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> -// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () // PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> // PRE-NEXT: return // ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> @@ -65,16 +65,16 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> - %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) + %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> - %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) + %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> - // BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) + // BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // BOTH-NEXT: return @@ -86,9 +86,9 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, // BOTH-LABEL: func @copy func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.copy"(%tensor_operand) + %tensor_result = "mhlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -98,9 +98,9 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exponential"(%tensor_operand) + %tensor_result = "mhlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -110,9 +110,9 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @log func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.log"(%tensor_operand) + %tensor_result = "mhlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -125,9 +125,9 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_pred = tensor_load %pred : memref<2x2xi1> %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -138,10 +138,10 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } @@ -151,10 +151,10 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x // BOTH-LABEL: func @broadcast func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_operand = tensor_load %operand : memref<5xf32> - %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) + %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64> func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref - %shape = call @external_func() : () -> tensor<3xi64> - %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { + %c1 = constant 1 : i64 + %shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64> + %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor - // BOTH: %[[SHAPE:.*]] = call @external_func() + // BOTH: %[[SHAPE:.*]] = tensor_from_elements // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index @@ -204,12 +205,12 @@ func @dyn_broadcast(%operand: memref) { // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index - // BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast + // BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] // BOTH-SAME: : memref -> memref - // BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { + // BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> // BOTH-SAME: } : (memref, memref) -> () @@ -226,9 +227,9 @@ func @complex(%real: memref<2x2xf32>, %result: memref<2x2xcomplex>) { %tensor_real = tensor_load %real : memref<2x2xf32> %tensor_imag = tensor_load %imag : memref<2x2xf32> - %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag) + %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> - // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xcomplex> return } @@ -238,9 +239,9 @@ func @complex(%real: memref<2x2xf32>, // BOTH-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> - %tensor_result = "xla_hlo.real"(%tensor_operand) + %tensor_result = "mhlo.real"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -250,9 +251,9 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> - %tensor_result = "xla_hlo.imag"(%tensor_operand) + %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -261,9 +262,9 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @iota func @iota(%result: memref<10xi32>) { - %tensor_result = "xla_hlo.iota"() + %tensor_result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } @@ -273,9 +274,9 @@ func @iota(%result: memref<10xi32>) { // BOTH-LABEL: func @abs func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.abs"(%tensor_operand) + %tensor_result = "mhlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -285,9 +286,9 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.ceil"(%tensor_operand) + %tensor_result = "mhlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -297,9 +298,9 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @convert func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.convert"(%tensor_operand) + %tensor_result = "mhlo.convert"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) // BOTH-NOT: tensor_store tensor_store %tensor_result, %result : memref<2x2xf32> return @@ -310,9 +311,9 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.cosine"(%tensor_operand) + %tensor_result = "mhlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -322,9 +323,9 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.negate"(%tensor_operand) + %tensor_result = "mhlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -334,9 +335,9 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.rsqrt"(%tensor_operand) + %tensor_result = "mhlo.rsqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -346,9 +347,9 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @sign func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.sign"(%tensor_operand) + %tensor_result = "mhlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -358,9 +359,9 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @sqrt func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.sqrt"(%tensor_operand) + %tensor_result = "mhlo.sqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -370,9 +371,9 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @tanh func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.tanh"(%tensor_operand) + %tensor_result = "mhlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -383,9 +384,9 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -395,7 +396,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x // Dynamic shape binary element-wise operation. // BOTH-LABEL: func @add_dyn func @add_dyn(%lhs: tensor, %rhs: tensor) { - %result = "xla_hlo.add"(%lhs, %rhs) + %result = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref @@ -411,7 +412,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () + // BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () return } @@ -420,7 +421,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // Dynamic shape unary element-wise operation. // BOTH-LABEL: func @tanh_dyn func @tanh_dyn(%arg0: tensor) { - %result = "xla_hlo.tanh"(%arg0) + %result = "mhlo.tanh"(%arg0) : (tensor) -> tensor // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref @@ -436,7 +437,7 @@ func @tanh_dyn(%arg0: tensor) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () + // BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return } @@ -447,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () - %dot = "xla_hlo.dot"(%arg0, %arg0) +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () + %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) +// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) // ESC: return %[[ALLOC]] return %dot : tensor<1024x1024xf32> } @@ -461,12 +462,12 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { %c0 = constant 0 : index // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> - // BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) + // BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) // BOTH-SAME: padding = dense<[ // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: rhs_dilation = dense<[1, 2]> // BOTH-SAME: window_strides = dense<[2, 1]> - %out = "xla_hlo.convolution"(%filter, %input) { + %out = "mhlo.convolution"(%filter, %input) { batch_group_count = 1 : i64, dimension_numbers = { input_batch_dimension = 0 : i64, diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir similarity index 84% rename from tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir rename to tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 399a5cc343890e..320ce069ac01c6 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @float_add @@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]] // CHECK: linalg.yield %[[RESULT]] - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: addi - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: mulf - %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: muli - %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: remf - %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: remi_signed - %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>, // CHECK-LABEL: func @float_rsqrt func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { - %tensor_result = "xla_hlo.rsqrt"(%operand) + %tensor_result = "mhlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK: linalg.generic // CHECK: rsqrt @@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: subf - %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: subi - %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: absf - %0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: exp - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: log - %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: ceilf - %0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: negf - %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: tanh - %0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: and - %0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>, // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { - %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> return %0 : tensor<2x2xi1> } @@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { - %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) return %0 : tensor<2x2xi1> } @@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>, func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: cos - %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: sin - %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { - %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) + %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) return %0 : tensor<2x4x8xf32> } // CHECK: return [[ARG]] : tensor<2x4x8xf32> @@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { // CHECK-LABEL: func @select func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "xla_hlo.select"(%pred, %lhs, %rhs) + %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) return %0 : tensor<2x2xf32> } @@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> return %0: tensor<4x2x1xf32> } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK-LABEL: func @broadcast func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> return %0: tensor<4x2x1x4x?x16xf32> } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> // CHECK-LABEL: func @broadcast_in_dim func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> return %0 : tensor<7x10x6x4x5xf32> @@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { // CHECK-LABEL: func @broadcast_in_dim_with_one_to_one func @broadcast_in_dim_with_one_to_one( %operand: tensor<1xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x5xf32> return %0 : tensor<1x5xf32> @@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one( // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<7x10x6xf32> return %0 : tensor<7x10x6xf32> @@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32> } @@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> return %0 : tensor<12x42xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> return %0 : tensor<12x42xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> return %0 : tensor<12x1x42x1xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "xla_hlo.minimum"(%lhs, %rhs) + %0 = "mhlo.minimum"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @maxi func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.maximum"(%lhs, %rhs) + %0 = "mhlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @add_scalar func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor return %0 : tensor } // CHECK: linalg.generic @@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { func @reshape_collapse_single_dim (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> return %0 : tensor<1x784xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> @@ -428,7 +428,7 @@ func @reshape_collapse_single_dim // ----- func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> return %0 : tensor<2x4x3xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> @@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { // ----- func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> return %0 : tensor<2x4x2xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)> @@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { // ----- func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> return %0 : tensor<1x4x2xf32> } // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { func @reshape_multiple_collapse (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> return %0 : tensor<1x4x5x6xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> @@ -476,7 +476,7 @@ func @reshape_multiple_collapse // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> + %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } // CHECK: linalg.generic @@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> + %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } // CHECK: linalg.generic @@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { // CHECK-LABEL: func @convert_i32_to_i16 func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> + %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> return %result : tensor<2x2xi16> } // CHECK: linalg.generic @@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { // CHECK-LABEL: func @convert_f32_to_f64 func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> return %result : tensor<2x2xf64> } // CHECK: linalg.generic @@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { // CHECK-LABEL: func @convert_f64_to_f32 func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> + %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } // CHECK: linalg.generic @@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } // CHECK: linalg.generic @@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { - %result = "xla_hlo.reverse"(%input) { + %result = "mhlo.reverse"(%input) { dimensions = dense<1> : tensor<1xi64> } : (tensor<2x3xf32>) -> tensor<2x3xf32> return %result : tensor<2x3xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/inlining.mlir b/tensorflow/compiler/mlir/hlo/tests/inlining.mlir similarity index 54% rename from tensorflow/compiler/mlir/xla/tests/inlining.mlir rename to tensorflow/compiler/mlir/hlo/tests/inlining.mlir index 9d1582c99e5580..f4ed563623f952 100644 --- a/tensorflow/compiler/mlir/xla/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/inlining.mlir @@ -1,28 +1,28 @@ -// RUN: xla-opt %s -inline | FileCheck %s +// RUN: mlir-hlo-opt %s -inline | FileCheck %s -// Test case: Basic test of inlining into xla_hlo.while. +// Test case: Basic test of inlining into mhlo.while. // CHECK-LABEL: func @caller -// CHECK: "xla_hlo.while"{{.*}}( { +// CHECK: "mhlo.while"{{.*}}( { // CHECK: }, { -// CHECK: "xla_hlo.exponential" +// CHECK: "mhlo.exponential" // CHECK: }) // CHECK-LABEL: func @callee func @caller(%arg0: tensor, %pred: tensor) -> tensor { - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^entry(%unused: tensor): - "xla_hlo.return"(%pred) : (tensor) -> () + "mhlo.return"(%pred) : (tensor) -> () }, { ^entry(%0: tensor): %1 = call @callee(%0) : (tensor) -> (tensor) - "xla_hlo.return"(%1) : (tensor) -> () + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor) -> (tensor) return %0 : tensor } func @callee(%arg0: tensor) -> tensor { - %0 = "xla_hlo.exponential"(%arg0) : (tensor) -> tensor + %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir similarity index 64% rename from tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir rename to tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir index 83880bc8ce95da..274792e62a2a07 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir @@ -1,24 +1,24 @@ -// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -mhlo-legalize-control-flow %s -o - | FileCheck %s // CHECK-LABEL: func @while(%arg0: tensor) -> tensor { func @while(%arg0: tensor) -> tensor { //CHECK: br ^bb1(%arg0 : tensor) //CHECK: ^bb1([[VAL0:%.+]]: tensor): - //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]]) + //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]]) //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor), ^bb3([[VAL0]] : tensor) //CHECK: ^bb2([[VAL3:%.+]]: tensor): - //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]] + //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]] //CHECK: br ^bb1([[VAL4]] : tensor) //CHECK: ^bb3([[VAL5:%.+]]: tensor): - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor // CHECK-NEXT: return [[VAL5]] @@ -30,27 +30,27 @@ func @conditional(%arg0: tensor) -> tensor { // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor %cst = constant dense<1.000000e+01> : tensor - // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor), ^bb2(%arg0 : tensor) - %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( { + %1 = "mhlo.if"(%0, %arg0, %arg0) ( { ^bb0(%arg1: tensor): // CHECK: ^bb1([[VAL2:%.+]]: tensor): - // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor) -> tensor + // CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor) -> tensor // CHECK: br ^bb3([[VAL3]] : tensor) - %2 = "xla_hlo.log"(%arg1) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.log"(%arg1) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg1: tensor): // CHECK: ^bb2([[VAL4:%.+]]: tensor): - // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor) -> tensor + // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor) -> tensor // CHECK: br ^bb3([[VAL5]] : tensor) - %2 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.exponential"(%arg1) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor // CHECK: ^bb3([[VAL6:%.+]]: tensor): @@ -62,27 +62,27 @@ func @conditional(%arg0: tensor) -> tensor { func @while_with_multiple_blocks_in_body(%arg0: tensor) -> tensor { // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor) // CHECK: ^[[COND_ENTRY]](%0: tensor): - // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: %2 = extract_element %1[] : tensor // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%3: tensor): // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor) // CHECK: ^[[BODY_SUCC]](%4: tensor): - // CHECK: %5 = xla_hlo.add %4, %4 : tensor + // CHECK: %5 = mhlo.add %4, %4 : tensor // CHECK: br ^[[COND_ENTRY]](%5 : tensor) // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor // CHECK: } - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^cond_entry(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^body_entry(%arg1: tensor): br ^body_succ(%arg1: tensor) ^body_succ(%0: tensor): - %1 = xla_hlo.add %0, %0 : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %0, %0 : tensor + "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor return %0 : tensor @@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK: ^[[COND_ENTRY]](%0: tensor): // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor) // CHECK: ^[[COND_SUCC]](%1: tensor): - // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: %3 = extract_element %2[] : tensor // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%4: tensor): @@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK: ^[[EXIT]](%5: tensor): // CHECK: return %5 : tensor // CHECK: } - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^cond_entry(%arg1: tensor): br ^cond_succ(%arg1: tensor) ^cond_succ(%0: tensor): - %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^body_entry(%arg1: tensor): - "xla_hlo.return"(%arg1) : (tensor) -> () + "mhlo.return"(%arg1) : (tensor) -> () }) : (tensor) -> tensor return %0 : tensor @@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, % // CHECK: ^[[THEN_ENTRY]](%1: tensor): // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor) // CHECK: ^[[THEN_SUCC]](%2: tensor): - // CHECK: %3 = "xla_hlo.log"(%2) : (tensor) -> tensor + // CHECK: %3 = "mhlo.log"(%2) : (tensor) -> tensor // CHECK: br ^[[EXIT:.+]](%3 : tensor) // CHECK: ^[[ELSE_ENTRY]](%4: tensor): - // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor) -> tensor + // CHECK: %5 = "mhlo.exponential"(%4) : (tensor) -> tensor // CHECK: br ^[[EXIT]](%5 : tensor) // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor // CHECK: } - %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( { + %1 = "mhlo.if"(%pred, %arg0, %arg1) ( { ^then_entry(%arg2: tensor): br ^then_succ(%arg2: tensor) ^then_succ(%0: tensor): - %2 = "xla_hlo.log"(%0) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.log"(%0) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }, { ^else_entry(%arg2: tensor): - %2 = "xla_hlo.exponential"(%arg2) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.exponential"(%arg2) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor return %1 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir similarity index 64% rename from tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir rename to tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir index ebb54152e30711..37a61498fbfaee 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir @@ -1,21 +1,21 @@ -// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -mhlo-legalize-to-std %s -o - | FileCheck %s // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32> - %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> - %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32> - %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32> - %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %4 : tensor<4xf32> return %4 : tensor<4xf32> @@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf // CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32> - %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32> - %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> - %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32> - %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32> - %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: return %4 : tensor<4xi32> return %4 : tensor<4xi32> @@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32> - %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32> - %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32> - %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32> - %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32> - %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } @@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi // CHECK-LABEL: func @compare_float func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32> - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32> - %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32> - %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32> - %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32> - %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32> - %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } // CHECK-LABEL: func @int_constant func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor) // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32> - %1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + %1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32> - %2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + %2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xi32>, tensor<2x3xi32> return %0, %1, %2: tensor, tensor<2x3xi32>, tensor<2x3xi32> } @@ -92,11 +92,11 @@ func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { // CHECK-LABEL: func @float_constant func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor - %0 = "xla_hlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) + %0 = "mhlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32> - %1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + %1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32> - %2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + %2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xf32>, tensor<2x3xf32> return %0, %1, %2: tensor, tensor<2x3xf32>, tensor<2x3xf32> } @@ -105,7 +105,7 @@ func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { func @iota.const.1() -> tensor<4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> // CHECK-NEXT: return %[[CST]] : tensor<4xi32> return %0 : tensor<4xi32> } @@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> { // CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { func @iota.const.2() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { func @iota.const.3() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.f32 func @iota.const.f32() -> tensor<4xf32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> // CHECK-NEXT: return %[[CST]] : tensor<4xf32> return %0 : tensor<4xf32> } @@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> { // CHECK-LABEL: func @iota.const.f64 func @iota.const.f64() -> tensor<4xf64> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> // CHECK-NEXT: return %[[CST]] : tensor<4xf64> return %0 : tensor<4xf64> } @@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> { // CHECK-LABEL: func @iota.const.bf16 func @iota.const.bf16() -> tensor<4xbf16> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> // CHECK-NEXT: return %[[CST]] : tensor<4xbf16> return %0 : tensor<4xbf16> } @@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> { func @iota.const.complex.f32() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32> - // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]]) + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> return %0 : tensor<4xcomplex> } @@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex> { func @iota.const.complex.f64() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64> - // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]]) + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> return %0 : tensor<4xcomplex> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir similarity index 98% rename from tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir rename to tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir index f3bdc7d96cb247..aa834d36ac476b 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize_tanh_to_approximation.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s func @tanh_f64(%arg0 : f64) -> f64 { %res = tanh %arg0 : f64 diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir new file mode 100644 index 00000000000000..6d7992cb868d32 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s + +// CHECK-LABEL: func @remove_simple +func @remove_simple(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @remove_without_dealloc +func @remove_without_dealloc(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @replace_dependency +func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @keep_copies +func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_not_be_removed +func @must_not_be_removed(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_first +func @must_be_removed_first(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "lmhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_second +func @must_be_removed_second(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "lmhlo.terminator"() : () -> () +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir similarity index 96% rename from tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir index b04c97f42d7200..6a674664a36667 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir @@ -1,6 +1,6 @@ -// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always -// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED -// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP +// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always +// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED +// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP #map0 = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir similarity index 93% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir index 9887860ca26a78..a6bb876d3dc0b1 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir @@ -1,27 +1,27 @@ // GenericAtomicRMWOp should contain only ops with no side effects. // Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt -// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body. +// to LMHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body. // Lowering to STD dialect and store forwarding pass would be required to get // rid of them. This is exactly what is done in the real MLIR GPU pipeline, but // here we disable verification with `verify-each=0` to check the output IR. -// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s func @select_and_scatter(%arg: memref<112x112xf32>, %src: memref<56x56xf32>, %init: memref, %result: memref<112x112xf32>) { - "xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { + "lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( { // select ^bb0(%lhs: memref, %rhs: memref, %pred: memref): - "xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : + "lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }, { // scatter ^bb0(%lhs: memref, %rhs: memref, %out: memref): - "xla_lhlo.add"(%lhs, %rhs, %out) : + "lmhlo.add"(%lhs, %rhs, %out) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, } : (memref<112x112xf32>, memref<56x56xf32>, memref, memref<112x112xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // CHECK-LABEL: func @select_and_scatter( // CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, @@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref // Compute PRED. - // CHECK: "xla_lhlo.compare"( + // CHECK: "lmhlo.compare"( // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref @@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref // Compute scatter value. -// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : +// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK-SAME: (memref, memref, memref) -> () // CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir similarity index 87% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir index 483204cf0d5b04..87818045993761 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s // Smoke test. // CHECK-LABEL: func @min_op @@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK: return - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () return } @@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: addf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: addi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: and %{{.*}}, %{{.*}} : i32 - "xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"} + "lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: divf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: mulf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: muli %{{.*}}, %{{.*}} : i32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: subf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: subi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () return } @@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir similarity index 81% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir index c86744a9090a17..02ad36536394be 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir @@ -1,13 +1,13 @@ -// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s func @reduce(%arg: memref<100x10xf32>, %init: memref, %result: memref<100xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10xf32>, memref, memref<100xf32>) -> () return @@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref -// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () +// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () // CHECK: } // CHECK: gpu.terminator // CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir similarity index 83% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index f57270635985eb..dd88e5c80bf8d8 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -1,10 +1,10 @@ -// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, %result: memref) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return } @@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref, // CHECK-LABEL: func @element_wise_scalar func @element_wise_scalar(%lhs: memref, %rhs: memref, %result: memref) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return } @@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref, %rhs: memref, // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.minimum"(%lhs, %rhs, %result) + "lmhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.maximum"(%lhs, %rhs, %result) + "lmhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @and func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.and"(%lhs, %rhs, %result) + "lmhlo.and"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @log func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @copy func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { - "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () + "lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () return } // CHECK: linalg.generic @@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () return } @@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return } @@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.select"(%pred, %lhs, %rhs, %result) + "lmhlo.select"(%pred, %lhs, %rhs, %result) : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota func @iota(%out: memref<7x10xf32>) { - "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () + "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () return } // CHECK: linalg.indexed_generic @@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref, memref<4x2x1xf32>) -> () return @@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { // CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<4x?x16xf32>, %result: memref<4x2x1x4x?x16xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () return @@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>, // CHECK-LABEL: func @dynamic_broadcast_in_dim func @dynamic_broadcast_in_dim(%operand: memref, %result: memref) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> } : (memref, memref) -> () return @@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref, // CHECK-LABEL: func @static_broadcast_in_dim_no_expansion func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<5xf32>, memref<5x10xf32>) -> () return @@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_expansion func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, %result: memref<5x10x100xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> } : (memref<1x5xf32>, memref<5x10x100xf32>) -> () return @@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_scalar func @static_broadcast_in_dim_scalar(%operand: memref, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[]> : tensor<0xi64> } : (memref, memref<5x10xf32>) -> () return @@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<1xf32>, memref<1x5xf32>) -> () return @@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, %result: memref<5x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (memref<1xf32>, memref<5x5xf32>) -> () return @@ -323,19 +323,19 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, // CHECK-LABEL: func @constant func @constant(%value: memref) { - "xla_lhlo.constant"(%value) { + "lmhlo.constant"(%value) { value = dense<10> : tensor } : (memref) -> () return } // CHECK: %[[CONSTANT:.*]] = constant 10 : i32 -// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref +// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref // ----- // CHECK-LABEL: func @absf func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @absi func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>, // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: memref<2x2xi16>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>, // CHECK-LABEL: func @convert_i32_to_i16 func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () return } // CHECK: linalg.generic @@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { // CHECK-LABEL: func @convert_f32_to_f64 func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () return } // CHECK: linalg.generic @@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { // CHECK-LABEL: func @convert_f64_to_f32 func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_i32 func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @convert_f32_to_f32 func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xi32>) -> () return } @@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>, // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @negi func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.remainder"(%lhs, %rhs, %result) + "lmhlo.remainder"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @rsqrt func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sign func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @tanh func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @complex(%real: memref<2x2xf32>, %imag: memref<2x2xf32>, %cplx: memref<2x2xcomplex>) { - "xla_lhlo.complex"(%real, %imag, %cplx) + "lmhlo.complex"(%real, %imag, %cplx) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex>) -> () return } @@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>, // CHECK-LABEL: func @real func @real(%cplx: memref<2x2xcomplex>, %real: memref<2x2xf32>) { - "xla_lhlo.real"(%cplx, %real) + "lmhlo.real"(%cplx, %real) : (memref<2x2xcomplex>, memref<2x2xf32>) -> () return } @@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex>, // CHECK-LABEL: func @imag func @imag(%cplx: memref<2x2xcomplex>, %imag: memref<2x2xf32>) { - "xla_lhlo.imag"(%cplx, %imag) + "lmhlo.imag"(%cplx, %imag) : (memref<2x2xcomplex>, memref<2x2xf32>) -> () return } @@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex>, // CHECK: func @slice(%[[IN:.*]]: memref, %[[OUT:.*]]: memref) func @slice(%operand: memref, %result: memref) { - "xla_lhlo.slice"(%operand, %result) { + "lmhlo.slice"(%operand, %result) { start_indices = dense<[0,1]> : tensor<2xi64>, limit_indices = dense<[2,3]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64> @@ -653,7 +653,7 @@ func @slice(%operand: memref, %result: memref) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x1x42xi32>, memref<12x42xi32>) -> () return } @@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () return } @@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () return } @@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { - "xla_lhlo.reverse"(%arg0, %arg1) { + "lmhlo.reverse"(%arg0, %arg1) { dimensions = dense<1> : tensor<1xi64> } : (memref<2x3xf32>, memref<2x3xf32>) -> () return @@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: strides = [2, 1]} // With all atributes explicitly specified. - "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () // Dilation left unspecified, sets default dilation since linalg expects it. // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK-SAME: dilations = [1, 1] // Padding is not set if it's zero. // CHECK-NOT: padding - "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.terminator"() : () -> () } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir similarity index 94% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir index bd552282fcc801..a25a508b2d3e26 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir @@ -1,8 +1,8 @@ -// RUN: xla-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s // CHECK-LABEL: func @static_memref_cast func @static_memref_cast(%buf : memref<10x1x5xf32>) { - %0 = xla_lhlo.static_memref_cast %buf + %0 = lmhlo.static_memref_cast %buf : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> return } @@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref) { %size_Y = constant 50 : index %stride_X = constant 1 : index %stride_Y = constant 0 : index - %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] + %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] : memref -> memref return } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir similarity index 90% rename from tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir index 5127bcfcd8fa6a..1530f59317d6f5 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir @@ -1,13 +1,13 @@ -// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s func @reduce(%arg: memref<100x10x5xf32>, %init: memref, %result: memref<100x5xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () return @@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>, func @reduce_no_outer_loop(%arg: memref<100xf32>, %init: memref, %result: memref<1xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<100xf32>, memref, memref<1xf32>) -> () return @@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } @@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, func @dynamic_reduce(%arg: memref, %init: memref, %result: memref) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref, memref, memref) -> () return @@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref, func @reduce_window(%arg: memref<112x112xf32>, %init: memref, %result: memref<56x56xf32>) { - "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + "lmhlo.reduce_window"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.maximum"(%lhs, %rhs, %res) + "lmhlo.maximum"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir similarity index 67% rename from tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir rename to tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir index 77a3d0fe4a9037..30ff9659d3b0b3 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir @@ -1,10 +1,10 @@ -// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s // ----- // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @add_memrefs func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1 // CHECK-LABEL: func @abs_memref func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @convert_memref func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () return } @@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () return } @@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.log"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @neg_memref func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) - func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sign_memref func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) - func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } // ----- func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { - // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} - "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () + // expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}} + "lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () return } @@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { // CHECK-LABEL: func @add_memref func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @div_memref func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @max_memref func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @min_memref func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @mul_memref func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @sub_memref func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32> // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) - func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32> // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @broadcast_in_dim_memref func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () return } @@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) - // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () return } @@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi // CHECK-LABEL: func @reduce_memref func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { - "xla_lhlo.reduce"(%input, %init, %out) ( { + "lmhlo.reduce"(%input, %init, %out) ( { ^bb0(%arg1: memref, %arg2: memref, %result: memref): - "xla_lhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref, memref<1xf32>) -> () return } @@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf // CHECK-LABEL: func @fusion_memref func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.fusion"() ( { + "lmhlo.fusion"() ( { %0 = tensor_load %input1 : memref<10xf32> %1 = tensor_load %input2 : memref<10xf32> - %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = tensor_load %input3 : memref<10xf32> - %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> tensor_store %4, %out : memref<10xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) : () -> () return } @@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m // CHECK-LABEL: func @case_memref func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { - "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { ^bb0(%arg0: memref): - "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.negate"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} : (memref, memref, memref, memref, memref) -> () @@ -430,7 +430,7 @@ func @case_memref(%index: memref, %operand_1: memref, %operand_2: memr // ----- func @static_memref_cast(%in: memref<10x1xf32>) { - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> return } @@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) { func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { // expected-error @+1 {{operand must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> return } @@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { // expected-error @+1 {{result must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> return } @@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { func @dynamic_memref_cast(%in: memref) { %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref) { // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref - // CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]] + // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]] // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref - %dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1) + %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1) : (memref<*xf32>, memref<1xi32>) -> memref - // CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]] + // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]] // CHECK-SAME: : (memref, memref<2xi32>) -> memref - %dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2) + %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2) : (memref, memref<2xi32>) -> memref - // CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]] + // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]] // CHECK-SAME: : (memref, memref) -> memref<*xf32> - %new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3) + %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3) : (memref, memref) -> memref<*xf32> return } @@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, func @reshape_memref_cast_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref } @@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch( func @reshape_memref_cast_dst_ranked_shape_unranked( %buf: memref<*xf32>, %shape: memref) { // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref) -> memref return } @@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked( func @reshape_memref_cast_dst_shape_rank_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{length of shape operand differs from the result's memref rank}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref return } @@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, %shape: memref<1xi32>) { // expected-error @+1 {{operand memref type should have identity affine map}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) -> memref<8xf32> return @@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex> func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref // CHECK-LABEL: func @bitcast_convert_memrefs func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () return } @@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () return } @@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> // CHECK-LABEL: func @clz_memrefs func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @imag_memrefs func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @real_memrefs func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @is_finite_memrefs func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () + "lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () return } @@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () return } @@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @popcnt_memrefs func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @reduce_precision_memrefs func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () return } @@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> // CHECK-LABEL: func @round_memrefs func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @shift_left_memrefs func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m // CHECK-LABEL: func @shift_right_arithmetic_memrefs func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, // CHECK-LABEL: func @shift_right_logical_memrefs func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -801,17 +801,17 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a // CHECK-LABEL: func @all_reduce_memrefs func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.maximum %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.maximum %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, @@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () // CHECK-LABEL: func @collective_permute_memrefs func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> } : (memref<128x32xf32>, memref<128x32xf32>) -> () - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_id = { handle = 5 : i64, type = 2 : i64 } } : (memref<128x32xf32>, memref<128x32xf32>) -> () @@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128 // CHECK-LABEL: func @fft_memrefs func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex>) -> () { - "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () + "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () return } @@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, %grad_offset: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return @@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, // CHECK-LABEL: func @batch_norm_inference_memrefs func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { - "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () return } @@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, %batch_var: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return } @@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3 // CHECK-LABEL: func @cholesky_memrefs func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () { - "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () - "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () return } @@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x // CHECK-LABEL: func @infeed_memrefs func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { - "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () return } @@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { // CHECK-LABEL: func @outfeed_memrefs func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { - "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () return } @@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { // CHECK-LABEL: func @replica_id_memrefs func @replica_id_memrefs(%arg_out: memref) -> () { - "xla_lhlo.replica_id"(%arg_out) : (memref) -> () + "lmhlo.replica_id"(%arg_out) : (memref) -> () return } @@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref) -> () { // CHECK-LABEL: func @triangular_solve_memrefs func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { - "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} + "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () return } @@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, % // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { - "xla_lhlo.while"(%arg0, %arg_out) ( - { ^bb0(%arg: memref, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%arg: memref, %body_out: memref): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg_out) ( + { ^bb0(%arg: memref, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%arg: memref, %body_out: memref): "lmhlo.terminator"() : () -> () } ) : (memref, memref) -> () return } @@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>) -> () { - "xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () } ) : (memref, memref<5xf32>, memref, memref<5xf32>) -> () return } @@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref< // CHECK-LABEL: func @bitcast_memrefs func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { - "xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () + "lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () return } @@ -956,10 +956,10 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { // CHECK-LABEL: func @scatter_memrefs func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>, %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { - "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ + "lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors - %add = xla_hlo.add %lhs, %rhs : tensor - "xla_hlo.return"(%add) : (tensor) -> () + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () }) { scatter_dimension_numbers = { update_window_dims = dense<[1]> : tensor<1xi64>, @@ -977,10 +977,10 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32 // CHECK-LABEL: func @map_memrefs func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): - %c = xla_hlo.add %a, %b : tensor - "xla_hlo.return"(%c) : (tensor) -> () + %c = mhlo.add %a, %b : tensor + "mhlo.return"(%c) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> () return } @@ -989,10 +989,10 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): - %c = xla_hlo.add %a, %b : tensor - "xla_hlo.return"(%c) : (tensor) -> () + %c = mhlo.add %a, %b : tensor + "mhlo.return"(%c) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> () return } @@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref // CHECK-LABEL: func @rng_get_and_update_state_memrefs func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { - "xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () + "lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () return } @@ -1010,10 +1010,10 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } @@ -1023,10 +1023,10 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } @@ -1036,10 +1036,10 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir new file mode 100644 index 00000000000000..8d84e7140f34e0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir @@ -0,0 +1,224 @@ +// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s + +// CHECK-LABEL: @add +func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 + %4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @add_unranked +func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 + %4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @sub +func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 + %4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @sub_unranked +func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 + %4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @mul +func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]] + %4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32> + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @mul_unranked +func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]] + %4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @div +func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @div_unranked +func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @abs +func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]]) + %1 = "mhlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]] + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: @exp +func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %3 = "mhlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @exp_unranked +func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + %1 = "mhlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %3 = "mhlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<*xf32>, tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-general-dot.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-general-dot.mlir new file mode 100644 index 00000000000000..36cb1fd6159497 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lower-general-dot.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-hlo-opt -mhlo-test-lower-general-dot -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @testDebatch1 +func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { + // CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + // CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> + + return %0 : tensor<1x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @testDebatch2 +func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { + // CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> + // CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32> + // CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> + // CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> + return %0 : tensor<3x1x1xf32> +} + +// ----- + +// CHECK-LABEL: @testBatchPassthrough +func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> { + // CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1) + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> + return %0 : tensor<3x2x1xf32> +} + diff --git a/tensorflow/compiler/mlir/hlo/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/materialize-broadcasts.mlir new file mode 100644 index 00000000000000..682987d776dd7f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/materialize-broadcasts.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-hlo-opt -mhlo-test-materialize-broadcasts -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @clampBroadcast +// CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) +func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { + // CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo-fusion.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo-fusion.mlir new file mode 100644 index 00000000000000..d349077e881c43 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo-fusion.mlir @@ -0,0 +1,97 @@ +// RUN: mlir-hlo-opt %s -mhlo-fusion -split-input-file | FileCheck %s + +// CHECK-LABEL: func @multi_outputs_same +func @multi_outputs_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "mhlo.add"(%1, %1) : (tensor, tensor) -> tensor + // CHECK: %[[RET:.*]]:2 = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.return + return %1, %2 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @multi_outputs_same_2 +func @multi_outputs_same_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor, tensor) { + %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor + %1 = "mhlo.abs"(%arg1) : (tensor) -> tensor + %2 = "mhlo.add"(%0, %1) : (tensor, tensor) -> tensor + %3 = "mhlo.abs"(%0) : (tensor) -> tensor + %4 = "mhlo.abs"(%1) : (tensor) -> tensor + // CHECK: %[[RET:.*]]:3 = "mhlo.fusion" + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.return + return %2, %3, %4 : tensor, tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @multi_outputs_not_sure_same +func @multi_outputs_not_sure_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "mhlo.add"(%arg0, %arg0) : (tensor, tensor) -> tensor + // CHECK-NOT: mhlo.fusion + %1 = "mhlo.subtract"(%arg1, %arg1) : (tensor, tensor) -> tensor + return %0, %1 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @reduce +func @reduce(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: %[[RET0:.*]] = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.return + // Currently we do not support fuse arguments and ops without direct producer-consumer + // relationship. Thus Reduce Op should not be fused with above two ops. + + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.reduce"(%arg0, %2) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %4 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor + %4 = "mhlo.add"(%3, %3) : (tensor, tensor) -> tensor + // Above two ops should not be fused since reduce op can not be + // fused with its consumer. + // CHECK-NOT: mhlo.fusion + + return %1, %4 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @reduce_2 +func @reduce_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.reduce"(%1, %2) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %4 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[RET0:.*]]:2 = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.constant + // CHECK-NEXT: mhlo.reduce + // CHECK: mhlo.return + + // Following op should not be fused with the above ops since reduce op can not be + // fused with its consumer. + // CHECK-NOT: mhlo.fusion + %4 = "mhlo.add"(%3, %3) : (tensor, tensor) -> tensor + return %1, %4 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir similarity index 63% rename from tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir rename to tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir index c3b6497b934110..80474156f29106 100644 --- a/tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s // Check the validity of expected IR. // CHECK-LABEL: @sqr_transform_result @@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %num_elements = shape.num_elements %shape %num_elements_as_index = shape.size_to_index %num_elements %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> - %flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape) + %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor // Apply operation. - %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor) -> tensor + %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor - %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> - // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor + // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> + %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> } @@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt_ranked // CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>) func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> // CHECK-NEXT: return %[[B]] : tensor<3x?xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> + %b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> return %b : tensor<3x?xf32> } @@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { // CHECK-LABEL: @sqrt_static // CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>) func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: return %[[B]] : tensor<2x3xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %b : tensor<2x3xf32> } @@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> - // CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor + // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> - %result = xla_hlo.add %a, %b : tensor<*xf32> + %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir similarity index 62% rename from tensorflow/compiler/mlir/xla/tests/ops.mlir rename to tensorflow/compiler/mlir/hlo/tests/ops.mlir index 2c68a0f5c8969c..b46827b88a5375 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -1,21 +1,21 @@ -// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s // Tests for types, ops with custom constraints, verifiers, printer or parser // methods. -// CHECK-LABEL: func @token_type() -> !xla_hlo.token -func @token_type() -> !xla_hlo.token +// CHECK-LABEL: func @token_type() -> !mhlo.token +func @token_type() -> !mhlo.token // ----- -// expected-error@+1 {{unknown xla_hlo type: foobar}} -func @invalid_type() -> !xla_hlo.foobar +// expected-error@+1 {{unknown mhlo type: foobar}} +func @invalid_type() -> !mhlo.foobar // ----- // CHECK-LABEL: func @alltoall func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 4 : i64, @@ -28,7 +28,7 @@ func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK-LABEL: func @alltoall_unranked_input func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 5 : i64, @@ -41,7 +41,7 @@ func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // expected-error@+1 {{split dimension has size 16, expected to be a multiple of split_count 5}} - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 5 : i64, @@ -54,7 +54,7 @@ func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf3 // CHECK-LABEL: func @broadcast func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -62,7 +62,7 @@ func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -70,7 +70,7 @@ func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result rank (3) does not match operand rank (1) plus size of broadcast_sizes (1)}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -78,7 +78,7 @@ func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [1, 3] instead of [2, 3]}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> return %0 : tensor<1x3xi32> } @@ -86,7 +86,7 @@ func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [2, 1] instead of [2, 3]}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> return %0 : tensor<2x1xi32> } @@ -94,7 +94,7 @@ func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2 // CHECK-LABEL: func @broadcast_in_dim func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> return %0 : tensor<1x2x2xi32> } @@ -102,7 +102,7 @@ func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { // CHECK-LABEL: func @broadcast_in_dim_zero_rank func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -110,7 +110,7 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { // CHECK-LABEL: func @dynamic_broadcast_in_dim func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor return %0 : tensor } @@ -118,7 +118,7 @@ func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -126,7 +126,7 @@ func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -134,7 +134,7 @@ func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { // expected-error@+1 {{result rank (1) is less than operand rank (3)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -142,7 +142,7 @@ func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions contains invalid value 9 for result result with rank 3}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -150,7 +150,7 @@ func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> ten func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -158,18 +158,18 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor, %arg1: tensor): - %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -179,18 +179,18 @@ func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %oper func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{branch 1 returned values do not match op result types}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1, %arg0) : (tensor, tensor) -> () + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1, %arg0) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -200,18 +200,18 @@ func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %o func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = xla_hlo.constant dense<2.0> : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant dense<2.0> : tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -221,18 +221,18 @@ func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %oper func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{branch 1 returned values do not match op result types}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = xla_hlo.constant dense<2> : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant dense<2> : tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -242,7 +242,7 @@ func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %o func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { // expected-error@+1 {{cannot have empty regions}} - "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor + "mhlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor return } @@ -250,7 +250,7 @@ func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -258,7 +258,7 @@ func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { // expected-error@+1 {{'comparison_direction' failed to satisfy constraint}} - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -266,7 +266,7 @@ func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3 func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate sources not allowed}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [0, 2], [2, 3]]> : tensor<3x2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -276,7 +276,7 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate targets not allowed}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 1]]> : tensor<3x2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -286,7 +286,7 @@ func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute to be of rank 2, but got rank 1}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[0, 1]> : tensor<2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -296,7 +296,7 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute of shape (N, 2), but got (2, 3)}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -306,15 +306,15 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< // CHECK-LABEL: @concat_1D func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } // ----- func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> + // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -322,23 +322,23 @@ func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tenso // CHECK-LABEL: @concat_1D_unranked func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> } // ----- func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> + // expected-error@+1 {{'mhlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } // ----- func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> + // expected-error@+1 {{'mhlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -346,7 +346,7 @@ func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi // CHECK-LABEL: func @clamp func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -354,15 +354,15 @@ func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @clamp_scalar func @clamp_scalar(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi32> { - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> return %0: tensor<1xi32> } // ----- func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { - // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}} - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // expected-error@+1 {{'mhlo.clamp' op requires the same element type for all operands and results}} + %0 = "mhlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -370,7 +370,7 @@ func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32 func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}} - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -378,7 +378,7 @@ func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> t // CHECK-LABEL: func @dot_vector func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor { - %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor return %0: tensor } @@ -386,7 +386,7 @@ func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor // CHECK-LABEL: func @dot_matrix func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -394,7 +394,7 @@ func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi // CHECK-LABEL: func @dot_precision_config func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -402,23 +402,23 @@ func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> te func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { // expected-error@+1 {{'precision_config' failed to satisfy constraint}} - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } // ----- -func @infeed_invalid_number_of_results(%token: !xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> { +func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple>, !mhlo.token, tensor> { // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} - %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> - return %0 : tuple>, !xla_hlo.token, tensor> + %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple>, !mhlo.token, tensor> + return %0 : tuple>, !mhlo.token, tensor> } // ----- -func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple>, tensor> { +func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple>, tensor> { // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} - %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, tensor> + %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple>, tensor> return %0 : tuple>, tensor> } @@ -426,7 +426,7 @@ func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple tensor { // expected-error@+1 {{does not support scalars}} - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor return %0 : tensor } @@ -434,7 +434,7 @@ func @iota_scalar() -> tensor { func @iota_invalid_iota_dimension() -> tensor<4xi32> { // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -442,10 +442,10 @@ func @iota_invalid_iota_dimension() -> tensor<4xi32> { func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg: tensor): - %1 = xla_hlo.add %arg, %arg {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg, %arg {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -454,10 +454,10 @@ func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor<5xf32>): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -466,10 +466,10 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4 func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -478,10 +478,10 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation must return single output, but got: 0}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"() : () -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"() : () -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -490,10 +490,10 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> - "xla_hlo.return"(%1) : (tensor<5xf32>) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> + "mhlo.return"(%1) : (tensor<5xf32>) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -502,10 +502,10 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4 func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -514,10 +514,10 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5 func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -526,10 +526,10 @@ func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf3 func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{applied to a subset of dimensions currently not supported: operand dimensions = 2, requested map dimensions size = 3}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -538,48 +538,48 @@ func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tenso // CHECK-LABEL: func @map_unranked func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // ----- -func @recv_invalid_number_of_results(%token: !xla_hlo.token) -> tuple, tensor, !xla_hlo.token> { +func @recv_invalid_number_of_results(%token: !mhlo.token) -> tuple, tensor, !mhlo.token> { // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} - %0 = "xla_hlo.recv"(%token) { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 3 : i64 // Host to device channel }, is_host_transfer = true - } : (!xla_hlo.token) -> tuple, tensor, !xla_hlo.token> - return %0 : tuple, tensor, !xla_hlo.token> + } : (!mhlo.token) -> tuple, tensor, !mhlo.token> + return %0 : tuple, tensor, !mhlo.token> } // ----- -func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple, tensor> { +func @recv_non_token_second_result(%token: !mhlo.token) -> tuple, tensor> { // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} - %0 = "xla_hlo.recv"(%token) { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 3 : i64 // Host to device channel }, is_host_transfer = true - } : (!xla_hlo.token) -> tuple, tensor> + } : (!mhlo.token) -> tuple, tensor> return %0 : tuple, tensor> } // ----- func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { - %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{but got 'tensor>'}} - %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32> } @@ -587,7 +587,7 @@ func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) - // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -595,7 +595,7 @@ func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi3 // CHECK-LABEL: func @select_scalar_pred func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -603,7 +603,7 @@ func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tenso // CHECK-LABEL: func @select_cast_compatible_types func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg2: tensor<2x3xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> } @@ -612,7 +612,7 @@ func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %arg2: tensor) -> tensor { // TODO(lucyfox): Update once this is supported. // expected-error@+1 {{currently unsupported operand types: 'tensor<2x?xi32>' and 'tensor'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor return %0 : tensor } @@ -620,7 +620,7 @@ func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %a // CHECK-LABEL: func @select_scalar_x_y func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -628,7 +628,7 @@ func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -636,7 +636,7 @@ func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{incompatible operand types: 'tensor<2x4xi32>' and 'tensor<2x3xi32>'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -644,7 +644,7 @@ func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %ar func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{incompatible operand types: 'tensor<2x3xf32>' and 'tensor<2x3xi32>'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -652,7 +652,7 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3 // CHECK-LABEL: func @slice func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -660,7 +660,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -668,7 +668,7 @@ func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // expected-error@+1 {{requires the same element type for all operands and results}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -676,7 +676,7 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // CHECK-LABEL: func @dynamic_slice func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -684,7 +684,7 @@ func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -692,7 +692,7 @@ func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor, // CHECK-LABEL: @dynamic_slice_different_indice_element_type func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -700,7 +700,7 @@ func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xf32> { // expected-error@+1 {{failed to verify that all of {operand, result} have same element type}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -708,7 +708,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -716,7 +716,7 @@ func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) // CHECK-LABEL: @dynamic_update_slice func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { - %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> } @@ -724,7 +724,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> } @@ -732,21 +732,21 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } // ----- func @transpose_ranked(%arg0: tensor) -> tensor { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor return %0: tensor } // ----- func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> return %0: tensor<*xi32> } @@ -754,7 +754,7 @@ func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation has rank 2 instead of rank 1}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -762,7 +762,7 @@ func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{operand rank (4) does not match permutation size (1)}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -770,7 +770,7 @@ func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> tensor<2xi32> { // expected-error@+1 {{result rank (1) does not match permutation size (4)}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -778,14 +778,14 @@ func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> ten func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>) -> tensor { // expected-error@+1 {{result type tensor is incompatible with the expected type tensor}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor return %0: tensor } // ----- func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -793,7 +793,7 @@ func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> t func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -801,7 +801,7 @@ func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3x func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -809,7 +809,7 @@ func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tenso func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{operands must have equal rank, but got 'tensor<10x4x4xf32>' and 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -817,7 +817,7 @@ func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3 func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> { // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -825,7 +825,7 @@ func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> { // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> return %0 : tensor<10x6x4x3xf32> } @@ -833,7 +833,7 @@ func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{result and operand 'b' must have same shape, but got 'tensor<4x4xf32>' and 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32> return %0 : tensor<4x4xf32> } @@ -841,7 +841,7 @@ func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: // CHECK-LABEL: func @tuple func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> { - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> return %0: tuple, tensor<1x2xf32>> } @@ -849,7 +849,7 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> return %0 : tuple, tensor, tensor> } @@ -857,29 +857,29 @@ func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, %arg1: tensor) -> tuple, tensor> { // expected-error@+1 {{has return type tuple, tensor>, but expected tuple, tensor>}} - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> return %0 : tuple, tensor> } // ----- func @get_tuple_element(%arg0: tuple, tensor>) -> tensor { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } // ----- -func @get_tuple_element_token(%arg0: tuple, !xla_hlo.token>) -> !xla_hlo.token { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !xla_hlo.token>) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @get_tuple_element_token(%arg0: tuple, !mhlo.token>) -> !mhlo.token { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token>) -> !mhlo.token + return %0 : !mhlo.token } // ----- func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tensor { // expected-error@+1 {{has return type tensor, but expected tensor}} - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -887,7 +887,7 @@ func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tens func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor>) -> tensor { // expected-error@+1 {{index 2 is out of bounds of operand with size 2}} - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -895,14 +895,14 @@ func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.and"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } // ----- // CHECK-LABEL: func @or_i1_type func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = "mhlo.or"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -910,7 +910,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // expected-error@+1 {{but got 'tensor<4xf32>'}} - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -918,7 +918,7 @@ func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { // expected-error@+1 {{must be tensor of floating-point values, but got 'tensor<4xi32>'}} - %0 = "xla_hlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -927,11 +927,11 @@ func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { // Verifiers HLO constant op custom printing and parsing. // CHECK-LABEL: func @constants func @constants() -> () { - // CHECK: xla_hlo.constant dense<0> : tensor - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + // CHECK: mhlo.constant dense<0> : tensor + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor) - // CHECK: xla_hlo.constant {extra_attr = 3 : i32} dense<0> : tensor - %1 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) + // CHECK: mhlo.constant {extra_attr = 3 : i32} dense<0> : tensor + %1 = "mhlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) return } @@ -939,18 +939,18 @@ func @constants() -> () { func @constant_invalid() -> () { // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) return } // ----- func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - // CHECK: xla_hlo.sort - %0 = "xla_hlo.sort"(%input0, %input1) ( { + // CHECK: mhlo.sort + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -959,10 +959,10 @@ func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_no_operands() { // expected-error @+1 {{op requires at least one input}} - %0 = "xla_hlo.sort"() ( { + %0 = "mhlo.sort"() ( { ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): - %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> return } @@ -970,10 +970,10 @@ func @sort_no_operands() { // ----- func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -982,10 +982,10 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -994,10 +994,10 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op requires all inputs to have the same dimensions}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1006,10 +1006,10 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1018,10 +1018,10 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1030,10 +1030,10 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block should have 4 arguments}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1042,10 +1042,10 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1054,7 +1054,7 @@ func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x1 // CHECK: func @dequantize func @dequantize(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1062,7 +1062,7 @@ func @dequantize(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @dequantize_wrong_shape(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { // expected-error @+1 {{mismatched dimensions.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = true} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = true} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1070,7 +1070,7 @@ func @dequantize_wrong_shape(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @dequantize_wrong_size(%arg: tensor<16x16xi32>) -> tensor<16x16xbf16> { // expected-error @+1 {{last dimension of output should be 4x of the input.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x16xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x16xbf16> return %0 : tensor<16x16xbf16> } @@ -1078,7 +1078,7 @@ func @dequantize_wrong_size(%arg: tensor<16x16xi32>) -> tensor<16x16xbf16> { func @dequantize_wrong_mode(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { // expected-error @+1 {{Dequantization mode. Only MIN_COMBINED is supported.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "hello", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "hello", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1086,7 +1086,7 @@ func @dequantize_wrong_mode(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { // expected-error @+1 {{number of output elements (9) doesn't match expected number of elements (8)}} - %0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> + %0 = "mhlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } @@ -1094,7 +1094,7 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} - %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, @@ -1107,7 +1107,7 @@ func @dot_general(%arg0: tensor, %arg1: tensor) { func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} - %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[]> : tensor<0xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -1119,7 +1119,7 @@ func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } @@ -1127,6 +1127,6 @@ func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { // expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}} - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/reduce.mlir b/tensorflow/compiler/mlir/hlo/tests/reduce.mlir similarity index 56% rename from tensorflow/compiler/mlir/xla/tests/reduce.mlir rename to tensorflow/compiler/mlir/hlo/tests/reduce.mlir index d49b34d6f74cd8..586a1995471738 100644 --- a/tensorflow/compiler/mlir/xla/tests/reduce.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/reduce.mlir @@ -1,14 +1,14 @@ -// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) // CHECK: return %[[ARG0]] func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %2 = "xla_hlo.reduce"(%arg0, %0) ( { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): - %4 = xla_hlo.add %arg1, %arg2 : tensor - "xla_hlo.return"(%4) : (tensor) -> () + %4 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%4) : (tensor) -> () }) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } diff --git a/tensorflow/compiler/mlir/hlo/tests/reshape.mlir b/tensorflow/compiler/mlir/hlo/tests/reshape.mlir new file mode 100644 index 00000000000000..9aa28a44f4e2f0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/reshape.mlir @@ -0,0 +1,149 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @const_fold_collapse_to_scalar +func @const_fold_collapse_to_scalar() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor<1x1xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_fold_collapse_to_tensor +func @const_fold_collapse_to_tensor() -> tensor<2xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32> + %cst = mhlo.constant dense<42> : tensor<1x2xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_expand +func @const_fold_expand() -> tensor<1xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32> + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.reshape"(%cst) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_nontrivial +func @const_fold_nontrivial() -> tensor<16xi64> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64> + %cst = mhlo.constant dense<42> : tensor<4x4xi64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xi64> +} + +// ----- + +// CHECK-LABEL: func @const_fold_flatten +func @const_fold_flatten() -> tensor<16xi64> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64> + %cst = mhlo.constant dense<42> : tensor<4x4xi64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xi64> +} + +// ----- + +// CHECK-LABEL: func @const_fold_6 +func @const_fold_6() -> tensor<6xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<6xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_same_shape +func @const_fold_same_shape() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_float +func @const_fold_float() -> tensor<16xf64> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64> + %cst = mhlo.constant dense<4.2> : tensor<4x4xf64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xf64> +} + +// ----- + +// CHECK-LABEL: func @non_const_same_shape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: return [[ARG]] + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) { + // CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> + // CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_unused_parent +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[RES]] + return %1 : tensor<6xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[ARG]] + return %1 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_many_chained_reshapes +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> { + // CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> + %2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> + %3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> + %4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> + // CHECK-NEXT: return [[RES]] + return %4 : tensor<1x2x4x3xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/reverse.mlir b/tensorflow/compiler/mlir/hlo/tests/reverse.mlir similarity index 51% rename from tensorflow/compiler/mlir/xla/tests/reverse.mlir rename to tensorflow/compiler/mlir/hlo/tests/reverse.mlir index e0e80400b81eb4..6e291af8f87960 100644 --- a/tensorflow/compiler/mlir/xla/tests/reverse.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/reverse.mlir @@ -1,9 +1,9 @@ -// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK: return %[[ARG0]] return %0 : tensor<1x2xf32> } diff --git a/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 00000000000000..f8b6b629c9ec7d --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s + +// Tests sinking constants to a while loop. + +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor) -> tensor { + // CHECK-NEXT: mhlo.while + %c0 = mhlo.constant dense<1> : tensor + %c1 = mhlo.constant dense<2> : tensor + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1A:.+]]: tensor + // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor + // CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1B:.+]]: tensor + // CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]] + %2 = mhlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]] + %3 = mhlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]] + %4 = mhlo.add %c1, %3 : tensor + "mhlo.return"(%4) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// Tests sinking constants to a conditional op. + +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor) -> tensor { + %c0 = mhlo.constant dense<1> : tensor + %c1 = mhlo.constant dense<2> : tensor + %0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "mhlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: mhlo.if + %2 = "mhlo.if"(%0, %1, %1) ( { + ^bb0(%arg1: tuple>): + // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor + %3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]], + %4 = mhlo.add %c0, %3 : tensor + %5 = "mhlo.tuple"(%4) : (tensor) -> tuple> + "mhlo.return"(%5) : (tuple>) -> () + }, { + ^bb0(%arg1: tuple>): + // CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor + %6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], + %7 = mhlo.add %c1, %6 : tensor + %8 = "mhlo.tuple"(%7) : (tensor) -> tuple> + "mhlo.return"(%8) : (tuple>) -> () + }) : (tensor, tuple>, tuple>) -> tuple> + %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + return %9 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/transpose.mlir b/tensorflow/compiler/mlir/hlo/tests/transpose.mlir similarity index 53% rename from tensorflow/compiler/mlir/xla/tests/transpose.mlir rename to tensorflow/compiler/mlir/hlo/tests/transpose.mlir index 11470a2fd88925..bbfedc57383e49 100644 --- a/tensorflow/compiler/mlir/xla/tests/transpose.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/transpose.mlir @@ -1,9 +1,9 @@ -// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @remove_noop // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> + %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> // CHECK-NEXT: return [[ARG]] return %0 : tensor<2x3x9x5xi32> } @@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { // CHECK-LABEL: func @keep_real_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + // CHECK-NEXT: "mhlo.transpose"([[ARG]]) + %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32> } @@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-LABEL: func @keep_same_shape_real_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { - // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> + // CHECK-NEXT: "mhlo.transpose"([[ARG]]) + %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> return %0 : tensor<4x4xi32> } diff --git a/tensorflow/compiler/mlir/hlo/tests/tuple.mlir b/tensorflow/compiler/mlir/hlo/tests/tuple.mlir new file mode 100644 index 00000000000000..4ecc1e308ba1be --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/tuple.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @fold_access +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @fold_access(%arg : tensor) -> tensor { + // CHECK-NEXT: return [[ARG]] + %tuple = "mhlo.tuple"(%arg) : (tensor) -> tuple> + %element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple>) -> tensor + return %element : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir similarity index 52% rename from tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir rename to tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir index b33a13c0cb9716..c1930721218ebf 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s +// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s // CHECK-LABEL: @batchNormInference_2D_inner_features // CHECK-SAME: %[[X:[^:[:space:]]+]] @@ -10,19 +10,18 @@ func @batchNormInference_2D_inner_features( %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { - // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<256xf32> - // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> @@ -36,12 +35,12 @@ func @batchNormInference_2D_inner_features( // the verifier to enforce the rest. // CHECK-SAME: %[[X:[^:]+]] // CHECK-SAME: %[[SCALE:[^:]+]] -// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> +// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> func @batchNormInference_4D_middle_features( %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<3x4x256x6xf32>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<3x4x256x6xf32> @@ -51,12 +50,12 @@ func @batchNormInference_4D_middle_features( // ----- // CHECK-LABEL: @batchNormInference_f64 // Validate that epsilon is properly promoted to f64 -// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf64> func @batchNormInference_f64( %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, %mean: tensor<256xf64>, %variance: tensor<256xf64>) -> (tensor<4x256xf64>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.0 : f32, feature_index = 1 : i64} : (tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>) -> tensor<4x256xf64> @@ -66,12 +65,12 @@ func @batchNormInference_f64( // ----- // CHECK-LABEL: @batchNormInference_f16 // Validate that epsilon is properly promoted to f64 -// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf16> func @batchNormInference_f16( %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, %mean: tensor<256xf16>, %variance: tensor<256xf16>) -> (tensor<4x256xf16>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.0 : f32, feature_index = 1 : i64} : (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>) -> tensor<4x256xf16> @@ -85,7 +84,7 @@ func @batchNormInference_f16_overflow( %mean: tensor<256xf16>, %variance: tensor<256xf16>) -> (tensor<4x256xf16>) { // expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}} - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.00000001 : f32, feature_index = 1 : i64} : (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>) -> tensor<4x256xf16> @@ -108,26 +107,26 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index - // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor + // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor - // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor - // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor - // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor + // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.001 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index b1f448899644dc..8d0c204f434939 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -472,23 +472,17 @@ filegroup( ], ) -genrule( +gentbl( name = "op_quant_spec_getters_inc", - srcs = [ - "ir/tfl_ops.td", - "ir/tfl_op_interfaces.td", - "experimental/tfl_hardware_interfaces.td", + tbl_outs = [("", "utils/generated_op_quant_spec_getters.inc")], + tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", + td_file = "ir/tfl_ops.td", + td_srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", + "experimental/tfl_hardware_interfaces.td", + "ir/tfl_op_interfaces.td", ], - outs = [ - "utils/generated_op_quant_spec_getters.inc", - ], - cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " + - "-I external/llvm-project/mlir/include " + - "-I tensorflow/compiler/mlir " + - "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"), - tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"], ) # Library with tensorflow Lite dialect static initialization. diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 83ff997124641d..92c45b98ea7021 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -30,7 +30,7 @@ struct PassConfig { explicit PassConfig(QuantizationSpecs specs) : emit_builtin_tflite_ops(true), lower_tensor_list_ops(false), - trim_functions_whitelist({}), + trim_functions_allowlist({}), quant_specs(std::move(specs)), form_clusters(false), unfold_batch_matmul(true), @@ -44,8 +44,8 @@ struct PassConfig { // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic // TF ops before legalization to TF Lite dialect. bool lower_tensor_list_ops; - // The whitelist of functions that would be preserved after trimming. - llvm::ArrayRef trim_functions_whitelist; + // The allowlist of functions that would be preserved after trimming. + llvm::ArrayRef trim_functions_allowlist; // All information about quantization. QuantizationSpecs quant_specs; // If `form_clusters` is true , clusters are formed by grouping consecutive diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index ee8b34598e27c9..fb20e842a75ef3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -71,7 +71,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_util.h" @@ -101,7 +101,7 @@ using mlir::Value; using tensorflow::OpOrArgLocNameMapper; using tensorflow::OpOrArgNameMapper; using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; +using tflite::flex::IsAllowlistedFlexOp; using xla::StatusOr; template @@ -972,7 +972,7 @@ Optional> Translator::BuildOperator( // model is of an open op system. // // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex + // if flex is enabled and the op is allowlisted as flex // we emit op as flex. // if custom is enabled // we emit the op as custom. @@ -982,11 +982,11 @@ Optional> Translator::BuildOperator( } // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op + // Eventually, the allowlist will go away and we will rely on some TF op // trait (e.g. No side effect) to determine if it is a supported "Flex" // op or not. if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { + IsAllowlistedFlexOp(node_def->op())) { // Construct ops as flex op encoding TensorFlow node definition // as custom options. // Flex ops are named with the kFlexOpNamePrefix prefix to the actual @@ -1037,7 +1037,7 @@ Optional> Translator::BuildOperator( } // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { + if (IsAllowlistedFlexOp(node_def->op())) { failed_flex_ops_.insert(os.str()); } else { failed_custom_ops_.insert(os.str()); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index faa6bc36824de6..fa85b4e50fd97f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -152,6 +152,14 @@ StatusOr GetQuantizedType(const TensorT& tensor, Builder builder, uint32_t flags = is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0; + // Rejects if quantized tensors have zero scales. + for (float scale : quant_params.scale) { + if (scale == 0) { + return errors::InvalidArgument( + "Quantized tensors must have non-zero scales"); + } + } + // Scale size can't be zero as it is checked before. if (quant_params.scale.size() != 1) { llvm::SmallVector scales(quant_params.scale.begin(), @@ -443,8 +451,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, if (auto float_type = elem_type.dyn_cast()) { TF_ASSIGN_OR_RETURN(value, ConvertFloatBuffer(shaped_type, float_type, buffer)); - } else if (elem_type.isa() || - elem_type.isa()) { + } else if (elem_type.isa()) { TF_ASSIGN_OR_RETURN(value, ConvertIntBuffer(shaped_type, elem_type, buffer)); } else if (elem_type.isa()) { @@ -456,8 +463,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, refs.push_back({ref.data(), ref.size()}); value = mlir::DenseStringElementsAttr::get(shaped_type, refs); - } else if (elem_type.isa() || - elem_type.isa()) { + } else if (elem_type.isa()) { auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 1b409bc939bcc7..0c9ccf1a97944b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() { fn_.walk([&](Operation *op) { if (op->isKnownTerminator() || op->hasTrait() || - llvm::isa(op) || - llvm::isa(op)) + llvm::isa(op)) return; work_list_.push_back(op); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 35c930281d0318..4ced43014f5963 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern { Operation* def = pre_quantized.getDefiningOp(); if (!def) return failure(); - if (llvm::isa(def) || - llvm::isa(def) || + if (llvm::isa(def) || def->hasTrait()) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index e1c81493c2e80c..1a61bc3f517dea 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -13,7 +13,7 @@ func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> { } func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> { - // CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32> + // CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<> : tensor<0x3xi32>} : () -> tensor<0x3xi32> %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor} : () -> tensor>> // CHECK: return %[[LIST]] diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 361af19040d658..7ce60d9806292c 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s +// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FileCheck %s module{ func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} { @@ -453,3 +453,31 @@ func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor, % // CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor, tensor, tensor, tensor, tensor // CHECK: } } + +// ----- + +module { +func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tensor<1x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} { + %0 = "tf.Const"() {value = dense<1> : tensor<1x10xi32>} : () -> tensor<1x10xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor + return %0, %1 : tensor<1x10xi32>, tensor +} + +// CHECK: func @nms_padded(%[[VAL_119:.*]]: tensor<100x4xf32>, %[[VAL_120:.*]]: tensor<100xf32>, %[[VAL_121:.*]]: tensor, %[[VAL_122:.*]]: tensor, %[[VAL_123:.*]]: tensor, %[[VAL_124:.*]]: tensor, %[[VAL_125:.*]]: tensor, %[[VAL_126:.*]]: tensor, %[[VAL_127:.*]]: tensor) -> (tensor<1x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} { +// CHECK: %[[VAL_128:.*]], %[[VAL_129:.*]] = "tfl.non_max_suppression_v4"(%[[VAL_119]], %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]]) : (tensor<100x4xf32>, tensor<100xf32>, tensor, tensor, tensor) -> (tensor<1x10xi32>, tensor) +// CHECK: return %[[VAL_128]], %[[VAL_129]] : tensor<1x10xi32>, tensor +// CHECK: } +} + +// ----- + +module { +// expected-error @+1 {{Invalid number of results from non_max_suppression_padded_v2}} +func @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} + +// expected-error @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}} +func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor, %arg3: tensor) -> (tensor<1x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} + +// expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}} +func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tensor<2x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 05eb8de71e9737..53caf15bc8f6e1 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s // CHECK-LABEL: quantize_float_placeholder_only func @quantize_float_placeholder_only(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor, tensor<2x3xi32>, tensor<2x3xf32>) { diff --git a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir index 0087ae1215657f..0b8c147cde2068 100644 --- a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s +// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { return %arg0 : tensor<1x4xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index bd3c217605b19a..d26a4906420946 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { return failure(); ShapedType filter_type = filter_cst.getType(); - if (llvm::isa(binary_op) || llvm::isa(binary_op)) { + if (llvm::isa(binary_op)) { auto padding = fc_op.template getAttrOfType("padding"); if (padding && padding.getValue() != "VALID") return failure(); @@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { rewriter.create(fc_op.getLoc(), new_bias_type, new_bias); fc_op.setOperand(0, binary_op->getOperand(0)); fc_op.setOperand(2, new_bias_op); - } else if (llvm::isa(binary_op) || llvm::isa(binary_op)) { + } else if (llvm::isa(binary_op)) { // The fusion of mul/div is actually applying the following // transformation: // w * (x ' c) + b => (w ' c) x + b diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 105c9394fb4168..af97931b2a3cb6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -61,7 +61,7 @@ std::unique_ptr> CreatePostQuantizePass( // Creates an instance of the TensorFlow Lite dialect TrimFunctions // pass. std::unique_ptr> CreateTrimFunctionsPass( - llvm::ArrayRef trim_funcs_whitelist); + llvm::ArrayRef trim_funcs_allowlist); // Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions // pass. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 221e8c70cd7f0f..3d2ab662e6f50d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -57,6 +57,7 @@ namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; constexpr char kTfTextAPIPRefix[] = "tftext:"; +constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; // Abstracts the conversion of the embedded lookup composite function. class ConvertEmbeddedLookupFunc { @@ -94,6 +95,59 @@ class ConvertEmbeddedLookupFunc { FuncOp func_; }; +// Abstracts the conversion of the padded NMS composite function. +class ConvertNMSPaddedFunc { + public: + explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {} + + void RewriteFunc() { + func_.setAttr(kTFImplements, + StringAttr::get(kTfNMSPadded, func_.getContext())); + Value boxes = func_.getArgument(0); + Value scores = func_.getArgument(1); + Value max_output_size = func_.getArgument(2); + Value iou_threshold = func_.getArgument(3); + Value score_threshold = func_.getArgument(4); + auto output_type0 = func_.getType().getResult(0); + auto output_type1 = func_.getType().getResult(1); + + OpBuilder builder(func_.getBody()); + auto op = builder.create( + func_.getLoc(), output_type0, output_type1, boxes, scores, + max_output_size, iou_threshold, score_threshold); + + builder.create(func_.getLoc(), op.getResults()); + } + + LogicalResult VerifySignature() { + // Verify high-level function signature. + // Relevant argument characteristics are checked by the TFL op definition. + if (func_.getNumArguments() < 5) { + return func_.emitError() + << "Invalid number of arguments to " + "non_max_suppression_padded_v2 (need atleast 5): " + << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 2) { + return func_.emitError() << "Invalid number of results from " + "non_max_suppression_padded_v2 (need 2): " + << func_.getType().getNumResults(); + } + // The TFLite fused op does not support batching yet. + // TODO(b/158709815): Add support for batches with padded NMS. + auto boxes_type = + func_.getArgument(0).getType().dyn_cast(); + if (!boxes_type.hasRank() || boxes_type.getRank() != 2) { + return func_.emitError() << "TFLite does not support batched input for " + "non_max_suppression_padded"; + } + return success(); + } + + private: + FuncOp func_; +}; + // This pass uses mechanisms listed in RFC: // https://github.com/tensorflow/community/pull/113 // It prepares composite functions that are attributed to indicate @@ -139,6 +193,14 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func, if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) { return signalPassFailure(); } + } else if (attr.getValue() == kTfNMSPadded) { + func.eraseBody(); + func.addEntryBlock(); + ConvertNMSPaddedFunc convert_nms_padded(func); + if (failed(convert_nms_padded.VerifySignature())) { + return signalPassFailure(); + } + convert_nms_padded.RewriteFunc(); } } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 579063f9c9dd70..9a27d0de62a735 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -35,9 +35,9 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" // NOLINTNEXTLINE -static llvm::cl::list quantize_whitelist( - "tfl-test-quantize-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma separated list of whitelisted functions to be " +static llvm::cl::list quantize_allowlist( + "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"), + llvm::cl::desc("comma separated list of allowlisted functions to be " "quantized. Only used in tests"), llvm::cl::CommaSeparated); @@ -108,7 +108,7 @@ class PrepareQuantizePass // Get the min and max values from the quantization specification for the // current function function and argument index. Uses default values if - // the function is specified in the `quantize_whitelist`. + // the function is specified in the `quantize_allowlist`. std::pair, llvm::Optional> GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) { if (func_name == quant_specs_.target_func) { @@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { // Skip this function because it isn't the target function from the spec or // in the function while list. if (target_func != func_name && - !llvm::is_contained(quantize_whitelist, func_name)) { + !llvm::is_contained(quantize_allowlist, func_name)) { return false; } diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index 013ffc26ea8da4..9eedf2b4fa6050 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -29,12 +29,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -// The cmd line flag to specify the whitelist of functions. Rest are trimmed +// The cmd line flag to specify the allowlist of functions. Rest are trimmed // after this pass is run. // NOLINTNEXTLINE -static llvm::cl::list trim_funcs_whitelist( - "tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma separated list of whitelisted functions. The first " +static llvm::cl::list trim_funcs_allowlist( + "tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"), + llvm::cl::desc("comma separated list of allowlisted functions. The first " "function specified will be used as main."), llvm::cl::CommaSeparated); @@ -43,25 +43,25 @@ namespace TFL { namespace { // The pass to trim functions before we legalize to TFL -// dialect using the specified whitelist. +// dialect using the specified allowlist. class TrimFunctionsPass : public mlir::PassWrapper> { public: - explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} - explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_whitelist) - : trim_funcs_whitelist_(trim_funcs_whitelist) {} + explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {} + explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_allowlist) + : trim_funcs_allowlist_(trim_funcs_allowlist) {} private: void runOnOperation() override; bool TrimModule(); void Verify(); - llvm::ArrayRef trim_funcs_whitelist_; + llvm::ArrayRef trim_funcs_allowlist_; }; void TrimFunctionsPass::runOnOperation() { - // trim the functions in the module using the trim_funcs_whitelist_ - // by removing functions not in the whitelist. + // trim the functions in the module using the trim_funcs_allowlist_ + // by removing functions not in the allowlist. if (TrimModule()) { // verify the updated module is still valid, if not signal the // pass as failed. @@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() { } bool TrimFunctionsPass::TrimModule() { - // if no trim_funcs_whitelist_ is specified, this pass is a no-op. - if (trim_funcs_whitelist_.empty()) return false; + // if no trim_funcs_allowlist_ is specified, this pass is a no-op. + if (trim_funcs_allowlist_.empty()) return false; llvm::SmallVector funcs_to_trim; for (auto func : getOperation().getOps()) { - if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) { - // If no main is specified in the whitelist, use the 1st func - // in trim_funcs_whitelist as the main. + if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) { + // If no main is specified in the allowlist, use the 1st func + // in trim_funcs_allowlist as the main. // TODO(ashwinm): Currently tflite flatbuffer export assumes there is // always a main. This is strictly not required for TFlite. We need to // remove that restriction once we have support to attribute the main // tensorflow function in MLIR TF import using an entry_point attr. - if (!llvm::is_contained(trim_funcs_whitelist_, "main") && - func.getName() == trim_funcs_whitelist_[0]) { + if (!llvm::is_contained(trim_funcs_allowlist_, "main") && + func.getName() == trim_funcs_allowlist_[0]) { func.setName("main"); } } else { @@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() { } // validate that all reachable functions from the remaining functions are -// also in the whitelist. +// also in the allowlist. void TrimFunctionsPass::Verify() { // TODO(ashwinm): Instead, we should make sure that references to all // SymbolRefAttrs of all ops are present. @@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() { auto walk_result = func.walk([&](CallOp op) -> WalkResult { if (!symbol_table.lookup(op.getCallee())) return getOperation().emitError() - << func.getName() << " is not in the funcs whitelist"; + << func.getName() << " is not in the funcs allowlist"; return WalkResult::advance(); }); if (walk_result.wasInterrupted()) return signalPassFailure(); @@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() { // Creates an instance of the TensorFlow Lite dialect TrimFunctions /// pass. std::unique_ptr> CreateTrimFunctionsPass( - llvm::ArrayRef trim_funcs_whitelist) { - return std::make_unique(trim_funcs_whitelist); + llvm::ArrayRef trim_funcs_allowlist) { + return std::make_unique(trim_funcs_allowlist); } static PassRegistration pass( "tfl-trim-funcs-tf", - "Trim functions to restrict them to a specified whitelist prior to " + "Trim functions to restrict them to a specified allowlist prior to " "legalization to TensorFlow lite dialect"); } // namespace TFL diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 0765f292e371e9..67002aa65bf708 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -115,12 +115,12 @@ Status MlirFunctionOptimizationPass::Run( }); if (!is_enabled) { - VLOG(1) << "None of the MLIR optimization passes are enabled " + VLOG(0) << "None of the MLIR optimization passes are enabled " << "(registered " << registry_->passes().size() << ")"; return Status::OK(); } - VLOG(1) << "Running MLIR Graph Optimization Passes " + VLOG(0) << "Running MLIR Graph Optimization Passes " << "(registered " << registry_->passes().size() << " passes)"; GraphDebugInfo debug_info; @@ -187,12 +187,12 @@ Status MlirV1CompatGraphOptimizationPass::Run( }); if (!is_enabled) { - VLOG(1) << "None of the MLIR optimization passes are enabled " + VLOG(0) << "None of the MLIR optimization passes are enabled " << "(registered" << registry_->passes().size() << " passes)"; return Status::OK(); } - VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes " + VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes " << "(registered" << registry_->passes().size() << " passes)"; GraphDebugInfo debug_info; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 7945d324dea7d3..5e21dddd444bb4 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -22,6 +22,7 @@ tf_python_pybind_extension( "//tensorflow/python:pybind11_status", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:StandardOps", "@pybind11", ], diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 6f468cd426788f..63ca4c7bb283fb 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -15,7 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -29,6 +33,21 @@ PYBIND11_MODULE(mlir_wrapper, m) { mlir::registerDialect(); mlir::registerDialect(); }); + m.def("verify", [](std::string input) { + llvm::SourceMgr SM = llvm::SourceMgr(); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + mlir::MLIRContext ctx; + auto module = mlir::parseSourceFile(SM, &ctx); + if (!module) { + return false; + } + if (failed(mlir::verify(*module))) { + module->emitError("Invalid MLIR module: failed verification."); + return false; + } + return true; + }); init_basic_classes(m); init_types(m); diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 00d909b523e2a8..e3158f21cb271e 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -70,7 +70,7 @@ config.mlir_tools_dir, config.llvm_tools_dir ] tool_names = [ - 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', + 'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir' diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index c5cd2b17920842..82175d7f680495 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -42,6 +42,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir', + 'tensorflow/compiler/mlir/hlo', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tfjs', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 9ba0d03517cc51..3ecbf5ea98a558 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -144,6 +144,7 @@ gentbl( td_srcs = [ "@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", + "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td", ], test = True, ) @@ -542,6 +543,7 @@ cc_library( "transforms/freeze_global_tensors.cc", "transforms/lift_variables_pass.cc", "transforms/optimize_global_tensors.cc", + "transforms/remove_vars_in_session_initializer.cc", ], hdrs = [ "transforms/tf_saved_model_passes.h", @@ -623,6 +625,7 @@ cc_library( "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_space_to_depth_pass.cc", + "transforms/tpu_update_embedding_enqueue_op_inputs.cc", "transforms/tpu_variable_runtime_reformatting.cc", "translate/breakup-islands.cc", "translate/tf_executor_to_functional.cc", @@ -786,7 +789,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/utils:transitive_fanin", - "//tensorflow/core/platform:protobuf_internal", "//tensorflow/core/platform:types", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", @@ -1261,27 +1263,23 @@ tf_native_cc_binary( ], ) -genrule( +gentbl( name = "derived_attr_populator_inc", - srcs = [ + tbl_outs = [ + ("", "translate/derived_attr_populator.inc"), + ], + tblgen = ":derived_attr_populator_gen", + td_file = "ir/tf_ops.td", + td_srcs = [ + "@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", - "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", "ir/tf_op_interfaces.td", - "ir/tf_ops.td", - ], - outs = [ - "translate/derived_attr_populator.inc", ], - cmd = ("$(location :derived_attr_populator_gen) " + - "-I external/llvm-project/mlir/include " + - "-I tensorflow/compiler/mlir " + - "$(location //tensorflow/compiler/mlir/tensorflow:ir/tf_ops.td) " + " -o $@"), - tools = [":derived_attr_populator_gen"], ) filegroup( @@ -1328,7 +1326,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "//tensorflow/compiler/mlir/hlo:hlo", - "//tensorflow/compiler/mlir/hlo:xla_sink_constants_to_control_flow", + "//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index e4de66b59e2659..be203e0397e50f 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { var_handle.resource(), GetOrCreateIdForVarHandle(var_handle, &next_unique_id, &var_handle_name_id_map)); - } else if (llvm::isa(op) || - llvm::isa(op)) { + } else if (llvm::isa(op)) { for (auto operand_and_result : llvm::zip(op->getOperands(), op->getResults())) { forward_input_to_output(std::get<0>(operand_and_result), @@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op, const ResourceAliasAnalysis& alias_analysis) { // TODO(yuanzx): Add other types of resources. return llvm::isa(op) || - ((llvm::isa(op) || llvm::isa(op)) && + (llvm::isa(op) && !FindAccessedResources(op, alias_analysis).empty()); } diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 2bb79acb8dab7f..ffd9c149d2df6e 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -93,6 +93,15 @@ class MlirTensor : public TracingTensorHandle { explicit MlirTensor(Value value) : TracingTensorHandle(kMlir), value_(value) {} + tensorflow::DataType DataType() const override { + tensorflow::DataType type; + Status s = ConvertScalarTypeToDataType(value_.getType(), &type); + if (!s.ok()) { + return tensorflow::DT_INVALID; + } + return type; + } + void Release() override { delete this; } Value getValue() { return value_; } @@ -127,7 +136,7 @@ class MlirAbstractOp : public TracingOperation { Status SetDeviceName(const char* name) override; Status AddInput(AbstractTensorHandle* input) override; - Status AddInputList(absl::Span inputs) override; + Status AddInputList(absl::Span inputs) override; Status Execute(absl::Span retvals, int* num_retvals) override; @@ -464,7 +473,8 @@ Status MlirAbstractOp::SetDeviceName(const char* name) { return Status::OK(); } -Status MlirAbstractOp::AddInputList(absl::Span inputs) { +Status MlirAbstractOp::AddInputList( + absl::Span inputs) { return tensorflow::errors::Unimplemented( "AddInputList has not been implemented yet."); } diff --git a/tensorflow/compiler/mlir/tensorflow/g3doc/images/space_to_depth_transform.png b/tensorflow/compiler/mlir/tensorflow/g3doc/images/space_to_depth_transform.png new file mode 100644 index 00000000000000..1add1369cc1dbd Binary files /dev/null and b/tensorflow/compiler/mlir/tensorflow/g3doc/images/space_to_depth_transform.png differ diff --git a/tensorflow/compiler/mlir/tensorflow/g3doc/space_to_depth.md b/tensorflow/compiler/mlir/tensorflow/g3doc/space_to_depth.md new file mode 100644 index 00000000000000..5eb2d2a5ed6ab8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/g3doc/space_to_depth.md @@ -0,0 +1,196 @@ +# Automatic Space to Depth Transform in MLIR Bridge + +Author: wangtao@, yuanzx@, hinsu@, lyandy@, chiachenc@, aminim@, jpienaar@, +dehao@ + +## TL;DR + +_This document describes an automatic space to depth transform for the first +convolution in the new MLIR bridge to improve MXU efficiency of low batch size +convolutions._ + +## Background + +For image models, the first layer is usually not MXU friendly as it has a +feature size of 3. This results in poor performance especially with small batch. + +One way to address this issue is to use the `space-to-depth` transform. This +optimization tiles the 2x2 space dimensions to the feature dimension so that the +feature dimension becomes 3\*4=12, which is more MXU friendly. In order to make +this optimization efficient, the shape of the weight needs to be padded and +transposed to the shape that the convolution emitter expects. The input also +needs to be transposed on the host and padded on the device to make the +convolution efficient. Although a 2x2 space-to-depth transform works only when +the first convolution has a stride of 2, many image models, ResNet-like in +particular, have a stride-2 convolution in the first layer. + +Space to depth helped models such as MaskRCNN, SSD and I3D gain more than 2X +speedup and reduce memory usage in the first convolution. + +The first convolution in many image models, including ResNet or ResNet-like, is +a (kernel=7, stride=2) 2D convolution. The input of the convolution is images, +which usually has RGB channels. The input of this first convolution is of shape +[batch\_size, height, width, 3] and the kernel size is [kernel\_size, +kernel\_size, 3, out\_channel]. Space to depth is to transform this first +convolution's input to [batch\_size, height // stride, width // stride, 3 \* +stride \* stride] and the kernel to [kernel\_size // stride, kernel\_size // +stride, 3 \* stride \* stride, out\_channel] to improve TPU MXU utilization. + +![drawings](images/space_to_depth_transform.png) + +This optimization can be automatically done by the graph optimizer where weight +transformation is done at variable loading time and the input transformation is +done for every inference invocation. A further optimization can fuse this (at +host) with the double transpose to minimize memory operation on host. + +## Proposed Method + +**block\_size** is defined as the number of space sizes transformed to the depth +dimension. _stride % block\_size == 0_ and _stride >= block\_size_ is required +to do the transform. There are three parts of automatically space to depth +transformation: + +1. Transform input on the host. + + Space-to-depth performs the following permutation, which is equivalent to + `tf.nn.space_to_depth`. + + ```python + images = tf.reshape(images, [batch, h // block_size, block_size, + w // block_size, block_size, c]) + images = tf.transpose(images, [0, 1, 3, 2, 4, 5]) + images = tf.reshape(images, [batch, h // block_size, w // block_size, + c * (block_size ** 2)]) + ``` + + `SpaceToDepthOp` can be called on the host to perform the transform. + +1. Weight Transformation + + Weight Transformation is similar to Input Transform. Weight transform is + needed to apply space to depth optimization for a model that needs to load a + pre-train checkpoint. This transform can be done on the host or TPU device + based on the cost. As the size of the kernel is relatively small, this won't + add additional cost to TPU device time. Below is the logic to transform the + kernel of shape [7, 7, 3, 64] to [4, 4, 12, 84]. + + ```python + conv0 = tf.compat.v1.layers.Conv2D( + filters=filters, + kernel_size=kernel_size, + strides=2, + padding=('SAME' if strides == 1 else 'VALID'), + use_bias=False, + kernel_initializer=tf.variance_scaling_initializer(), + data_format=data_format) + + # Use the image size without space-to-depth transform as the input of conv0. + batch_size, h, w, channel = inputs.get_shape().as_list() + conv0.build([ + batch_size, h * space_to_depth_block_size, w * space_to_depth_block_size, + channel // (space_to_depth_block_size**2) + ]) + + kernel = conv0.weights[0] + # [7, 7, 3, 64] --> [8, 8, 3, 64] + + kernel = tf.pad( + kernel, + paddings=tf.constant([[1, 0], [1, 0], [0, 0], [0, 0]]), + mode='CONSTANT', + constant_values=0.) + # Transform kernel follows the space-to-depth logic: https://www.tensorflow.org/api_docs/python/tf/nn/space_to_depth) + kernel = tf.reshape( + kernel, + [4, space_to_depth_block_size, 4, space_to_depth_block_size, 3, filters]) + + kernel = tf.transpose(kernel, [0, 2, 1, 3, 4, 5]) + kernel = tf.reshape(kernel, [4, 4, int(channel), filters]) + kernel = tf.cast(kernel, inputs.dtype) + ``` + + If kernel\_size % block\_size != 0, padding is needed for the weight before + transform, input of Convolution needs to be padded as well. + +1. Rewrite the first convolution + + Need to rewrite the first convolution's shape of input from [batch\_size, + height, width, 3] to [batch\_size, height // block\_size, width // + block\_size, 3 \* block\_size \* block\_size] and kernel shape from + [kernel\_size, kernel\_size, 3, out\_channel] to [kernel\_size // + block\_size, kernel\_size // block\_size, 3 \* block\_size \* block\_size, + + This is the proposed workflow for automatic space to depth transformation. + All the transformations will be triggered in a MLIR SpaceToDepthRewritePass, + this Rewrite pass will be triggered before TPURewrite so that no metadata + rewrite is needed. + +* First, the rewrite pass will walk through all the convolutions in func of + tf\_device::LaunchOp and get the first Convolution and its shape; +* Second, the rewrite pass will apply transformations to the first + convolution, the padding before the first convolution, first convolution's + filters and its Conv2DBackPropFilter; +* At last, the rewrite pass will insert SpaceToDepthOp after IteratorGetNext + where the iterator's result has the same shape as the first convolution's + input. + +#### Pseudo MLIR code before and after RewritePass + +```mlir +// Example: original program: +// +module { + func @while_body { + %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: + -> tensor<2x224x224x3xf32> + %device_launch = "tf_device.launch_func"(%input,...) {func = @_func,...) + return ... + } + func @_func(%input: tensor<2x224x224x3xf32>, + %filter: tensor<7x7x3x64xf32>) { + %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}: + (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> + tensor<2x112x112x64xf32> + } +} + +// With this pass, the program will be transformed into: +module { + func @while_body { + %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"} + -> tensor<2x224x224x3xf32> + %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}: + (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> + %device_launch = "tf_device.launch_func"(%space_to_depth,...) {func = @_func,...) + return ... + } + func @_func(%input: tensor<2x112x112x12xf32>, + %filter: tensor<7x7x3x64xf32>) { + %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter): + tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32> + %conv = "tf.Conv2D"(%input, %filter_transfrom) {strides = [1, 1, 1, 1]}: + (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) -> + tensor<2x112x112x64xf32> + } +} +``` + +### SpaceToDepth Trigger Condition + +Space to depth will only be triggered when batch size is small and the first +convolution channel size is small. Stride of the convolution should be bigger +than 1 as well. A cost model will be built that takes input shape and host cost +into consideration to trigger the transformation. There will be a flag to +disable this feature as well. + +### Fuse SpaceToDepth with Automatic Double Transpose + +The transpose and reshape op in SpaceToDepthOp on TPU hosts may cause image +model to be infeed bound. To reduce host time, space to depth transform can be +fused with `automatic double transpose` to reduce extra overhead on the host. + +### Extend from Conv2D to Conv3D + +SpaceToDepth not only helps with 2D image models but also 3D image models such +as I3D. The plan is to apply automatic space to depth for Conv2D as the first +step. After Conv2D is well tested, will generalize this technique to Conv3D. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 7dd7428248772e..77008b55672013 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -229,7 +229,8 @@ namespace { ParseResult ParseReplicateOpOperands( OpAsmParser* parser, OperationState* state, llvm::SmallVectorImpl>* - operands, + replicated_inputs, + llvm::SmallVectorImpl* packed_inputs, llvm::SmallVectorImpl* region_args, llvm::SmallVectorImpl* region_arg_types) { // No operands or empty operand list. @@ -238,26 +239,61 @@ ParseResult ParseReplicateOpOperands( return success(); // Parse comma separated operands of the following format: - // [%a, ...] as %block_arg: type + // replicated_input + // [%a, ...] as %block_arg0: type + // packed_input + // %b as %block_arg1: type + // + // Replicated inputs are placed before packed inputs when forming the op. + llvm::SmallVector replicated_region_args; + llvm::SmallVector packed_region_args; + llvm::SmallVector replicated_region_arg_types; + llvm::SmallVector packed_region_arg_types; do { - if (parser->parseOperandList(operands->emplace_back(), - OpAsmParser::Delimiter::Square) || - parser->parseKeyword("as", - " between replicated inputs and block argument") || - parser->parseRegionArgument(region_args->emplace_back()) || - parser->parseColonType(region_arg_types->emplace_back())) + OpAsmParser::OperandType operand_type; + if (parser->parseOptionalOperand(operand_type).hasValue()) { + packed_inputs->emplace_back(operand_type); + if (parser->parseKeyword("as", + " between packed input and block argument") || + parser->parseRegionArgument(packed_region_args.emplace_back()) || + parser->parseColonType(packed_region_arg_types.emplace_back())) + return failure(); + } else if (parser->parseOperandList(replicated_inputs->emplace_back(), + OpAsmParser::Delimiter::Square) || + parser->parseKeyword( + "as", " between replicated inputs and block argument") || + parser->parseRegionArgument( + replicated_region_args.emplace_back()) || + parser->parseColonType( + replicated_region_arg_types.emplace_back())) { return failure(); + } } while (succeeded(parser->parseOptionalComma())); + region_args->reserve(replicated_region_args.size() + + packed_region_args.size()); + region_args->append(replicated_region_args.begin(), + replicated_region_args.end()); + region_args->append(packed_region_args.begin(), packed_region_args.end()); + + region_arg_types->reserve(replicated_region_arg_types.size() + + packed_region_arg_types.size()); + region_arg_types->append(replicated_region_arg_types.begin(), + replicated_region_arg_types.end()); + region_arg_types->append(packed_region_arg_types.begin(), + packed_region_arg_types.end()); + // Parse remaining `)` surrounding operands. return parser->parseRParen(); } -ParseResult SetOperands( +ParseResult SetReplicateOpOperands( llvm::SMLoc loc, OpAsmParser* parser, OperationState* state, - llvm::ArrayRef> operands, - llvm::ArrayRef region_arg_types, int* n) { - if (operands.empty()) return success(); + llvm::ArrayRef> + replicated_inputs, + llvm::ArrayRef packed_inputs, + llvm::ArrayRef region_arg_types, int32_t* n) { + if (replicated_inputs.empty() && packed_inputs.empty()) return success(); for (const auto& attr : state->attributes) if (attr.first.strref() == "n") @@ -267,38 +303,68 @@ ParseResult SetOperands( if (*n < 2) return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n; - for (int i = 0, e = operands.size(); i < e; ++i) { - const auto& operand = operands[i]; + for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) { + const int32_t idx = replicated_input_and_idx.index(); + const auto& replicated_input = replicated_input_and_idx.value(); // Check if replicated input matches `n`. - if (operand.size() != *n) + if (replicated_input.size() != *n) return parser->emitError(loc) - << "expects number of operands for replicated input " << i - << " to be 'n' (" << *n << "), got " << operand.size(); + << "expects number of operands for replicated input " << idx + << " to be 'n' (" << *n << "), got " << replicated_input.size(); // Resolve replicated input and block argument type. - if (parser->resolveOperands(operand, region_arg_types[i], state->operands)) + if (parser->resolveOperands(replicated_input, region_arg_types[idx], + state->operands)) + return failure(); + } + + const int32_t num_replicated_block_args = replicated_inputs.size(); + for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) { + const int32_t idx = packed_input_and_idx.index(); + const auto& packed_input = packed_input_and_idx.value(); + + // Resolve packed input and block argument type. + if (parser->resolveOperand( + packed_input, region_arg_types[idx + num_replicated_block_args], + state->operands)) return failure(); } return success(); } +constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; + ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) { llvm::SMLoc loc = parser->getCurrentLocation(); // Parse operands, attributes, and region of op. - llvm::SmallVector, 8> operands; + llvm::SmallVector, 8> + replicated_inputs; + llvm::SmallVector packed_inputs; llvm::SmallVector region_args; llvm::SmallVector region_arg_types; - int n = 0; + int32_t n = 0; Region& body = *state->addRegion(); - if (ParseReplicateOpOperands(parser, state, &operands, ®ion_args, + if (ParseReplicateOpOperands(parser, state, &replicated_inputs, + &packed_inputs, ®ion_args, ®ion_arg_types) || parser->parseOptionalAttrDict(state->attributes) || - SetOperands(loc, parser, state, operands, region_arg_types, &n) || + SetReplicateOpOperands(loc, parser, state, replicated_inputs, + packed_inputs, region_arg_types, &n) || parser->parseRegion(body, region_args, region_arg_types)) return failure(); + // Add derived `operand_segment_sizes` attribute based on parsed operands. + if (!state->attributes.get(kOperandSegmentSizesAttr)) { + int32_t num_replicated_inputs = replicated_inputs.size() * n; + int32_t num_packed_inputs = packed_inputs.size(); + auto attr = DenseIntElementsAttr::get( + VectorType::get({2}, parser->getBuilder().getI32Type()), + {num_replicated_inputs, num_packed_inputs}); + state->addAttribute(kOperandSegmentSizesAttr, attr); + } + // Ensure that the region is well formed: it contains at least a block with // a ReturnOp terminator. ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location); @@ -323,22 +389,40 @@ void Print(ReplicateOp op, OpAsmPrinter* p) { *p << op.getOperationName(); // Print comma separated operands of the following format: - // [%a, ...] as %block_arg: type - int n = op.getAttrOfType("n").getInt(); + // replicated_input + // [%a, ...] as %block_arg0: type + // packed_input + // %b as %block_arg1: type + const int32_t n = op.n().getSExtValue(); + const int32_t num_replicated_inputs = + (*op.operand_segment_sizes().int_value_begin()).getSExtValue(); + const int32_t num_replicated_block_args = num_replicated_inputs / n; + if (op.getNumOperands()) { *p << '('; Block& block = op.body().front(); interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) { const int block_arg_num = arg.getArgNumber(); - *p << '['; - p->printOperands(std::next(op.operand_begin(), block_arg_num * n), - std::next(op.operand_begin(), (block_arg_num + 1) * n)); - *p << "] as " << arg << ": " << arg.getType(); + if (block_arg_num < num_replicated_block_args) { + *p << '['; + p->printOperands( + std::next(op.replicated_inputs().begin(), block_arg_num * n), + std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n)); + *p << "]"; + } else { + p->printOperand(*std::next(op.packed_inputs().begin(), + block_arg_num - num_replicated_block_args)); + } + *p << " as " << arg << ": " << arg.getType(); }); *p << ')'; } - p->printOptionalAttrDict(op.getAttrs()); + // Skip derived `operand_segment_sizes` attribute as custom print format of + // operands holds enough information to calculate these variadic operand list + // lengths. + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef{ + kOperandSegmentSizesAttr}); p->printRegion(op.body(), /*printEntryBlockArgs=*/false); } @@ -353,9 +437,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) { } LogicalResult Verify(ReplicateOp op) { - uint64_t n = op.n().getLimitedValue(); - if (n < 2) - return op.emitOpError() << "expects 'n' to be at least 2, got " << n; + int32_t n = op.n().getSExtValue(); // Check number of devices, if set, matches `n`. if (op.devices().hasValue()) { @@ -381,22 +463,46 @@ LogicalResult Verify(ReplicateOp op) { Block& block = op.body().front(); - // Check number of operands matches `n` * number of block arguments. - if (op.getNumOperands() != n * block.getNumArguments()) + auto operand_segment_sizes = op.operand_segment_sizes(); + const int32_t num_replicated_inputs = + operand_segment_sizes.getValue({0}).getInt(); + const int32_t num_packed_inputs = + operand_segment_sizes.getValue({1}).getInt(); + + if (num_replicated_inputs % n != 0) + return op.emitOpError() + << "expects number of replicated inputs (" << num_replicated_inputs + << ") to be evenly divisible by 'n' (" << n << ")"; + + const int32_t num_replicated_block_args = num_replicated_inputs / n; + if (num_replicated_block_args + num_packed_inputs != block.getNumArguments()) return op.emitOpError() - << "expects number of operands (" << op.getNumOperands() - << ") to be equal to 'n' * number of block arguments (" << n << " * " - << block.getNumArguments() << ")"; + << "expects number of block arguments (" << block.getNumArguments() + << ") to be equal to number of replicated inputs (" + << num_replicated_inputs << ") / 'n' (" << n + << ") + number of packed inputs (" << num_packed_inputs << ")"; + + // Check input types match block argument types. + auto verify_operand_types = [&](BlockArgument block_arg, + int32_t op_operand_idx) -> LogicalResult { + Type op_operand_type = op.getOperand(op_operand_idx).getType(); + if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type))) + return op.emitOpError() + << "expects operand " << op_operand_idx << " (" << op_operand_type + << ") and block argument " << block_arg.getArgNumber() << " (" + << block_arg.getType() << ") to have compatible types"; - // Check replicated input types match block argument types. + return success(); + }; for (auto block_arg : block.getArguments()) { - Type block_arg_type = block_arg.getType(); - for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i) - if (failed(VerifyCompatibleTypes(block_arg_type, - op.getOperand(i).getType()))) - return op.emitOpError() - << "incompatible types for operand " << i - << " and block argument " << block_arg.getArgNumber(); + if (block_arg.getArgNumber() < num_replicated_block_args) { + for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i) + if (failed(verify_operand_types(block_arg, i))) return failure(); + } else { + const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args + + num_replicated_inputs; + if (failed(verify_operand_types(block_arg, idx))) return failure(); + } } Operation& terminator = block.back(); @@ -412,8 +518,8 @@ LogicalResult Verify(ReplicateOp op) { for (auto operand_type_and_idx : llvm::enumerate(terminator.getOperandTypes())) { Type operand_type = operand_type_and_idx.value(); - int operand_idx = operand_type_and_idx.index(); - for (int i = n * operand_idx, e = i + n; i < e; ++i) + int32_t operand_idx = operand_type_and_idx.index(); + for (int32_t i = n * operand_idx, e = i + n; i < e; ++i) if (failed(VerifyCompatibleTypes(operand_type, op.getType(i)))) return op.emitOpError() << "incompatible types for result " << i << " and terminator operand " << operand_idx; @@ -428,7 +534,7 @@ void BuildReplicateOp( const llvm::SmallDenseMap>& devices, llvm::ArrayRef> replicated_inputs, - ResultsTy replica_output_types) { + llvm::ArrayRef packed_inputs, ResultsTy replica_output_types) { DCHECK_GE(n, 2); state->addAttribute("n", builder->getI32IntegerAttr(n)); @@ -456,6 +562,19 @@ void BuildReplicateOp( block.addArgument(replicated_input.second); } + for (auto& packed_input : packed_inputs) { + state->addOperands(packed_input); + block.addArgument(packed_input.getType()); + } + + // Add derived `operand_segment_sizes` attribute. + int32_t num_replicated_inputs = replicated_inputs.size() * n; + int32_t num_packed_inputs = packed_inputs.size(); + auto operand_segment_sizes = + DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()), + {num_replicated_inputs, num_packed_inputs}); + state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes); + for (const auto& output_type : replica_output_types) state->addTypes(llvm::SmallVector(n, output_type)); } @@ -466,9 +585,10 @@ void ReplicateOp::build( const llvm::SmallDenseMap>& devices, llvm::ArrayRef, Type>> replicated_inputs, + llvm::ArrayRef packed_inputs, llvm::ArrayRef replica_output_types) { BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, - replica_output_types); + packed_inputs, replica_output_types); } void ReplicateOp::build( @@ -476,9 +596,69 @@ void ReplicateOp::build( const llvm::SmallDenseMap>& devices, llvm::ArrayRef> replicated_inputs, + llvm::ArrayRef packed_inputs, Operation::result_type_range replica_output_types) { BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, - replica_output_types); + packed_inputs, replica_output_types); +} + +// Returns the number of packed block arguments. +unsigned ReplicateOp::GetNumPackedBlockArguments() { + return packed_inputs().size(); +} + +// Returns the number of replicated block arguments. +unsigned ReplicateOp::GetNumReplicatedBlockArguments() { + return GetBody().getNumArguments() - GetNumPackedBlockArguments(); +} + +// Returns the replicated block arguments. A copy should be made if the +// replicate op is being modified. +llvm::ArrayRef ReplicateOp::GetReplicatedBlockArguments() { + return GetBody().getArguments().drop_back(GetNumPackedBlockArguments()); +} + +// Returns the packed block arguments. A copy should be made if the replicate op +// is being modified. +llvm::ArrayRef ReplicateOp::GetPackedBlockArguments() { + return GetBody().getArguments().take_back(GetNumPackedBlockArguments()); +} + +// Checks if a block argument is replicated (forwarding replicated inputs). +bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) { + assert(block_arg.getOwner() == &GetBody()); + return block_arg.getArgNumber() < GetNumReplicatedBlockArguments(); +} + +// Checks if a block argument is packed (forwarding a packed input). +bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) { + return !IsReplicatedBlockArgument(block_arg); +} + +// Returns the operand index of the operand being forwarded as a +// replicated/packed block argument for a given replica. This assumes a valid +// block argument (of the replicate op) and a valid replica is provided. +unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument( + BlockArgument block_arg, unsigned replica) { + const int32_t num_replicas = nAttr().getInt(); + assert(replica < num_replicas && block_arg.getOwner() == &GetBody()); + + const unsigned num_replicated_args = GetNumReplicatedBlockArguments(); + if (block_arg.getArgNumber() < num_replicated_args) + return block_arg.getArgNumber() * num_replicas + replica; + + return block_arg.getArgNumber() - num_replicated_args + + replicated_inputs().size(); +} + +// Returns the operand being forwarded as a replicated/packed block argument for +// a given replica. This assumes a valid block argument (of the replicate op) +// and a valid replica is provided. +Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg, + unsigned replica) { + const unsigned operand_index = + GetReplicaOperandIndexForBlockArgument(block_arg, replica); + return getOperand(operand_index); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index d0c15f7e9ecedc..3a92e3237dc947 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -177,8 +177,8 @@ def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute", let verifier = [{ return Verify(*this); }]; } -def TfDevice_ReplicateOp : - TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> { +def TfDevice_ReplicateOp : TfDevice_Op<"replicate", + [SingleBlockImplicitTerminator<"ReturnOp">, AttrSizedOperandSegments]> { let summary = "Wraps an N-way replicated computation."; let description = [{ @@ -187,22 +187,30 @@ across multiple devices. The number of replications is based on the `n` attribute. Explicit devices can be populated in the `devices` attribute, and it must be a mapping of device alias to list of explicit or aliased device names from the outer scope. The device name map specifies devices on which replicated -ops inside tf_device.replicate will be executed. A tf_device.parallel_execute -inside the tf_device.replicate op region may be used to represent computations -across a larger set of devices. In that case, the device alias can be used to -specify device assignment and replication of each concurrent execution -(i.e. region) defined by tf_device.parallel_execute op. The size of each value -list in the device name map must match `n`. Within a replica, the execution -semantics follow standard sequential behavior. Ops in the tf_device.replicate -wrapped with a tf_device.launch will have its device set to the associated -replicated device from `devices` if the tf_device.launch refers to an aliased -device name. Otherwise the device already set in tf_device.launch is used -instead. Operands are replicated inputs: each group of `n` inputs corresponds to -an input for a single individual replica and is mapped to a single region -argument. Inside one group the operands are matching in order the `devices` -attribute. Each replicated input must have compatible shapes and types. Operands -not replicated can be implicitly captured by ops in the region. Results are -replicated each from the regions terminator. +ops inside tf_device.replicate will be executed. + +A tf_device.parallel_execute inside the tf_device.replicate op region may be +used to represent computations across a larger set of devices. In that case, the +device alias can be used to specify device assignment and replication of each +concurrent execution (i.e. region) defined by tf_device.parallel_execute op. +The size of each value list in the device name map must match `n`. Within a +replica, the execution semantics follow standard sequential behavior. Ops in the +tf_device.replicate wrapped with a tf_device.launch will have its device set to +the associated replicated device from `devices` if the tf_device.launch refers +to an aliased device name. Otherwise the device already set in tf_device.launch +is used instead. + +Operands are replicated inputs and packed inputs. + +replicated_inputs: each group of `n` inputs corresponds to an input for a single +individual replica and is mapped to a single region argument. Inside one group +the operands are matching in order the `devices` attribute. Each replicated +input must have compatible shapes and types. +packed_inputs: each input corresponds to an input broadcasted across all +replicas and is mapped to a single region argument. + +Operands not replicated can be implicitly captured by ops in the region. Results +are replicated each from the regions terminator. For example: ``` @@ -214,46 +222,55 @@ For example: %5 = "tf.opF"() : () -> tensor %6 = "tf.opG"() : () -> tensor %7 = "tf.opH"() : () -> tensor -%8 = "tf.opI"() : () -> tensor -%output:8 = tf_device.replicate([%0, %1] as %input_0:tensor, - [%2, %3] as %input_1:tensor, - [%4, %5] as %input_2:tensor - [%6, %7] as %input_3:tensor) +%8 = "tf.opI"() : () -> tensor +%9 = "tf.opJ"() : () -> tensor +%output:8 = tf_device.replicate([%0, %1] as %input_0: tensor, + [%2, %3] as %input_1: tensor, + [%4, %5] as %input_2: tensor, + [%6, %7] as %input_3: tensor, + %8 as %input_4: tensor) {n = 2 : i32, devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"], DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} { // Inside the region, %0, %2, %4, and %6 corresponds to // "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to // "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used. - %j = "tf_device.launch"() ( { - %9 = "tf.opJ"(%input_0, %6) : (tensor, tensor) -> tensor + %k = "tf_device.launch"() ( { + %9 = "tf.opK"(%input_0, %input_4, %9) : + (tensor, tensor, tensor) -> tensor tf_device.return %9 : tensor }) {device = "DEVICE_ALIAS_0"} : () -> tensor - %k = "tf_device.launch"() ( { - %10 = "tf.opK"(%input_1, %6) : (tensor, tensor) -> tensor + %l = "tf_device.launch"() ( { + %10 = "tf.opL"(%input_1, %input_4, %9) : + (tensor, tensor, tensor) -> tensor tf_device.return %10 : tensor }) {device = "DEVICE_ALIAS_1"} : () -> tensor - %l = "tf_device.launch"() ( { - %11 = "tf.opL"(%input_2, %6) : (tensor, tensor) - -> tensor + %m = "tf_device.launch"() ( { + %11 = "tf.opM"(%input_2, %input_4, %9) : + (tensor, tensor, tensor) + -> tensor tf_device.return %11 : tensor }) {device = "/DEVICE:4"} : () -> tensor - %m = "tf.opM"(%input_3, %6) : (tensor, tensor) - -> tensor - tf_device.return %j, %k, %l, %m : + %n = "tf.opN"(%input_3, %input_4, %9) : + (tensor, tensor, tensor) + -> tensor + tf_device.return %k, %l, %m, %n : tensor, tensor, tensor, tensor } -// %output#0 corresponds to %j returned from "/DEVICE:0" -// %output#1 corresponds to %j returned from "/DEVICE:1" -// %output#2 corresponds to %k returned from "/DEVICE:2" -// %output#3 corresponds to %k returned from "/DEVICE:3" -// %output#4, %output#5 corresponds to %l and will be returned from "/DEVICE:4" -// %output#6, %output#7 corresponds to %m and will have no device set +// %output#0 corresponds to %k returned from "/DEVICE:0" +// %output#1 corresponds to %k returned from "/DEVICE:1" +// %output#2 corresponds to %l returned from "/DEVICE:2" +// %output#3 corresponds to %l returned from "/DEVICE:3" +// %output#4, %output#5 corresponds to %m and will be returned from "/DEVICE:4" +// %output#6, %output#7 corresponds to %n and will have no device set ``` }]; let arguments = (ins Variadic:$replicated_inputs, + Variadic:$packed_inputs, + + I32ElementsAttr:$operand_segment_sizes, Confined]>:$n, OptionalAttr:$devices ); @@ -266,16 +283,26 @@ For example: let extraClassDeclaration = [{ Block &GetBody() { return getOperation()->getRegion(0).front(); } + unsigned GetNumReplicatedBlockArguments(); + unsigned GetNumPackedBlockArguments(); + llvm::ArrayRef GetPackedBlockArguments(); + llvm::ArrayRef GetReplicatedBlockArguments(); + bool IsReplicatedBlockArgument(BlockArgument block_arg); + bool IsPackedBlockArgument(BlockArgument block_arg); + unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica); + Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica); }]; let builders = [ OpBuilder<"OpBuilder& builder, OperationState& state, int n, " "const llvm::SmallDenseMap>& devices, " "llvm::ArrayRef, Type>> replicated_inputs, " + "llvm::ArrayRef packed_inputs, " "llvm::ArrayRef replica_output_types">, OpBuilder<"OpBuilder& builder, OperationState& state, int n, " "const llvm::SmallDenseMap>& devices, " "llvm::ArrayRef> replicated_inputs, " + "llvm::ArrayRef packed_inputs, " "Operation::result_type_range replica_output_types"> ]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 1e66eee06bbc7f..1b1d5ba6f3b01b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { // Allow inlining into tf.island regions if the incoming region has a single // block. return llvm::isa(dest->getParentOp()) && - std::next(src->begin()) == src->end(); + llvm::hasSingleElement(*src); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 1edac0e535fba8..a0e73f116cf020 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -87,7 +87,7 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, +def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -136,7 +136,7 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -648,7 +648,7 @@ tf.math.atan(y) # [1.047, 0.785] = x TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Computes arctangent of `y/x` element-wise, respecting signs of the arguments. @@ -765,7 +765,7 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } -def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> { +def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "Multiplies slices of two tensors in batches."; let description = [{ @@ -806,7 +806,7 @@ It is computed as: let hasCanonicalizer = 1; } -def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> { +def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "Multiplies slices of two tensors in batches."; let description = [{ @@ -1422,7 +1422,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> { +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "Clips tensor values to a specified min and max."; let description = [{ @@ -1984,7 +1984,7 @@ Given an input tensor, this function computes hyperbolic cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> { +def TF_CrossOp : TF_Op<"Cross", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Compute the pairwise cross product."; let description = [{ @@ -2469,7 +2469,7 @@ Computes Psi, the derivative of Lgamma (the log of the absolute value of TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise."; @@ -2494,7 +2494,7 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, let hasFolder = 1; } -def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if the denominator is zero."; @@ -3374,7 +3374,7 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [NoSideEffect, ResultsBroadcastableShape]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns element-wise remainder of division. When `x < 0` xor `y < 0` is @@ -4111,7 +4111,7 @@ def ApplyG(op, dy, _): TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } -def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Compute the lower regularized incomplete Gamma function `P(a, x)`. @@ -4145,7 +4145,7 @@ Gamma function. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Computes the gradient of `igamma(a, x)` wrt `a`."; @@ -4161,7 +4161,7 @@ def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableS TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Compute the upper regularized incomplete Gamma function `Q(a, x)`. @@ -4928,7 +4928,7 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); } -def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect]> { +def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = [{ Multiply the matrix "a" by the matrix "b". }]; @@ -5692,7 +5692,7 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { }]; } -def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; @@ -5766,7 +5766,7 @@ retained with length 1. TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } -def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns the min of x and y (i.e. x < y ? x : y) element-wise."; @@ -5899,7 +5899,7 @@ graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.Tensor TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } -def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns element-wise remainder of division. This emulates C semantics in that @@ -5925,7 +5925,7 @@ the result here is consistent with a truncating divide. E.g. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, +def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -6426,7 +6426,7 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>; } -def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Computes the power of one value to another."; @@ -6905,7 +6905,7 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded. }]; } -def TF_RangeOp : TF_Op<"Range", [NoSideEffect]> { +def TF_RangeOp : TF_Op<"Range", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "Creates a sequence of numbers."; let description = [{ @@ -7270,6 +7270,8 @@ reshape(t, []) ==> 7 let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_ResizeBilinearOp : TF_Op<"ResizeBilinear", [NoSideEffect]> { @@ -8264,6 +8266,39 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> ]; } +def TF_SelfAdjointEigV2Op : TF_Op<"SelfAdjointEigV2", [NoSideEffect]> { + let summary = [{ +Computes the eigen decomposition of one or more square self-adjoint matrices. + }]; + + let description = [{ +Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues +are sorted in non-decreasing order. + +```python +# a is a tensor. +# e is a tensor of eigenvalues. +# v is a tensor of eigenvectors. +e, v = self_adjoint_eig(a) +e = self_adjoint_eig(a, compute_v=False) +``` + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$compute_v + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$e, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$v + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SeluOp : TF_Op<"Selu", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` @@ -9516,7 +9551,7 @@ Examples: TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } -def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; @@ -10673,7 +10708,7 @@ Python Semantics. let hasCanonicalizer = 1; } -def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns element-wise remainder of division. This emulates C semantics in that @@ -11146,7 +11181,7 @@ where(input) ==> [[0, 0, 0], TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; @@ -11512,7 +11547,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> { +def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; let arguments = (ins @@ -11527,7 +11562,7 @@ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 8dced9f4288f0e..de6ce2d313a31c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -506,28 +506,52 @@ LogicalResult FoldOperandsPermutation( //===----------------------------------------------------------------------===// namespace { -// Folder that returns LHS of an Arithmetic Op if the RHS is a constant -// known to be Identity (e.g X+0) +// Fold Arithmetic Op if one of the operands is a constant known to be an +// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if +// known identity value is either lhs or rhs. template < typename OpT, typename std::enable_if::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { - auto result_op_type = arithmetic_op.getResult().getType(); auto lhs_type = arithmetic_op.x().getType().template cast(); - if (!result_op_type.template cast().hasStaticShape()) return {}; + auto rhs_type = arithmetic_op.y().getType().template cast(); + auto result_type = + arithmetic_op.getResult().getType().template cast(); + + // We can fold arithmetic operation only of we can prove that we will not + // accidentally hide a broadcasting error. + auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, + ShapedType result_ty) -> bool { + // Scalar identity is broadcastable to any operand shape, we only need to + // check that operand has the same shape as a result. + bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; + if (scalar_identity) return operand_ty == result_ty; + + // If identity is not a scalar, we must verify that all shapes are equal + // and statically known. + // + // TODO(ezhulenev): Fold if identity shape is statically know to be + // broadcastable to the operand shape. + return operand_ty == result_ty && identity_ty == result_ty && + result_ty.hasStaticShape(); + }; - // We only handle non-broadcastable case. - if (result_op_type != lhs_type) { - return {}; - } + // Check that we have a constant operand on one side (candidate for identity). + const bool is_commutative = + (std::is_same::value || std::is_same::value); + auto lhs_attr = operands[0].dyn_cast_or_null(); + auto rhs_attr = operands[1].dyn_cast_or_null(); + if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; // Mul and Div ops have identity value one while AddV2 and SubOp have identity // value zero. - int identity = + const int identity = (std::is_same::value || std::is_same::value || - std::is_same::value); + std::is_same::value) + ? 1 + : 0; Type element_ty = lhs_type.getElementType(); Attribute identity_attr; @@ -539,23 +563,19 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, return {}; } - if (auto attr = operands[1].dyn_cast_or_null()) { - if (attr.isSplat() && attr.getSplatValue() == identity_attr) + // Fold: Op(Operand, Identity) -> Operand. + if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { + if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) return arithmetic_op.x(); } - auto rhs_type = arithmetic_op.y().getType().template cast(); - // TODO(chhe): we could fold and add an identity to force the broadcast. - if (result_op_type != rhs_type) { - return {}; - } - - bool is_symmetric = - (std::is_same::value || std::is_same::value); - if (auto attr = operands[0].dyn_cast_or_null()) { - if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr) + // Fold: Op(Identity, Operand) -> Operand for commutative operations. + if (lhs_attr && is_commutative && + is_valid_broadcasting(rhs_type, lhs_type, result_type)) { + if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) return arithmetic_op.y(); } + return {}; } } // namespace @@ -1168,8 +1188,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, ShapedType type; if (auto elem_attr = value.dyn_cast()) { return ConstOp::build(builder, result, elem_attr); - } else if (value.isa() || value.isa() || - value.isa()) { + } else if (value.isa()) { // All TensorFlow types must be tensor types. In the build() method, // we want to provide more flexibility by allowing attributes of scalar // types. But we need to wrap it up with ElementsAttr to construct @@ -2137,10 +2156,6 @@ static LogicalResult Verify(IfRegionOp op) { return failure(); if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) return failure(); - if (op.then_branch().front().getNumArguments() != 0) - return op.emitOpError() << "then region cannot have any arguments"; - if (op.else_branch().front().getNumArguments() != 0) - return op.emitOpError() << "else region cannot have any arguments"; return success(); } @@ -2870,6 +2885,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, return unranked(); } +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// @@ -4274,6 +4294,117 @@ LogicalResult WhileRegionOp::moveOutOfLoop( return success(); } +//===----------------------------------------------------------------------===// +// WhileRegionOp canonicalization +//===----------------------------------------------------------------------===// +namespace { +// Eliminate values that pass through the WhileRegionOp body. +struct WhileRegionEliminatePassThrough + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileRegionOp while_op, + PatternRewriter &rewriter) const override { + // Replace values that simply passthrough the body with extern values. The + // block arguments of body and while match and so the corresponding cond + // argument can be easily found. + int old_num_operands = while_op.getNumOperands(); + int new_num_operands = old_num_operands; + auto &body_block = while_op.body().front(); + auto &cond_block = while_op.cond().front(); + auto &yield = *body_block.getTerminator(); + + // Bit mask indicating which operands will be removed. + SmallVector removed_operand(old_num_operands, false); + + for (int op_idx : llvm::seq(0, old_num_operands)) { + auto body_arg = body_block.getArgument(op_idx); + if (body_arg == yield.getOperand(op_idx)) { + // Replace the use of the passthrough value with the while operand + // in the body and condition regions, as well as the while output (if + // type match) + // TODO(jurahul): Use PatternRewriter API for IR modification. + auto value = while_op.getOperand(op_idx); + if (body_arg.getType() == value.getType()) + body_arg.replaceAllUsesWith(value); + + auto cond_arg = cond_block.getArgument(op_idx); + if (cond_arg.getType() == value.getType()) + cond_arg.replaceAllUsesWith(value); + + auto result = while_op.getResult(op_idx); + if (result.getType() == value.getType()) + result.replaceAllUsesWith(value); + } + + // Now check if the operand is unused in both regions as well as the + // result. If so, mark it for removal. + if (body_block.getArgument(op_idx).use_empty() && + cond_block.getArgument(op_idx).use_empty() && + while_op.getResult(op_idx).use_empty()) { + removed_operand[op_idx] = true; + new_num_operands--; + } + } + + if (new_num_operands == old_num_operands) return failure(); + + // Compress the operands, region arguments, and outputs. + SmallVector new_while_operands; + SmallVector new_result_types; + new_while_operands.reserve(new_num_operands); + new_result_types.reserve(new_num_operands); + + // Build new operands and result type. + int next_idx = 0; + for (int op_idx : llvm::seq(0, old_num_operands)) { + if (removed_operand[op_idx]) continue; + new_while_operands.push_back(while_op.getOperand(op_idx)); + new_result_types.push_back(while_op.getResult(op_idx).getType()); + next_idx++; + } + + // Create the new while operation. + auto new_while_op = + rewriter.create(while_op.getLoc(), new_result_types, + new_while_operands, while_op.getAttrs()); + + // Move region bodies to the new while. + rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), + new_while_op.cond().end()); + rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), + new_while_op.body().end()); + + auto &new_cond_block = new_while_op.cond().front(); + auto &new_body_block = new_while_op.body().front(); + auto &new_yield = *new_body_block.getTerminator(); + + // Build a vector of new results. Also patch up the region bodies and yield. + SmallVector new_results; + next_idx = 0; + for (int op_idx : llvm::seq(0, old_num_operands)) { + if (removed_operand[op_idx]) { + new_cond_block.eraseArgument(next_idx); + new_body_block.eraseArgument(next_idx); + new_yield.eraseOperand(next_idx); + new_results.push_back(nullptr); + } else { + new_results.push_back(new_while_op.getResult(next_idx++)); + } + } + + rewriter.replaceOp(while_op, new_results); + return success(); + } +}; + +} // anonymous namespace + +void WhileRegionOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // XdivyOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index d004bc521a8941..4f319bf4c30103 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -247,7 +247,7 @@ def TF_YieldOp : TF_Op<"Yield", } def TF_IfRegionOp : TF_Op<"IfRegion", - [SingleBlockImplicitTerminator<"YieldOp">]> { + [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> { let summary = "output = cond ? then_branch output : else_branch output"; let description = [{ @@ -658,7 +658,9 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion", This implies that the operand and result types for tf.WhileRegion should be the same. Note that the condition and body regions can implicitly capture - loop invariant values directly. + loop invariant values directly. In canonical form, iteration variables that + pass through the loop body unmodified are converted to implicitly captured + references to their values outside the loop. }]; let arguments = (ins @@ -676,6 +678,8 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion", let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> { @@ -717,7 +721,8 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co DefaultValuedAttr:$host_compute_core, DefaultValuedAttr:$padding_map, DefaultValuedAttr:$step_marker_location, - DefaultValuedAttr:$allow_soft_placement + DefaultValuedAttr:$allow_soft_placement, + DefaultValuedAttr:$use_spmd_for_xla_partitioning ); let results = (outs); @@ -1160,4 +1165,35 @@ array([0, 2, 2]) ); } +def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { + let summary = "Calls a function placed on a specified TPU device."; + + let arguments = (ins + Variadic:$args, + I32Tensor:$device_ordinal, + + SymbolRefAttr:$f, + DefaultValuedAttr:$autotuner_thresh + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let extraClassDeclaration = [{ + // Gets the argument operands to the called function. + operand_range getArgOperands() { return args(); } + + // Returns the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType("f"); + } + }]; + + let verifier = [{ return VerifyPartitionedCall(*this); }]; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index feccd433a9a973..edfc7feefd566d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -141,16 +141,27 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { return mlir::success(); } -Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor) { - auto type = global_tensor.type().cast(); - return RankedTensorType::get( - {}, TF::ResourceType::get({type}, type.getContext())); +Type GetBoundInputArgTypeFor(mlir::Operation *op) { + if (auto global_tensor = llvm::dyn_cast(op)) { + auto type = global_tensor.type().cast(); + return RankedTensorType::get( + {}, TF::ResourceType::get({type}, type.getContext())); + } + + if (auto asset = llvm::dyn_cast(op)) { + return RankedTensorType::get({}, TF::StringType::get(asset.getContext())); + } + + op->emitError() << "unknown symbol operation"; + return {}; } static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics, Type arg_type, - GlobalTensorOp global_tensor) { - auto expected_type = GetBoundInputArgTypeFor(global_tensor); + mlir::Operation *symbol_op) { + auto expected_type = GetBoundInputArgTypeFor(symbol_op); + if (!expected_type) return failure(); + if (arg_type != expected_type) { return op_for_diagnostics->emitError() << "bound input with type " << arg_type << " expected to have type " @@ -169,14 +180,14 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( } auto symbol_name = named_attr.second.cast().getValue(); auto module = op->getParentOfType(); - auto global_tensor = module.lookupSymbol(symbol_name); - if (!global_tensor) { + mlir::Operation *symbol_op = module.lookupSymbol(symbol_name); + if (!symbol_op) { return op->emitError() << "'tf_saved_model.bound_input' attribute must " "reference a valid symbol, got invalid symbol '" << symbol_name << "'"; } auto arg_type = cast(op).getArgument(arg_index).getType(); - return VerifyBoundInputArgType(op, arg_type, global_tensor); + return VerifyBoundInputArgType(op, arg_type, symbol_op); } if (named_attr.first == "tf_saved_model.index_path") { return VerifyIndexPath(op, named_attr); @@ -345,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) { LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( Operation *op, NamedAttribute named_attr) { if (named_attr.first == "tf_saved_model.exported_names") { - if (!isa(op) && !isa(op)) { + if (!isa(op)) { return op->emitError() << "'tf_saved_model.exported_names' must be on a " "'func' or 'tf_saved_model.global_tensor' op"; } @@ -404,12 +415,12 @@ bool HasTfSavedModelSemantics(ModuleOp module) { return module.getAttr("tf_saved_model.semantics") != nullptr; } -GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, - const SymbolTable &symbol_table) { +Operation *LookupBoundInput(FuncOp func, int arg_index, + const SymbolTable &symbol_table) { auto attr = func.getArgAttrOfType( arg_index, "tf_saved_model.bound_input"); if (!attr) return nullptr; - return symbol_table.lookup(attr.getValue()); + return symbol_table.lookup(attr.getValue()); } SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 056df4d6a43c27..02b7f0b75f49ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -36,6 +36,8 @@ class TensorFlowSavedModelDialect : public Dialect { NamedAttribute named_attr) override; LogicalResult verifyOperationAttribute(Operation *op, NamedAttribute named_attr) override; + + static StringRef getDialectNamespace() { return "tf_saved_model"; } }; // Declares the operations for this dialect using the generated header. @@ -54,12 +56,19 @@ bool HasTfSavedModelSemantics(ModuleOp module); // Returns the tf_saved_model.global_tensor op that func's arg_index'th argument // refers to as a bound input, or null. -GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, - const SymbolTable &symbol_table); - -// Gets the type that an exported function arg that is bound to `global_tensor` -// should have. -Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor); +Operation *LookupBoundInput(FuncOp func, int arg_index, + const SymbolTable &symbol_table); + +template +T LookupBoundInputOfType(FuncOp func, int arg_index, + const SymbolTable &symbol_table) { + return llvm::dyn_cast_or_null( + LookupBoundInput(func, arg_index, symbol_table)); +} + +// Gets the type that an exported function arg that is bound to symbol ops such +// as `global_tensor` and `asset` should have. +Type GetBoundInputArgTypeFor(mlir::Operation *op); // Returns the session initializer of this module if it exists. Returns null // otherwise. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index dc1210a4d2a49f..a22a684953be70 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -19,6 +19,7 @@ limitations under the License. #define SAVED_MODEL_DIALECT include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// // Dialect definition @@ -154,4 +155,24 @@ def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> { let hasCanonicalizer = 1; } +def TfSavedModel_AssetOp: TfSavedModel_Op<"asset", [Symbol]> { + let summary = "Represents an asset in saved model."; + let description = [{ + Represents an asset in the saved model that points to an external file. It + is a scalar string tensor and it is passed as an argument to the session + initializer function. + + The `sym_name` represents the symbol table name used for internal IR + references. + + The `filename` attribute contains the file path to the asset file and it is + relative to saved model directory. + }]; + + let arguments = (ins + StrAttr:$sym_name, + StrAttr:$filename + ); +} + #endif // SAVED_MODEL_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index a250ca1af8ca9f..f352bc0eb476a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -90,8 +90,7 @@ class TensorFlowType : public Type { // Returns true if the specified type is a valid TensorFlow element type. static inline bool IsValidTFElementType(Type type) { - return type.isa() || type.isa() || - type.isa() || type.isa(); + return type.isa(); } // Returns true if this is a valid TensorFlow tensor type. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir index e0e777728ead80..c327a47f62d494 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir @@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor, - // CHECK-SAME: %[[ARG1:.*]]: tensor {xla_hlo.is_same_data_across_replicas} + // CHECK-SAME: %[[ARG1:.*]]: tensor {mhlo.is_same_data_across_replicas} // CHECK-SAME: %[[ARG2:.*]]: tensor) func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} { } // CHECK-LABEL: func @_func - // CHECK-SAME: %[[ARG0:.*]]: tensor {xla_hlo.is_same_data_across_replicas}, + // CHECK-SAME: %[[ARG0:.*]]: tensor {mhlo.is_same_data_across_replicas}, // CHECK-SAME: %[[ARG1:.*]]: tensor, - // CHECK-SAME: %[[ARG2:.*]]: tensor>> {xla_hlo.is_same_data_across_replicas} + // CHECK-SAME: %[[ARG2:.*]]: tensor>> {mhlo.is_same_data_across_replicas} func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor @@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { } // CHECK-LABEL: func @_func - // CHECK-NOT: xla_hlo.is_same_data_across_replicas + // CHECK-NOT: mhlo.is_same_data_across_replicas func @_func(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 21ad6036aaf5d2..8597740a4aefa3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -190,6 +190,27 @@ func @testSubOfNeg(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8 // CHECK: return %0 } +// CHECK-LABEL: testSubOfZero +func @testSubOfZero(%arg0: tensor, %arg1: tensor<4x1xf32>) -> (tensor, tensor<4x1xf32>) { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Sub"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "tf.Sub"(%arg1, %0) : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> + return %1, %2: tensor, tensor<4x1xf32> + +// CHECK: return %arg0, %arg1 +} + +// CHECK-LABEL: testSubOfZeroWithBroadcasting +func @testSubOfZeroWithBroadcasting(%arg0: tensor<4x1xf32>) -> tensor<4x4xf32> { + // This is an identity arithmetic operation, however we do not currently fold + // it because it has a broadcasting. + %0 = "tf.Const"() {value = dense<[[0.0, 0.0, 0.0, 0.0]]> : tensor<1x4xf32>} : () -> tensor<1x4xf32> + %1 = "tf.Sub"(%arg0, %0) : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + +// CHECK: return %1 +} + // CHECK-LABEL: testSquareOfSub func @testSquareOfSub(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> @@ -257,6 +278,46 @@ func @testAddV2OfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> t // CHECK: return %0 } +// CHECK-LABEL: testAddV2IdentityScalar +func @testAddV2IdentityScalar(%arg0: tensor, %arg1: tensor, %arg2: tensor<4xf32>) -> (tensor, tensor, tensor<4xf32>) { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + + // Identity scalar (0.0) is foldable with operand of any shape because + // scalar is safely broadcastable to any shape. + + %1 = "tf.AddV2"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%arg1, %0) : (tensor, tensor) -> tensor + %3 = "tf.AddV2"(%arg2, %0) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + %4 = "tf.AddV2"(%0, %1) : (tensor, tensor) -> tensor + %5 = "tf.AddV2"(%0, %2) : (tensor, tensor) -> tensor + %6 = "tf.AddV2"(%0, %3) : (tensor, tensor<4xf32>) -> tensor<4xf32> + + // CHECK: return %arg0, %arg1, %arg2 + return %4, %5, %6: tensor, tensor, tensor<4xf32> +} + +// CHECK-LABEL: testAddV2IdentityTensor +func @testAddV2IdentityTensor(%arg0: tensor, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { + %0 = "tf.Const"() {value = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf32>} : () -> tensor<4xf32> + + // If operand is a scalar, then the identity value (0.0 for addition) can + // be of any shape, because operand is safely broadcastable to any shape. + // + // However we can't fold this arithmetic operation because the operand + // shape does not match the result shape. + + %1 = "tf.AddV2"(%arg0, %0) : (tensor, tensor<4xf32>) -> tensor<4xf32> + %2 = "tf.AddV2"(%0, %arg0) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + // If operand has the same shape as a result, we can fold it. + %3 = "tf.AddV2"(%arg1, %0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %4 = "tf.AddV2"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK: return %1, %2, %arg1, %arg1 + return %1, %2, %3, %4: tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32> +} + // CHECK-LABEL: testDoubleConj func @testDoubleConj(%arg0: tensor<8x16x32x64xcomplex>) -> tensor<8x16x32x64xcomplex> { %0 = "tf.Conj"(%arg0) : (tensor<8x16x32x64xcomplex>) -> tensor<8x16x32x64xcomplex> @@ -302,6 +363,20 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi // CHECK: return %arg0 } +// CHECK-LABEL: testRedundantReshape +func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> { + %0 = "tf.Const"() {value = dense<[8, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Const"() {value = dense<[2, 8]> : tensor<2xi32>} : () -> tensor<2xi32> + %2 = "tf.Reshape"(%arg0, %0) : (tensor<4x4xi32>, tensor<2xi32>) -> tensor<8x2xi32> + %3 = "tf.Reshape"(%2, %1) : (tensor<8x2xi32>, tensor<2xi32>) -> tensor<2x8xi32> + return %3: tensor<2x8xi32> + + // CHECK: %0 = "tf.Const" + // CHECK-SAME: value = dense<[2, 8]> : tensor<2xi32> + // CHECK: %1 = "tf.Reshape"(%arg0, %0) + // CHECK: return %1 : tensor<2x8xi32> +} + // CHECK-LABEL: testSelectScalarPred func @testSelectScalarPred(%arg0: tensor, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> @@ -648,3 +723,152 @@ func @erase_tf_var_is_initialized(%arg0 : tensor>>) -> // Unused VarIsInitializedOp is erased. // CHECK: tf.VarHandleOp // CHECK-NEXT: tf.UnknownOp + + +// Simple pass through value +// CHECK-LABEL: testWhileRegionSimplePassThrough +func @testWhileRegionSimplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: "tf.WhileRegion"(%arg1) + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%barg0, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return %arg0 : tensor<*xf32> + return %0#0 : tensor<*xf32> +} + +// Multiple pass through values +// CHECK-LABEL: testWhileRegionMultiplePassThrough +func @testWhileRegionMultiplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor) -> tensor<*xf32> { + // Verify that first 3 operands are elimiinated. + // CHECK: "tf.WhileRegion"(%arg3) + %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg3, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg3, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%barg0, %barg1, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) + + // CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %arg1) + // CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]]) + // CHECK: return %[[SUB1]] + %sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %sub1 : tensor<*xf32> +} + +// Multiple non contiguous pass through values +// CHECK-LABEL: testWhileRegionMultiplePassThroughNonContiguous +func @testWhileRegionMultiplePassThroughNonContiguous(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor) -> tensor<*xf32> { + // Verify arg0 and arg2 are eliminated + // CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg1, %arg3) + %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg3, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor): + %arg1neg = "tf.Neg"(%barg1) : (tensor<*xf32>) -> tensor<*xf32> + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg3, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%barg0, %arg1neg, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor) + + // Verify that use of while loop results corresponding to result #0 and 2 of + // the while are replaces with corresponding WhileRegion operands + // CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %[[WHILE_OUT]]#0) + // CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]]) + // CHECK: return %[[SUB1]] + %sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %sub1 : tensor<*xf32> +} + +// Pass through but with type mismatch (tensor<*xf32> is compatible with +// tensor in the body). WhileRegion canonicalization does not handle +// this. +// CHECK-LABEL: testWhileRegionPassThroughTypeMismatch +func @testWhileRegionPassThroughTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // Verify that the While stay's unchanged + // CHECK: "tf.WhileRegion"(%arg0, %arg1) + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%barg0, %sub) : (tensor, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // Verify that the result stays uchanged + // CHECK: return %arg0 : tensor<*xf32> + return %0#0 : tensor<*xf32> +} + +// Unused value flowing through the while (operand 2 and 3, is unused in the +// while and the corresponding result is unused as well). Canonicalization will +// eliminate them. +// CHECK-LABEL: testWhileRegionUnusedValue +func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2: tensor) -> tensor<*xf32> { + %cst = constant dense <33.0> : tensor + // Verify that last 2 operands of while (unused) are removed + // CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg0, %arg1) + %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %cst) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor, %carg2:tensor, %carg3:tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor, %barg2:tensor, %barg3:tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + %dummy0 = constant dense<7> : tensor + %dummy1 = constant dense<3.0> : tensor + "tf.Yield"(%add, %sub, %dummy0, %dummy1) : (tensor<*xf32>, tensor, tensor, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor, tensor, tensor) -> (tensor<*xf32>, tensor, tensor, tensor) + + // Verify that return still uses while result # 0 + // CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32> + return %0#0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 3ae6023400cd47..7b8c998bcf1566 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -4,7 +4,7 @@ func @testShape(tensor, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0xi32>, tensor, tensor) { ^bb0(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>): - // CHECK: tf.Const{{.*}} dense<[]> : tensor<0xi32> + // CHECK: tf.Const{{.*}} dense<> : tensor<0xi32> %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor) -> tensor<0xi32> // Result shape need not be static. Folding harness uses TensorFlow constant @@ -91,7 +91,7 @@ func @testEmptybf16() -> (tensor<5xbf16>) { // CHECK-LABEL: func @testShapeN func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor) { - // CHECK: "tf.Const"() {value = dense<[]> : tensor<0xi64> + // CHECK: "tf.Const"() {value = dense<> : tensor<0xi64> // CHECK: "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>} %0:2 = "tf.ShapeN"(%arg0, %arg1) : (tensor, tensor<1x32x32x16xf32>) -> (tensor<0xi64>, tensor<4xi64>) @@ -442,3 +442,24 @@ func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> { // CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32> } + +// Do not fold if total result size is large (>128 KB) and more than 2 times +// the size of operands. + +// LINT.IfChange(folding-policy-test) +// CHECK-LABEL: DontFoldTile +func @DontFoldTile() -> (tensor<8x10000xi32>) { + %const_10000 = "tf.Const"() {value = dense<10000> : tensor} : () -> tensor + %const_0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %const_1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const_8_1 = "tf.Const"() {value = dense<[8, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Range"(%const_0, %const_10000, %const_1) : (tensor, tensor, tensor) -> tensor<10000xi32> + %2 = "tf.ExpandDims"(%1, %const_0) : (tensor<10000xi32>, tensor) -> tensor<1x10000xi32> + %3 = "tf.Tile"(%2, %const_8_1) : (tensor<1x10000xi32>, tensor<2xi32>) -> tensor<8x10000xi32> + // CHECK-NOT: tf.Range + // CHECK-NOT: tf.ExpandDims + // CHECK: [[TILE:%.*]] = "tf.Tile" + // CHECK: return [[TILE]] + return %3 : tensor<8x10000xi32> +} +// LINT.ThenChange(../transforms/constant_fold.cc:folding-policy) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index 4d557119a0b5dd..130887555b01dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -72,6 +72,32 @@ func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> // CHECK: return %[[v2]] : tensor<3x4x6x2xf32> } +func @einsum_reduceddim(%arg0: tensor<2x5x7xf32>, %arg1: tensor<2x5x7x3xf32>) -> tensor<2x5x3xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bin,binj->bij"}: (tensor<2x5x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x3xf32> + return %0 : tensor<2x5x3xf32> + // CHECK-LABEL: einsum_reduceddim + // CHECK: %[[cst:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[2, 5, 3]> : tensor<3xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> + // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32> + // CHECK: return %[[v2]] : tensor<2x5x3xf32> +} + +func @einsum_transposereduceddim(%arg0: tensor<2x5x7xf32>, %arg1: tensor<2x5x3x7xf32>) -> tensor<2x5x3xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bij,binj->bin"}: (tensor<2x5x7xf32>, tensor<2x5x3x7xf32>) -> tensor<2x5x3xf32> + return %0 : tensor<2x5x3xf32> + // CHECK-LABEL: einsum_transposereduceddim + // CHECK: %[[cst:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[cst_2:.*]] = constant dense<[2, 5, 3]> : tensor<3xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> + // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32> + // CHECK: return %[[v3]] : tensor<2x5x3xf32> +} + func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -89,7 +115,7 @@ func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> t func @einsum_no_match5D(%arg0: tensor<4x5xf32>, %arg1: tensor<2x4x7x3x5xf32>) -> tensor<4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<2x4x7x3x5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> -// CHECK-LABEL: einsum_no_match5D +// CHECK-LABEL: einsum_no_match5D // CHECK: %[[v0:.*]] = "tf.Einsum" // CHECK: return %[[v0]] } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index c8542ab3bae260..4f044cd5eff987 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2,734 +2,734 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - %1 = xla_hlo.add %0, %arg0 : tensor<2xi32> + %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> + %1 = mhlo.add %0, %arg0 : tensor<2xi32> return %1 : tensor<2xi32> } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> return %0 : tensor<2xi32> } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - %0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1> + %0 = mhlo.and %arg0, %arg0 : tensor<2xi1> return %0 : tensor<2xi1> } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - %0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1> + %0 = mhlo.or %arg0, %arg0 : tensor<2xi1> return %0 : tensor<2xi1> } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32> + %0 = mhlo.or %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32> + %0 = mhlo.and %arg0, %arg1 : tensor<4xi32> return %0 : tensor<4xi32> } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32> + %0 = mhlo.power %arg0, %arg0 : tensor<2xf32> return %0 : tensor<2xf32> } func @pow_dynamic(%arg0: tensor) -> tensor { - %0 = xla_hlo.power %arg0, %arg0 : tensor + %0 = mhlo.power %arg0, %arg0 : tensor return %0 : tensor } func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %2 = xla_hlo.constant dense<0> : tensor<3xi32> - %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> - %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %8 = xla_hlo.constant dense<1> : tensor<3xi32> - %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> - %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> - %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = mhlo.constant dense<0> : tensor<2x3xi32> + %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %2 = mhlo.constant dense<0> : tensor<3xi32> + %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> + %8 = mhlo.constant dense<1> : tensor<3xi32> + %9 = mhlo.subtract %7, %8 : tensor<3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> + %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = xla_hlo.constant dense<0> : tensor<3xi32> - %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> - %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> - %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> - %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> - %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> - %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> - %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = mhlo.constant dense<0> : tensor<3xi32> + %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %2 = mhlo.constant dense<0> : tensor<2x3xi32> + %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %8 = mhlo.constant dense<1> : tensor<2x3xi32> + %9 = mhlo.subtract %7, %8 : tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %13 = mhlo.divide %11, %12 : tensor<2x3xi32> + %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> - %1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> - %2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32> + %0 = mhlo.divide %arg0, %arg0 : tensor<2xf32> + %1 = mhlo.divide %arg0, %arg0 : tensor<2xf32> + %2 = "mhlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32> return %2 : tensor<2xf32> } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> return %2 : tensor<6x3xf32> } func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { - %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> return %2 : tensor<3x6xf32> } func @const() -> tensor<2xi32> { - %0 = xla_hlo.constant dense<0> : tensor<2xi32> + %0 = mhlo.constant dense<0> : tensor<2xi32> return %0 : tensor<2xi32> } func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %0 = mhlo.constant dense<0> : tensor + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor) -> tensor { - %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %0 = mhlo.constant dense<0> : tensor + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = xla_hlo.constant dense<0> : tensor - %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<6> : tensor + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor) -> tensor { - %0 = xla_hlo.constant dense<0> : tensor - %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<6> : tensor + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %3 : tensor } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor - %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> + %3 = "mhlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> } func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> return %0 : tensor<3x2xi32> } func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> return %2 : tensor<3x2xf32> } func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { - %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32> - %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> - %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32> + %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> return %2 : tensor<3x2x1xf32> } func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { - %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> - %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> - %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> + %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> return %2 : tensor<3x2x1xf32> } func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { - %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor<4x?xf32> + %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor<4x?xf32> return %2 : tensor<4x?xf32> } func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> + %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> + %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> return %2 : tensor<*xf32> } func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @abs_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.abs"(%arg0) : (tensor) -> tensor + %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor return %0 : tensor } func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @ceil_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.ceil"(%arg0) : (tensor) -> tensor + %0 = "mhlo.ceil"(%arg0) : (tensor) -> tensor return %0 : tensor } func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - %0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @cos_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.cosine"(%arg0) : (tensor) -> tensor + %0 = "mhlo.cosine"(%arg0) : (tensor) -> tensor return %0 : tensor } func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @exp_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.exponential"(%arg0) : (tensor) -> tensor + %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @floor_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor return %0 : tensor } func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { - %0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> return %0 : tensor<2xi1> } func @is_finite_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.is_finite"(%arg0) : (tensor) -> tensor + %0 = "mhlo.is_finite"(%arg0) : (tensor) -> tensor return %0 : tensor } func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { - %0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> + %0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> return %0 : tensor<*xi1> } func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @log_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.log"(%arg0) : (tensor) -> tensor + %0 = "mhlo.log"(%arg0) : (tensor) -> tensor return %0 : tensor } func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @log1p_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor) -> tensor + %0 = "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor return %0 : tensor } func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @neg_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + %0 = "mhlo.negate"(%arg0) : (tensor) -> tensor return %0 : tensor } func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = xla_hlo.constant dense<5.000000e-01> : tensor - %1 = xla_hlo.constant dense<2> : tensor<1xi64> - %2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32> - %3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32> - %4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32> - %5 = xla_hlo.multiply %4, %2 : tensor<2xf32> - %6 = xla_hlo.add %5, %2 : tensor<2xf32> + %0 = mhlo.constant dense<5.000000e-01> : tensor + %1 = mhlo.constant dense<2> : tensor<1xi64> + %2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32> + %3 = mhlo.multiply %arg0, %2 : tensor<2xf32> + %4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32> + %5 = mhlo.multiply %4, %2 : tensor<2xf32> + %6 = mhlo.add %5, %2 : tensor<2xf32> return %6 : tensor<2xf32> } func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @sin_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.sine"(%arg0) : (tensor) -> tensor + %0 = "mhlo.sine"(%arg0) : (tensor) -> tensor return %0 : tensor } func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @rsqrt_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.rsqrt"(%arg0) : (tensor) -> tensor + %0 = "mhlo.rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @sqrt_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.sqrt"(%arg0) : (tensor) -> tensor + %0 = "mhlo.sqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @tanh_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.tanh"(%arg0) : (tensor) -> tensor + %0 = "mhlo.tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @bitcast_dynamic(%arg0: tensor) -> tensor { - %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor) -> tensor + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor return %0 : tensor } func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { - %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> - %1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> - %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> - %3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> - %4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> - %5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> - %6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + %4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + %6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> return %6 : tensor<1x2x3x4xf32> } func @size_rank_one_i32(%arg0: tensor) -> tensor { - %0 = xla_hlo.constant dense<1> : tensor + %0 = mhlo.constant dense<1> : tensor return %0 : tensor } func @size_rank_one_i64(%arg0: tensor) -> tensor { - %0 = xla_hlo.constant dense<1> : tensor + %0 = mhlo.constant dense<1> : tensor return %0 : tensor } func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> return %0 : tensor<3xcomplex> } func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { - %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { - %0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> + %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> return %0 : tensor<1x519xf32> } func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> return %0 : tensor<2x2x6xf32> } func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> return %0 : tensor<1xf32> } func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor return %0 : tensor } func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32> } func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> return %0 : tensor<3x5x1x4xf32> } func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - %0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = + %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> @@ -737,7 +737,7 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32> } func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - %0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = + %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> @@ -745,7 +745,7 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2 } func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - %0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = + %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> @@ -753,22 +753,22 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3 } func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.reduce"(%arg0, %0) ( { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): - %2 = xla_hlo.add %arg1, %arg2 : tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor) -> tensor<1xf32> return %1 : tensor<1xf32> } func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0xFF800000" represents -INF for f32. - %0 = xla_hlo.constant dense<0xFF800000> : tensor - %1 = "xla_hlo.reduce"(%arg0, %0) ( { + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): - %2 = xla_hlo.maximum %arg1, %arg2 : tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor) -> tensor<1xf32> return %1 : tensor<1xf32> } @@ -776,11 +776,11 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { // "0x7F800000" represents INF for f32. - %0 = xla_hlo.constant dense<0x7F800000> : tensor - %1 = "xla_hlo.reduce"(%arg0, %0) ( { + %0 = mhlo.constant dense<0x7F800000> : tensor + %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): - %2 = xla_hlo.minimum %arg1, %arg2 : tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = mhlo.minimum %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor) -> tensor<1xf32> return %1 : tensor<1xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index c04f034ede6874..3215055a249a48 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -86,6 +86,15 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32 return %0 : tensor<2x3xf32> } +// CHECK-LABEL: @is_nan +func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { + // CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor} : () -> tensor + // CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + %0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> + // CHECK: return %[[RESULT]] + return %0 : tensor<3x4xi1> +} + // CHECK-LABEL: func @fill // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xi64>, %[[ARG1:.*]]: tensor<*xf32>) func @fill(%arg0: tensor<*xi64>, %arg1: tensor<*xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 9a7732ce238a30..9931a45f99578c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -121,6 +121,26 @@ func @replicate_control() { // CHECK: tf_executor.fetch %[[SINK]] +// Tests unused replica are pinned to the graph fetch. +// CHECK-LABEL: func @unused_replica +func @unused_replica(%arg0: tensor) { + %0 = tf_executor.graph { + %1:3 = tf_executor.island { + %2:2 = tf_device.replicate([%arg0, %arg0] as %ri0: tensor) {n = 2 : i32} { + tf_device.return %ri0 : tensor + } + tf_executor.yield %2#0, %2#1 : tensor, tensor + } + tf_executor.fetch %1#1 : tensor + } + return +} + +// CHECK: {{%.*}}, [[REPLICA_0_CONTROL:%.*]] = tf_executor.island +// CHECK: [[REPLICA_1_OUTPUT:%.*]], {{%.*}} = tf_executor.island +// CHECK: tf_executor.fetch [[REPLICA_1_OUTPUT]], [[REPLICA_0_CONTROL]] + + // Tests replicate results are remapped correctly. // CHECK-LABEL: func @replicate_result func @replicate_result(%arg0: tensor, %arg1: tensor) { @@ -143,6 +163,33 @@ func @replicate_result(%arg0: tensor, %arg1: tensor) { // CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 +// Tests replicate results are remapped correctly with packed inputs. +// CHECK-LABEL: func @replicate_with_packed_input +func @replicate_with_packed_input(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate(%arg0 as %arg2: tensor, %arg1 as %arg3: tensor) + {n = 2 : i32, _packed_input_indices = [0, 1]} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg3) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + return +} + +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: "tf.opA"(%arg0) +// CHECK: "tf.opB"(%arg1) +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: "tf.opA"(%arg0) +// CHECK: "tf.opB"(%arg1) +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0 + + // Tests replica id is added correctly. // CHECK-LABEL: func @replica_id_attr_added func @replica_id_attr_added(%arg0: tensor, %arg1: tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 24964284c28026..7c8e4382e2b1ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -710,3 +710,29 @@ func @callee(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.res %0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource>> return %0 : tensor<*x!tf.resource>> } + +// ----- + +// Tests call op where it's result is the result of a tf.ReadVariableOp. + +// CHECK-LABEL: func @call_with_forwarded_read_only_result +// CHECK-SAME: (%[[RESOURCE_ARG0:.*]]: tensor<*x!tf.resource>>) +func @call_with_forwarded_read_only_result(%arg0: tensor<*x!tf.resource>>) { + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[RESOURCE_ARG0]]) + %0 = "tf_device.cluster"() ( { + // CHECK: %[[CALL:.*]] = "tf.StatefulPartitionedCall"(%[[READ]]) + %1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<*x!tf.resource>>) -> tensor + // CHECK-NEXT: tf_device.return %[[CALL]] + tf_device.return %1 : tensor + }) {} : () -> tensor + return +} + +func @callee(%arg0: tensor<*x!tf.resource>>) -> tensor { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} + +// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor) -> tensor +// CHECK-NEXT: return %[[A0]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index a0e9b0a5115b95..4193edf8cc642e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -354,21 +354,27 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // Test propagation from called functions to the call site. // CHECK-LABEL: func @stateful_partitioned_call( // CHECK-SAME: -> tensor<20xi32> - func @stateful_partitioned_call(%arg0: tensor<20xi32>) -> tensor<*xi32> { + func @stateful_partitioned_call(%arg0: tensor<20xi32>, %arg1: tensor) -> tensor<*xi32> { // CHECK: tf.PartitionedCall // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32> - %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @a_called_func} : (tensor<20xi32>) -> (tensor<*xi32>) + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @partitioned_called_func} : (tensor<20xi32>) -> tensor<*xi32> // CHECK: tf.StatefulPartitionedCall // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32> - %1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_partitioned_call_func} : (tensor<20xi32>) -> (tensor<*xi32>) + %1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_partitioned_call_func} : (tensor<20xi32>) -> tensor<*xi32> + // CHECK: tf.TPUPartitionedCall + // CHECK-SAME: (tensor<20xi32>, tensor) -> tensor<20xi32> + %2 = "tf.TPUPartitionedCall"(%arg0, %arg1) {autotuner_thresh = 0 : i64, f = @tpu_partitioned_call_func} : (tensor<20xi32>, tensor) -> tensor<*xi32> return %0 : tensor<*xi32> } - func @a_called_func(%arg0: tensor) -> (tensor) { + func @partitioned_called_func(%arg0: tensor) -> (tensor) { return %arg0 : tensor } func @stateful_partitioned_call_func(%arg0: tensor) -> (tensor) { return %arg0 : tensor } + func @tpu_partitioned_call_func(%arg0: tensor) -> (tensor) { + return %arg0 : tensor + } // Test propagation involving const values across caller and callee. func @partitioned_call_const(%arg0 : tensor<6xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir index f0ca3d2b7f8494..e4fdad2eddbe4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir @@ -36,7 +36,7 @@ func @main() -> tensor { // CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[READ_SIZE1]], %[[CONST1_1]]) // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[READ_VAL1]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[SUB]]) : (tensor>>, tensor<1xi32>) -> () "tf.StackCloseV2"(%stack) : (tensor) -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index c453a3815f2c08..3d187aa5d604c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -4,8 +4,8 @@ // CHECK-LABEL: func @main func @main() -> (tensor, tensor) { - // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor} %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -28,10 +28,10 @@ func @main() -> (tensor, tensor) { // CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[NEW_SIZE]], %[[CONST1_1]]) // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor %pop:2 = "tf.TensorListPopBack"(%push, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} + // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} // CHECK-NEXT: %[[LENGTH:.*]] = "tf.Reshape"(%[[NEW_SIZE]], %[[SCALAR_SHAPE]]) %length = "tf.TensorListLength"(%push) : (tensor>>) -> tensor // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor @@ -46,8 +46,8 @@ func @main() -> (tensor, tensor) { // CHECK-LABEL: func @main // CHECK-SAME: (%[[ARG0:.*]]: tensor) -> (tensor, tensor<10xf32>, tensor) func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { - // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[NUM:.*]] = "tf.Const"() {value = dense<10> : tensor} %num = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -69,7 +69,7 @@ func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE2]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[UPDATE]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor %get = "tf.TensorListGetItem"(%set, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor // CHECK-NEXT: %[[ADDN:.*]] = "tf.AddN"(%[[UPDATE]], %[[BROADCAST]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> @@ -92,8 +92,8 @@ func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { // CHECK-LABEL: func @main // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<10xf32>) -> tensor func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { - // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG1]]) : (tensor<10xf32>) -> tensor<10xf32> // CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor>> @@ -101,7 +101,7 @@ func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[BUFFER]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor %get = "tf.TensorListGetItem"(%tl, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor // CHECK-NEXT: return %[[ELEM]] : tensor @@ -164,7 +164,7 @@ func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5x8x9x // CHECK-LABEL: func @main func @main() -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> @@ -205,7 +205,7 @@ func @while_cond(%arg0: tensor>>, %arg1: tensor) -> // CHECK-LABEL: func @main func @main(%arg0: tensor) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> @@ -231,14 +231,14 @@ func @if_then(%arg0: tensor>>) -> tensor, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) func @if_else(%arg0: tensor>>) -> tensor>> { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NOT: "tf.TensorListPopBack" // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]]) // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]]) // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) @@ -252,7 +252,7 @@ func @if_else(%arg0: tensor>>) -> tensor) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> @@ -278,14 +278,14 @@ func @branch_0(%arg0: tensor>>) -> tensor, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) func @branch_1(%arg0: tensor>>) -> tensor>> { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NOT: "tf.TensorListPopBack" // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]]) // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]]) // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) @@ -294,14 +294,14 @@ func @branch_1(%arg0: tensor>>) -> tensor, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) func @branch_2(%arg0: tensor>>) -> tensor>> { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NOT: "tf.TensorListPopBack" // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]]) // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]]) // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> - // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) @@ -314,7 +314,7 @@ func @branch_2(%arg0: tensor>>) -> tensor) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList // CHECK: %[[INIT:.*]] = "tf.BroadcastTo" @@ -357,7 +357,7 @@ func @callee(%arg0: tensor>>, %arg1: tensor) -> tens // CHECK-LABEL: func @main func @main(%arg0: tensor) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList // CHECK: %[[INIT:.*]] = "tf.BroadcastTo" @@ -403,7 +403,7 @@ func @main() -> () { } // CHECK: func @callee() func @callee() -> () attributes {sym_visibility = "public"} { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList // CHECK: "tf.BroadcastTo" @@ -416,7 +416,7 @@ func @callee() -> () attributes {sym_visibility = "public"} { // Tests that the pass reports error on unknown maximum size. func @main(%arg0: tensor) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // expected-error @+1 {{unknown max element count}} %tl = "tf.EmptyTensorList"(%elem_shape, %arg0) : (tensor<0xi32>, tensor) -> tensor>> return @@ -439,7 +439,7 @@ func @main(%arg0: tensor<*xi32>) -> () { // list. func @main(%arg0: tensor<*xi32>) -> () { - %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %num = "tf.Const"() {value = dense<10> : tensor} : () -> tensor %tl = "tf.TensorListReserve"(%elem_shape, %num) : (tensor<0xi32>, tensor) -> tensor>> %elem = "tf._SomeOp"() : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 3ae75b475d637a..4464669051937b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1055,7 +1055,7 @@ func @testIfRegionThenConsumingElse(%arg0: tensor, %arg1: tensor<2xf32>) -> // The regions for IfRegion themselves cannot have any arguments func @testInvalidIfRegionThenArg(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> - // expected-error @+1 {{then region cannot have any arguments}} + // expected-error @+1 {{'tf.IfRegion' op region #0 should have no arguments}} %0 = "tf.IfRegion"(%arg0) ({ ^bb(%arg_bb: tensor<2xf32>): %t = "tf.Abs"(%arg_bb) : (tensor<2xf32>) -> tensor<2xf32> @@ -1072,7 +1072,7 @@ func @testInvalidIfRegionThenArg(%arg0: tensor, %arg1: tensor<2xf32>) -> ten func @testInvalidIfRegionElseArg(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> - // expected-error @+1 {{else region cannot have any arguments}} + // expected-error @+1 {{'tf.IfRegion' op region #1 should have no arguments}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%neg) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir index e72b02156d7a37..745cf72f959bd0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir @@ -42,16 +42,19 @@ func @empty_replicate() { // CHECK-LABEL: func @replicate_with_multiple_operands func @replicate_with_multiple_operands() { - %0 = "tf.opA"() : () -> (tensor<*xi1>) - %1 = "tf.opB"() : () -> (tensor<*xi1>) - %2 = "tf.opC"() : () -> (tensor<*xi1>) - %3 = "tf.opD"() : () -> (tensor<*xi32>) - %4 = "tf.opE"() : () -> (tensor<*xi32>) - %5 = "tf.opF"() : () -> (tensor<*xi32>) - %6 = "tf.opG"() : () -> (tensor<*xf32>) - %7 = "tf.opH"() : () -> (tensor<*xf32>) - %8 = "tf.opI"() : () -> (tensor<*xf32>) - tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, [%3, %4, %5] as %input1: tensor<*xi32>, [%6, %7, %8] as %input2: tensor<*xf32>) {n = 3 : i32} { + %0 = "tf.opA"() : () -> tensor<*xi1> + %1 = "tf.opB"() : () -> tensor<*xi1> + %2 = "tf.opC"() : () -> tensor<*xi1> + %3 = "tf.opD"() : () -> tensor<*xi32> + %4 = "tf.opE"() : () -> tensor<*xi32> + %5 = "tf.opF"() : () -> tensor<*xi32> + %6 = "tf.opG"() : () -> tensor<*xf32> + %7 = "tf.opH"() : () -> tensor<*xf32> + %8 = "tf.opI"() : () -> tensor<*xf32> + %9 = "tf.opJ"() : () -> tensor<*xi8> + %10 = "tf.opK"() : () -> tensor<*xi16> + %11 = "tf.opL"() : () -> tensor<*xi64> + tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, %9 as %input1: tensor<*xi8>, %10 as %input2: tensor<*xi16>, [%3, %4, %5] as %input3: tensor<*xi32>, [%6, %7, %8] as %input4: tensor<*xf32>, %11 as %input5: tensor<*xi64>) {n = 3 : i32} { tf_device.return } return @@ -65,12 +68,32 @@ func @replicate_with_multiple_operands() { // CHECK: %[[OP_G:[a-z0-9]*]] = "tf.opG" // CHECK: %[[OP_H:[a-z0-9]*]] = "tf.opH" // CHECK: %[[OP_I:[a-z0-9]*]] = "tf.opI" +// CHECK: %[[OP_J:[a-z0-9]*]] = "tf.opJ" +// CHECK: %[[OP_K:[a-z0-9]*]] = "tf.opK" +// CHECK: %[[OP_L:[a-z0-9]*]] = "tf.opL" // CHECK: tf_device.replicate -// CHECK-SAME: ([%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1>, [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32>, [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32>) +// CHECK-SAME: [%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1> +// CHECK-SAME: [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32> +// CHECK-SAME: [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32> +// CHECK-SAME: %[[OP_J]] as %{{[a-z0-9]*}}: tensor<*xi8> +// CHECK-SAME: %[[OP_K]] as %{{[a-z0-9]*}}: tensor<*xi16> +// CHECK-SAME: %[[OP_L]] as %{{[a-z0-9]*}}: tensor<*xi64> // CHECK-SAME: n = 3 // CHECK-NEXT: tf_device.return } +// CHECK-LABEL: func @replicate_derived_operand_segment_sizes +func @replicate_derived_operand_segment_sizes() { + tf_device.replicate {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} { + } + return + +// CHECK: tf_device.replicate +// CHECK-SAME: n = 2 +// CHECK-NOT: operand_segment_sizes +// CHECK-NEXT: tf_device.return +} + // CHECK-LABEL: func @replicate_with_return // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>) func @replicate_with_return(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xi32>) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir index 0eb5f878c2ae3b..ed205c1b9266cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir @@ -61,7 +61,7 @@ func @parser_replicate_terminator() { func @verifier_replicate_no_block() { "tf_device.replicate" () ({ // expected-error@-1 {{'tf_device.replicate' op region #0 ('body') failed to verify constraint: region with 1 blocks}} - }) {n = 2 : i32} : () -> () + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () return } @@ -72,7 +72,7 @@ func @verifier_replicate_empty_block() { "tf_device.replicate" () ({ // expected-error@-1 {{'tf_device.replicate' op expects a non-empty block}} ^entry: - }) {n = 2 : i32} : () -> () + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () return } @@ -85,7 +85,7 @@ func @verifier_replicate_terminator() { // expected-error@-1 {{'tf_device.replicate' op expects regions to end with 'tf_device.return', found 'std.return'}} ^entry: return - }) {n = 2 : i32} : () -> () + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () return } @@ -97,7 +97,7 @@ func @verifier_replicate_n() { // expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 2}} ^entry: tf_device.return - }) {n = 1 : i32} : () -> () + }) {n = 1 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () } // ----- @@ -109,43 +109,66 @@ func @verifier_replicate_n_device() { // expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}} ^entry: tf_device.return - }) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}} : () -> () + }) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () } // ----- -// Check that replicate op's `devices` attribute must consist of dictionary +// Check that replicate op's 'devices' attribute must consist of dictionary // with values as list with size equal to 'n' attribute. func @verifier_replicate_n_device_multiple_alias() { "tf_device.replicate" () ({ // expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}} ^entry: tf_device.return - }) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}} : () -> () + }) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> () } // ----- -// Check that a replicate with mismatched operand and block arg counts is -// invalid. -func @verifier_replicate_operand_block_arg_count(%arg0: tensor<*xi32>) { - "tf_device.replicate" (%arg0, %arg0, %arg0) ({ -// expected-error@-1 {{'tf_device.replicate' op expects number of operands (3) to be equal to 'n' * number of block arguments (2 * 1)}} - ^entry(%input0: tensor<*xi32>): +// Check number of replicated inputs is evenly divisible by 'n'. +func @verifier_replicate_bad_operand_segment_sizes(%arg0: tensor<*xi32>) { + "tf_device.replicate" (%arg0, %arg0, %arg0, %arg0) ({ +// expected-error@-1 {{'tf_device.replicate' op expects number of replicated inputs (4) to be evenly divisible by 'n' (3)}} + ^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>): tf_device.return - }) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> () + }) {n = 3 : i32, operand_segment_sizes = dense<[4, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> () } // ----- -// Check that a replicate with incompatible operand and block argument type is -// invalid. -func @verifier_replicate_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { +// Check number of replicated inputs / 'n' + number of packed inputs matches the +// number of block arguments. +func @verifier_replicate_num_block_args(%arg0: tensor<*xi32>) { + "tf_device.replicate" (%arg0, %arg0, %arg0, %arg0, %arg0) ({ +// expected-error@-1 {{'tf_device.replicate' op expects number of block arguments (2) to be equal to number of replicated inputs (3) / 'n' (3) + number of packed inputs (2)}} + ^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>): + tf_device.return + }) {n = 3 : i32, operand_segment_sizes = dense<[3, 2]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> () +} + +// ----- + +// Check that a replicate with incompatible replicated operand and block +// argument type is invalid. +func @verifier_replicate_replicated_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { "tf_device.replicate" (%arg0, %arg1) ({ -// expected-error@-1 {{'tf_device.replicate' op incompatible types for operand 1 and block argument 0}} +// expected-error@-1 {{'tf_device.replicate' op expects operand 1 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}} + ^entry(%input0: tensor<*xi32>): + tf_device.return + }) {n = 2 : i32, operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi1>) -> () +} + +// ----- + +// Check that a replicate with incompatible packed operand and block argument +// type is invalid. +func @verifier_replicate_packed_operand_block_arg_type(%arg0: tensor<*xi1>) { + "tf_device.replicate" (%arg0) ({ +// expected-error@-1 {{'tf_device.replicate' op expects operand 0 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}} ^entry(%input0: tensor<*xi32>): tf_device.return - }) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi1>) -> () + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 1]> : vector<2xi32>} : (tensor<*xi1>) -> () } // ----- @@ -157,7 +180,7 @@ func @verifier_replicate_result_return_operand_count(%arg0: tensor<*xi32>) { // expected-error@-1 {{'tf_device.replicate' op expects number of results (3) to be equal to 'n' * number of terminator operands (2 * 1)}} ^entry: tf_device.return %arg0 : tensor<*xi32> - }) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) } // ----- @@ -169,7 +192,7 @@ func @verifier_replicate_result_return_operand_type(%arg0: tensor<*xi32>) { // expected-error@-1 {{'tf_device.replicate' op incompatible types for result 1 and terminator operand 0}} ^entry: tf_device.return %arg0 : tensor<*xi32> - }) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi1>) + }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi1>) } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py index 51475197a12700..c51fcbfb259624 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py @@ -64,9 +64,9 @@ def Test(): inputs={'x': tensor_info_x}, outputs={'r': tensor_info_r}, method_name='some_function')) - } + }, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py index 6160e25577bbe3..7a61b4b4f6a8ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -46,10 +46,7 @@ def set_tf_options(): # This function needs to take a "create_module_fn", as opposed to just the # module itself, because the creation of the module has to be delayed until # after absl and tensorflow have run various initialization steps. -def do_test(signature_def_map, - init_op=None, - canonicalize=False, - show_debug_info=False): +def do_test(create_signature, canonicalize=False, show_debug_info=False): """Runs test. 1. Performs absl and tf "main"-like initialization that must run before almost @@ -62,10 +59,10 @@ def do_test(signature_def_map, This is only for use by the MLIR SavedModel importer tests. Args: - signature_def_map: A map from string key to signature_def. The key will be - used as function name in the resulting MLIR. - init_op: The initializer op for the saved model. If set, it will generate a - initializer graph in the resulting MLIR. + create_signature: A functor that return signature_def_map, init_op and + assets_collection. signature_def_map is a map from string key to + signature_def. The key will be used as function name in the resulting + MLIR. canonicalize: If true, canonicalizer will be run on the resulting MLIR. show_debug_info: If true, shows debug locations in the resulting MLIR. """ @@ -84,6 +81,8 @@ def app_main(argv): else: save_model_path = tempfile.mktemp(suffix='.saved_model') + signature_def_map, init_op, assets_collection = create_signature() + sess = tf.Session() sess.run(tf.initializers.global_variables()) builder = tf.saved_model.builder.SavedModelBuilder(save_model_path) @@ -91,6 +90,7 @@ def app_main(argv): sess, [tf.saved_model.tag_constants.SERVING], signature_def_map, main_op=init_op, + assets_collection=assets_collection, strip_default_attrs=True) builder.save() diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py new file mode 100644 index 00000000000000..78fde0dca014dc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py @@ -0,0 +1,61 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/control_flow_duplicate_v1 | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# Tests handling dupliate functions after V1 control flow is functionalized. + +# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_1"] +# CHECK: "tf.If" +# CHECK-SAME: else_branch = @[[else:[a-zA-Z_0-9]+]] +# CHECK-SAME: then_branch = @[[then:[a-zA-Z_0-9]+]] + +# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_2"] +# CHECK: "tf.If" +# CHECK-SAME: else_branch = @[[else]] +# CHECK-SAME: then_branch = @[[then]] + +# CHECK: func @[[else]]( +# CHECK: func @[[then]]( + + +def Test(): + + zero = tf.constant(0) + one = tf.constant(1) + x = tf.placeholder(tf.int32, shape=(), name='input') + result = tf.cond(x > zero, lambda: tf.square(x), lambda: tf.add(x, one)) + + tensor_info_result = tf.compat.v1.saved_model.utils.build_tensor_info(result) + + signature_def = tf.saved_model.signature_def_utils.build_signature_def( + inputs=None, + outputs={'result': tensor_info_result}, + method_name='some_function') + + return {'key_1': signature_def, 'key_2': signature_def}, None, None + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py index 4684dc071f2c93..209ed3492e8bad 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py @@ -47,9 +47,9 @@ def Test(): outputs={'result': tensor_info_result}, method_name='some_function') - return {'key': signature_def} + return {'key': signature_def}, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py index 8bd128898a093d..cdd80789ebc746 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py @@ -55,9 +55,9 @@ def test_defun(): }, outputs={'z': tensor_info_z}, method_name='test_function')) - } + }, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(test_defun()) + common_v1.do_test(test_defun) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/duplicate_method_names_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/duplicate_method_names_v1.py index 43fea693198117..204eafe8eda032 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/duplicate_method_names_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/duplicate_method_names_v1.py @@ -51,9 +51,9 @@ def Test(): inputs=None, outputs={'t': tensor_info_t}, method_name='some_function') # Create two signatures that share the same variable. - return {'key': signature_def, 'key2': signature_def2} + return {'key': signature_def, 'key2': signature_def2}, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py new file mode 100644 index 00000000000000..7e86953eb8f74a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py @@ -0,0 +1,73 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/hash_table_asset_v1| FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> () +# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset:.*]]"} + +# CHECK: func [[init]] +# CHECK-SAME: [[ARG:%.*]]: tensor {tf_saved_model.bound_input = @[[asset]]} +# CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"() +# CHECK-SAME: shared_name = "[[hash_table:.*]]" +# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG]]) + + +def write_vocabulary_file(vocabulary): + """Write temporary vocab file for module construction.""" + tmpdir = tempfile.mkdtemp() + vocabulary_file = os.path.join(tmpdir, 'tokens.txt') + with tf.io.gfile.GFile(vocabulary_file, 'w') as f: + for entry in vocabulary: + f.write(entry + '\n') + return vocabulary_file + + +def test(): + + table_initializer = tf.lookup.TextFileInitializer( + write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']), tf.string, + tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, + tf.lookup.TextFileIndex.LINE_NUMBER) + table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10) + + x = tf.placeholder(tf.string, shape=(), name='input') + r = table.lookup(x) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r) + + return { + 'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={'x': tensor_info_x}, + outputs={'r': tensor_info_r}, + method_name='some_function')) + }, tf.tables_initializer(), tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS) + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test(test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py index 16290455608c41..3044a9b1c6130d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py @@ -84,9 +84,9 @@ def Test(): inputs={'x': tensor_info_x}, outputs={'r': tensor_info_r}, method_name='some_function')) - } + }, tf.tables_initializer(), None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test(), tf.tables_initializer()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py index 1eb7861736849d..f25319353eda46 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py @@ -84,9 +84,9 @@ def Test(): 'd': tensor_info_s, }, method_name='reverse_arguments')) - } + }, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py index ada77026006f5f..1159f1328a6c21 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py @@ -56,9 +56,9 @@ def Test(): inputs=None, outputs={'z': tensor_info_z}, method_name='some_function')) - } + }, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py index 117132649d71a9..d5ed626a29044f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py @@ -65,10 +65,9 @@ def Test(): inputs={'x': tensor_info_x}, outputs={'r': tensor_info_r}, method_name='some_function')) - } + }, tf.initializers.global_variables(), None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test( - Test(), tf.initializers.global_variables(), canonicalize=True) + common_v1.do_test(Test, canonicalize=True) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py index 753b108c9868dd..5446764f2857e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py @@ -61,9 +61,9 @@ def Test(): method_name='some_other_function') # Create two signatures that share the same variable. - return {'key': signature_def, 'key2': signature_def2} + return {'key': signature_def, 'key2': signature_def2}, None, None if __name__ == '__main__': common_v1.set_tf_options() - common_v1.do_test(Test()) + common_v1.do_test(Test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index 326d356d985bfd..7156a1fab63758 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -7,6 +7,12 @@ module attributes {tf_saved_model.semantics} { initializer = @init } : () -> () + // CHECK: tf_saved_model.asset + "tf_saved_model.asset"() { + filename = "asset_filename", + sym_name = "asset_sym_name" + } : () -> () + // Representation for constants: (immutable) global tensor. // CHECK: tf_saved_model.global_tensor "tf_saved_model.global_tensor"() { @@ -48,6 +54,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: func @init // CHECK-SAME: exported_names = ["__tf_saved_model_session_initializer"] func @init( + %arg0: tensor {tf_saved_model.bound_input = @asset_sym_name}, %arg1: tensor>> {tf_saved_model.bound_input = @some_constant} ) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index 99201d9d964f4c..dcb889ff99e76d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics +// RUN: tf-opt %s -split-input-file -verify-diagnostics -allow-unregistered-dialect module attributes {tf_saved_model.semantics} { @@ -387,3 +387,16 @@ module attributes {tf_saved_model.semantics} { return } } + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{unknown symbol operation}} + "some_dialect.some_op"() {sym_name = "v"} : () -> () + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes { tf_saved_model.exported_names = ["a"] } { + return + } + +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_remove_vars_in_session_initializer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_remove_vars_in_session_initializer.mlir new file mode 100644 index 00000000000000..a2eed45690e7e8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_remove_vars_in_session_initializer.mlir @@ -0,0 +1,83 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-remove-vars-in-session-initializer | FileCheck %s + +module attributes {tf_saved_model.semantics} { + // Test case: No session initializer op +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test case: No matching function for the given session initializer. + // expected-error@+1 {{'tf_saved_model.session_initializer' op the initializer function does not exist}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test case: Invalid multiple blocks in the initializer funcion. + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + // expected-error@+1 {{expects exactly one block in the MLIR function}} + func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} { + br ^bb1 + ^bb1: + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test case: No variables + // CHECK: func @init() + // CHECK: tf.Const + // CHECK: return + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} { + "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + // Test case: Variable removal. + // CHECK: func @init() + // CHECK-NOT: tf.VarHandleOp + // CHECK-NOT: tf.Const + // CHECK-NOT: tf.AssignAddVariableOp + // CHECK: return + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} { + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource>> + %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignAddVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>, tensor) -> () + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + // Test case: Removal of shared variables. + // CHECK: func @init() + // CHECK-NOT: tf.VarHandleOp + // CHECK-NOT: tf.Const + // CHECK-NOT: tf.AssignAddVariableOp + // CHECK: return + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} { + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource>> + %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %4 = "tf.Add"(%2, %3) : (tensor, tensor) -> tensor + "tf.AssignAddVariableOp"(%0, %4) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>, tensor) -> () + "tf.AssignAddVariableOp"(%1, %4) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>, tensor) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index fb2c03e27c83f5..43be8743e51526 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -282,3 +282,111 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %1 : tensor } } + +// ----- + +// Tests that the pass can correctly transform a training loop with a packed +// variable. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + // CHECK-LABEL: func @main + func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) { + + %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor + // CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"() + // CHECK-SAME: device = "/device:TPU:0" + // CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"() + // CHECK-SAME: device = "/device:TPU:1" + // CHECK: %[[WHILE:.*]]:6 = "tf.While"( + // CHECK-SAME: %[[STATE0]], %[[STATE1]]) + %1:4 = "tf.While"(%0, %arg0, %arg1, %arg2) + {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", + "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], + body = @while_body_7560, + cond = @while_cond_7550, device = "", is_stateless = false, + output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) + // CHECK: %[[DEFAULT:.*]] = "tf.Const"() + // CHECK: tf_device.replicate + // CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor>>, + // CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource>> + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + return + } + // CHECK-LABEL: func @while_body_7560 + func @while_body_7560(%arg0: tensor, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg3: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) { + // CHECK-SAME: (%[[ITER:.*]]: tensor, + // CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + // CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + // CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}, + // CHECK-SAME: %[[STATE_ARG0:.*]]: tensor>> {tf.device = "/device:TPU:0"}, + // CHECK-SAME: %[[STATE_ARG1:.*]]: tensor>> {tf.device = "/device:TPU:1"}) + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + %compile:2 = "tf_device.launch"() ( { + %2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. + metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor) + tf_device.return %2#0, %2#1 : tensor, tensor + }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource>>, + // CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor>>, + // CHECK-SAME: %[[BODY_ARG3]] as %[[R1:.*]]: tensor<*x!tf.resource>> + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"] + %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, + %arg3 as %arg31: tensor<*x!tf.resource>>) + {_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]]) + %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1) + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + return %1, %arg1, %arg2, %arg3 : tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, tensor<*x!tf.resource>> + } + // CHECK-LABEL: func @while_cond_7550 + func @while_cond_7550(%arg0: tensor, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg3: tensor<*x!tf.resource>> {tf.device = "/device:COMPOSITE:0"}) + -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir index 460a4185f88c94..41055152ab6ccd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/end_to_end.mlir @@ -9,7 +9,7 @@ module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", tf_executor.graph { %outputs, %control = tf_executor.island wraps "std.constant"() {value = dense<2.000000e+00> : tensor} : () -> tensor %outputs_0, %control_1 = tf_executor.island wraps "std.constant"() {value = dense<3.000000e+00> : tensor} : () -> tensor - %control_2 = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true} : () -> () + %control_2 = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> () %outputs_3, %control_4 = tf_executor.island wraps "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "x", shape = "tfshape$dim { }"} : () -> tensor<0xf32> %outputs_5, %control_6 = tf_executor.island wraps "tf.TPUReplicatedInput"(%outputs_3) {N = 1 : i64, T = "tfdtype$DT_FLOAT", device = "", name = "input0"} : (tensor<0xf32>) -> tensor<0xf32> %outputs_7, %control_8 = tf_executor.island wraps "tf.Identity"(%outputs_5) {T = "tfdtype$DT_FLOAT", _tpu_input_identity = true, _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<0xf32>) -> tensor<0xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 4e4317ce5dd2a7..37dfec5e6df1ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -41,115 +41,6 @@ func @metadata_op_removed() { // CHECK-NOT: "tf.TPUReplicateMetadata" -// Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a `tf_device.cluster`. -// CHECK-LABEL: func @simple_island -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @simple_island(%arg0 : tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () - %3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - %4 = "tf.opB"() : () -> tensor - %5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor) -> tensor - tf_executor.yield %5 : tensor - } - tf_executor.fetch %1#0 : tensor - } - return %0 : tensor -} - -// CHECK: "tf.opB" -// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { -// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) -// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) -// CHECK-NEXT: tf_device.return %[[OP_C]] -// CHECK-NEXT: _tpu_replicate = "replicate" -// CHECK-SAME: device = "device" -// CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[CLUSTER]] - - -// Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a `tf_device.cluster`, even when the associated TPUReplicateMetadata op -// is in a different island. -// CHECK-LABEL: func @simple_island_separate_metadata -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @simple_island_separate_metadata(%arg0 : tensor) -> tensor { - %0 = tf_executor.graph { - %1 = tf_executor.island { - "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () - } - %2:2 = tf_executor.island { - %3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - %4 = "tf.opB"() : () -> tensor - %5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor) -> tensor - tf_executor.yield %5 : tensor - } - tf_executor.fetch %2#0 : tensor - } - return %0 : tensor -} - -// CHECK: "tf.opB" -// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { -// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) -// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) -// CHECK-NEXT: tf_device.return %[[OP_C]] -// CHECK-NEXT: _tpu_replicate = "replicate" -// CHECK-SAME: device = "device" -// CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[CLUSTER]] - - -// Test ops in multiple islands with the same `_tpu_replicate` attribute are -// merged under `tf_device.cluster` ops only within their respective island. -// CHECK-LABEL: func @multiple_islands_separate_metadata -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @multiple_islands_separate_metadata(%arg0 : tensor) -> (tensor, tensor) { - %0:2 = tf_executor.graph { - %1 = tf_executor.island { - "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () - } - %2:2 = tf_executor.island { - %3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - %4 = "tf.opB"() : () -> tensor - %5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor) -> tensor - tf_executor.yield %5 : tensor - } - %6:2 = tf_executor.island { - %7 = "tf.opD"(%2#0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - %8 = "tf.opE"() : () -> tensor - %9 = "tf.opF"(%arg0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - tf_executor.yield %9 : tensor - } - tf_executor.fetch %2#0, %6#0 : tensor, tensor - } - return %0#0, %0#1 : tensor, tensor -} - -// CHECK: %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island { -// CHECK: "tf.opB" -// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ( { -// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) -// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) -// CHECK-NEXT: tf_device.return %[[OP_C]] -// CHECK-NEXT: _tpu_replicate = "replicate" -// CHECK-SAME: device = "device" -// CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[CLUSTER_0]] -// CHECK: tf_executor.island { -// CHECK: "tf.opE" -// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ( { -// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ISLAND_1]]) -// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]]) -// CHECK-NEXT: tf_device.return %[[OP_F]] -// CHECK-NEXT: _tpu_replicate = "replicate" -// CHECK-SAME: device = "device" -// CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[CLUSTER_1]] - - // Test ops in a function body with the same `_tpu_replicate` attribute are // merged under a `tf_device.cluster` op. // CHECK-LABEL: func @ops_in_func_body @@ -185,9 +76,9 @@ func @ops_in_func_body(%arg0 : tensor) -> (tensor, tensor, tensor) func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { %0 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor) -> tensor - %1 = tf_executor.graph { - tf_executor.fetch %0 : tensor - } + %1 = "tf_device.launch"() ( { + tf_device.return %0 : tensor + }) {device = "device"} : () -> tensor %2 = "tf.opB"(%0) {_tpu_replicate = "replicate"} : (tensor) -> tensor "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () return %2 : tensor @@ -200,8 +91,8 @@ func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.graph { -// CHECK-NEXT: tf_executor.fetch %[[CLUSTER]]#0 +// CHECK: tf_device.launch +// CHECK-NEXT: tf_device.return %[[CLUSTER]]#0 // CHECK: return %[[CLUSTER]]#1 @@ -363,15 +254,17 @@ func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> // Non-negative `index` should precede `index` of -1, and ordering of ops with // `index` of -1 does not matter. // CHECK-LABEL: func @sort_replicated_input -// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor, %[[ARG_3:.*]]: tensor, %[[ARG_4:.*]]: tensor, %[[ARG_5:.*]]: tensor) -func @sort_replicated_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor) { +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor, %[[ARG_3:.*]]: tensor, %[[ARG_4:.*]]: tensor, %[[ARG_5:.*]]: tensor, %[[ARG_6:.*]]: tensor, %[[ARG_7:.*]]: tensor) +func @sort_replicated_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) { %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor, tensor) -> tensor - %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 2 : i64} : (tensor, tensor) -> tensor + %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 3 : i64} : (tensor, tensor) -> tensor %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor, tensor) -> tensor %3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor, tensor) -> tensor %4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor, tensor) -> tensor - %5 = "tf.TPUReplicatedInput"(%arg5, %arg5) {index = -1 : i64} : (tensor, tensor) -> tensor - "tf.opA"(%0, %1, %2, %3, %4, %5) {_tpu_replicate = "replicate", device = "device"} : (tensor, tensor, tensor, tensor, tensor, tensor) -> () + %5 = "tf.TPUReplicatedInput"(%arg5) {index = -1 : i64, is_packed = true} : (tensor) -> tensor + %6 = "tf.TPUReplicatedInput"(%arg6) {index = 2 : i64, is_packed = true} : (tensor) -> tensor + %7 = "tf.TPUReplicatedInput"(%arg7, %arg7) {index = -1 : i64} : (tensor, tensor) -> tensor + "tf.opA"(%0, %1, %2, %3, %4, %5, %6, %7) {_tpu_replicate = "replicate", device = "device"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () return } @@ -382,47 +275,60 @@ func @sort_replicated_input(%arg0: tensor, %arg1: tensor, %arg2: tensor< // CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}} // CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}} // CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}} -// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}} -// CHECK-SAME: _replicated_input_indices = [0, 1, 2, -1, -1, -1] +// CHECK-DAG: [%[[ARG_7]], %[[ARG_7]]] as %{{[a-z0-9]*}} +// CHECK-DAG: %[[ARG_6]] as %{{[a-z0-9]*}} +// CHECK-DAG: %[[ARG_5]] as %{{[a-z0-9]*}} +// CHECK-SAME: _replicated_input_indices = [0, 1, 3, -1, -1, -1, 2, -1] // Test TPUReplicatedInputs with non contiguous `index` attributes. // CHECK-LABEL: func @non_contigous_indices -// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor) -func @non_contigous_indices(%arg0: tensor, %arg1: tensor, %arg2: tensor) { +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor, %[[ARG_3:.*]]: tensor, %[[ARG_4:.*]]: tensor, %[[ARG_5:.*]]: tensor) +func @non_contigous_indices(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor) { %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 8 : i64} : (tensor, tensor) -> tensor "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () - %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) : (tensor, tensor) -> tensor - "tf.opB"(%1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () - %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 2 : i64} : (tensor, tensor) -> tensor - "tf.opC"(%2) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () + %1 = "tf.TPUReplicatedInput"(%arg1) {index = 6 : i64, is_packed = true} : (tensor) -> tensor + "tf.opA"(%1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () + %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) : (tensor, tensor) -> tensor + "tf.opB"(%2) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () + %3 = "tf.TPUReplicatedInput"(%arg3) {is_packed = true} : (tensor) -> tensor + "tf.opB"(%3) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () + %4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 2 : i64} : (tensor, tensor) -> tensor + "tf.opC"(%4) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () + %5 = "tf.TPUReplicatedInput"(%arg5) {index = 4 : i64, is_packed = true} : (tensor) -> tensor + "tf.opC"(%5) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor) -> () "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () return } // CHECK: tf_device.replicate -// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}} +// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}} // CHECK-SAME: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}} -// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}} -// CHECK-SAME: _replicated_input_indices = [2, 8, -1] +// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}} +// CHECK-SAME: %[[ARG_5]] as %{{[a-z0-9]*}} +// CHECK-SAME: %[[ARG_1]] as %{{[a-z0-9]*}} +// CHECK-SAME: %[[ARG_3]] as %{{[a-z0-9]*}} +// CHECK-SAME: _replicated_input_indices = [2, 8, -1, 4, 6, -1] // Test that the `is_mirrored_variable` attribute is preserved in the // tf_device.replicate op. // CHECK-LABEL: func @mirrored_variables -// CHECK-SAME: (%[[ARG_0:.*]]: tensor>>, %[[ARG_1:.*]]: tensor>>, %[[ARG_2:.*]]: tensor>>, %[[ARG_3:.*]]: tensor>>) -func @mirrored_variables(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) { +// CHECK-SAME: (%[[ARG_0:.*]]: tensor>>, %[[ARG_1:.*]]: tensor>>, %[[ARG_2:.*]]: tensor>>, %[[ARG_3:.*]]: tensor>>, %[[ARG_4:.*]]: tensor>>) +func @mirrored_variables(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>, %arg4: tensor>>) { %0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor>>, tensor>>) -> tensor>> - "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor>>, tensor>>) -> () + %2 = "tf.TPUReplicatedInput"(%arg4) {index = 2 : i64, is_mirrored_variable = true, is_packed = true} : (tensor>>) -> tensor>> + "tf.opA"(%0, %1, %2) {_tpu_replicate = "replicate", device = "device"} : (tensor>>, tensor>>, tensor>>) -> () "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () return } // CHECK: tf_device.replicate // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}} -// CHECK-SAME: _mirrored_variable_indices = [1] -// CHECK-SAME: _replicated_input_indices = [0, 1] +// CHECK-SAME: %[[ARG_4]] as %{{[a-z0-9]*}} +// CHECK-SAME: _mirrored_variable_indices = [1, 2] +// CHECK-SAME: _replicated_input_indices = [0, 1, 2] // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir index 7feea3314fd3e4..2e1b1549e9f232 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir @@ -17,7 +17,7 @@ func @single_arg_single_shape(%arg0: tensor) { } // CHECK-LABEL: func @func0 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor) func @func0(%arg0: tensor, %arg1: tensor) { return } @@ -44,7 +44,7 @@ func @single_arg_multiple_shapes(%arg0: tensor) { } // CHECK-LABEL: func @func1 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor) func @func1(%arg0: tensor, %arg1: tensor, %arg2: tensor) { return } @@ -76,7 +76,7 @@ func @multiple_args(%arg0: tensor) { } // CHECK-LABEL: func @func2 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}}) func @func2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) { return } @@ -97,7 +97,7 @@ func @remap_indices(%arg0: tensor) { } // CHECK-LABEL: func @func3 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}}) func @func3(%arg0: tensor, %arg1: tensor, %arg2: tensor) { return } @@ -196,7 +196,7 @@ func @missing_padding_arg(%arg0: tensor) { } // CHECK-LABEL: func @func8 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor) func @func8(%arg0: tensor, %arg1: tensor, %arg2: tensor) { return } @@ -218,7 +218,7 @@ func @missing_replicated_input_indices(%arg0: tensor) { } // CHECK-LABEL: func @func9 -// CHECK-NOT: xla_hlo.padding_map +// CHECK-NOT: mhlo.padding_map func @func9(%arg0: tensor, %arg1: tensor) { return } @@ -240,7 +240,7 @@ func @non_contigous_indices(%arg0: tensor) { } // CHECK-LABEL: func @func10 -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor, %{{[a-z0-9]+}}: tensor {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}}) func @func10(%arg0: tensor, %arg1: tensor) { return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 8cf55628a89c37..fa70ca85419406 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -5,7 +5,7 @@ // expected-error@+1 {{requires attribute 'tf.versions'}} module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_tf_versions() { - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -20,7 +20,7 @@ module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_devices() { // expected-error@+1 {{error in fetching TPU compilation/execution devices: no TPU_SYSTEM devices found}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -36,7 +36,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -51,7 +51,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -66,7 +66,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -81,7 +81,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_step_marker_location() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -96,7 +96,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_step_marker_location() { // expected-error@+1 {{bad 'step_marker_location' attribute with value 'test'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -111,7 +111,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -126,7 +126,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -141,7 +141,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0, not a string}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -156,7 +156,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0 with value 'test': failed to parse to tpu::PaddingMap}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -171,7 +171,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -186,7 +186,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -201,7 +201,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_topology() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -216,7 +216,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -231,7 +231,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -246,7 +246,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_device_assignment() { // expected-error@+1 {{bad 'device_assignment' attribute at index 0, not an int}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -282,7 +282,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_device_assignment() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () return } func @empty_func() { @@ -297,7 +297,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -322,7 +322,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -337,7 +337,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -352,7 +352,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unsupported_operand_type(%arg0: tensor) { // expected-error@+1 {{failed to determine operand type at index 0: Converting i2 to DataType}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -367,7 +367,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -382,7 +382,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -397,7 +397,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -412,7 +412,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = "", use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -427,7 +427,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -443,7 +443,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -458,7 +458,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -478,7 +478,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @default_step_marker_location func @default_step_marker_location() { - "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () // CHECK: metadata // CHECK-SAME: num_replicas: 1 // CHECK-SAME: num_cores_per_replica: 1 @@ -497,7 +497,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @unranked_shape_arg func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: metadata // CHECK-SAME: shape {\0A unknown_rank: true @@ -515,7 +515,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @partial_shape_arg func @partial_shape_arg(%arg0: tensor) -> tensor { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A } @@ -546,7 +546,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @static_shape_arg func @static_shape_arg(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape @@ -571,7 +571,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @resource_arg func @resource_arg(%arg0: tensor<*x!tf.resource>) -> tensor<*x!tf.resource> { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> // CHECK: metadata // CHECK: dtype: DT_RESOURCE // CHECK-SAME: kind: VARIABLE @@ -590,7 +590,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @parameter_arg func @parameter_arg(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: metadata // CHECK: dtype: DT_FLOAT // CHECK-SAME: kind: PARAMETER @@ -650,7 +650,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @metadata func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> { - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: dtype: DT_INT32 @@ -684,6 +684,24 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests metadata is populated correctly for use_spmd_for_xla_partitioning == +// true. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @metadata + func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> { + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = true} : (tensor<8xi32>) -> tensor<8xi32> + // CHECK: metadata + // CHECK-SAME: use_spmd_for_xla_partitioning: true + return %0: tensor<8xi32> + } + func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { + return %arg0 : tensor<8xi32> + } +} + +// ----- + // Tests shape ops are only generated for operands with non static shapes. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -694,7 +712,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT: "tf.Shape"(%[[ARG_3]]) // CHECK: %[[ARG_0_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]]) // CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]]) - %0 = "tf_device.cluster_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> // CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]]) return %0: tensor<8xi32> @@ -715,7 +733,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -777,7 +795,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1) - %2 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: tf_device.return %[[EXECUTE_OUTPUT]] tf_device.return %2 : tensor @@ -805,7 +823,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { func @single_gpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - %1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: tf_device.cluster_func // CHECK-SAME: device = "gpu0" // CHECK-SAME: func = @gpu0_func @@ -833,7 +851,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -882,7 +900,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -927,7 +945,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -980,7 +998,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1027,7 +1045,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1041,7 +1059,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1083,7 +1101,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1097,7 +1115,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1135,7 +1153,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1205,7 +1223,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUCompileSucceededAssert" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" - %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor @@ -1248,7 +1266,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf.D"(%program) : (tensor) -> () tf_device.return }, { - %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor tf_device.return %4 : tensor }, { %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor) @@ -1291,7 +1309,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1" - %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32> return %0 : tensor<8xi32> } func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { @@ -1355,7 +1373,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1390,7 +1408,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2) // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1428,7 +1446,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1466,7 +1484,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1533,7 +1551,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1601,7 +1619,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2 - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1644,7 +1662,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1684,7 +1702,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_output_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect sharding format for outputs. Number of tiled outputs(4) must match the number of logical devices(2)}} - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1790,7 +1808,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1897,7 +1915,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1981,7 +1999,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#3, %[[PARALLEL_EXECUTE_OUTPUT]]#4 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2066,7 +2084,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2150,7 +2168,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#2, %[[PARALLEL_EXECUTE_OUTPUT]]#0 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index e4fc3a89d4d33c..2e3e38c700419a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -30,8 +30,8 @@ func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) } // CHECK-LABEL: func @func_without_sharding -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { return %arg0 : tensor<*xi32> } @@ -51,8 +51,8 @@ func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) { } // CHECK-LABEL: func @func_without_sharding -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { %0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> @@ -72,8 +72,8 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) { } // CHECK-LABEL: func @inputs_with_sharding_func -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) @@ -94,8 +94,8 @@ return } // CHECK-LABEL: func @func_with_sharding -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"}) func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1> @@ -119,8 +119,8 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { } // CHECK-LABEL: func @func_with_sharding_after_identity -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"}) func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> @@ -145,8 +145,8 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi } // CHECK-LABEL: func @func_with_sharding_after_read_variable -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource>> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource>> {xla_hlo.sharding = "\04\05\06"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource>> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource>> {mhlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"}) func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32> @@ -173,8 +173,8 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { } // CHECK-LABEL: func @func_with_sharding_after_cast -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"}) func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1> @@ -200,8 +200,8 @@ func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>, %arg1: tensor<*x } // CHECK-LABEL: func @func_with_device_training_loop -// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) -// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"}) func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { %1:2 = "tf.StatefulPartitionedCall"(%arg0){f= @func_body, config="", config_proto="", executor_type=""} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir index 017a331946d8cb..199426b1aa98dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -45,8 +45,8 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" return %10 : tensor } // CHECK-LABEL: func @_func - // CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { - func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + // CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32> %2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir new file mode 100644 index 00000000000000..b77e4b1fbd0d3f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir @@ -0,0 +1,79 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-update-embedding-enqueue-op-inputs | FileCheck %s + +// CHECK-LABEL: func @check_enqueue_ops_update_for_eval +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor +func @check_enqueue_ops_update_for_eval(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @check_enqueue_ops_update_for_training +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor +func @check_enqueue_ops_update_for_training(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + + %2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + "tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> () + + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + // expected-error @+1 {{'tf.EnqueueTPUEmbeddingSparseTensorBatch' op must have a corresponding 'tf.RecvTPUEmbeddingActivations' op}} + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + // expected-error @+1 {{'tf.RecvTPUEmbeddingActivations' op requires attribute '_tpu_embedding_layer'}} + %2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index dc24e478378aed..e275b0aefaeaff 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -33,7 +33,7 @@ namespace TFDevice { namespace { -constexpr char kReplicationAttr[] = "xla_hlo.is_same_data_across_replicas"; +constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas"; constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices"; // Analyzes the inputs to ClusterFuncOps in the module, and annotates their diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index ef7f63f82e3ab1..1963931b497874 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -72,22 +72,21 @@ tensorflow::Status RunTPUBridge( } // namespace void CreateTPUBridgePipeline(OpPassManager &pm) { - // Run island coarsening before shape inference to allow more exact shape - // inference using constant folding within islands. - pm.addNestedPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); - // TODO(b/150462212): Move graph pruning before island coarsening. pm.addNestedPass(tf_executor::CreateTFExecutorGraphPruningPass()); + // It is assumed at this stage there are no V1 control flow ops as Graph + // functionalization is ran before import. Ops can be lifted out of + // tf_executor dialect islands/graphs. + pm.addNestedPass(CreateExecutorDialectToFunctionalConversionPass()); // Run shape inference so that tf_executor/tf_device ops created later will // likely to inherit more concrete types. pm.addPass(TF::CreateTFShapeInferencePass()); OpPassManager &func_pm = pm.nest(); func_pm.addPass(CreateTPUClusterFormationPass()); - func_pm.addPass(createCanonicalizerPass()); // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass // because DecomposeResourceOpsPass uses pattern rewriter which hoists // changed constants out of tf_device.Launch. func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); - pm.addNestedPass(CreateTPUHostComputationExpansionPass()); + func_pm.addPass(CreateTPUHostComputationExpansionPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); // Run another shape inference pass because resource decomposition might have // created new partial types. @@ -108,6 +107,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { } void CreateTPUBridgePipelineV1(OpPassManager &pm) { + pm.addPass(TF::CreateTFShapeInferencePass()); // For V1 compatibility, we process a module where the graph does not have // feeds and fetched. We extract first the TPU computation in a submodule, // where it'll be in a function with args and returned values, much more like diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index e6a9ce4ad6298a..9d72284da91e6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -194,6 +194,13 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), (replaceWithValue $arg)>; +//===----------------------------------------------------------------------===// +// Reshape op patterns. +//===----------------------------------------------------------------------===// + +def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape), + (TF_ReshapeOp $arg, $shape)>; + //===----------------------------------------------------------------------===// // Select op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 16de2874fda79c..007baaae433e9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -29,6 +29,43 @@ limitations under the License. namespace mlir { namespace TF { +// Implements a TF specific policy on when constant folding is allowed. +// Policy: Always allow folding if we do not know the shape of an operand or +// result i.e., one of these values has non-static shape. If we know all the +// shapes, find the total size of the operands and results. Folding of the op is +// allowed if one of the following conditions are met: +// 1. size of results is less than a certain threshold (`kSizeThreshold`), or +// 2. size of results is within a factor (`kSizeFactor`) of size of operands, or +// TODO(b/157226221): Look into other heuristics for constant fold policy. +// LINT.IfChange(folding-policy) +static bool ShouldBeFolded(Operation* inst) { + constexpr int kSizeFactor = 2; + constexpr int64_t kSizeThreshold = (1 << 20); // 128 KB + bool has_unknown_shape = false; + auto get_size = [&](TypeRange types) { + int64_t size = 0; + for (auto t : types) { + auto tensor_type = t.cast(); + // Ignore types with undefined bit widths. + if (!tensor_type.getElementType().isIntOrFloat()) continue; + if (!tensor_type.hasStaticShape()) { + has_unknown_shape = true; + return size; + } + size += tensor_type.getNumElements() * + tensor_type.getElementType().getIntOrFloatBitWidth(); + } + return size; + }; + + int64_t results_size = get_size(inst->getResultTypes()); + int64_t operands_size = get_size(inst->getOperandTypes()); + + return has_unknown_shape || (results_size <= kSizeThreshold) || + (results_size <= kSizeFactor * operands_size); +} +// LINT.ThenChange(../tests/constant-fold.mlir:folding-policy-test) + LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, SmallVectorImpl& results) { // NOLINT @@ -53,6 +90,10 @@ LogicalResult ConstantFoldFallbackHook( return failure(); } + // Determine if we should attempt to fold this operation by considering the + // size/size increase due to folding. + if (!ShouldBeFolded(inst)) return failure(); + // TODO(jpienaar): Currently this persists the entire program execution. This // should instead be per module/set from the Graph being executed in TF (if // any) so that the value of variables in the context could be read. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index d9af88bfbae681..1e622a295ecabe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -52,6 +52,8 @@ enum EinsumEquation { BroadcastMatMul, ReduceSum, TransposeMatMul, + BatchMatMulReducedDim, + TransposeReducedDim, UnsupportedEquation }; @@ -136,6 +138,14 @@ EinsumEquation parseEquation(const std::vector& eqn) { if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) { return EinsumEquation::TransposeMatMul; } + // BIN,BINJ->BIJ + if (is_equal(eqn, {A, B, C, COMMA, A, B, C, D, ARROW, A, B, D})) { + return EinsumEquation::BatchMatMulReducedDim; + } + // BIJ,BINJ->BIN + if (is_equal(eqn, {A, B, C, COMMA, A, B, D, C, ARROW, A, B, D})) { + return EinsumEquation::TransposeReducedDim; + } return EinsumEquation::UnsupportedEquation; } @@ -349,6 +359,53 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( rewriter.replaceOp(op, {bmm_op.getResult()}); return success(); } + if (einsum_eqn == EinsumEquation::BatchMatMulReducedDim) { + // Case "BIN,BINJ->BIJ" + // Reshape LHS + auto lhs_element_type = lhs_type.getElementType(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + const int lhs_dim2 = lhs_shape[2]; + const int rhs_dim3 = rhs_shape[3]; + + auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, + lhs_element_type, loc, &rewriter); + std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim3}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + + auto bmm_element_type = bmm_type.getElementType(); + auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim3}, + bmm_element_type, loc, &rewriter); + rewriter.replaceOp(op, {final_reshape.getResult()}); + } + if (einsum_eqn == EinsumEquation::TransposeReducedDim) { + // Case "BIJ,BINJ->BIN" + // Reshape LHS + auto lhs_element_type = lhs_type.getElementType(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + const int lhs_dim2 = lhs_shape[2]; + const int rhs_dim2 = rhs_shape[2]; + + auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, + lhs_element_type, loc, &rewriter); + // Transpose RHS + rhs = createTransposeOp(rhs, loc, {0, 1, 3, 2}, &rewriter); + std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim2}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + + auto bmm_element_type = bmm_type.getElementType(); + auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim2}, + bmm_element_type, loc, &rewriter); + rewriter.replaceOp(op, {final_reshape.getResult()}); + } + return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index a0cf9c8eb9a253..e076dbae0b6504 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -60,7 +60,8 @@ void FreezeGlobalTensorsPass::runOnOperation() { for (int i = 0, e = func.getNumArguments(); i < e; ++i) { SmallVector read_variable_ops_to_erase; - auto global_tensor = LookupBoundInput(func, i, symbol_table); + auto global_tensor = + LookupBoundInputOfType(func, i, symbol_table); if (!global_tensor) continue; frozen_global_tensors.insert(global_tensor); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 3480b0cace7ba3..c263dcc75d14e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -47,14 +47,14 @@ namespace mlir { namespace TF { namespace { -using xla_hlo::DotDimensionNumbers; +using mhlo::DotDimensionNumbers; -class ConvertConvOp : public OpConversionPattern { +class ConvertConvOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_hlo::ConvOp conv_op, ArrayRef args, + mhlo::ConvOp conv_op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { if (!IsSupportedConvOp(conv_op)) { return failure(); @@ -120,7 +120,7 @@ class ConvertConvOp : public OpConversionPattern { }; private: - bool IsSamePadding(xla_hlo::ConvOp conv_op, int num_spatial_dims, + bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims, ArrayRef strides, ArrayRef dilation, ArrayRef padding_array) const { for (auto i : llvm::seq(0, num_spatial_dims)) { @@ -142,7 +142,7 @@ class ConvertConvOp : public OpConversionPattern { return true; } - void CreateConvOp(xla_hlo::ConvOp conv_op, ArrayRef strides, + void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef strides, StringRef padding, ArrayRef dilation, bool is_depthwise_conv, ConversionPatternRewriter &rewriter) const { @@ -167,13 +167,13 @@ class ConvertConvOp : public OpConversionPattern { } } - bool IsSupportedConvOp(xla_hlo::ConvOp conv_op) const { + bool IsSupportedConvOp(mhlo::ConvOp conv_op) const { if (!conv_op.lhs().getType().cast().hasStaticShape() || !conv_op.rhs().getType().cast().hasStaticShape() || !conv_op.getType().cast().hasStaticShape()) return false; - // All ones in "lhs_dilation" means this "xla_hlo.conv" op should be + // All ones in "lhs_dilation" means this "mhlo.conv" op should be // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp". if (conv_op.lhs_dilation().hasValue()) { auto lhs_dilation = conv_op.lhs_dilation().getValue(); @@ -236,15 +236,15 @@ class ConvertConvOp : public OpConversionPattern { } }; -class ConvertSliceOp : public OpConversionPattern { +class ConvertSliceOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_hlo::SliceOp slice_op, ArrayRef args, + mhlo::SliceOp slice_op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { DenseIntElementsAttr strides = slice_op.strides(); - // Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op. + // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op. if (!strides.isSplat() || strides.getSplatValue().cast().getInt() != 1) return failure(); @@ -374,10 +374,10 @@ class DotDimensionsInfo { DimensionSetVector out_dimensions_; }; -// Converts xla_hlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be +// Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be // inserted to convert to well-formed matrix multiply. Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { - auto dot_general_op = cast(old_op); + auto dot_general_op = cast(old_op); auto lhs_type = dot_general_op.lhs().getType().cast(); auto rhs_type = dot_general_op.rhs().getType().cast(); auto result_type = dot_general_op.getResult().getType().cast(); @@ -405,7 +405,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { lhs_dot_dimensions_info.batch_dimensions().SizesArray(), lhs_dot_dimensions_info.out_dimensions().SizesArray(), lhs_dot_dimensions_info.contracting_dimensions().SizesArray()); - auto lhs_transposed = rewriter.create( + auto lhs_transposed = rewriter.create( loc, RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()), dot_general_op.lhs(), @@ -423,7 +423,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { rhs_dot_dimensions_info.batch_dimensions().SizesArray(), rhs_dot_dimensions_info.contracting_dimensions().SizesArray(), rhs_dot_dimensions_info.out_dimensions().SizesArray()); - auto rhs_transposed = rewriter.create( + auto rhs_transposed = rewriter.create( loc, RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()), dot_general_op.rhs(), @@ -438,7 +438,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { lhs_dot_dimensions_info.FlattenedOutDimensionSize()}, llvm::ArrayRef{ lhs_dot_dimensions_info.FlattenedContractingDimensionSize()}); - auto lhs_flattend = rewriter.create( + auto lhs_flattend = rewriter.create( loc, RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()), lhs_transposed.getResult()); @@ -450,7 +450,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { rhs_dot_dimensions_info.FlattenedContractingDimensionSize()}, llvm::ArrayRef{ rhs_dot_dimensions_info.FlattenedOutDimensionSize()}); - auto rhs_flattend = rewriter.create( + auto rhs_flattend = rewriter.create( loc, RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()), rhs_transposed.getResult()); @@ -466,14 +466,14 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { loc, RankedTensorType::get(matmul_shape, result_type.getElementType()), lhs_flattend.getResult(), rhs_flattend.getResult()); auto reshaped = - rewriter.create(loc, result_type, matmul.getResult()); + rewriter.create(loc, result_type, matmul.getResult()); return reshaped.getResult(); } -// This function tries to match that the "xla_hlo::ReduceOp" only has one -// input, one init_value and one result. Also "xla_hlo::ReduceOp" has two ops +// This function tries to match that the "mhlo::ReduceOp" only has one +// input, one init_value and one result. Also "mhlo::ReduceOp" has two ops // in the region, and the last one is return op. -LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) { +LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) { if (reduce_op.operands().size() != 1 || reduce_op.init_values().size() != 1 || reduce_op.getResults().size() != 1) return failure(); @@ -489,23 +489,23 @@ LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) { return success(); } -// TODO(jingpu): This "xla_hlo::ReduceOp" can corresponds to many TF ops +// TODO(jingpu): This "mhlo::ReduceOp" can corresponds to many TF ops // with different ops in reduce_op.body. Now we only match to "tf.Max", "tf.Min" // and "tf.Sum". -class ConvertReduceOpToTfSum : public OpConversionPattern { +class ConvertReduceOpToTfSum : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_hlo::ReduceOp reduce_op, ArrayRef args, + mhlo::ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { if (failed(MatchReduceOpInput(reduce_op))) return failure(); Operation *first_op = &reduce_op.body().front().front(); - if (!llvm::isa(first_op)) return failure(); + if (!llvm::isa(first_op)) return failure(); // In `MatchReduceOpInput` function, we already match that the - // "xla_hlo::ReduceOp" only has one input, one init_value and one result. + // "mhlo::ReduceOp" only has one input, one init_value and one result. auto input = reduce_op.operands()[0]; // Get reduction dimension. DenseIntElementsAttr dimension = reduce_op.dimensions(); @@ -531,20 +531,20 @@ class ConvertReduceOpToTfSum : public OpConversionPattern { }; }; -class ConvertReduceOpToTfMax : public OpConversionPattern { +class ConvertReduceOpToTfMax : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_hlo::ReduceOp reduce_op, ArrayRef args, + mhlo::ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { if (failed(MatchReduceOpInput(reduce_op))) return failure(); Operation *first_op = &reduce_op.body().front().front(); - if (!llvm::isa(first_op)) return failure(); + if (!llvm::isa(first_op)) return failure(); // In `MatchReduceOpInput` function, we already match that the - // "xla_hlo::ReduceOp" only has one input, one init_value and one result. + // "mhlo::ReduceOp" only has one input, one init_value and one result. auto input = reduce_op.operands()[0]; // Get reduction dimension. DenseIntElementsAttr dimension = reduce_op.dimensions(); @@ -572,20 +572,20 @@ class ConvertReduceOpToTfMax : public OpConversionPattern { }; }; -class ConvertReduceOpToTfMin : public OpConversionPattern { +class ConvertReduceOpToTfMin : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_hlo::ReduceOp reduce_op, ArrayRef args, + mhlo::ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { if (failed(MatchReduceOpInput(reduce_op))) return failure(); Operation *first_op = &reduce_op.body().front().front(); - if (!llvm::isa(first_op)) return failure(); + if (!llvm::isa(first_op)) return failure(); // In `MatchReduceOpInput` function, we already match that the - // "xla_hlo::ReduceOp" only has one input, one init_value and one result. + // "mhlo::ReduceOp" only has one input, one init_value and one result. Value input = reduce_op.operands()[0]; // Get reduction dimension. DenseIntElementsAttr dimension = reduce_op.dimensions(); @@ -645,10 +645,10 @@ ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) { return rewriter.create(value.getLoc(), attr_type, attr); } -// Converts xla_hlo.dot to tf.MatMul. Reshape ops will be inserted when +// Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when // necessary. Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { - auto dot_op = cast(old_op); + auto dot_op = cast(old_op); const mlir::Location loc = dot_op.getLoc(); // Normalizes a ShapedType to 2d if the ShapedType is less than 2d by // inserting dummy 1-element dimensions in the begining. Does nothing if the @@ -677,7 +677,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { return input; } - auto reshape = rewriter.create( + auto reshape = rewriter.create( loc, normalize_rank(input_type), input); return reshape.getResult(); }; @@ -694,7 +694,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { loc, normalize_rank(output_type), a, b, /*transpose_a=*/rewriter.getBoolAttr(false), transpose_b); auto reshape = - rewriter.create(loc, output_type, matmul.product()); + rewriter.create(loc, output_type, matmul.product()); return reshape.getResult(); } @@ -752,7 +752,7 @@ void LegalizeHloToTf::runOnFunction() { target.addLegalDialect(); target.addLegalOp(); if (failed(applyPartialConversion(getFunction(), target, patterns))) { - getFunction().emitError("xla_hlo to TF legalization failed."); + getFunction().emitError("mhlo to TF legalization failed."); signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index acf9cd27b47e6c..6b7d7178ab64f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -154,6 +154,14 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp], def LowerFillOp : Pat<(TF_FillOp $dims, $value), (TF_BroadcastToOp $value, $dims)>; +//===----------------------------------------------------------------------===// +// NaN op patterns. +//===----------------------------------------------------------------------===// + +def LowerIsNanOp : Pat<(TF_IsNanOp $x), + (TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)), + /*incompatible_shape_error*/ConstBoolAttrTrue)>; + //===----------------------------------------------------------------------===// // L2Loss op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 85efb761d8bb6d..5af8a0195a4481 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -287,6 +287,11 @@ CreateTPUExtractHeadTailOutsideCompilationPass(); // that are only used for host computation. std::unique_ptr> CreateTPUHostComputationExpansionPass(); +// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that +// correct ops are invoked during training and evaluation. +std::unique_ptr> +CreateTPUUpdateEmbeddingEnqueueOpInputsPass(); + // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // ops to a separate parallel_execute region to run on CPU. std::unique_ptr> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc new file mode 100644 index 00000000000000..f916706a5977b1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace tf_saved_model { +namespace { +using mlir::Operation; +using mlir::TF::VarHandleOp; + +class RemoveVariablesInSessionInitializerPass + : public PassWrapper> { + public: + void runOnOperation() override; +}; + +void RecursiveRemove(Operation* op, + llvm::SmallVectorImpl& erase_list, + llvm::SmallPtrSetImpl& dead_ops) { + for (mlir::Value res : op->getResults()) { + for (Operation* user : res.getUsers()) { + if (!dead_ops.insert(user).second) continue; + RecursiveRemove(user, erase_list, dead_ops); + } + } + + erase_list.push_back(op); + + for (auto& use : op->getOpOperands()) { + if (auto op_result = use.get().dyn_cast()) { + Operation* def = op_result.getDefiningOp(); + if (!dead_ops.insert(def).second) continue; + RecursiveRemove(def, erase_list, dead_ops); + } + } +} + +void RemoveVariables(llvm::ArrayRef vars) { + // TODO(b/160906885): Repalce the following code with an non-recursive one. + llvm::SmallVector erase_list; + llvm::SmallPtrSet dead_ops; + + // Marks all the variables dead. + dead_ops.insert(vars.begin(), vars.end()); + + // Removes relevant ops in topological order. + for (auto& op : vars) RecursiveRemove(op, erase_list, dead_ops); + + // Erases the ops. + for (auto op : erase_list) op->erase(); +} + +void RemoveVariablesInSessionInitializerPass::runOnOperation() { + ModuleOp module = getOperation(); + SessionInitializerOp session_init_op = GetSessionInitializerOp(module); + + if (!session_init_op) return; + + SymbolTable symbol_table(module); + FuncOp init_func_op = + symbol_table.lookup(session_init_op.initializer()); + + if (!init_func_op) { + module.emitError("no session initializer function found"); + return signalPassFailure(); + } + + if (init_func_op.getBlocks().size() != 1) { + init_func_op.emitError("expects exactly one block in the MLIR function"); + return signalPassFailure(); + } + + auto var_handle_ops = init_func_op.getBlocks().front().getOps(); + llvm::SmallVector init_vars(var_handle_ops.begin(), + var_handle_ops.end()); + RemoveVariables(init_vars); +} + +} // namespace + +static PassRegistration pass( + "tf-saved-model-remove-vars-in-session-initializer", + "Remove variables in tf saved model's session initializer."); + +std::unique_ptr> +CreateRemoveVariablesInSessionInitializerPass() { + return std::make_unique(); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 6eedfbbaf4b3d8..b16868311f0f74 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -103,8 +103,8 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Map block arg to replica arg. mapping.clear(); for (auto& block_arg : replicate_op.GetBody().getArguments()) - mapping.map(block_arg, replicate_op.getOperand( - block_arg.getArgNumber() * num_replicas + i)); + mapping.map(block_arg, + replicate_op.GetReplicaOperandForBlockArgument(block_arg, i)); // Copy over replicate region into replica island. replicate_op.body().cloneInto(&replica.body(), mapping); @@ -184,6 +184,7 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // tf_executor.yield %a1, %b1 : tensor, tensor // } void CreateIslandsFromReplicate(const Dialect* tf_dialect, + tf_executor::GraphOp graph_op, tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); @@ -225,18 +226,36 @@ void CreateIslandsFromReplicate(const Dialect* tf_dialect, island_op.control().replaceAllUsesWith(island_sink.control()); } + // Replicas with no uses should be pinned to a graph fetch so they still + // execute. + llvm::SmallVector unused_replica_controls; + for (auto& replica : replicas) + if (replica.use_empty()) + unused_replica_controls.push_back(replica.control()); + + if (!unused_replica_controls.empty()) { + tf_executor::FetchOp fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector<8>(fetch.getOperands()); + fetches.append(unused_replica_controls.begin(), + unused_replica_controls.end()); + builder.setInsertionPoint(fetch); + builder.create(fetch.getLoc(), fetches); + fetch.erase(); + } + island_op.erase(); } // Finds islands with a single `tf_device.replicate` and create individual // islands per replica of the replicate. void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, + tf_executor::GraphOp graph_op, tf_executor::IslandOp island_op) { if (!island_op.WrapsSingleOp()) return; if (auto replicate_op = llvm::dyn_cast(&island_op.GetBody().front())) - CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); + CreateIslandsFromReplicate(tf_dialect, graph_op, island_op, replicate_op); } void ReplicateToIslandPass::runOnFunction() { @@ -246,8 +265,10 @@ void ReplicateToIslandPass::runOnFunction() { getFunction().emitError() << "'tf' dialect is not registered"; } - getFunction().walk([&](tf_executor::IslandOp island_op) { - LowerSingleIslandReplicateToIslands(tf_dialect, island_op); + getFunction().walk([&](tf_executor::GraphOp graph_op) { + for (auto island_op : + llvm::make_early_inc_range(graph_op.getOps())) + LowerSingleIslandReplicateToIslands(tf_dialect, graph_op, island_op); }); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 2d30bbd1b938fd..6a67f0bea0adcb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo( info.data_type = assign.value().getType(); continue; } - if (isa(user) || isa(user)) { + if (isa(user)) { // Stacks will be handled by a separate pass. do_not_touch = true; break; @@ -880,23 +880,34 @@ LogicalResult HandlePartitionedCallOpCallee( result->arg_data_type_and_updated_output_index[entry.getFirst()] = { entry.getSecond(), -1}; } - llvm::SmallVector new_retvals; - for (auto val : callee.front().getTerminator()->getOperands()) { - // Remove resource type outputs. - if (getElementTypeOrSelf(val.getType()).isa()) continue; - new_retvals.push_back(val); + llvm::SmallVector retval_indices_to_preserve; + for (auto& val : callee.front().getTerminator()->getOpOperands()) { + // Store indices of results that are not resources. + if (!getElementTypeOrSelf(val.get().getType()).isa()) + retval_indices_to_preserve.push_back(val.getOperandNumber()); } + int64_t num_retvals = retval_indices_to_preserve.size(); + llvm::SmallVector new_retvals; // Lift resources. LiftArgRetResourcesForFunction( callee, remaining_resource_data_types, [&](int64_t index, Value value) { result->arg_data_type_and_updated_output_index[index].second = - new_retvals.size(); + num_retvals++; new_retvals.push_back(value); }); + auto old_return = callee.front().getTerminator(); + llvm::SmallVector old_and_new_retvals; + old_and_new_retvals.reserve(retval_indices_to_preserve.size() + + new_retvals.size()); + for (int64_t retval_index : retval_indices_to_preserve) + old_and_new_retvals.push_back(old_return->getOperand(retval_index)); + + old_and_new_retvals.append(new_retvals.begin(), new_retvals.end()); // Replace old return with the new ones with update values. OpBuilder builder(old_return); - auto new_return = builder.create(old_return->getLoc(), new_retvals); + auto new_return = + builder.create(old_return->getLoc(), old_and_new_retvals); old_return->erase(); callee.setType(FunctionType::get( callee.getType().getInputs(), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 2afc1c2d7b6666..f9c81634ae5bbe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -205,9 +205,9 @@ GetSubtypes(Type type) { // Returns whether type can be further refined. bool CanBeRefined(Type type) { auto shape_type = type.dyn_cast(); - return shape_type && (!shape_type.hasStaticShape() || - shape_type.getElementType().isa() || - shape_type.getElementType().isa()); + return shape_type && + (!shape_type.hasStaticShape() || + shape_type.getElementType().isa()); } // Infers the shape from a (Stateful)PartionedCall operation by looking up the @@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. - if (isa(op) || isa(op) || - isa(op) || isa(op)) { + if (isa(op)) { return RefineTypeForPassThroughOperands(op, op->getOperands(), op->getResults()); } @@ -729,7 +728,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // Handle call operations by looking up callee and infering return shape as // needed. - if (isa(op) || isa(op)) + if (isa( + op)) return InferShapeForCall(op); // tf.Cast are only inferred if they have at least one user in the TF dialect @@ -889,8 +889,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { }; auto new_element_type = shaped_type.getElementType(); // Populate the handle shapes for a resource/variant. - if (new_element_type.isa() || - new_element_type.isa()) { + if (new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { SmallVector subtypes; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 734a7d04a86ebc..5e095a311eebbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal( llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { // Removes identity nodes in the block. The device computation does not // need such nodes to carry information. op.replaceAllUsesWith(op.getOperands()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index cbd24f8a8154c5..9c659a950780b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto ta = llvm::dyn_cast(&op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h index 242c4c002c9a25..f7a73dc1561865 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -30,6 +30,16 @@ std::unique_ptr> CreateOptimizeGlobalTensorsPass(); // Creates a pass that freezes tf_saved_model.global_tensor ops. std::unique_ptr> CreateFreezeGlobalTensorsPass(); +// Creates as pass that removes variables in the session initializer. +// This job is required with lifting variable passes. Originally, the session +// initializer function does assigning variables. However, the read-only +// variable assignments will be done via lifting variables pass by converting +// the read-only variables to constant ops, instead. This pass removes the +// redundant operations. This pass should be located in front of the pass for +// lifting read-only variables. +std::unique_ptr> +CreateRemoveVariablesInSessionInitializerPass(); + // Creates as pass that creates GlobalTensorOp for each variable from function // arguments and converts the function arguments to the corresponding saved // model arguments. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index f3337ec0dfc2b6..9abf67b62a94fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -49,7 +49,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -330,40 +329,51 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // Index attribute value stored on TPUReplicatedInput op. These will be used // later for dynamic padder. llvm::SmallVector replicated_input_indices; + llvm::SmallVector packed_input_indices; bool has_replicated_input_index = false; // Indices of the replicate op's arguments that are mirrored variables. llvm::SmallVector mirrored_variable_indices; // Check if number of operands of each used TPUReplicatedInput op matches - // `num_replicas`. Collect all their operands and associated type for creating - // the replicate op. + // `num_replicas` or 1. Collect all their operands and associated type for + // creating the replicate op. llvm::SmallVector, 8> replicated_inputs; + llvm::SmallVector packed_inputs; for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { auto input = pos_and_input.value(); - if (input->getNumOperands() != num_replicas) - return input->emitOpError() << "requires " << num_replicas << " operands"; - - replicated_inputs.push_back( - {input->getOperands(), input->getOperand(0).getType()}); + bool is_packed = llvm::cast(input).is_packed(); + int num_inputs = is_packed ? 1 : num_replicas; + if (input->getNumOperands() != num_inputs) + return input->emitOpError() << "requires " << num_inputs << " operands"; auto tpu_replicated_input = llvm::cast(input); int64_t tpu_replicated_input_index = tpu_replicated_input.index().getSExtValue(); - replicated_input_indices.push_back(tpu_replicated_input_index); + if (is_packed) { + packed_inputs.push_back(input->getOperand(0)); + packed_input_indices.push_back(tpu_replicated_input_index); + } else { + replicated_inputs.push_back( + {input->getOperands(), input->getOperand(0).getType()}); + replicated_input_indices.push_back(tpu_replicated_input_index); + } if (tpu_replicated_input_index != -1) has_replicated_input_index = true; if (tpu_replicated_input.is_mirrored_variable()) mirrored_variable_indices.push_back(pos_and_input.index()); } + replicated_input_indices.append(packed_input_indices.begin(), + packed_input_indices.end()); + // Create replicate op. OpBuilder builder(cluster); auto replicate_op = builder.create( cluster.getLoc(), num_replicas, llvm::SmallDenseMap>(), - replicated_inputs, cluster.getResultTypes()); + replicated_inputs, packed_inputs, cluster.getResultTypes()); if (has_replicated_input_index) replicate_op.setAttr(kReplicatedInputIndicesAttr, builder.getI64ArrayAttr(replicated_input_indices)); @@ -493,19 +503,9 @@ void TPUClusterFormation::runOnFunction() { if (failed(FormClustersInBlock(&block, metadata_map))) return signalPassFailure(); - auto island_result = getFunction().walk([&](tf_executor::IslandOp island) { - if (failed(FormClustersInBlock(&island.GetBody(), metadata_map))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (island_result.wasInterrupted()) return signalPassFailure(); - // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. auto remove_result = getFunction().walk([&](Operation* op) { - if (!llvm::isa(op) && - !llvm::isa(op)) + if (!llvm::isa(op)) return WalkResult::advance(); // Forward operand to result. When `num_replicas` attribute is 1, no diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index 5f33654d070006..09339928b27f87 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -151,7 +151,7 @@ LogicalResult GetRemappedPaddings( // Inserts padding maps for relevant arguments as argument attributes on the // encapsulated function. The padding maps will be in the form of: -// %arg0 : type {xla_hlo.padding_map = {shape_indices = [...], +// %arg0 : type {mhlo.padding_map = {shape_indices = [...], // padding_arg_indices = [...]}} void AnnotateFunctionArgumentsWithPaddings( FuncOp func, @@ -174,7 +174,7 @@ void AnnotateFunctionArgumentsWithPaddings( "padding_arg_indices", builder.getI32ArrayAttr(padding.getSecond().second)); func.setArgAttr( - padding.getFirst(), "xla_hlo.padding_map", + padding.getFirst(), "mhlo.padding_map", builder.getDictionaryAttr({shape_indices, padding_arg_indices})); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index ec9b3df525f075..050ba24417ff2f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -69,6 +69,7 @@ constexpr char kPaddingMapAttr[] = "padding_map"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; +constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning"; constexpr char kBadStringArrayElementMsg[] = "bad '{0}' attribute at index {1}, not a string"; @@ -331,6 +332,10 @@ LogicalResult SetMetadataProtoFromClusterFuncOp( if (xla_device_assignment.hasValue()) *metadata->mutable_device_assignment() = std::move(xla_device_assignment.getValue()); + auto use_spmd_attr = op.getAttrOfType(kUseXlaSpmdAttr); + if (!use_spmd_attr) + return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr)); + metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue()); if (failed(SetMetadataProtoArgs(op, metadata))) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 1203eea2f8411b..0b9eaba8c979f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -39,7 +39,7 @@ namespace mlir { namespace TFTPU { namespace { -constexpr char kShardingAttr[] = "xla_hlo.sharding"; +constexpr char kShardingAttr[] = "mhlo.sharding"; struct TPUShardingIdentificationPass : public PassWrapper { + void runOnFunction() override; +}; + +// Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and +// clear the attribute from the operation. This ensures that future optimization +// passes does not trigger additional logic due to presence of this attribute. +LogicalResult ExtractEmbeddingAttribute( + Operation* op, llvm::StringMap* embedding_op_map) { + auto embedding_attr = op->getAttrOfType(kTPUEmbeddingAttr); + if (!embedding_attr) + return op->emitOpError("requires attribute '_tpu_embedding_layer'"); + + if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second) + return op->emitOpError( + "found duplicate TPU embedding ops potentially from multiple " + "TPUEmbedding layers"); + + op->removeAttr(kTPUEmbeddingAttr); + return success(); +} + +LogicalResult FindTPUEmbeddingOps( + FuncOp func_op, llvm::StringMap* enqueue_op_map, + llvm::StringMap* recv_activation_op_map, + llvm::StringMap* send_gradient_op_map) { + auto walk_result = func_op.walk([&](Operation* op) { + if (llvm::isa(op)) + if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map))) + return WalkResult::interrupt(); + + if (llvm::isa(op)) + if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map))) + return WalkResult::interrupt(); + + if (llvm::isa(op)) + if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + return failure(walk_result.wasInterrupted()); +} + +// Updates the operand of TPU embedding enqueue ops depending on whether +// the graph is in training mode or in non-training mode. +// If SendTPUEmbeddingGradients op is present, this means that graph is in +// training mode. As so, correctly feed in `then` branch value of SelectV2 +// operand as inputs to the TPU embedding enqueue ops. +LogicalResult UpdateEmbeddingEnqueueOpInput( + const llvm::StringMap& enqueue_op_map, + const llvm::StringMap& recv_activation_op_map, + const llvm::StringMap& send_gradient_op_map) { + for (const auto& it : enqueue_op_map) { + const auto& embedding_attr = it.getKey(); + Operation* embedding_op = it.second; + if (!recv_activation_op_map.count(embedding_attr)) + return embedding_op->emitOpError() + << "must have a corresponding '" + << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op"; + + // TPU Embedding enqueue ops take different inputs depending on whether + // graph is in training mode or in eval/prediction mode. The inputs to the + // enqueue ops are present/listed as operands to SelectV2 op. Then branch + // operand of the SelectV2 op represents input to take during training + // and else branch operand represents input to take during + // prediction/evaluation. If SendTPUEmbeddingGradients op exists in the + // graph, then graph is in training mode, so correctly forward the input + // of SelectV2 op as operand to the TPU embedding enqueue op. + bool is_training = send_gradient_op_map.count(embedding_attr); + for (auto enqueue_operand : embedding_op->getOperands()) { + if (auto select = llvm::dyn_cast_or_null( + enqueue_operand.getDefiningOp())) { + enqueue_operand.replaceAllUsesWith(is_training ? select.t() + : select.e()); + } + } + } + + return success(); +} + +void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() { + OpBuilder builder(&getContext()); + auto func_op = getFunction(); + + // All TPU embedding layer related ops are annotated with + // `_tpu_embedding_layer` attribute along with corresponding string attribute. + // Store all tpu embedding layer related ops with value of + // `_tpu_embedding_layer` attribute as map key. + llvm::StringMap enqueue_op_map; + llvm::StringMap recv_activation_op_map; + llvm::StringMap send_gradient_op_map; + if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map, + &recv_activation_op_map, + &send_gradient_op_map))) + return signalPassFailure(); + + if (enqueue_op_map.size() != recv_activation_op_map.size()) { + func_op.emitError() << "expects the number of embedding enqueue ops to " + "match the number of '" + << TF::RecvTPUEmbeddingActivationsOp::getOperationName() + << "' ops"; + return signalPassFailure(); + } + + if (failed(UpdateEmbeddingEnqueueOpInput( + enqueue_op_map, recv_activation_op_map, send_gradient_op_map))) + return signalPassFailure(); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUUpdateEmbeddingEnqueueOpInputsPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-update-embedding-enqueue-op-inputs", + "Updates inputs to TPU embedding enqueue ops depending on whether graph " + "is in training mode or in evaluation mode."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index b8f55e3b9793b6..5bc6bd4e053ca4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -192,15 +192,25 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // The XLA backend does not yet support formatting 64-bit data types. if (data_type.getIntOrFloatBitWidth() == 64) continue; + const auto& block_arg = replicate.GetBody().getArgument(replicate_arg); + + int64_t num_inputs = 0; + if (replicate.IsReplicatedBlockArgument(block_arg)) { + num_inputs = num_replicas; + } else { + num_inputs = 1; + } + // We have found a mirrored variable which is an input to the replicated // `execute`. Now find if this mirrored variable is a pass-through of while // arguments. llvm::SmallVector while_args; - for (int64_t i = 0; i < num_replicas; ++i) { + for (int64_t i = 0; i < num_inputs; ++i) { llvm::SmallPtrSet skipped_identities; - auto replicate_operand = - SkipIdentity(replicate.getOperand(num_replicas * replicate_arg + i), - /*allow_other_use=*/false, &skipped_identities); + + auto replicate_operand = SkipIdentity( + replicate.GetReplicaOperandForBlockArgument(block_arg, i), + /*allow_other_use=*/false, &skipped_identities); auto block_arg = replicate_operand.dyn_cast(); // To qualify for a valid pass-through mirrored variable, it must satisfy // 1) it is the body's argument; @@ -267,26 +277,39 @@ tf_device::ReplicateOp AddInputsToReplicateOp( llvm::SmallVector, Type>, 8> new_replicated_inputs; + llvm::SmallVector new_packed_inputs; llvm::SmallVector, 8> replicated_inputs; - replicated_inputs.reserve(replicate.GetBody().getNumArguments()); - for (auto arg : llvm::enumerate(replicate.GetBody().getArguments())) { - int64_t i = arg.index(); + replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments()); + new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments()); + for (const auto& arg : replicate.GetReplicatedBlockArguments()) { replicated_inputs.emplace_back(); - for (int64_t j = i * num_replicas; j < (i + 1) * num_replicas; ++j) { - replicated_inputs.back().push_back(replicate.getOperand(j)); + for (int64_t i = 0; i < num_replicas; ++i) { + replicated_inputs.back().push_back( + replicate.GetReplicaOperandForBlockArgument(arg, i)); } - new_replicated_inputs.emplace_back(replicated_inputs.back(), - arg.value().getType()); + new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType()); + } + for (const auto& arg : replicate.GetPackedBlockArguments()) { + new_packed_inputs.emplace_back( + replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0)); } new_replicated_inputs.emplace_back(new_inputs, new_inputs.front().getType()); OpBuilder builder(replicate); auto new_replicate = builder.create( replicate.getLoc(), num_replicas, devices, new_replicated_inputs, + new_packed_inputs, llvm::to_vector<8>( replicate.GetBody().getTerminator()->getOperandTypes())); for (auto arg : replicate.GetBody().getArguments()) { - arg.replaceAllUsesWith( - new_replicate.GetBody().getArgument(arg.getArgNumber())); + if (replicate.IsReplicatedBlockArgument(arg)) { + arg.replaceAllUsesWith( + new_replicate.GetBody().getArgument(arg.getArgNumber())); + } else { + // There is a new added replicated state variable between replicated args + // and packed args. + arg.replaceAllUsesWith( + new_replicate.GetBody().getArgument(arg.getArgNumber() + 1)); + } } for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) { op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end()); @@ -495,7 +518,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, } reformat_operands.push_back(compile_launch.getResult(1)); reformat_operands.push_back(replicate.GetBody().getArgument( - replicate.GetBody().getNumArguments() - 1)); + replicate.GetNumReplicatedBlockArguments() - 1)); builder.setInsertionPoint(execute_launch); auto reformat_op = builder.create( execute_launch.getLoc(), llvm::ArrayRef{}, reformat_operands, @@ -507,14 +530,20 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, // replicate op. llvm::SmallVector, Type>, 8> unformat_replicate_operands; + llvm::SmallVector unformat_packed_operands; for (const auto& entry : execute_arg_to_outer_args) { - unformat_replicate_operands.emplace_back(entry.second, - entry.second.front().getType()); + if (entry.second.size() > 1) { + unformat_replicate_operands.emplace_back(entry.second, + entry.second.front().getType()); + } else { + unformat_packed_operands.emplace_back(entry.second.front()); + } } llvm::SmallVector state_var_vals(state_vars.size()); for (const auto& entry : llvm::enumerate(state_vars)) { state_var_vals[entry.index()] = entry.value().resource(); } + // Add the replicated state var to the end of the replicate operands. unformat_replicate_operands.emplace_back(state_var_vals, state_var_vals.front().getType()); // Build a constant default key to specify that the unformatting should @@ -529,13 +558,21 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, // With all replicated inputs, now build the replicate op. auto unformat_replicate = builder.create( while_op.getLoc(), num_replicas, devices, unformat_replicate_operands, - ArrayRef{}); + unformat_packed_operands, ArrayRef{}); // Then build the unformat op in the replicate op. builder.setInsertionPointToEnd(&unformat_replicate.GetBody()); llvm::SmallVector unformat_operands; - for (auto arg : unformat_replicate.GetBody().getArguments()) { - unformat_operands.push_back(arg); - } + // Add the replicated state var (the last replicated operand of the + // ReplicateOp) as the last operand of TPUReshardVariablesOp. + BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back(); + auto replicated_block_args = + unformat_replicate.GetReplicatedBlockArguments().drop_back(1); + auto packed_block_args = unformat_replicate.GetPackedBlockArguments(); + unformat_operands.append(replicated_block_args.begin(), + replicated_block_args.end()); + unformat_operands.append(packed_block_args.begin(), packed_block_args.end()); + unformat_operands.push_back(state); + // Insert the default key as the second last operand. unformat_operands.insert( unformat_operands.begin() + unformat_operands.size() - 1, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index b6fad8f598752d..7983dfe0065038 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -576,9 +576,8 @@ StatusOr> Exporter::Convert( // Adds nodes for operations. for (Operation& inst : graph_op.GetBody()) { for (auto type : inst.getResultTypes()) - if (!type.isa() && - !type.isa() && - !type.isa()) + if (!type.isa()) return errors::InvalidArgument( "Values must be of tensor type, TensorFlow control type, or " "TensorFlow token type. Found ", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 617f413e2382a3..c7d5339f93cce8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -102,6 +102,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -119,6 +120,7 @@ namespace tensorflow { using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; +using mlir::tf_saved_model::AssetOp; using mlir::tf_saved_model::GlobalTensorOp; using mlir::tf_saved_model::SessionInitializerOp; using stream_executor::port::StatusOr; @@ -2894,8 +2896,8 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) { llvm::SmallVector new_input_types; for (int i = 0, e = func.getNumArguments(); i < e; i++) { auto arg = func.front().getArgument(i); - auto global_tensor = - mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table); + auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType< + mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table); if (global_tensor) { auto old_type = arg.getType(); auto new_type = @@ -2968,17 +2970,37 @@ void SortSavedModelModule(mlir::ModuleOp module) { mlir::FuncOp func; }; llvm::SmallVector named_funcs; + llvm::SmallVector private_funcs; for (auto func : module.getOps()) { auto exported_names = mlir::tf_saved_model::GetExportedNames(func); - named_funcs.push_back( - {exported_names.empty() ? "" : exported_names.front(), func}); + if (!exported_names.empty()) + named_funcs.push_back({exported_names.front(), func}); + else + private_funcs.push_back(func); } llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) { - return std::make_tuple(a.name.empty(), a.name) < - std::make_tuple(b.name.empty(), b.name); + return a.name < b.name; + }); + llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) { + return a.getName() < b.getName(); + }); + + struct NamedAsset { + llvm::StringRef name; + AssetOp asset; + }; + llvm::SmallVector assets; + for (auto asset : module.getOps()) { + assets.push_back({asset.getName(), asset}); + } + llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) { + return a.name < b.name; }); // Move onto the front of the module in reverse of the final desired order. + for (auto func : llvm::reverse(private_funcs)) { + func.getOperation()->moveBefore(&module.getBody()->front()); + } for (auto named_func : llvm::reverse(named_funcs)) { named_func.func.getOperation()->moveBefore(&module.getBody()->front()); } @@ -2987,6 +3009,10 @@ void SortSavedModelModule(mlir::ModuleOp module) { &module.getBody()->front()); } + for (auto asset : assets) { + asset.asset.getOperation()->moveBefore(&module.getBody()->front()); + } + auto initializers = module.getOps(); if (!initializers.empty()) { (*initializers.begin()) @@ -3282,7 +3308,8 @@ class SavedModelSignatureDefImporter { graph_(std::make_unique(OpRegistry::Global())), debug_info_(), exported_names_(exported_names), - module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) { + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))), + symbol_table_(module_.get()) { // debug_info might not be loaded with loader_lite. if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info; } @@ -3297,8 +3324,13 @@ class SavedModelSignatureDefImporter { Status ConvertSignature(const std::string& sig_def_key, const SignatureDef& signature_def); + struct AssetInfo { + std::string tensor_name; + mlir::tf_saved_model::AssetOp op; + }; + StatusOr> ConvertAssets(); // Converts the initialization graph in the SavedModel to an MLIR function. - Status ConvertInitializer(); + Status ConvertInitializer(const std::vector& assets); // Converts a graph with feeds and fetches to an MLIR function. StatusOr ConvertGraph( @@ -3307,19 +3339,13 @@ class SavedModelSignatureDefImporter { const std::vector>& outputs, const std::vector control_outputs); - // Remove variables in the session initializer. - Status RemoveVariablesInSessionInitializer(); - - // Removes the variable and related ops in the init function if it is already - // imported as a global tensor. - void RemoveVariable(VarHandleOp op); - - // Runs graph pruning and executor dialect to functional conversion. - Status ExecutorDialectToFunctional(); - // Lifts the variables in `module_`. Status LiftVariables(); + // Moves the functions in `sub_module` to `module_` and skips the duplicate + // functions. + void MoveConvertedFunctionsToModule(mlir::ModuleOp sub_module); + GraphImportConfig::InputArrays ParseInputArrays( const std::vector>& inputs); @@ -3331,6 +3357,7 @@ class SavedModelSignatureDefImporter { GraphDebugInfo debug_info_; absl::Span exported_names_; mlir::OwningModuleRef module_; + mlir::SymbolTable symbol_table_; }; Status SavedModelSignatureDefImporter::InitializeGraph(bool upgrade_legacy) { @@ -3352,32 +3379,77 @@ Status SavedModelSignatureDefImporter::InitializeGraph(bool upgrade_legacy) { return Status::OK(); } -Status SavedModelSignatureDefImporter::ConvertInitializer() { +StatusOr> +SavedModelSignatureDefImporter::ConvertAssets() { std::vector asset_file_defs; TF_RETURN_IF_ERROR( internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs)); - if (!asset_file_defs.empty()) - return errors::Unimplemented( - absl::StrCat("Assets are not supported in signaturedef importer")); + std::vector results; + results.reserve(asset_file_defs.size()); + + mlir::OpBuilder builder(module_->getBodyRegion()); + for (const auto& asset : asset_file_defs) { + auto asset_op = builder.create( + module_->getLoc(), + /*sym_name=*/ + builder.getStringAttr( + absl::StrCat("__tf_saved_model_asset_", asset.filename())), + /*filename=*/ + builder.getStringAttr( + io::JoinPath(kSavedModelAssetsDirectory, asset.filename()))); + + results.push_back({asset.tensor_info().name(), asset_op}); + } + + return results; +} +void SavedModelSignatureDefImporter::MoveConvertedFunctionsToModule( + mlir::ModuleOp sub_module) { + // Iterate through all functions and insert the ones that do not already exist + // in `module_`. + for (auto func : sub_module.getOps()) { + if (symbol_table_.lookup(func.getName())) continue; + symbol_table_.insert(func.clone()); + } +} + +Status SavedModelSignatureDefImporter::ConvertInitializer( + const std::vector& assets) { std::string init_node_name; TF_RETURN_IF_ERROR( internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name)); if (init_node_name.empty()) return Status::OK(); - TF_ASSIGN_OR_RETURN(auto sub_module, - ConvertGraph(init_node_name, {}, {}, {init_node_name})); + std::vector> inputs; + inputs.reserve(assets.size()); + for (const auto& asset : assets) { + TensorInfo tensor_info; + tensor_info.set_name(asset.tensor_name); + tensor_info.set_dtype(DT_STRING); + inputs.push_back({asset.tensor_name, tensor_info}); + } - mlir::SymbolTable symbol_table(*sub_module); + TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(init_node_name, inputs, {}, + {init_node_name})); - auto init_func_op = symbol_table.lookup(init_node_name); + mlir::SymbolTable sub_symbol_table(*sub_module); + auto init_func_op = sub_symbol_table.lookup(init_node_name); init_func_op.removeAttr("tf.entry_function"); mlir::OpBuilder builder(module_->getBodyRegion()); + // Bind asset inputs to asset ops. + assert(init_func_op.getNumArguments() == assets.size()); + for (const auto& iter : llvm::enumerate(assets)) { + auto asset_op = iter.value().op; + init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input", + builder.getSymbolRefAttr(asset_op.getName())); + } + // Set the exported name of init function to an reserved name for // tf_saved_model. init_func_op.setAttr( @@ -3388,11 +3460,7 @@ Status SavedModelSignatureDefImporter::ConvertInitializer() { module_->getLoc(), builder.getSymbolRefAttr(init_func_op.getName())); // Move the converted functions to top level MLIR module. - auto* block = module_->getBody(); - auto* sub_block = sub_module->getBody(); - block->getOperations().splice( - mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(), - sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator())); + MoveConvertedFunctionsToModule(*sub_module); return Status::OK(); } @@ -3426,14 +3494,13 @@ SavedModelSignatureDefImporter::ConvertSignatures() { TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def)); } - TF_RETURN_IF_ERROR(ConvertInitializer()); + TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets()); + TF_RETURN_IF_ERROR(ConvertInitializer(assets)); mlir::OpBuilder builder(module_->getBodyRegion()); module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); module_->setAttr("tf_saved_model.under_construction", builder.getUnitAttr()); - TF_RETURN_IF_ERROR(ExecutorDialectToFunctional()); - TF_RETURN_IF_ERROR(RemoveVariablesInSessionInitializer()); TF_RETURN_IF_ERROR(LiftVariables()); module_->removeAttr("tf_saved_model.under_construction"); @@ -3482,8 +3549,8 @@ Status SavedModelSignatureDefImporter::ConvertSignature( mlir::OpBuilder builder(sub_module->getBodyRegion()); // Find the FuncOp which corresponds to current SignatureDef. - mlir::SymbolTable symbol_table(*sub_module); - auto func_op = symbol_table.lookup(sig_def_key); + mlir::SymbolTable sub_symbol_table(*sub_module); + auto func_op = sub_symbol_table.lookup(sig_def_key); TF_RET_CHECK(func_op) << "Graphdef importer should have created a function named " << sig_def_key << "."; @@ -3504,92 +3571,19 @@ Status SavedModelSignatureDefImporter::ConvertSignature( } // Move the converted functions to top level MLIR module. - auto* block = module_->getBody(); - auto* sub_block = sub_module->getBody(); - block->getOperations().splice( - mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(), - sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator())); - - return Status::OK(); -} // namespace - -Status SavedModelSignatureDefImporter::RemoveVariablesInSessionInitializer() { - // TODO(b/153507667): Make a pass for the job. - SessionInitializerOp session_initializer = - mlir::tf_saved_model::GetSessionInitializerOp(*module_); - - if (!session_initializer) return Status::OK(); - - mlir::FuncOp session_initializer_func = nullptr; - - for (auto func : module_->getOps()) { - if (session_initializer.initializer() == func.getName()) { - session_initializer_func = func; - break; - } - } - - if (!session_initializer_func) - return errors::Internal("No session initializer function found."); - - if (session_initializer_func.getBlocks().size() != 1) - return errors::Internal("Expects exactly one block in the MLIR function."); - - llvm::SmallVector init_vars; - mlir::Block& block = session_initializer_func.getBlocks().front(); - for (VarHandleOp op : block.getOps()) { - init_vars.push_back(op); - } - - for (auto op : init_vars) RemoveVariable(op); + MoveConvertedFunctionsToModule(*sub_module); return Status::OK(); } -void SavedModelSignatureDefImporter::RemoveVariable(VarHandleOp op) { - llvm::SmallVector work_list; - work_list.push_back(op); - while (!work_list.empty()) { - auto* op = work_list.back(); - work_list.pop_back(); - - for (mlir::Value res : op->getResults()) { - for (mlir::Operation* user : res.getUsers()) { - work_list.push_back(user); - } - } - - for (auto& use : op->getOpOperands()) { - if (mlir::Value value = use.get()) { - mlir::Operation* def = value.getDefiningOp(); - work_list.push_back(def); - } - } - - op->dropAllReferences(); - op->dropAllDefinedValueUses(); - - op->erase(); - } -} - -Status SavedModelSignatureDefImporter::ExecutorDialectToFunctional() { +Status SavedModelSignatureDefImporter::LiftVariables() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); mlir::PassManager pm(module_->getContext()); pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass()); - if (mlir::failed(pm.run(*module_))) - return diag_handler.Combine( - errors::Internal("failed to coarsening islands.")); - - return Status::OK(); -} - -Status SavedModelSignatureDefImporter::LiftVariables() { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - - mlir::PassManager pm(module_->getContext()); + pm.addPass( + mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass()); pm.addPass( mlir::TF:: CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index de9078d9c40392..5e548da55f13ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -108,7 +108,7 @@ Status GetXlaInputShapes( // Rewrite layout with sharding, if sharding is set. auto sharding = - main_func.getArgAttrOfType(i, "xla_hlo.sharding"); + main_func.getArgAttrOfType(i, "mhlo.sharding"); if (!sharding) continue; absl::optional arg_sharding; @@ -253,7 +253,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); - mlir::registerDialect(); + mlir::registerDialect(); return true; }(); (void)init_once; @@ -279,9 +279,9 @@ Status ConvertMLIRToXlaComputation( // LegalizeTFControlFlow encapsulates arguments for control flow operations // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. - tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass()); + tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - tf2xla.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(true)); + tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(true)); for (auto& target_pass : custom_legalization_passes) { tf2xla.addNestedPass(std::move(target_pass)); } @@ -290,7 +290,7 @@ Status ConvertMLIRToXlaComputation( // Leverage tf2xla kernels for ops that didn't get lowered in the previous // legalization pass. - tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); + tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type)); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); // Run shape inference pass to propagate shapes through tensor_cast operations @@ -303,12 +303,11 @@ Status ConvertMLIRToXlaComputation( // expose more graph pruning and canonicalization opportunities that are // necessary for the second LegalizeTFPass(allow_partial_conversion=false) // invocation. - tf2xla.addNestedPass( - mlir::xla_hlo::createLegalizeTFPass(false)); + tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); // In order to export to XLA, we must sink constants to control flow regions, // since XLA uses functional control flow. tf2xla.addNestedPass( - mlir::xla_hlo::createSinkConstantsToControlFlowPass()); + mlir::mhlo::createSinkConstantsToControlFlowPass()); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index cc724ba36f8ca8..dde2408c83a1f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -184,7 +184,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { // only be lowered when tf.Shape is folded into a constant. constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { + func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> return %1 : tensor<10x19xf32> @@ -344,7 +344,7 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) { constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) { return } } @@ -383,7 +383,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) { constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) { + func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) { return } } @@ -403,7 +403,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) { constexpr char mlir_module[] = R"( module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { - func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) { return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> } } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 49a8f704b309e5..b23fbe7d73c737 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -761,7 +761,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) { auto replicate = builder.create( mlir::UnknownLoc::get(&context), /*num_replicas=*/2, devices, llvm::ArrayRef, mlir::Type>>{}, - llvm::ArrayRef{}); + llvm::ArrayRef{}, llvm::ArrayRef{}); builder.setInsertionPoint(&replicate.body().front(), replicate.body().front().begin()); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 7b4e2d9c2e51b8..b5735f823e4533 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/hlo:xla_materialize_broadcasts", # buildcleaner: keep - "//tensorflow/compiler/mlir/hlo:xla_unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service/gpu:stream_executor_util", @@ -41,7 +41,7 @@ cc_library( tf_cc_binary( name = "tf_to_cubin", srcs = ["tf_to_cubin.cc"], - visibility = ["//tensorflow/core/kernels/cubin_headers:__pkg__"], + visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], deps = [ ":cubin_creator", "//tensorflow/compiler/mlir:init_mlir", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index 05456bf60f3aa8..1f511e27d9e0bf 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -87,15 +87,15 @@ struct MaterializeBroadcastsPass mlir::ConversionTarget conversionTarget(getContext()); mlir::OwningRewritePatternList conversionPatterns; - // Consider the xla_hlo dialect legal for tests. - conversionTarget.addLegalDialect(); + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); - mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(), - &conversionTarget); - mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(), - &conversionPatterns); + mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); if (failed(applyPartialConversion(getFunction(), conversionTarget, conversionPatterns))) { @@ -108,7 +108,7 @@ struct UnfuseBatchNormPass : public mlir::PassWrapper { void runOnFunction() override { mlir::OwningRewritePatternList patterns; - mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } }; @@ -122,13 +122,13 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { /*shouldPrintAfterPass=*/enable_if_vlog_is_on, /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/false, llvm::dbgs()); - pm.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(false)); + pm.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); pm.addNestedPass( absl::make_unique()); pm.addNestedPass(absl::make_unique()); - pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass( + pm.addPass(mlir::mhlo::createLegalizeToLhloPass( /*results_escape_functions=*/true)); - pm.addNestedPass(mlir::xla_lhlo::createLhloCopyRemovalPass()); + pm.addNestedPass(mlir::lmhlo::createLhloCopyRemovalPass()); if (failed(pm.run(module))) { return InternalError("Lowering TF to LHLO failed."); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index f0e9c12bc86d20..838b060079c682 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -91,7 +91,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_inc_gen", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla/client:xla_builder", @@ -116,9 +115,9 @@ cc_library( ) cc_library( - name = "xla_hlo_to_lhlo_with_xla", - srcs = ["transforms/xla_hlo_to_lhlo_with_xla.cc"], - hdrs = ["transforms/xla_hlo_to_lhlo_with_xla.h"], + name = "mhlo_to_lhlo_with_xla", + srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"], + hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"], deps = [ ":hlo_module_importer", ":hlo_utils", @@ -341,23 +340,23 @@ tf_native_cc_binary( ], ) -genrule( +gentbl( name = "operator_writer_inc", - srcs = [ + tbl_outs = [("", "operator_writers.inc")], + tblgen = ":operator_writer_gen", + td_file = "//tensorflow/compiler/mlir/hlo:include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_relative_includes = [ + "../hlo/include", + ], + td_srcs = [ + "@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", - "@llvm-project//mlir:include/mlir/IR/OpBase.td", "//tensorflow/compiler/mlir/hlo:hlo_ops_td_files", - "//tensorflow/compiler/mlir/hlo:include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + # Any file in this directory is OK: this will force the current path to exist so + # that the relative path can be resolved. + "BUILD", ], - outs = ["operator_writers.inc"], - cmd = ("$(location :operator_writer_gen) " + - "-I external/llvm-project/mlir/include " + - "-I tensorflow/compiler/mlir " + - "-I tensorflow/compiler/mlir/hlo/include " + - "$(location //tensorflow/compiler/mlir/hlo:include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td) " + - " -o $@"), - tools = [":operator_writer_gen"], ) cc_library( @@ -366,28 +365,28 @@ cc_library( "//tensorflow/compiler/mlir:__subpackages__", ], deps = [ - ":xla_hlo_to_lhlo_with_xla", + ":mhlo_to_lhlo_with_xla", ":xla_legalize_tf", ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:legalize_control_flow", + "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", + "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", + "//tensorflow/compiler/mlir/hlo:legalize_to_standard", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_parallel_loops", - "//tensorflow/compiler/mlir/hlo:xla_hlo_fusion", - "//tensorflow/compiler/mlir/hlo:xla_legalize_control_flow", - "//tensorflow/compiler/mlir/hlo:xla_legalize_tanh_to_approximation", - "//tensorflow/compiler/mlir/hlo:xla_legalize_to_linalg", - "//tensorflow/compiler/mlir/hlo:xla_legalize_to_standard", - "//tensorflow/compiler/mlir/hlo:xla_lower", - "//tensorflow/compiler/mlir/hlo:xla_sink_constants_to_control_flow", - "//tensorflow/compiler/mlir/hlo:xla_test_passes", - "//tensorflow/compiler/mlir/hlo:xla_transform_unranked_hlo", + "//tensorflow/compiler/mlir/hlo:mhlo_fusion", + "//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns", + "//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow", + "//tensorflow/compiler/mlir/hlo:test_passes", + "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", ], ) diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc index 201ec0d053f91e..5a3b20b97cac7d 100644 --- a/tensorflow/compiler/mlir/xla/attribute_importer.cc +++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc @@ -42,7 +42,7 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, } // Converts the gather dimensions to attributes. -mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( +mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) { std::vector offset_dims(dnums.offset_dims().begin(), dnums.offset_dims().end()); @@ -50,14 +50,14 @@ mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); std::vector start_index_map(dnums.start_index_map().begin(), dnums.start_index_map().end()); - return mlir::xla_hlo::GatherDimensionNumbers::get( + return mlir::mhlo::GatherDimensionNumbers::get( Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder), Convert(start_index_map, builder), builder->getI64IntegerAttr(dnums.index_vector_dim()), builder->getContext()); } -mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( +mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) { std::vector update_window_dims(dnums.update_window_dims().begin(), dnums.update_window_dims().end()); @@ -66,7 +66,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( std::vector scatter_dims_to_operand_dims( dnums.scatter_dims_to_operand_dims().begin(), dnums.scatter_dims_to_operand_dims().end()); - return mlir::xla_hlo::ScatterDimensionNumbers::get( + return mlir::mhlo::ScatterDimensionNumbers::get( Convert(update_window_dims, builder), Convert(inserted_window_dims, builder), Convert(scatter_dims_to_operand_dims, builder), @@ -74,7 +74,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( builder->getContext()); } -mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( +mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers( const DotDimensionNumbers& dnums, mlir::Builder* builder) { std::vector rhs_contracting_dimensions( dnums.rhs_contracting_dimensions().begin(), @@ -93,12 +93,12 @@ mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder); auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder); - return mlir::xla_hlo::DotDimensionNumbers::get( + return mlir::mhlo::DotDimensionNumbers::get( lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr, rhs_contracting_dims_attr, builder->getContext()); } -mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( +mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers( const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) { llvm::SmallVector input_spatial_dims( dnums.input_spatial_dimensions().begin(), @@ -109,7 +109,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( llvm::SmallVector output_spatial_dims( dnums.output_spatial_dimensions().begin(), dnums.output_spatial_dimensions().end()); - return mlir::xla_hlo::ConvDimensionNumbers::get( + return mlir::mhlo::ConvDimensionNumbers::get( builder->getI64IntegerAttr(dnums.input_batch_dimension()), builder->getI64IntegerAttr(dnums.input_feature_dimension()), Convert(input_spatial_dims, builder), diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.h b/tensorflow/compiler/mlir/xla/attribute_importer.h index 25ef9680220a72..d84d8762f855e0 100644 --- a/tensorflow/compiler/mlir/xla/attribute_importer.h +++ b/tensorflow/compiler/mlir/xla/attribute_importer.h @@ -29,19 +29,19 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, mlir::Builder* builder); // Converts the gather dimensions to attributes. -mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( +mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); // Converts the scatter dimensions to attributes. -mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( +mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); // Converts the dot dimensions to attributes. -mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( +mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers( const DotDimensionNumbers& dnums, mlir::Builder* builder); // Converts the conv dimensions to attributes. -mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( +mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers( const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 77b9857c91cc0e..ad177ce1dc5652 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -171,7 +171,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( if (llvm::isa(block->getParentOp())) { builder.create(loc, result); } else { - builder.create(loc, result); + builder.create(loc, result); } return tensorflow::Status::OK(); } @@ -202,18 +202,18 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kIota: { return func_builder - ->create( + ->create( loc, result_type, func_builder->getI64IntegerAttr( Cast(instruction)->iota_dimension())) .getOperation(); } -#define MakeAndReturn(mlir_op) \ - { \ - mlir::Operation* new_operation = \ - func_builder->create(loc, result_type, \ - operands, attributes); \ - return new_operation; \ +#define MakeAndReturn(mlir_op) \ + { \ + mlir::Operation* new_operation = \ + func_builder->create(loc, result_type, operands, \ + attributes); \ + return new_operation; \ } case HloOpcode::kBroadcast: { // Note that the HLO broadcast is more powerful than the XLA broadcast op. @@ -314,14 +314,14 @@ StatusOr HloFunctionImporter::ImportInstruction( instruction->dynamic_slice_sizes().begin(), instruction->dynamic_slice_sizes().end()); return func_builder - ->create( + ->create( loc, result_type, operands[0], makeArrayRef(operands).drop_front(), Convert(slice_sizes)) .getOperation(); } case HloOpcode::kDynamicUpdateSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], operands[1], llvm::ArrayRef(operands.begin() + 2, operands.end())) .getOperation(); @@ -354,10 +354,10 @@ StatusOr HloFunctionImporter::ImportInstruction( } return func_builder - ->create(loc, result_type, operands[0], - operands[1], Convert(edge_padding_low), - Convert(edge_padding_high), - Convert(interior_padding)) + ->create(loc, result_type, operands[0], + operands[1], Convert(edge_padding_low), + Convert(edge_padding_high), + Convert(interior_padding)) .getOperation(); } case HloOpcode::kScatter: { @@ -372,7 +372,7 @@ StatusOr HloFunctionImporter::ImportInstruction( attributes.push_back(builder_->getNamedAttr( "unique_indices", builder_->getBoolAttr(scatter->unique_indices()))); - auto scatter_op = func_builder->create( + auto scatter_op = func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(), &scatter_op.update_computation())); @@ -394,7 +394,7 @@ StatusOr HloFunctionImporter::ImportInstruction( Convert(window_dimensions))); attributes.push_back(ConvertPadding(padding)); auto select_scatter_op = - func_builder->create( + func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(), &select_scatter_op.select())); @@ -410,7 +410,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], ConvertDimensions(instruction->slice_starts()), ConvertDimensions(instruction->slice_limits()), @@ -419,7 +419,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kSort: { auto sort_instruction = Cast(instruction); - auto sort_op = func_builder->create( + auto sort_op = func_builder->create( loc, result_type, operands, builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getBoolAttr(sort_instruction->is_stable())); @@ -437,8 +437,8 @@ StatusOr HloFunctionImporter::ImportInstruction( TF_RETURN_IF_ERROR(GetMlirTypes( {instruction->true_computation()->root_instruction()}, &rets)); - auto op = func_builder->create(loc, rets, operands, - attributes); + auto op = func_builder->create(loc, rets, operands, + attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(), &op.true_branch())); TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(), @@ -451,7 +451,7 @@ StatusOr HloFunctionImporter::ImportInstruction( {instruction->branch_computation(0)->root_instruction()}, &rets)); int num_branches = instruction->branch_count(); - auto op = func_builder->create( + auto op = func_builder->create( loc, rets, operands, attributes, num_branches); for (auto index_and_computation : llvm::enumerate(instruction->branch_computations())) { @@ -465,7 +465,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr // for concatenate dimension. return func_builder - ->create( + ->create( loc, result_type, operands, builder_->getI64IntegerAttr(instruction->concatenate_dimension())) .getOperation(); @@ -474,7 +474,7 @@ StatusOr HloFunctionImporter::ImportInstruction( auto all_reduce = Cast(instruction); attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups())); attributes.push_back(ConvertChannelHandle(all_reduce->channel_id())); - auto all_reduce_op = func_builder->create( + auto all_reduce_op = func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(), &all_reduce_op.computation())); @@ -484,7 +484,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // Operands in the first half are reduction inputs and the remaining // operands are corresponding initial values. size_t num_inputs = operands.size() / 2; - auto reduce = func_builder->create( + auto reduce = func_builder->create( loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs), llvm::makeArrayRef(operands).drop_front(num_inputs), ConvertDimensions(instruction->dimensions())); @@ -494,7 +494,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kReverse: { return func_builder - ->create( + ->create( loc, result_type, operands[0], ConvertDimensions(instruction->dimensions())) .getOperation(); @@ -505,14 +505,14 @@ StatusOr HloFunctionImporter::ImportInstruction( switch (instruction->random_distribution()) { case xla::RNG_UNIFORM: return func_builder - ->create( - loc, result_type, operands[0], operands[1], shape) + ->create(loc, result_type, operands[0], + operands[1], shape) .getOperation(); case xla::RNG_NORMAL: return func_builder - ->create( - loc, result_type, operands[0], operands[1], shape) + ->create(loc, result_type, operands[0], + operands[1], shape) .getOperation(); default: @@ -522,7 +522,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } } case HloOpcode::kWhile: { - auto op = func_builder->create( + auto op = func_builder->create( loc, operands[0].getType(), operands[0]); TF_RETURN_IF_ERROR( ImportAsRegion(*instruction->while_condition(), &op.cond())); @@ -585,14 +585,14 @@ StatusOr HloFunctionImporter::ImportInstruction( attributes.push_back(builder_->getNamedAttr( "window_dilations", ConvertDimensions(win_dilations))); attributes.push_back(ConvertPadding(padding)); - auto reduce = func_builder->create( + auto reduce = func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR( ImportAsRegion(*instruction->to_apply(), &reduce.body())); return reduce.getOperation(); } case HloOpcode::kMap: { - auto op = func_builder->create( + auto op = func_builder->create( loc, result_type, operands, ConvertDimensions(instruction->dimensions())); TF_RETURN_IF_ERROR( @@ -714,7 +714,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // is not mentioned in xla client anywhere or in the hlo of our sample // models. default: { - mlir::OperationState result(loc, "xla_hlo.unknown"); + mlir::OperationState result(loc, "mhlo.unknown"); result.addOperands(operands); result.addTypes(result_type); for (auto attr : attributes) { @@ -840,7 +840,7 @@ mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( const xla::ChannelHandle& channel) { return builder_->getNamedAttr( "channel_handle", - mlir::xla_hlo::ChannelHandle::get( + mlir::mhlo::ChannelHandle::get( builder_->getI64IntegerAttr(channel.handle()), builder_->getI64IntegerAttr(channel.type()), context_)); } diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 67f8b59c815516..db981bb0227bd0 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ #include @@ -143,4 +143,4 @@ class HloFunctionImporter { } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 6b7aded0eb6e5b..69ac1e282193bf 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_ #include @@ -59,4 +59,4 @@ class HloModuleImporter { } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h index 5f212ffc893629..e613ce72b23758 100644 --- a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h +++ b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_ #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/status.h" @@ -36,4 +36,4 @@ Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module); } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 930e7f6ef02147..84c574139e99fb 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -197,7 +197,7 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, } } -mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers( +mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers( const GatherDimensionNumbers& input, mlir::Builder builder) { auto offset_dims = CreateDenseIntElementsAttrFromVector( llvm::SmallVector{input.offset_dims().begin(), @@ -215,7 +215,7 @@ mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers( mlir::IntegerAttr index_vector_dim = builder.getI64IntegerAttr(input.index_vector_dim()); - return mlir::xla_hlo::GatherDimensionNumbers::get( + return mlir::mhlo::GatherDimensionNumbers::get( offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim, builder.getContext()); } diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index 3db1c7324b829c..1b77d60c83c672 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines helpers useful when creating or manipulating lhlo/hlo. -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -39,7 +39,7 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, mlir::Builder builder); -mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers( +mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers( const GatherDimensionNumbers& input, mlir::Builder builder); template @@ -77,11 +77,11 @@ static StatusOr ConvertShapeToType(const Shape& shape, return builder.getTupleType(contents); } if (shape.IsToken()) { - return mlir::xla_hlo::TokenType::get(builder.getContext()); + return mlir::mhlo::TokenType::get(builder.getContext()); } return ConvertTensorShapeToType(shape, builder); } } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 49bf8e84d016bc..31512c90f097c4 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -33,8 +33,7 @@ namespace xla { static std::string GetMlirOpName(HloOpcode opcode) { std::string op_name = HloOpcodeString(opcode); absl::c_replace(op_name, '-', '_'); - return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." + - op_name; + return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name; } static std::string ToString(mlir::Type ty) { @@ -90,7 +89,7 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr, CreateDenseElementsAttrFromLiteral(literal, builder_)); - auto op = builder_.create(loc_, attr); + auto op = builder_.create(loc_, attr); return MakeXlaOp(op); }); } @@ -108,7 +107,7 @@ StatusOr MlirHloBuilder::ConvGeneralDilatedInternal( mlir::ArrayAttr config_attr; if (precision_config) config_attr = ConvertPrecisionConfig(precision_config, &builder_); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), GetI64ElementsAttr(window_strides, &builder_), ConvertPadding(padding, &builder_), @@ -125,7 +124,7 @@ StatusOr MlirHloBuilder::FftInternal( absl::Span fft_length) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(operand), builder_.getStringAttr(FftType_Name(fft_type)), GetI64ElementsAttr(fft_length, &builder_)); @@ -135,15 +134,16 @@ StatusOr MlirHloBuilder::FftInternal( StatusOr MlirHloBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, - absl::optional> operand_shapes_with_layout) { + absl::optional> operand_shapes_with_layout, + bool has_side_effect) { if (operand_shapes_with_layout.has_value()) return Unimplemented( "CustomCall doesn't support operands shapes with layout"); TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name), - /*has_side_effect=*/builder_.getBoolAttr(false), + /*has_side_effect=*/builder_.getBoolAttr(has_side_effect), builder_.getStringAttr(opaque)); return MakeXlaOp(op); } @@ -155,13 +155,13 @@ StatusOr MlirHloBuilder::ReduceInternal( // Reduce takes two set of variadic operands inputs and init_values. // all_operands contains both of these so split operands into two parts. int64_t num_args = all_operands.size() / 2; - auto op = builder_.create( + auto op = builder_.create( loc_, GetValues(all_operands.first(num_args)), GetValues(all_operands.subspan(num_args)), GetI64ElementsAttr(dimensions_to_reduce, &builder_)); TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body())); if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0)); - auto tuple = builder_.create(loc_, op.getResults()); + auto tuple = builder_.create(loc_, op.getResults()); return MakeXlaOp(tuple); } @@ -183,7 +183,7 @@ StatusOr MlirHloBuilder::ReduceWindowInternal( auto padding_ty = mlir::RankedTensorType::get({static_cast(padding.size()) / 2, 2}, builder_.getIntegerType(64)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(operand), GetValue(init_value), GetI64ElementsAttr(sizes, &builder_), GetI64ElementsAttr(strides, &builder_), @@ -199,7 +199,7 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) { TF_ASSIGN_OR_RETURN( mlir::Type ty, ConvertShapeToType(shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension)); return MakeXlaOp(op); @@ -210,7 +210,7 @@ StatusOr MlirHloBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_)); return MakeXlaOp(op); } @@ -219,7 +219,7 @@ StatusOr MlirHloBuilder::RevInternal( const Shape& shape, XlaOp operand, absl::Span dimensions) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_)); return MakeXlaOp(op); } @@ -230,7 +230,7 @@ StatusOr MlirHloBuilder::GatherInternal( absl::Span slice_sizes, bool indices_are_sorted) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(input), GetValue(start_indices), ConvertGatherDimensionNumbers(dimension_numbers, &builder_), GetI64ElementsAttr(slice_sizes, &builder_)); @@ -244,7 +244,7 @@ StatusOr MlirHloBuilder::ScatterInternal( bool unique_indices) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates), ConvertScatterDimensionNumbers(dimension_numbers, &builder_), builder_.getBoolAttr(indices_are_sorted), @@ -262,11 +262,11 @@ StatusOr MlirHloBuilder::RngOpInternal( // and RngNormal can be mapped to the new op. std::string op_name; if (distribution == xla::RandomDistribution::RNG_UNIFORM) { - op_name = "xla_hlo.rng_uniform"; + op_name = "mhlo.rng_uniform"; } else { TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL) << "Unexpected distribution: " << distribution; - op_name = "xla_hlo.rng_normal"; + op_name = "mhlo.rng_normal"; } if (shape.is_dynamic()) @@ -288,7 +288,7 @@ StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); mlir::Value value = GetValue(operand); - auto op = builder_.create(loc_, ty, value); + auto op = builder_.create(loc_, ty, value); return MakeXlaOp(op.getResult()); } @@ -298,7 +298,7 @@ StatusOr MlirHloBuilder::DotGeneralInternal( const PrecisionConfig* precision_config) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), ConvertDotDimensionNumbers(dimension_number, &builder_), ConvertPrecisionConfig(precision_config, &builder_)); @@ -312,7 +312,7 @@ StatusOr MlirHloBuilder::InDimBroadcast( TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); mlir::Value value = GetValue(operand); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_)); return MakeXlaOp(op.getResult()); } @@ -322,7 +322,7 @@ StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, ComparisonDirection direction) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), builder_.getStringAttr(ComparisonDirectionToString(direction))); return MakeXlaOp(op.getResult()); @@ -343,8 +343,8 @@ StatusOr MlirHloBuilder::AddOpWithShape( XlaOp MlirHloBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { - return MakeXlaOp(builder_.create( - loc_, mlir::xla_hlo::TokenType::get(builder_.getContext()))); + return MakeXlaOp(builder_.create( + loc_, mlir::mhlo::TokenType::get(builder_.getContext()))); }); } @@ -353,16 +353,16 @@ StatusOr MlirHloBuilder::InfeedWithTokenInternal( TF_ASSIGN_OR_RETURN(mlir::Type result_type, ConvertShapeToType( infeed_instruction_shape, builder_)); - return MakeXlaOp(builder_.create( - loc_, result_type, GetValue(token), - /*infeed_config=*/config)); + return MakeXlaOp( + builder_.create(loc_, result_type, GetValue(token), + /*infeed_config=*/config)); } StatusOr MlirHloBuilder::OutfeedWithTokenInternal( XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config) { - auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext()); - return MakeXlaOp(builder_.create( + auto token_type = mlir::mhlo::TokenType::get(builder_.getContext()); + return MakeXlaOp(builder_.create( loc_, token_type, GetValue(operand), GetValue(token), outfeed_config)); } @@ -372,7 +372,7 @@ StatusOr MlirHloBuilder::ConcatInDimInternal( mlir::Type result_type, ConvertShapeToType(shape, builder_)); auto mlir_operands = GetValues(operands); - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension))); } @@ -382,7 +382,7 @@ StatusOr MlirHloBuilder::GetTupleElementInternal(const Shape& shape, TF_ASSIGN_OR_RETURN( mlir::Type result_type, ConvertShapeToType(shape, builder_)); - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, result_type, GetValue(tuple_data), builder_.getI32IntegerAttr(index))); } @@ -390,7 +390,7 @@ StatusOr MlirHloBuilder::GetTupleElementInternal(const Shape& shape, StatusOr MlirHloBuilder::SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_), GetI64ElementsAttr(limit_indices, &builder_), GetI64ElementsAttr(strides, &builder_))); @@ -402,7 +402,7 @@ StatusOr MlirHloBuilder::DynamicSliceInternal( TF_ASSIGN_OR_RETURN( mlir::Type result_ty, ConvertShapeToType(shape, builder_)); - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, result_ty, GetValue(operand), GetValues(start_indices), GetI64ElementsAttr(slice_sizes, &builder_))); } @@ -413,7 +413,7 @@ StatusOr MlirHloBuilder::DynamicUpdateSliceInternal( TF_ASSIGN_OR_RETURN( mlir::Type result_ty, ConvertShapeToType(shape, builder_)); - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, result_ty, GetValue(operand), GetValue(update), GetValues(start_indices))); } @@ -432,7 +432,7 @@ StatusOr MlirHloBuilder::PadInternal( high.push_back(dimension.edge_padding_high()); internal.push_back(dimension.interior_padding()); } - return MakeXlaOp(builder_.create( + return MakeXlaOp(builder_.create( loc_, result_type, GetValue(operand), GetValue(padding_value), GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_), GetI64ElementsAttr(internal, &builder_))); @@ -444,7 +444,7 @@ StatusOr MlirHloBuilder::TupleInternal( for (auto& element : elements) { operands.push_back(GetValue(element)); } - return MakeXlaOp(builder_.create(loc_, operands)); + return MakeXlaOp(builder_.create(loc_, operands)); } StatusOr MlirHloBuilder::CreateOp( diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 4d7d93af7a7514..ab1a0d2c9b3c8e 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -34,7 +34,7 @@ limitations under the License. namespace xla { -// Provides a way to construct xla_hlo dialect ops in MLIR using XlaBuilder +// Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder // interface. // // Requires that all XlaOp arguments are either returned by any of the builder @@ -124,11 +124,11 @@ class MlirHloBuilder : public XlaBuilder { FftType fft_type, absl::Span fft_length) override; - StatusOr CustomCallInternal(const string& call_target_name, - absl::Span operands, - const Shape& shape, const string& opaque, - absl::optional> - operand_shapes_with_layout) override; + StatusOr CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout, + bool has_side_effect) override; StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 4f414df680d864..a4c3c43cfbfdb4 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -69,12 +69,12 @@ using ::tensorflow::uint32; using ::tensorflow::uint64; using ::tensorflow::uint8; -constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map"; +constexpr char kPaddingMapAttr[] = "mhlo.padding_map"; constexpr char kShapeIndicesAttr[] = "shape_indices"; constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices"; -constexpr char kShardingAttr[] = "xla_hlo.sharding"; -constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes"; -constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas"; +constexpr char kShardingAttr[] = "mhlo.sharding"; +constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; +constexpr char kRepicationAttr[] = "mhlo.is_same_data_across_replicas"; // Passes through everything except for unique_ptr, on which it calls get(). // This exists to allow the generated code to call XLA functions that take a raw @@ -247,7 +247,7 @@ static std::unique_ptr Convert_precision_config( } static xla::DotDimensionNumbers Convert_dot_dimension_numbers( - mlir::xla_hlo::DotDimensionNumbers dot_dimension_numbers_attr) { + mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) { xla::DotDimensionNumbers dot_dimension_numbers; auto rhs_contracting_dimensions = @@ -282,7 +282,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers( } static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( - mlir::xla_hlo::ConvDimensionNumbers input) { + mlir::mhlo::ConvDimensionNumbers input) { xla::ConvolutionDimensionNumbers output; output.set_input_batch_dimension( @@ -315,7 +315,7 @@ static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( return output; } -xla::ChannelHandle Convert_channel_handle(mlir::xla_hlo::ChannelHandle attr) { +xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) { xla::ChannelHandle channel_handle; channel_handle.set_handle(ConvertAPInt(attr.handle().getValue())); channel_handle.set_type(static_cast( @@ -333,7 +333,7 @@ static xla::ComparisonDirection Convert_comparison_direction( } static xla::GatherDimensionNumbers Convert_dimension_numbers( - mlir::xla_hlo::GatherDimensionNumbers input) { + mlir::mhlo::GatherDimensionNumbers input) { xla::GatherDimensionNumbers output; auto offset_dims = ConvertDenseIntAttr(input.offset_dims()); @@ -357,7 +357,7 @@ static xla::GatherDimensionNumbers Convert_dimension_numbers( } static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( - mlir::xla_hlo::ScatterDimensionNumbers input) { + mlir::mhlo::ScatterDimensionNumbers input) { xla::ScatterDimensionNumbers output; auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims()); @@ -574,7 +574,7 @@ llvm::SmallVector GetTuple(mlir::Operation::operand_range values, } // namespace namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { @@ -829,7 +829,7 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { } LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) { - // Failure on purpose because `xla_hlo::ReturnOp` will be handled by + // Failure on purpose because `mhlo::ReturnOp` will be handled by // special purpose logic in `ConvertToHloModule::Lower`. return failure(); } @@ -943,7 +943,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { } } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir #include "tensorflow/compiler/mlir/xla/operator_writers.inc" @@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower( return success(); } - if (isa(inst)) { + if (isa(inst)) { // Construct the return value for the function. If there are multiple // values returned, then create a tuple, else return value directly. xla::XlaOp return_value; @@ -1405,7 +1405,7 @@ void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding, } // Validates and populates dynamic parameter bindings from a module's entry -// function `xla_hlo.padding_map` argument attributes to a `xla::HloModuleProto` +// function `mhlo.padding_map` argument attributes to a `xla::HloModuleProto` // `DynamicParameterBindingProto`. LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, xla::HloModuleProto* hlo_module_proto, diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 7c2aaa381ba3e0..108544d96ff71c 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -73,8 +73,8 @@ static StringRef GetClientBuilder(const Operator& op) { } static void BuildOperator(const Operator& op, raw_ostream& os) { - os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::" - << op.getCppClassName() << " op, OpLoweringContext ctx) {\n" + os << "mlir::LogicalResult ExportXlaOp(mlir::mhlo::" << op.getCppClassName() + << " op, OpLoweringContext ctx) {\n" << " auto& value_map = *ctx.values;\n" << " auto result = op.getResult();\n"; @@ -164,12 +164,12 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { Operator op(def); // Cast to the current operation and build the exporter. - os << " if (auto xla_op = llvm::dyn_cast(op)) {\n"; os << " return "; // The autogenerated converters aren't in the same namespace. // TODO(jpienaar): Reconsider this. - if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::"; + if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::mhlo::"; os << "ExportXlaOp(xla_op, lowering_context);\n"; os << " }\n"; } diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir deleted file mode 100644 index 1954c3344df98a..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ /dev/null @@ -1,457 +0,0 @@ -// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s - -// CHECK-LABEL: add_fold -func @add_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> - %1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[6, 8, 10, 12]> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: add_scalar_fold -func @add_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<1> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<6> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: add_fold_float -func @add_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - return %2 : tensor<4xf64> -} - -// CHECK-LABEL: sub_scalar_fold -func @sub_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<5> : tensor<4xi64> - %1 = xla_hlo.constant dense<1> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<4> - %2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: multiply_scalar_fold -func @multiply_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<5> : tensor<4xi64> - %1 = xla_hlo.constant dense<3> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<15> - %2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: divide_scalar_fold -func @divide_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<1> - %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: divide_fold_float -func @divide_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> - %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - return %2 : tensor<4xf64> -} - -// CHECK-LABEL: max_scalar_fold -func @max_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<7> - %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: max_fold_float -func @max_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]> - %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - return %2 : tensor<4xf64> -} - -// CHECK-LABEL: min_scalar_fold -func @min_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<-5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<-5> - %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - return %2 : tensor<4xi64> -} - -// CHECK-LABEL: min_fold_float -func @min_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> - %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - return %2 : tensor<4xf64> -} - -// CHECK-LABEL: concatenate_noop -func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> - %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> - - // CHECK: return [[ARG]] - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: concatenate_remove_operand -func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { - // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> - // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> - - // CHECK: return [[ARG0]] - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: concatenate_empty_bool -func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> - - return %0 : tensor<0xi1> -} - -// CHECK-LABEL: concatenate_empty_int -func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> - - return %0 : tensor<0xi32> -} - -// CHECK-LABEL: concatenate_empty_float -func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> - - return %0 : tensor<0xf32> -} - -// CHECK-LABEL: concatenate_const_1D -func @concatenate_const_1D() -> tensor<4xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]> - %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32> - %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> - - // CHECK: return [[VAL]] - return %2 : tensor<4xi32> -} - -// CHECK-LABEL: concatenate_const_1D_float -func @concatenate_const_1D_float() -> tensor<4xf32> { - // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> - - %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32> - %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> - - // CHECK: return [[VAL]] - return %2 : tensor<4xf32> -} - -// CHECK-LABEL: concatenate_const_2D_vertical -func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ - // CHECK-SAME: [0, 1], [2, 3] - // CHECK-SAME: ]> - %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32> - %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> - - // CHECK: return [[VAL]] - return %2 : tensor<2x2xi32> -} - -// CHECK-LABEL: concatenate_const_2D_horizontal -func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ - // CHECK-SAME: [0, 2], [1, 3] - // CHECK-SAME: ]> - %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32> - %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> - - // CHECK: return [[VAL]] - return %2 : tensor<2x2xi32> -} - -// CHECK-LABEL: dynamic_slice_variable_start -func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - // CHECK: "xla_hlo.dynamic-slice" - %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - return %1 : tensor<1x4xi32> -} - -// CHECK-LABEL: dynamic_slice_constant_start -func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) - // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64> - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} - // CHECK: return %[[RESULT]] : tensor<2xi32> - %0 = xla_hlo.constant dense<1> : tensor - %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> - return %1 : tensor<2xi32> -} - -// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape -func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { - // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) - // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64> - // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64> - // CHECK: return %[[RESULT]] : tensor - %0 = xla_hlo.constant dense<1> : tensor - %1 = xla_hlo.constant dense<0> : tensor - %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor - return %2 : tensor -} - -// CHECK-LABEL: slice_2D_noop -// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> -func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) - - // CHECK-NEXT: return [[ARG]] - return %0 : tensor<2x2xi64> -} - -// CHECK-LABEL: slice_1D_fold -func @slice_1D_fold() -> tensor<2xi64> { - %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[7, 9]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) - return %1 : tensor<2xi64> -} - -// CHECK-LABEL: slice_1D_fp -func @slice_1D_fp() -> tensor<2xf32> { - %0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> - // CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) - return %1 : tensor<2xf32> -} - -// CHECK-LABEL: slice_1D_strided_fold -func @slice_1D_strided_fold() -> tensor<2xi64> { - %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[7, 10]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) - return %1 : tensor<2xi64> -} - -// CHECK-LABEL: slice_2D_fold -func @slice_2D_fold() -> tensor<2x2xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ - // CHECK-SAME: [6, 7], - // CHECK-SAME: [10, 11] - // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) - return %1 : tensor<2x2xi64> -} - -// CHECK-LABEL: slice_2D_fold_horizontal -func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ - // CHECK-SAME: [0, 1, 2, 3] - // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) - return %1 : tensor<1x4xi64> -} - -// CHECK-LABEL: slice_2D_fold_vertical -func @slice_2D_fold_vertical() -> tensor<4x1xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ - // CHECK-SAME: [2], [6], [10], [14] - // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) - return %1 : tensor<4x1xi64> -} - -// CHECK-LABEL: slice_concat_fold_first -func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) - // CHECK: return %arg0 - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_second -func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) - // CHECK: return %arg1 - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_second_with_slice -func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) - - // CHECK: return [[SLICE]] - return %1 : tensor<1x4xf32> -} - -// CHECK-LABEL: slice_concat_fold_middle -func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) - - // CHECK: return [[SLICE]] - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_two -func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { - // CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} - %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) - - // CHECK: return [[SLICE]] - return %1 : tensor<2x5xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_identity -func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - // CHECK: return %arg0 - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> - return %0 : tensor<2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts -func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { - // CHECK: xla_hlo.broadcast_in_dim - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation -func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: xla_hlo.broadcast_in_dim - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic -func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> - %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> - // CHECK: return %[[RESULT]] : tensor<5x4xf32> - return %0 : tensor<5x4xf32> -} - -// CHECK-LABEL: @complex_expand_fold -func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) - %1 = "xla_hlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - // CHECK: return %arg0, %arg1 - return %1, %2 : tensor<4xf32>, tensor<4xf32> -} - -// CHECK-LABEL: @complex_collapse_fold -func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { - %0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - // CHECK: return %arg0 - return %2 : tensor<4xcomplex> -} - -// CHECK-LABEL: @dynamic_iota_is_static -func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" - // CHECK: return [[RESULT]] - %0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: @iota_not_lowered_to_constant -func @iota_not_lowered_to_constant() -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" - // CHECK: return [[RESULT]] - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: @unary_einsum -func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} - %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// CHECK-LABEL: func @fold_copy -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { - // CHECK: return [[ARG]] - %0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic -func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { - // CHECK: xla_hlo.reshape - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> - return %0 : tensor<4x1xf32> -} - -// CHECK-LABEL: do_not_dce_while_with_outfeed -func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { - // CHECK: xla_hlo.while - %0 = "xla_hlo.while"(%arg0) ( { - ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token - // Side-effecting op outfeed present inside while. - %2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !xla_hlo.token) -> !xla_hlo.token - "xla_hlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - - return %arg0 : tensor -} - -// CHECK-LABEL: dce_while_without_side_effect -func @dce_while_without_side_effect(%arg0: tensor) -> tensor { - // CHECK-NOT: xla_hlo.while - %0 = "xla_hlo.while"(%arg0) ( { - ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token - "xla_hlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - - return %arg0 : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/convert.mlir b/tensorflow/compiler/mlir/xla/tests/convert.mlir deleted file mode 100644 index 26d91132d32b73..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/convert.mlir +++ /dev/null @@ -1,225 +0,0 @@ -// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s - -// ----- - -// CHECK-LABEL: func @same_type -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @same_type(%arg: tensor) -> tensor { - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor - // CHECK-NEXT: return [[ARG]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @int_widening -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @int_widening(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor - // CHECK-NEXT: return [[RES]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @int_narrowing -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @int_narrowing(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor - // CHECK-NEXT: return [[RES]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @float_int -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @float_int(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor - // CHECK-NEXT: return [[RES]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @int_float -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @int_float(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor - // CHECK-NEXT: return [[RES]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @high_rank_tensor -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> - %0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> - // CHECK-NEXT: return [[RES]] - return %0 : tensor<2x3xf32> -} - -// ----- - - -// CHECK-LABEL: func @const_same_type -func @const_same_type() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_float_int -func @const_float_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_int_float -func @const_int_float() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_negative_int_float -func @const_negative_int_float() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<-4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_int_bf16 -func @const_int_bf16() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_bf16_int -func @const_bf16_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_int_narrowing -func @const_int_narrowing() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_int_widening -func @const_int_widening() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_negative_int_widening -func @const_negative_int_widening() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor - %cst = xla_hlo.constant dense<-42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_float_narrowing -func @const_float_narrowing() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4.2> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_f32_bf16 -func @const_f32_bf16() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_bf16_f64 -func @const_bf16_f64() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor - %cst = xla_hlo.constant dense<4.2> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_bf16_int -func @const_bf16_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - - -// ----- - -// CHECK-LABEL: func @const_high_rank_tensor -func @const_high_rank_tensor() -> tensor<2x3xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ - // CHECK-SAME: [1, 2, 3], [4, 5, 6] - // CHECK-SAME: ]> : tensor<2x3xi32> - %cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - %0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<2x3xi32> -} - diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 5c8cc84304058a..09a85177fae15f 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -1,308 +1,308 @@ // RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.abs +// CHECK: lmhlo.abs // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %abs = "xla_hlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %abs : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.add +// CHECK: lmhlo.add // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.and +// CHECK: lmhlo.and // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %res : tensor<2x2xi32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.ceil +// CHECK: lmhlo.ceil // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> -// CHECK: lhlo.complex +// CHECK: lmhlo.complex // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) + %res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) return %res : tensor<1x2xcomplex> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> -// CHECK: lhlo.cosine +// CHECK: lmhlo.cosine // CHECK-SAME: %[[ARG0]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.cosine"(%value0) : (tensor<1x2xcomplex>) -> tensor<1x2xcomplex> + %res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex>) -> tensor<1x2xcomplex> return %res : tensor<1x2xcomplex> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.divide +// CHECK: lmhlo.divide // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.exponential +// CHECK: lmhlo.exponential // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.log +// CHECK: lmhlo.log // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.maximum +// CHECK: lmhlo.maximum // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.minimum +// CHECK: lmhlo.minimum // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.multiply +// CHECK: lmhlo.multiply // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.negate +// CHECK: lmhlo.negate // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> -// CHECK: lhlo.real +// CHECK: lmhlo.real // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.real"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + %res = "mhlo.real"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) return %res : tensor<1x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> -// CHECK: lhlo.imag +// CHECK: lmhlo.imag // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.imag"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + %res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) return %res : tensor<1x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.remainder +// CHECK: lmhlo.remainder // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %res : tensor<2x2xi32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.rsqrt +// CHECK: lmhlo.rsqrt // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 -// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.params = 2 // CHECK-SAME: %[[ARG3:.*]]: memref<16xi8> func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.select +// CHECK: lmhlo.select // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]] // CHECK-NEXT: return - %0 = "xla_hlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) + %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) return %0 : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.sign +// CHECK: lmhlo.sign // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.sqrt +// CHECK: lmhlo.sqrt // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.subtract +// CHECK: lmhlo.subtract // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return - %res = "xla_hlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %res : tensor<2x2xi32> } // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.tanh +// CHECK: lmhlo.tanh // CHECK-SAME: %[[ARG0]], %[[VIEW]] - %res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } @@ -311,16 +311,16 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @main // CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32> // CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32> -// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {xla_lhlo.alloc = 0 -// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {xla_lhlo.alloc = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0 +// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {lmhlo.alloc = 1 // CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32> // CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32> -// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) +// CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> { - %res = "xla_hlo.sort"(%key, %value) ({ + %res = "mhlo.sort"(%key, %value) ({ ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %ret = "xla_hlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%ret) : (tensor) -> () + %ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%ret) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> return %res : tuple, tensor<5x5xf32>> diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir index d442319e7b2b30..cc07624d63df3d 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir @@ -3,14 +3,14 @@ // Current allocation will lead to one buffer argument for the "value" and // another one for the output, an no returned values. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index}, -// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true} +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 : index}, +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true} // CHECK-SAME: ) { func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // The only expected instruction is a copy from the input into the output. // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C0]]][] : memref<16xi8> to memref<2x2xf32> - // CHECK: xla_lhlo.copy + // CHECK: lmhlo.copy // CHECK-SAME: %[[ARG0]], %[[OUTPUT]] return %value : tensor<2x2xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index ce6dbe66581557..de03921f091e9d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -15,11 +15,11 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex> -// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> +// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> // CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex> -// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> -// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> +// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> +// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> // CHECK: return [[RESULT]] : tensor<3x4x4xf32> // CHECK: } @@ -29,9 +29,9 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_lhs_batch -// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} -// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -42,9 +42,9 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_rhs_batch -// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} -// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -55,7 +55,7 @@ func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: func @batchmatmulv2_dynamic -// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -66,10 +66,10 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { // CHECK-LABEL: func @batchmatmulv2_adj_real -// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { -// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, -// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, +// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> return %0 : tensor<5x4xf32> @@ -78,14 +78,14 @@ func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { // CHECK-LABEL: func @batchmatmulv2_adj_complex( // CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex>, [[RHS:%.*]]: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { -// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]]) -// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]]) -// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]]) -// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) -// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]]) -// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]]) -// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]]) -// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) +// CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]]) +// CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]]) +// CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]]) +// CHECK: [[LHSCONJ:%.*]] = "mhlo.complex"([[LHSRE]], [[LHSIMNEG]]) +// CHECK: [[RHSRE:%.*]] = "mhlo.real"([[RHS]]) +// CHECK: [[RHSIM:%.*]] = "mhlo.imag"([[RHS]]) +// CHECK: [[RHSIMNEG:%.*]] = "mhlo.negate"([[RHSIM]]) +// CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]]) // CHECK: shape.shape_of [[LHSCONJ]] // CHECK: shape.shape_of [[RHSCONJ]] %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 3d270a52f48244..45c90d26ab4c76 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -11,8 +11,8 @@ // CHECK-LABEL: func @add func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32> // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> @@ -24,8 +24,8 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { // patterns unambiguous and more interesting (once broadcastable trait is // fixed upstream). func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1 + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -34,8 +34,8 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x // TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream // broadcastable bug is fixed (helps make the CHECK matching unambiguous) func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - // CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1 + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32> } @@ -50,9 +50,9 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor : tensor<1xi64>} - // CHECK-NEXT: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor // CHECK-NEXT: shape.assuming_yield %[[RESULT]] %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor @@ -60,7 +60,7 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -68,7 +68,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @shift_left func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + // CHECK: mhlo.shift_left %arg0, %arg1 : tensor<4xi32> %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -82,21 +82,21 @@ func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + // CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @minimum func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32> %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: func @mul func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -104,14 +104,14 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @real_div func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @sub func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -119,7 +119,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @shift_right func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -140,7 +140,7 @@ func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8 // CHECK-LABEL: func @and func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.and + // CHECK-NEXT: mhlo.and %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -154,28 +154,28 @@ func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { // CHECK-LABEL: func @or func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.or + // CHECK-NEXT: mhlo.or %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @bitwise_or func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.or + // CHECK-NEXT: mhlo.or %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0: tensor<4xi32> } // CHECK-LABEL: func @bitwise_and func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.and + // CHECK-NEXT: mhlo.and %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0: tensor<4xi32> } // CHECK-LABEL: func @pow func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.power + // CHECK-NEXT: mhlo.power %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0: tensor<2xf32> } @@ -188,32 +188,33 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @equal func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @equal_dynamic func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] - // CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor) { - // CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 - // CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] + // TODO(jpienaar): Uncomment below when fallout from https://reviews.llvm.org/D83194 fixed. + // NOT-CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // NOT-CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // NOT-CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] + // NOT-CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor) { + // NOT-CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 + // NOT-CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]]) + // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // NOT-CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // NOT-CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // NOT-CHECK-NEXT: %[[RESULT:.+]] = "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + // NOT-CHECK-NEXT: shape.assuming_yield %[[RESULT]] %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0: tensor } // CHECK-LABEL: func @equal_broadcast func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"} + // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } @@ -255,7 +256,7 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> // CHECK-LABEL: func @notequal func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -268,15 +269,15 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-LABEL: func @greater func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + // CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @broadcast_greater func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"} + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } @@ -291,9 +292,9 @@ func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -307,21 +308,21 @@ func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { // CHECK-LABEL: func @greater_equal func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @less func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @less_equal func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index a4ceb8655af7c9..5a9089756a90a5 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -3,40 +3,40 @@ // CHECK-LABEL: @if func @if(%arg0: tensor, %arg1: tensor) -> (tensor) attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1) - // CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { + // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK: [[VAL1:%.+]] = "mhlo.tuple"(%arg0, %arg1) + // CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { // CHECK: ^bb0(%arg2: tuple, tensor>): - // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32} - // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32} + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]]) - // CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]]) - // CHECK: "xla_hlo.return"([[VAL7]]) : (tuple>) -> () + // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]]) + // CHECK: "mhlo.return"([[VAL7]]) : (tuple>) -> () // CHECK: }, { // CHECK: ^bb0(%arg2: tuple, tensor>) - // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32} - // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32} + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]]) - // CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]]) - // CHECK: "xla_hlo.return"([[VAL7]]) : (tuple>) -> () + // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]]) + // CHECK: "mhlo.return"([[VAL7]]) : (tuple>) -> () // CHECK: }) %1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor - // CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32} + // CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32} // CHECK: return [[VAL3]] return %1 : tensor } func @cond_false(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - %0 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor + %0 = "mhlo.exponential"(%arg1) : (tensor) -> tensor return %0 : tensor } func @cond_true(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - %0 = "xla_hlo.log"(%arg0) : (tensor) -> tensor + %0 = "mhlo.log"(%arg0) : (tensor) -> tensor return %0 : tensor } @@ -45,42 +45,42 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor) func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor, tensor, tensor) -> (tensor, tensor) - // CHECK: %[[TUPLE_INPUT:.*]] = "xla_hlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> tuple, tensor> - // CHECK: %[[CASE:.*]]:2 = "xla_hlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( { + // CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> tuple, tensor> + // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( { // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): - // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor - // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor // CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: "xla_hlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor, tensor) -> () + // CHECK: "mhlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor, tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): - // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor - // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor // CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: "xla_hlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor, tensor) -> () + // CHECK: "mhlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor, tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): - // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor - // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor // CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: "xla_hlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor, tensor) -> () + // CHECK: "mhlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor, tensor) -> () // CHECK: }) : (tensor, tuple, tensor>, tuple, tensor>, tuple, tensor>) -> (tensor, tensor) return %0#0, %0#1 : tensor, tensor // CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor, tensor } func @exponential(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor + %0 = "mhlo.exponential"(%arg1) : (tensor) -> tensor return %0, %arg1 : tensor, tensor } func @log(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.log"(%arg0) : (tensor) -> tensor + %0 = "mhlo.log"(%arg0) : (tensor) -> tensor return %0, %arg1 : tensor, tensor } func @floor(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor return %0, %arg1 : tensor, tensor } @@ -88,44 +88,44 @@ func @floor(%arg0: tensor, %arg1: tensor) -> (tensor, tensor // CHECK-LABEL: func @while func @while(%arg0: tensor {tf_saved_model.index_path = [0]}) -> (tensor {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$"]} { - // CHECK: [[VAL0:%.+]] = xla_hlo.constant dense<0> - // CHECK: [[VAL1:%.+]] = xla_hlo.constant dense<-1> - %0 = xla_hlo.constant dense<0> : tensor - %1 = xla_hlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]] = "xla_hlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]]) - // CHECK: [[VAL3:%.+]] = "xla_hlo.while"([[VAL2]]) ( { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]]) + // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( { // CHECK: ^bb0(%arg1: tuple, tensor, tensor>): - // CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} - // CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32} - // CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32} + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32} // CHECK: [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]]) - // CHECK: "xla_hlo.return"([[VAL10]]) + // CHECK: "mhlo.return"([[VAL10]]) // CHECK: }, { // CHECK: ^bb0(%arg1: tuple, tensor, tensor>): - // CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} - // CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32} - // CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32} + // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} + // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} + // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32} // CHECK: [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]]) - // CHECK: [[VAL11:%.+]] = "xla_hlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2) - // CHECK: "xla_hlo.return"([[VAL11]]) + // CHECK: [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2) + // CHECK: "mhlo.return"([[VAL11]]) // CHECK: }) : (tuple, tensor, tensor>) -> tuple, tensor, tensor> - // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 0 : i32} - // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 1 : i32} - // CHECK: [[VAL6:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 2 : i32} + // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32} + // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32} + // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32} // CHECK: return [[VAL6]] %2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) return %2#2 : tensor } func @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} { - %0 = xla_hlo.constant dense<10> : tensor - %1 = "xla_hlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %0 = mhlo.constant dense<10> : tensor + %1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %1 : tensor } func @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} { - %0 = xla_hlo.constant dense<1> : tensor - %1 = xla_hlo.add %arg2, %0 : tensor - %2 = xla_hlo.add %arg0, %0 : tensor + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.add %arg2, %0 : tensor + %2 = mhlo.add %arg0, %0 : tensor return %2, %arg1, %1 : tensor, tensor, tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 86a7f2b9e09cce..ad4ef4b8f770bf 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -7,7 +7,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: abs // expected-error@+1 {{unsupported device}} func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> // return %[[RESULT]] @@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } -// CHECK-LABEL: not_whitelisted_op -func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { +// CHECK-LABEL: not_allowlisted_op +func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: tf.TensorListReserve %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> // CHECK: tf.TensorListGetItem @@ -54,7 +54,7 @@ func @dynamic_operand(%arg0: tensor) -> tensor { func @tuple_type(%arg0: tuple, tensor>) -> tensor { // Verifies that the pass can handle operands of non-tensor type like tuple // from non TensorFlow ops. - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -69,9 +69,9 @@ func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> { // CHECK-LABEL: multiple_dialect_ops func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: xla_hlo.negate - %0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: xla_hlo.abs + // CHECK: mhlo.negate + %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: mhlo.abs %1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -79,21 +79,21 @@ func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: binary_op func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32> + // CHECK: mhlo.atan2 %arg0, %arg1 : tensor<2xf32> %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: binary_op_broadcast func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { - // CHECK: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32> - // CHECK: %[[RESHAPE0:.*]] = "xla_hlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32> - // CHECK: %[[UPDATED_ARG0:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32> + // CHECK: %[[RESHAPE0:.*]] = "mhlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> - // CHECK: %[[RESHAPE1:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32> - // CHECK: %[[UPDATED_ARG1:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK: %[[RESHAPE1:.*]] = "mhlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> - // CHECK: %[[RESULT:.*]] = xla_hlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32> + // CHECK: %[[RESULT:.*]] = mhlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32> // CHECK: return %[[RESULT]] : tensor<4x4x4xf32> %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32> @@ -102,23 +102,23 @@ func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> t // CHECK-LABEL: func @ternary_op func @ternary_op(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2) + // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @convert func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @constant func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: %[[SCALAR_ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[ONE]], %arg0 : tensor<2xf32> + // CHECK: %[[SCALAR_ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[ONE:.*]] = "mhlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor<2xf32> // CHECK: return %[[RESULT]] %0 = "tf.Inv"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -127,7 +127,7 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @greater func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -136,14 +136,14 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor, func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { - // CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]]) + // CHECK: "mhlo.pad"(%[[ARG0]], %[[ARG1]]) // CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64> // CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64> // CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64> - %0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32> - %1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32> - %2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32> + %0 = mhlo.constant dense<[2, 1]> : tensor<2xi32> + %1 = mhlo.constant dense<[1, 2]> : tensor<2xi32> + %2 = mhlo.constant dense<[1, 0]> : tensor<2xi32> %3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> return %3 : tensor<6x5xf64> } @@ -156,7 +156,7 @@ func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor // CHECK-LABEL: dynamic_result_type func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32> %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32> @@ -166,7 +166,7 @@ func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> { func @truncated_normal() -> tensor<2x2xf32> { // CHECK-NOT: tf.TruncatedNormal - %0 = xla_hlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = mhlo.constant dense<[2, 2]> : tensor<2xi32> %1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32> return %1 : tensor<2x2xf32> } @@ -175,21 +175,21 @@ func @truncated_normal() -> tensor<2x2xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32> func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> { - // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK: %[[SLICE0:.*]] = "mhlo.slice"(%[[ARG2]]) // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> - // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor + // CHECK: %[[DIM0:.*]] = "mhlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%[[ARG2]]) // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> - // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor + // CHECK: %[[DIM1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor - // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) + // CHECK: "mhlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> return %0: tensor<3x4xi32> @@ -199,32 +199,32 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2 // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor) func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor) -> tensor<3x3xf32> { -// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32> -// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> +// CHECK: %[[CST:.*]] = mhlo.constant dense<3> : tensor<2xi32> +// CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( { // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // no predecessors -// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor) -> () +// CHECK: "mhlo.return"(%[[ARG4]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: index_vector_dim = 1 : i64 // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64> // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64> -// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64> +// CHECK-SAME: update_window_dims = dense<> : tensor<0xi64> // CHECK-SAME: unique_indices = false // CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32> // return %[[RESULT]] : tensor<3x3xf32> - %cst = xla_hlo.constant dense<3> : tensor<2xi32> + %cst = mhlo.constant dense<3> : tensor<2xi32> %0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } // CHECK-LABEL: fft func @fft(%arg0: tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> { - // CHECK: "xla_hlo.fft"(%arg0) + // CHECK: "mhlo.fft"(%arg0) %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex>) -> tensor<3x5x8xcomplex> return %0 : tensor<3x5x8xcomplex> } @@ -238,7 +238,7 @@ func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: mirror_pad func @mirror_pad(%arg0: tensor<2x3xcomplex>) -> tensor<4x7xcomplex> { - %0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> + %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> // CHECK-NOT: tf.MirrorPad %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex>, tensor<2x2xi32>) -> tensor<4x7xcomplex> return %1 : tensor<4x7xcomplex> @@ -254,7 +254,7 @@ func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { // CHECK-LABEL: arg_min func @arg_min(%arg0: tensor<6xf64>) -> tensor { // CHECK-NOT: ArgMin - %0 = xla_hlo.constant dense<0> : tensor + %0 = mhlo.constant dense<0> : tensor %1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor) -> tensor return %1 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 10d69221979640..b4eef909750660 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -13,7 +13,7 @@ // CHECK-LABEL: fusedBatchNorm_notraining func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -28,7 +28,7 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, // CHECK-LABEL: fusedBatchNormV3_noTraining func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -36,11 +36,11 @@ func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf3 // CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision // CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { - // CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) - // CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - // CHECK: [[DUMMY:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK: [[Y_CONVERT:%.*]] = "mhlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32> // CHECK: [[DUMMY_CAST:%.*]] = tensor_cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> @@ -48,45 +48,45 @@ func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %a // CHECK-LABEL: fusedBatchNormV3_training func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK: xla_hlo.constant - // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK: mhlo.constant + // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } // CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { - // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: return %[[VAR]] return %0#4 : tensor<8xf32> } // CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { - // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-DAG: %[[BATCH_MEAN:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} - // CHECK-DAG: %[[BATCH_VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} + // CHECK-DAG: %[[BATCH_MEAN:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} + // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} - // CHECK: %[[FACTOR:.*]] = xla_hlo.constant dense<1.00195694> - // CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] + // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694> + // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] - // CHECK-DAG: %[[ALPHA:.*]] = xla_hlo.constant dense<0.199999988> - // CHECK-DAG: %[[BETA:.*]] = xla_hlo.constant dense<8.000000e-01> + // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988> + // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] - // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] - // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> @@ -94,22 +94,22 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, // CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK: "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> return %0#0 : tensor<8x8x8x8xbf16> } // CHECK-LABEL: fusedBatchNormV3_NCHW func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } // CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { - // CHECK: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) return %0#0 : tensor } @@ -130,42 +130,42 @@ func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor, // CHECK-LABEL: fusedBatchNormGrad_noTraining func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - - // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( { // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors - // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( { // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors - // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -174,13 +174,13 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-LABEL: fusedBatchNormGrad_Training func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -189,43 +189,43 @@ func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x // CHECK-LABEL: fusedBatchNormGradV2_noTraining func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - - // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( { // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors - // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( { // CHECK-NEXT: ^bb0(%arg5: tensor, %arg6: tensor): // no predecessors - // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -234,13 +234,13 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-LABEL: fusedBatchNormGradV2_Training func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -249,10 +249,10 @@ func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -261,13 +261,13 @@ func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32> // CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -276,43 +276,43 @@ func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, // CHECK-LABEL: fusedBatchNormGradV3_noTraining func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - - // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( { // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors - // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> - // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( { // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors - // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -321,13 +321,13 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-LABEL: fusedBatchNormGradV3_Training func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -336,10 +336,10 @@ func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -348,13 +348,13 @@ func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32> // CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -363,43 +363,43 @@ func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, // CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - - // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> - // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( { + // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor + + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> + + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( { // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors - // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> - // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> - // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( { + // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( { // CHECK-NEXT: ^bb0(%arg6: tensor, %arg7: tensor): // no predecessors - // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor - // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor) -> () + // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor + // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor) -> () // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> - // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> + // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) @@ -408,7 +408,7 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %{{.*}} = "xla_hlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -421,9 +421,9 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} - // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] + // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } @@ -432,9 +432,9 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] + // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } @@ -443,9 +443,9 @@ func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] - // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] + // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor return %0 : tensor } @@ -457,17 +457,17 @@ func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tenso // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { - // CHECK: %[[RS:.*]] = "xla_hlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> - // CHECK-DAG: %[[IOTA0:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32> - // CHECK-DAG: %[[IOTA1:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32> - // CHECK-DAG: %[[COMP:.*]] = "xla_hlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ZERO_MAT:.*]] = "xla_hlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor) -> tensor<12x12xf32> - // CHECK-DAG: %[[SEL:.*]] = "xla_hlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32> - // CHECK-DAG: %[[RED:.*]] = "xla_hlo.reduce"(%[[SEL]], %[[ZERO]]) - // CHECK-DAG: xla_hlo.add + // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> + // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32> + // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32> + // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor) -> tensor<12x12xf32> + // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32> + // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]]) + // CHECK-DAG: mhlo.add // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor) -> tensor<12xf32> - // CHECK-DAG: %[[RES:.*]] = "xla_hlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32> + // CHECK-DAG: %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32> // CHECK-DAG: return %[[RES]] : tensor<4x3xf32> %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> return %0: tensor<4x3xf32> @@ -479,14 +479,14 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { // CHECK-LABEL: func @einsum func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK: xla_hlo.einsum + // CHECK: mhlo.einsum %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> return %0: tensor<2x4xf32> } // CHECK-LABEL: func @unary_einsum func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: xla_hlo.unary_einsum + // CHECK: mhlo.unary_einsum %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> return %0: tensor<2x2xf32> } @@ -497,21 +497,21 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} - // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) - // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) - // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) + // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) + // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) + // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> return %0: tensor<2x3xi32> @@ -519,21 +519,21 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} - // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) - // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) - // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) + // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) + // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) + // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0: tensor<2x3xi32> @@ -541,8 +541,8 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 - // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) + // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0 + // CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0: tensor<2xf32> @@ -550,11 +550,11 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @floordiv_bf16 func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { - // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_chlo.broadcast_divide - // CHECK-NEXT: xla_hlo.floor - // CHECK-NEXT: xla_hlo.convert + // CHECK-NEXT: mhlo.convert + // CHECK-NEXT: mhlo.convert + // CHECK-NEXT: chlo.broadcast_divide + // CHECK-NEXT: mhlo.floor + // CHECK-NEXT: mhlo.convert // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> return %0: tensor<2xbf16> @@ -562,8 +562,8 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_chlo.broadcast_divide - // CHECK-NEXT: xla_hlo.floor + // CHECK-NEXT: chlo.broadcast_divide + // CHECK-NEXT: mhlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> return %0: tensor<2x3xf16> @@ -571,21 +571,21 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te // CHECK-LABEL: func @floordiv_dynamic func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} - // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) - // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) - // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) + // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) + // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) + // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor @@ -600,16 +600,16 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0: tensor<2x3xi32> @@ -617,16 +617,16 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} - // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> return %0: tensor<2x3xi32> @@ -634,16 +634,16 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32 // CHECK-LABEL: func @floormod_dynamic func @floormod_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} - // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor @@ -664,9 +664,9 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> - // CHECK: [[CST:%.+]] = xla_hlo.constant + // CHECK: [[CST:%.+]] = mhlo.constant // CHECK: [[CAST:%.+]] = tensor_cast [[CST]] : tensor<4xi32> to tensor<4xi32> - // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]]) + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> return %0 : tensor<16x16x16x16xf32> @@ -678,31 +678,31 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { // CHECK-LABEL: func @complex func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { - // CHECK: "xla_hlo.complex" + // CHECK: "mhlo.complex" %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> return %1 : tensor<3xcomplex> } // CHECK-LABEL: func @imag func @imag(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { - // CHECK: "xla_hlo.imag" + // CHECK: "mhlo.imag" %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> return %1 : tensor<3xf32> } // CHECK-LABEL: func @real func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { - // CHECK: "xla_hlo.real" + // CHECK: "mhlo.real" %1 = "tf.Real"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> return %1 : tensor<3xf32> } // CHECK-LABEL: func @conj func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { - // CHECK-DAG: [[R1:%.*]] = "xla_hlo.real"(%arg0) - // CHECK-DAG: [[R2:%.*]] = "xla_hlo.imag"(%arg0) - // CHECK-DAG: [[R3:%.*]] = "xla_hlo.negate"([[R2]]) - // CHECK: [[R4:%.*]] = "xla_hlo.complex"([[R1]], [[R3]]) + // CHECK-DAG: [[R1:%.*]] = "mhlo.real"(%arg0) + // CHECK-DAG: [[R2:%.*]] = "mhlo.imag"(%arg0) + // CHECK-DAG: [[R3:%.*]] = "mhlo.negate"([[R2]]) + // CHECK: [[R4:%.*]] = "mhlo.complex"([[R1]], [[R3]]) %1 = "tf.Conj"(%arg0) : (tensor<3xcomplex>) -> tensor<3xcomplex> return %1 : tensor<3xcomplex> } @@ -713,7 +713,7 @@ func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { // CHECK-LABEL: func @concat_v2 func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> return %1 : tensor<6x3xf32> @@ -721,7 +721,7 @@ func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf3 // CHECK-LABEL: func @concat_v2_neg_axis func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> @@ -730,7 +730,7 @@ func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tens // CHECK-LABEL: func @concat_v2_1d_axis func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { - // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> @@ -759,7 +759,7 @@ func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<* // CHECK-LABEL: func @padv2_1D func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> - // CHECK: "xla_hlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) { // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> @@ -770,7 +770,7 @@ func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { // CHECK-LABEL: func @padv2_2D func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> - // CHECK: "xla_hlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) { // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> @@ -781,7 +781,7 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { // CHECK-LABEL: func @padv2_i32_paddings func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> - // CHECK: "xla_hlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) { // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> @@ -827,11 +827,11 @@ func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { // CHECK-LABEL: func @infeed_dequeue_tuple func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { -// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token -// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple, tensor<4xf32>>, !xla_hlo.token> -// CHECK: [[INFEED_VAL:%.*]] = "xla_hlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple, tensor<4xf32>>, !xla_hlo.token>) -> tuple, tensor<4xf32>> -// CHECK: [[RES_1:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple, tensor<4xf32>>) -> tensor<3xi32> -// CHECK: [[RES_2:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple, tensor<4xf32>>) -> tensor<4xf32> +// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token +// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!mhlo.token) -> tuple, tensor<4xf32>>, !mhlo.token> +// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple, tensor<4xf32>>, !mhlo.token>) -> tuple, tensor<4xf32>> +// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple, tensor<4xf32>>) -> tensor<3xi32> +// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple, tensor<4xf32>>) -> tensor<4xf32> // CHECK: return [[RES_1]], [[RES_2]] %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) return %0#0, %0#1 : tensor<3xi32>, tensor<4xf32> @@ -850,7 +850,7 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { // CHECK-LABEL: infeed_dequeue_tuple_sharding func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { - // CHECK: "xla_hlo.infeed" + // CHECK: "mhlo.infeed" // An additional sharding is added at the end to account for token result. // Proto debug string: // type: TUPLE @@ -864,7 +864,7 @@ func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { // tile_assignment_dimensions: 1 // tile_assignment_devices: 0 // } - // CHECK-SAME: xla_hlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" + // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> return %0 : tensor<8xi32> } @@ -875,14 +875,14 @@ func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { // CHECK-LABEL: @const func @const() -> tensor<2xi32> { - // CHECK: xla_hlo.constant dense<0> : tensor<2xi32> + // CHECK: mhlo.constant dense<0> : tensor<2xi32> %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) return %0: tensor<2xi32> } // CHECK-LABEL: @const_dynamic_output func @const_dynamic_output() -> tensor<*xi32> { - // CHECK: [[CONST:%.*]] = xla_hlo.constant dense<0> : tensor<2xi32> + // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32> // CHECK: [[CAST:%.*]] = tensor_cast [[CONST]] : tensor<2xi32> to tensor<*xi32> %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) // CHECK: return [[CAST]] @@ -891,7 +891,7 @@ func @const_dynamic_output() -> tensor<*xi32> { // CHECK-LABEL: @opaque_const func @opaque_const() -> tensor>> { - // CHECK-NOT: xla_hlo.constant + // CHECK-NOT: mhlo.constant %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor} : () -> tensor>> return %0 : tensor>> } @@ -903,7 +903,7 @@ func @opaque_const() -> tensor>> { // CHECK-LABEL: matmul_notranspose // CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { - // CHECK: "xla_hlo.dot"(%[[A]], %[[B]]) + // CHECK: "mhlo.dot"(%[[A]], %[[B]]) %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> return %0 : tensor<5x11xf32> @@ -912,8 +912,8 @@ func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x // CHECK-LABEL: matmul_transpose_b // CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_B:.*]] = "xla_hlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} - // CHECK: "xla_hlo.dot"(%[[A]], %[[UPDATED_B]]) + // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]]) %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> return %0 : tensor<5x11xf32> @@ -922,9 +922,9 @@ func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x // CHECK-LABEL: matmul_transpose_both // CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_A:.*]] = "xla_hlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>} - // CHECK: %[[UPDATED_B:.*]] = "xla_hlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} - // CHECK: "xla_hlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]]) + // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]]) %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> return %0 : tensor<5x11xf32> @@ -933,7 +933,7 @@ func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor // Verify that MatMul with ranked inputs are lowered to HLO. // CHECK-LABEL: matmul_ranked func @matmul_ranked(%a: tensor, %b: tensor<7x?xf32>) -> tensor { - // CHECK: "xla_hlo.dot" + // CHECK: "mhlo.dot" %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor, tensor<7x?xf32>) -> tensor return %0 : tensor @@ -942,7 +942,7 @@ func @matmul_ranked(%a: tensor, %b: tensor<7x?xf32>) -> tensor // Verify that MatMul with unranked inputs are lowered to HLO. // CHECK-LABEL: matmul_unranked func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.dot" + // CHECK: "mhlo.dot" %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -951,7 +951,7 @@ func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> { // Verify SparseMatMul is legalized to dot. // CHECK-LABEL: test_sparse_mat_mul func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { - // CHECK: "xla_hlo.dot" + // CHECK: "mhlo.dot" %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> return %0: tensor<3x5xf32> } @@ -963,31 +963,31 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten // CHECK-LABEL: matrix_band_part // CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // CHECK: %[[M:.*]] = xla_hlo.constant dense<64> : tensor - // CHECK: %[[N:.*]] = xla_hlo.constant dense<64> : tensor + // CHECK: %[[M:.*]] = mhlo.constant dense<64> : tensor + // CHECK: %[[N:.*]] = mhlo.constant dense<64> : tensor - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[A:.*]] = "xla_hlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %[[B:.*]] = "xla_hlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor, tensor, tensor) -> tensor + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor, tensor, tensor) -> tensor - // CHECK: %[[C:.*]] = "xla_hlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %[[D:.*]] = "xla_hlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor + // CHECK: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor - // CHECK: %[[E:.*]] = "xla_hlo.convert"(%[[B]]) : (tensor) -> tensor - // CHECK: %[[F:.*]] = "xla_hlo.negate"(%[[E]]) : (tensor) -> tensor + // CHECK: %[[E:.*]] = "mhlo.convert"(%[[B]]) : (tensor) -> tensor + // CHECK: %[[F:.*]] = "mhlo.negate"(%[[E]]) : (tensor) -> tensor - // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> - // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> - // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> + // CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> + // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> + // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> + // CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> - // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> + // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor + // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<64x64xi1> + // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> - // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> - // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) + // CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16> + // CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) // CHECK: return %[[R]] %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> return %0 : tensor<64x64xbf16> @@ -996,18 +996,18 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK-LABEL: matrix_band_part_2 // CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<12x24x48xbf16> { - // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16> - // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> - // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> + // CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16> + // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> + // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> + // CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> - // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<24x48xi1> + // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor + // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> + // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> - // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> - // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) + // CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> + // CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) // CHECK: return %[[R]] %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor, tensor) -> tensor<12x24x48xbf16> return %0 : tensor<12x24x48xbf16> @@ -1037,10 +1037,10 @@ func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor, %arg2: t // CHECK-LABEL: maxpool_valid_padding // CHECK-SAME: %[[ARG:.*]]: tensor func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<-2147483648> : tensor - // CHECK: "xla_hlo.reduce_window"(%[[ARG]], %[[INIT]]) - // CHECK: xla_hlo.maximum - // CHECK: xla_hlo.return + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor + // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: mhlo.maximum + // CHECK: mhlo.return // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> @@ -1059,10 +1059,10 @@ func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> // CHECK-LABEL: maxpool_3d_valid_padding // CHECK-SAME: %[[ARG:.*]]: tensor func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: "xla_hlo.reduce_window"(%[[ARG]], %[[INIT]]) - // CHECK: xla_hlo.maximum - // CHECK: xla_hlo.return + // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor + // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: mhlo.maximum + // CHECK: mhlo.return // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>} %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> @@ -1085,15 +1085,15 @@ func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x // CHECK-LABEL: @max_pool_grad_valid // CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor, tensor) -> tensor - // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor - // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { @@ -1108,15 +1108,15 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te // CHECK-LABEL: @max_pool_3d_grad_valid // CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor, tensor) -> tensor - // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor - // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> @@ -1148,12 +1148,12 @@ func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: // CHECK-LABEL:one_hot func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { - // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[BCAST_ARG0:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> - // CHECK: %[[ON_VALUE:.*]] = "xla_hlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> - // CHECK: %[[OFF_VALUE:.*]] = "xla_hlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> + // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> + // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> + // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> // CHECK: return %[[RESULT]] : tensor<3x5xf32> %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> @@ -1167,9 +1167,9 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tenso // CHECK-LABEL: func @outfeed_enqueue_tuple // CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { -// CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple, tensor<4xf32>> -// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token -// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token +// CHECK: [[TUPLE:%.*]] = "mhlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple, tensor<4xf32>> +// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token +// CHECK: "mhlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple, tensor<4xf32>>, !mhlo.token) -> !mhlo.token "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () return } @@ -1180,9 +1180,9 @@ func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> ( // CHECK-LABEL: func @pack func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - // CHECK: "xla_hlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> - // CHECK: "xla_hlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> - // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> @@ -1251,7 +1251,7 @@ func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> // CHECK: return [[VAL]] : tensor<5xi32> @@ -1262,7 +1262,7 @@ func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) - // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> // CHECK: return [[VAL]] : tensor<5xi32> @@ -1273,7 +1273,7 @@ func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> // CHECK: return [[VAL]] : tensor<5x5xi32> @@ -1314,34 +1314,34 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } // CHECK-LABEL: func @relu6 func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[SIX:.*]] = xla_hlo.constant dense<6> : tensor - // CHECK: "xla_hlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor + // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } // CHECK-LABEL: func @relu6_unranked func @relu6_unranked(%arg0: tensor) -> tensor { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[SIX:.*]] = xla_hlo.constant dense<6> : tensor - // CHECK: "xla_hlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor, tensor, tensor) -> tensor + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor + // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor, tensor, tensor) -> tensor %0 = "tf.Relu6"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1349,10 +1349,10 @@ func @relu6_unranked(%arg0: tensor) -> tensor { // CHECK-LABEL: func @relu_grad // CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { - // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor - // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> @@ -1364,56 +1364,56 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tens // CHECK-LABEL: func @selectv2 func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) + // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @selectv2_pred_scalar func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) + // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @selectv2_broadcast_then func @selectv2_broadcast_then(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK: "xla_hlo.select"(%arg0, %[[BROADCAST]], %arg2) + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> + // CHECK: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> return %0: tensor<2x8x8xi32> } // CHECK-LABEL: func @selectv2_broadcast_else func @selectv2_broadcast_else(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK: "xla_hlo.select"(%arg0, %arg1, %[[BROADCAST]]) + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> + // CHECK: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]]) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> return %0: tensor<2x8x8xi32> } // CHECK-LABEL: func @selectv2_broadcast_pred func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1> - // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2) + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1> + // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> return %0: tensor<2x8x8xi32> } // CHECK-LABEL: func @selectv2_broadcast_tensor_pred func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { - // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> - // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2) + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> return %0: tensor<2x3xf16> } // CHECK-LABEL: func @selectv2_broadcast_all func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { - // CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> - // CHECK-DAG: %[[BROADCAST_1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32> - // CHECK-DAG: %[[BROADCAST_2:.*]] = "xla_hlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - // CHECK: "xla_hlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]) + // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> + // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32> + // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32> + // CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]) %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> return %0: tensor<8x8x8xi32> } @@ -1441,33 +1441,33 @@ func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: te func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify reduce op for max computation and its body. - // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.maximum - // CHECK: "xla_hlo.return" + // CHECK-DAG: %[[NEG_INF:.*]] = mhlo.constant dense<0xFF800000> : tensor + // CHECK-DAG: %[[CASTED_INP:.*]] = "mhlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK: %[[MAX:.*]] = "mhlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) + // CHECK: mhlo.maximum + // CHECK: "mhlo.return" // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[CASTED_MAX:.*]] = "mhlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> - // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] - // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) + // CHECK: %[[BCAST_MAX:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = mhlo.subtract %[[ARG0]], %[[BCAST_MAX]] + // CHECK: %[[EXP:.*]] = "mhlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. - // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" + // CHECK-DAG: %[[CASTED_EXP:.*]] = "mhlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SUM:.*]] = "mhlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) + // CHECK: mhlo.add + // CHECK: "mhlo.return" // CHECK: {dimensions = dense<1> : tensor<1xi64>} - // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> - // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] + // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[EXP]], %[[BCAST_SUM]] // CHECK: return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -1477,7 +1477,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. // CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor) -> tensor { - // CHECK: xla_hlo.divide {{.*}} : tensor + // CHECK: mhlo.divide {{.*}} : tensor %0 = "tf.Softmax"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1486,8 +1486,8 @@ func @dynamic_softmax(%arg0: tensor) -> tensor { func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { // Verify that conversion to f32 and then back to bf16 are introduced. - // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32> - // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16> + // CHECK: "mhlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32> + // CHECK: "mhlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16> %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16> return %0: tensor<2x3xbf16> @@ -1497,14 +1497,14 @@ func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // Verify that reduce op dimensions and broadcast dimensions are correct. - // CHECK: "xla_hlo.reduce" + // CHECK: "mhlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.reduce" + // CHECK: "mhlo.reduce" // CHECK: dimensions = dense<3> // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} - // CHECK: xla_hlo.divide {{.*}} + // CHECK: mhlo.divide {{.*}} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } @@ -1517,14 +1517,14 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK-LABEL: func @simple_logsoftmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: %{{.*}} = "xla_hlo.reduce"({{.*}}) - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"({{.*}}) - // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %{{.*}} = "mhlo.reduce"({{.*}}) + // CHECK: %[[SUM:.*]] = "mhlo.reduce"({{.*}}) + // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> - // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] + // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = mhlo.subtract {{.*}}, %[[BCAST_SUM]] // CHECK: return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> @@ -1538,7 +1538,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-LABEL: func @rfft_1D func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<8xcomplex> return %0 : tensor<8xcomplex> } @@ -1595,7 +1595,7 @@ func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-LABEL: @transpose_2d func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: "xla_hlo.transpose" + // CHECK: "mhlo.transpose" %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } @@ -1603,7 +1603,7 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { // CHECK-LABEL: @transpose_3d_int32 func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) - // CHECK: "xla_hlo.transpose" + // CHECK: "mhlo.transpose" %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> return %0 : tensor<3x2x1xf32> } @@ -1611,7 +1611,7 @@ func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { // CHECK-LABEL: @transpose_3d func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) - // CHECK: "xla_hlo.transpose" + // CHECK: "mhlo.transpose" %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> return %0 : tensor<3x2x1xf32> } @@ -1619,7 +1619,7 @@ func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { // CHECK-LABEL: @transpose_dynamic_2d func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: "xla_hlo.transpose" + // CHECK: "mhlo.transpose" %0 = "tf.Transpose"(%arg0, %permutation) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> return %0 : tensor<4x?xf32> } @@ -1627,7 +1627,7 @@ func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { // CHECK-LABEL: @transpose_unranked_2d func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - // CHECK: "xla_hlo.transpose" + // CHECK: "mhlo.transpose" %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -1639,245 +1639,245 @@ func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @abs func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @abs_dynamic func @abs_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.abs"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.abs"(%arg0) : (tensor) -> tensor %0 = "tf.Abs"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @abs_unranked func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.convert"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.convert"(%arg0) : (tensor) -> tensor %0 = "tf.Cast"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @cast_i2f func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @cast_c2f func @cast_c2f(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - //CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + //CHECK: "mhlo.convert"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: @ceil func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @ceil_dynamic func @ceil_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.ceil"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.ceil"(%arg0) : (tensor) -> tensor %0 = "tf.Ceil"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @ceil_unranked func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @complex_abs func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - // CHECK: "xla_hlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: @cos func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @cos_dynamic func @cos_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.cosine"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.cosine"(%arg0) : (tensor) -> tensor %0 = "tf.Cos"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @cos_unranked func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @exp func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @exp_dynamic func @exp_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.exponential"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.exponential"(%arg0) : (tensor) -> tensor %0 = "tf.Exp"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @exp_unranked func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @floor func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @floor_dynamic func @floor_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.floor"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.floor"(%arg0) : (tensor) -> tensor %0 = "tf.Floor"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @floor_unranked func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @is_finite func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { - // CHECK: "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + // CHECK: "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> return %0 : tensor<2xi1> } // CHECK-LABEL: func @is_finite_dynamic func @is_finite_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.is_finite"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.is_finite"(%arg0) : (tensor) -> tensor %0 = "tf.IsFinite"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @is_finite_unranked func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { - // CHECK: "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> + // CHECK: "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> return %0 : tensor<*xi1> } // CHECK-LABEL: @log func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @log_dynamic func @log_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.log"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.log"(%arg0) : (tensor) -> tensor %0 = "tf.Log"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @log_unranked func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @log1p func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @log1p_dynamic func @log1p_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @log1p_unranked func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @not_op_unranked func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { - // CHECK: "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: "mhlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> return %0 : tensor<*xi1> } // CHECK-LABEL: @neg func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @neg_dynamic func @neg_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.negate"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.negate"(%arg0) : (tensor) -> tensor %0 = "tf.Neg"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @neg_unranked func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: @sigmoid func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<1xindex> - // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<2xf32> - // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32> - // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<2xf32> - // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<2xf32> + // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<2xf32> %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: @sigmoid_complex func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { - // CHECK: [[R0:%.+]] = xla_hlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor> + // CHECK: [[R0:%.+]] = mhlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor> // CHECK-NOT: tf.Sigmoid %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> return %0 : tensor<2xcomplex> @@ -1885,14 +1885,14 @@ func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { // CHECK-LABEL: @sigmoid_unranked func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor - // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor<*xf32> - // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32> - // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> - // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<*xf32> - // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor<*xf32> + // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<*xf32> %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -1900,10 +1900,10 @@ func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sigmoid_grad func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32> - // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32> - // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> + // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32> + // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32> + // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> // CHECK: return [[MUL1]] %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -1911,10 +1911,10 @@ func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> // CHECK-LABEL: @sigmoid_grad_complex func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> tensor<2xcomplex> { - // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xcomplex> - // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xcomplex> - // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex> + // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex> + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex> // CHECK: return [[MUL1]] %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xcomplex> return %0 : tensor<2xcomplex> @@ -1922,112 +1922,112 @@ func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomple // CHECK-LABEL: @sin func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @sin_dynamic func @sin_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.sine"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.sine"(%arg0) : (tensor) -> tensor %0 = "tf.Sin"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @sin_unranked func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @rsqrt func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @rsqrt_dynamic func @rsqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.rsqrt"(%arg0) : (tensor) -> tensor %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @rsqrt_unranked func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @sqrt func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @sqrt_dynamic func @sqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.sqrt"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.sqrt"(%arg0) : (tensor) -> tensor %0 = "tf.Sqrt"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @sqrt_unranked func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @tanh func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @tanh_dynamic func @tanh_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.tanh"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.tanh"(%arg0) : (tensor) -> tensor %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @tanh_unranked func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @bitcast func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } // CHECK-LABEL: func @bitcast_dynamic func @bitcast_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.bitcast_convert"(%arg0) : (tensor) -> tensor + // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor %0 = "tf.Bitcast"(%arg0) : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @bitcast_unranked func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @bitcast_same_widths func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { - // CHECK: "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -2048,14 +2048,14 @@ func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2xf16> { // CHECK-LABEL: reshape func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { - // CHECK: "xla_hlo.reshape" + // CHECK: "mhlo.reshape" %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> return %0 : tensor<2x1xf32> } // CHECK-LABEL: reshape_dynamic func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { - // CHECK: "xla_hlo.reshape" + // CHECK: "mhlo.reshape" %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } @@ -2069,7 +2069,7 @@ func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor) -> tensor<1x10xf32> { - // CHECK: "xla_hlo.reshape" + // CHECK: "mhlo.reshape" %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> return %0 : tensor<1x10xf32> } @@ -2083,7 +2083,7 @@ func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: expand_dims func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { - // CHECK: "xla_hlo.reshape" + // CHECK: "mhlo.reshape" %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> return %0 : tensor<1x2xf32> } @@ -2091,10 +2091,10 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { // CHECK-LABEL: func @sign // CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: [[PRED:%.*]] = "xla_hlo.compare"([[ARG]], [[ARG]]) - // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> - // CHECK: [[SIGN:%.*]] = "xla_hlo.sign"([[ARG]]) - // CHECK: [[SELECT:%.*]] = "xla_hlo.select"([[PRED]], [[ZEROS]], [[SIGN]]) + // CHECK: [[PRED:%.*]] = "mhlo.compare"([[ARG]], [[ARG]]) + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> + // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]]) + // CHECK: [[SELECT:%.*]] = "mhlo.select"([[PRED]], [[ZEROS]], [[SIGN]]) // CHECK: return [[SELECT]] : tensor<1x2x3x4xf32> %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) return %0 : tensor<1x2x3x4xf32> @@ -2102,17 +2102,17 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { // CHECK-LABEL: slice_constant_start func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> // CHECK: %[[CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START:.*]]) : + // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START:.*]]) : // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> // CHECK: return %[[RESULT]] : tensor<2xi32> @@ -2124,15 +2124,15 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_i32_consts func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32> + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi32> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi32> to tensor<1xi32> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64> - // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor - // CHECK: "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> @@ -2141,15 +2141,15 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_constant_start_negative_one_size func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { - // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> + // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<3xi32> + // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<3xi32> // CHECK: return %[[RESULT]] : tensor<3xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2159,24 +2159,24 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi // CHECK-LABEL: slice_constant_start_dynamic_shape func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> + // CHECK: %[[START:.*]] = mhlo.constant dense<[1, 0]> : tensor<2xi64> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<2xi64> to tensor<2xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64> + // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : + // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : + // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice" + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice" // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> @@ -2189,15 +2189,15 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2 // CHECK-LABEL: slice_variable_start func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> + // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> @@ -2223,7 +2223,7 @@ func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: xla_hlo.slice + // CHECK: mhlo.slice // CHECK-DAG-SAME: start_indices = dense<[0, 1]> // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> // CHECK-DAG-SAME: strides = dense<[1, 3]> @@ -2252,9 +2252,9 @@ func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: xla_hlo.slice + // CHECK: mhlo.slice // CHECK-DAG-SAME: start_indices = dense<[0, 1]> // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> // CHECK-DAG-SAME: strides = dense<[1, 3]> @@ -2283,9 +2283,9 @@ func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<0x3xf32> { %end = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<[-1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} - // CHECK: xla_hlo.slice + // CHECK: mhlo.slice // CHECK-DAG-SAME: start_indices = dense<[3, 0]> // CHECK-DAG-SAME: limit_indices = dense<[3, 8]> // CHECK-DAG-SAME: strides = dense<[1, 3]> @@ -2319,9 +2319,9 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - // CHECK: %[[REVERSE:.*]] = "xla_hlo.reverse"(%[[INPUT]]) + // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) - // CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[REVERSE]]) + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> @@ -2329,7 +2329,7 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE]]) + // CHECK: "mhlo.reshape"(%[[SLICE]]) // CHECK-SAME: -> tensor<4x16x1022xf32> return @@ -2361,7 +2361,7 @@ func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - // CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[INPUT]]) + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> @@ -2369,7 +2369,7 @@ func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE]]) + // CHECK: "mhlo.reshape"(%[[SLICE]]) // CHECK-SAME: -> tensor<16xf32> return @@ -2382,7 +2382,7 @@ func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { // The ellipsis mask is applied to dim #1, #2, i.e, we get canonicalized // slice input[1, :, :, 8:, :10, 2:6:2] - // The start, limit indices and strides attributes of xla_hlo.slice would + // The start, limit indices and strides attributes of mhlo.slice would // reflect the canonicalized slice. // As output shape of StridedSlice differs, a reshape will follow. @@ -2390,14 +2390,14 @@ func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) - // CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[INPUT]]) + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE]]) + // CHECK: "mhlo.reshape"(%[[SLICE]]) // CHECK-SAME: -> tensor<4x8x8x10x2xf32> return @@ -2413,7 +2413,7 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] // This is then reshaped to add the new axes. - // The start, limit indices and strides attributes of xla_hlo.slice would + // The start, limit indices and strides attributes of mhlo.slice would // reflect the canonicalized slice. // As output shape of StridedSlice differs, a reshape will follow to reflect // new axes added. @@ -2422,14 +2422,14 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) - // CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[INPUT]]) + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE]]) + // CHECK: "mhlo.reshape"(%[[SLICE]]) // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> return @@ -2439,16 +2439,16 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { // CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { // StridedSlice gets input[8:10], which is same as input[8:10, ...] - // The start_indices, limit_indices, and strides attribute of xla_hlo.slice + // The start_indices, limit_indices, and strides attribute of mhlo.slice // reflect the canonicalized slice. %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]]) + // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> return %0 : tensor<2x16x2xf32> @@ -2464,26 +2464,26 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1 %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - // CHECK: %[[A:.*]] = "xla_hlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[BEGIN:.*]] = "xla_hlo.concatenate"(%[[A]]) + // CHECK: %[[A:.*]] = "mhlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK-NEXT: %[[INDEX:.*]] = "xla_hlo.slice"(%[[BEGIN]]) + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] + // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : + // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor - // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" + // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice" // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x97xi32> - // CHECK-NEXT: %[[FINAL:.*]] = "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> + // CHECK-NEXT: %[[FINAL:.*]] = "mhlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> return %result : tensor<1x97xi32> @@ -2564,7 +2564,7 @@ func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor // This ellipsis mask is supported because it refers to the last dimension. // [1, 0, 0] = 4 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: xla_hlo.dynamic-slice + // CHECK: mhlo.dynamic-slice %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> return %result : tensor<1x97xi32> } @@ -2574,7 +2574,7 @@ func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: ten // This shrink_axis mask is supported because it refers to a major dimension. // [1, 1, 1] = 7 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: xla_hlo.dynamic-slice + // CHECK: mhlo.dynamic-slice %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> return %result : tensor<1x97xi32> } @@ -2597,17 +2597,17 @@ func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: t // CHECK-LABEL: func @mean func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.add %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2634,15 +2634,15 @@ func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @sum func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.add %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2651,15 +2651,15 @@ func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @sum_dynamic func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.add %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2668,15 +2668,15 @@ func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @max func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.maximum %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2685,15 +2685,15 @@ func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @max_dynamic func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.maximum %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2702,15 +2702,15 @@ func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @min func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<0x7C00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.minimum %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.minimum %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2719,15 +2719,15 @@ func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: func @prod func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = "xla_hlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = "xla_hlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { + // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> + // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { // CHECK: ^bb0(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[REDUCE_BODY_RESULT:.*]] = xla_hlo.multiply %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () + // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.multiply %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> + // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> @@ -2737,11 +2737,11 @@ func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK-LABEL: @all func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense : tensor - // CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( { + // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor + // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( { // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[AND:.*]] = xla_hlo.and %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[AND]]) : (tensor) -> () + // CHECK: %[[AND:.*]] = mhlo.and %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[AND]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor) -> tensor<4xi1> %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> return %0 : tensor<4xi1> @@ -2749,7 +2749,7 @@ func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { // CHECK-LABEL: @all_keep_dim func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> + // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> return %0 : tensor<4x1xi1> @@ -2758,8 +2758,8 @@ func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { // CHECk-LABEL: @all_dynamic func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> - // CHECK: "xla_hlo.reduce"(%[[ARG]] + // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> + // CHECK: "mhlo.reduce"(%[[ARG]] %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> return %0 : tensor<4x1xi1> } @@ -2767,11 +2767,11 @@ func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { // CHECK-LABEL: @any func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense : tensor - // CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( { + // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor + // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( { // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): - // CHECK: %[[AND:.*]] = xla_hlo.or %[[ARGA]], %[[ARGB]] : tensor - // CHECK: "xla_hlo.return"(%[[AND]]) : (tensor) -> () + // CHECK: %[[AND:.*]] = mhlo.or %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "mhlo.return"(%[[AND]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor) -> tensor<4xi1> %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> return %0 : tensor<4xi1> @@ -2779,7 +2779,7 @@ func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { // CHECK-LABEL: @any_keep_dim func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> + // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> return %0 : tensor<4x1xi1> @@ -2788,8 +2788,8 @@ func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { // CHECk-LABEL: @any_dynamic func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> - // CHECK: "xla_hlo.reduce"(%[[ARG]] + // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> + // CHECK: "mhlo.reduce"(%[[ARG]] %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> return %0 : tensor<4x1xi1> } @@ -2800,8 +2800,8 @@ func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { // CHECK-LABEL: func @tile_by_reshape func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { - // CHECK: %[[BROADCASTED:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> + // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> // CHECK: return %[[RESULT]] : tensor<28x24xf32> %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> @@ -2810,7 +2810,7 @@ func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { // CHECK-LABEL: func @tile_just_broadcast func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> // CHECK: return %[[RESULT]] : tensor<7x3xf32> %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> @@ -2823,15 +2823,15 @@ func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { // CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<-9223372036854775808> : tensor - // CHECK: %[[INDEX_INIT:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32> - // CHECK: %[[REDUCE:.*]]:2 = "xla_hlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor + // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32> + // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - // CHECK: %[[RESULT1:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor, tensor, tensor) -> tensor - // CHECK: %[[RESULT2:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor, tensor, tensor) -> tensor - // CHECK: "xla_hlo.return"(%[[RESULT1]], %[[RESULT2]]) : (tensor, tensor) -> () + // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor, tensor, tensor) -> tensor + // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor, tensor, tensor) -> tensor + // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT2]]) : (tensor, tensor) -> () // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> @@ -2840,10 +2840,10 @@ func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32 // CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: %[[INDEX_INIT:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64> - // CHECK: %[[REDUCE:.*]]:2 = "xla_hlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) + // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor + // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64> + // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor) -> tensor<3xi64> @@ -2852,10 +2852,10 @@ func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64 // CHECK-LABEL: func @argmax_dynamic_shape_input_output func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<-2147483648> : tensor - // CHECK: %[[INDEX_INIT:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = "xla_hlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor + // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32> + // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor @@ -2864,10 +2864,10 @@ func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor // CHECK-LABEL: func @argmax_dynamic_shape_input func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { - // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<-2147483648> : tensor - // CHECK: %[[INDEX_INIT:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = "xla_hlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) + // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor + // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32> + // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor<3xi32> @@ -2880,10 +2880,10 @@ func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { // CHECK-LABEL: func @rng_uniform func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = "xla_hlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "xla_hlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] return %0 : tensor<12x?x64xf32> @@ -2891,10 +2891,10 @@ func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-LABEL: func @rng_std_normal func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = "xla_hlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "xla_hlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] return %0 : tensor<12x?x64xf32> @@ -2908,9 +2908,9 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor - // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -2918,19 +2918,19 @@ func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { // CHECK-LABEL: func @range_dynamic // CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"([[SUB]]) - // CHECK-DAG: [[CONVERT1:%.+]] = "xla_hlo.convert"([[ABS1]]) - // CHECK-DAG: [[CONVERT2:%.+]] = "xla_hlo.convert"(%arg2) - // CHECK-DAG: [[DIV:%.+]] = xla_hlo.divide [[CONVERT1]], [[CONVERT2]] - // CHECK-DAG: [[CEIL:%.+]] = "xla_hlo.ceil"([[DIV]]) - // CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"([[CEIL]]) - // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[CONVERT3]]) - // CHECK-DAG: [[IOTA:%.+]] = "xla_hlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} - // CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"(%arg0) - // CHECK-DAG: [[CONVERT4:%.+]] = "xla_hlo.convert"(%arg2) - // CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 + // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) + // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) + // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) + // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] + // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) + // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) + // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) + // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} + // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) + // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -2940,19 +2940,19 @@ func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) // CHECK-LABEL: func @range_int_dynamic // CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"([[SUB]]) - // CHECK-DAG: [[CONVERT1:%.+]] = "xla_hlo.convert"([[ABS1]]) - // CHECK-DAG: [[CONVERT2:%.+]] = "xla_hlo.convert"(%arg2) - // CHECK-DAG: [[DIV:%.+]] = xla_hlo.divide [[CONVERT1]], [[CONVERT2]] - // CHECK-DAG: [[CEIL:%.+]] = "xla_hlo.ceil"([[DIV]]) - // CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"([[CEIL]]) - // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[CONVERT3]]) - // CHECK-DAG: [[IOTA:%.+]] = "xla_hlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} - // CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"(%arg0) - // CHECK-DAG: [[CONVERT4:%.+]] = "xla_hlo.convert"(%arg2) - // CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 + // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) + // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) + // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) + // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] + // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) + // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) + // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) + // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} + // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) + // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -2962,16 +2962,16 @@ func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor, [[STOP:%.*]]: tensor func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { - // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4> + // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] - // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]]) - // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] - // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]]) + // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -2987,9 +2987,9 @@ func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: tensor) -> tensor { - // CHECK: xla_hlo.constant dense<[]> : tensor<0xi32> + // CHECK: mhlo.constant dense<> : tensor<0xi32> // CHECK: "tf.LinSpace" - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor<0xi32>) -> tensor return %1 : tensor } @@ -3001,7 +3001,7 @@ func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> { - // CHECK: "xla_hlo.convolution"(%arg0, %arg1) + // CHECK: "mhlo.convolution"(%arg0, %arg1) // Default attributes // CHECK-NOT: lhs_dilation @@ -3032,7 +3032,7 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) - // CHECK-LABEL: conv3d_simple func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32> { - // CHECK: "xla_hlo.convolution"(%arg0, %arg1) + // CHECK: "mhlo.convolution"(%arg0, %arg1) // Default attributes // CHECK-NOT: lhs_dilation @@ -3062,8 +3062,8 @@ func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16x // CHECK-LABEL: depthwiseconv_simple func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> { - // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> - // CHECK: "xla_hlo.convolution"(%arg0, %[[RESHAPED_FILTER]]) + // CHECK: %[[RESHAPED_FILTER:.*]] = "mhlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> + // CHECK: "mhlo.convolution"(%arg0, %[[RESHAPED_FILTER]]) // CHECK: feature_group_count = 3 %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { data_format = "NHWC", @@ -3078,7 +3078,7 @@ func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32 // CHECK-LABEL: conv_valid_padding func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { - // CHECK: "xla_hlo.convolution"(%arg0, %arg1) + // CHECK: "mhlo.convolution"(%arg0, %arg1) %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> return %0 : tensor<1x2x3x1xf32> @@ -3087,7 +3087,7 @@ func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) // CHECK-LABEL: conv_explicit_paddings func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> { - // CHECK: "xla_hlo.convolution"(%arg0, %arg1) + // CHECK: "mhlo.convolution"(%arg0, %arg1) // CHECK-SAME: padding = dense<{{\[\[}}6, 0], [3, 3]]> %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> @@ -3099,8 +3099,8 @@ func @conv2d_backprop_input( %filter: tensor<3x3x1x32xf32>, %out_backprop: tensor<100x26x26x32xf32> ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg1, %[[REV_FILTER]]) { + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]]) { // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 0 : i64, @@ -3133,8 +3133,8 @@ func @conv2d_backprop_input( // CHECK-LABEL: @conv3d_backprop_input func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} - // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg1, %[[REV_FILTER]]) + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]]) // CHECK-DAG-SAME: batch_group_count = 1 : i64, @@ -3166,7 +3166,7 @@ func @conv2d_backprop_filter( %input: tensor<100x28x28x1xf32>, %out_backprop: tensor<100x26x26x32xf32> ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg0, %arg1) { + // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1) { // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 3 : i64, @@ -3199,7 +3199,7 @@ func @conv2d_backprop_filter( // CHECK-LABEL: @conv3d_backprop_filter func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg0, %arg1) + // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1) // CHECK-DAG-SAME: batch_group_count = 1 : i64 @@ -3232,7 +3232,7 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> } : () -> tensor<2x4xi32> - // CHECK: xla_hlo.cross-replica-sum + // CHECK: mhlo.cross-replica-sum // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32> return %result : tensor<10xf32> @@ -3244,7 +3244,7 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @size_rank_one_i32 func @size_rank_one_i32(%input: tensor) -> (tensor) { - // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> + // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor) -> tensor // CHECK: return %[[CONST]] @@ -3253,7 +3253,7 @@ func @size_rank_one_i32(%input: tensor) -> (tensor) { // CHECK-LABEL: @size_rank_one_i64 func @size_rank_one_i64(%input: tensor) -> (tensor) { - // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> + // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor) -> tensor // CHECK: return %[[CONST]] @@ -3263,16 +3263,16 @@ func @size_rank_one_i64(%input: tensor) -> (tensor) { // CHECK-LABEL: @size_ranked // CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>) func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { - // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> - // CHECK: %[[DIM_0:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) + // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> + // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] - // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) + // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] - // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) + // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3307,8 +3307,8 @@ func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x // CHECK-LABEL: @split_match_and_split_into_two func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> - // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) // CHECK: return %[[ONE]], %[[TWO]] return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> @@ -3317,8 +3317,8 @@ func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32 // CHECK-LABEL: @split_match_and_split_into_two_dynamic func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) { %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> - // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) // CHECK: return %[[ONE]], %[[TWO]] return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32> @@ -3328,9 +3328,9 @@ func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor // CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> @@ -3360,16 +3360,16 @@ func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor - // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( { + // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( { // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"} - // CHECK-NEXT: "xla_hlo.return"(%[[CMP]]) + // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"} + // CHECK-NEXT: "mhlo.return"(%[[CMP]]) // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> - // CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32} - // CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32} - // CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[TUPL0:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32} + // CHECK-NEXT: %[[TUPL1:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32} + // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-NEXT: return %[[VAL]], %[[IDX]] %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> @@ -3384,9 +3384,9 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32> - // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32> %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> @@ -3396,9 +3396,9 @@ func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1x func @splitv_match_and_split_into_three_dynamic(%input: tensor) -> (tensor, tensor, tensor) { %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) return %0#0, %0#1, %0#2 : tensor, tensor, tensor } @@ -3432,12 +3432,12 @@ func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // TODO(b/156340000): Re-enable when fixed. // // C-HECK-LABEL: @unpack // func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { -// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> -// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> -// // C-HECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> +// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// // C-HECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> // %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) // // return %[[RES1]], %[[RES2]], %[[RES3]] @@ -3446,10 +3446,10 @@ func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // // C-HECK-LABEL: @unpack_dynamic // func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { -// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor -// // C-HECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor -// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor -// // C-HECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor +// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "mhlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor +// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "mhlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor // %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) // return %0#0, %0#1 : tensor, tensor @@ -3464,12 +3464,12 @@ func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ZERO:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "xla_hlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "xla_hlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]] : tensor - // CHECK: "xla_hlo.return"([[ADD]]) + // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor + // CHECK: "mhlo.return"([[ADD]]) // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> // CHECK: return [[SCATTER]] %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor) -> (tensor<4x64xf32>) @@ -3481,12 +3481,12 @@ func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x1 // CHECK-SAME: [[SI:%.*]]: tensor func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "xla_hlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "xla_hlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { + // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[MUL:%.*]] = xla_hlo.multiply [[LHS]], [[RHS]] : tensor - // CHECK: "xla_hlo.return"([[MUL]]) + // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor + // CHECK: "mhlo.return"([[MUL]]) // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> // CHECK: return [[SCATTER]] %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) @@ -3496,9 +3496,9 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: xla_hlo.constant dense<0x7F800000> : tensor - // CHECK: xla_hlo.scatter - // CHECK: xla_hlo.minimum + // CHECK: mhlo.constant dense<0x7F800000> : tensor + // CHECK: mhlo.scatter + // CHECK: mhlo.minimum %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) return %0: tensor<4x?xf32> } @@ -3506,9 +3506,9 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: xla_hlo.scatter - // CHECK: xla_hlo.maximum + // CHECK: mhlo.constant dense<0xFF800000> : tensor + // CHECK: mhlo.scatter + // CHECK: mhlo.maximum %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) return %0: tensor<4x?xf32> } @@ -3519,7 +3519,7 @@ func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> { - // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> + // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> return %1 : tensor<16x2x5xf32> @@ -3527,7 +3527,7 @@ func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16 // CHECK-LABEL: @gather_v2_dynamic func @gather_v2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor, tensor) -> tensor<*xf32> + // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor, tensor) -> tensor<*xf32> %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor<*xf32> return %1 : tensor<*xf32> @@ -3575,10 +3575,10 @@ func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[REVERSE:%.*]] = "xla_hlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[ZERO:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> + // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> // CHECK: return [[PAD]] @@ -3600,9 +3600,9 @@ func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf3 %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32> - // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64> // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64> // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64> @@ -3627,9 +3627,9 @@ func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32> - // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64> // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64> // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64> @@ -3655,9 +3655,9 @@ func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8 %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> - // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]]) + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> + // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64> // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64> // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64> @@ -3685,12 +3685,12 @@ func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor< %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) - // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> - // CHECK: [[ZERO:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor // The edge_padding_low, edge_padding_high and interior_padding attributes of - // xla_hlo.pad would reflect the padding required to get the shape of the + // mhlo.pad would reflect the padding required to get the shape of the // input of StridedSlice op. - // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZERO]]) + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]]) // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> @@ -3702,9 +3702,9 @@ func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor< // CHECK-LABEL: @tensor_scatter_update func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "xla_hlo.scatter"(%arg0, %arg1, %arg2) ( { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( { // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: "xla_hlo.return"(%arg4) : (tensor) -> () + // CHECK: "mhlo.return"(%arg4) : (tensor) -> () // CHECK: }) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers @@ -3732,15 +3732,15 @@ func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { // CHECK-LABEL: @random_shuffle_1D_16 // CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32> func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { - // CHECK: [[SHAPE:%.*]] = xla_hlo.constant dense<16> : tensor<1xi64> - // CHECK: [[LOWER:%.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: [[UPPER:%.*]] = xla_hlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "xla_hlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) - // CHECK: [[SORT:%.*]] = "xla_hlo.sort"([[RNG]], [[INPUT]]) ( { + // CHECK: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> + // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor + // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) + // CHECK: [[SORT:%.*]] = "mhlo.sort"([[RNG]], [[INPUT]]) ( { // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): - // CHECK: "xla_hlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"} + // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"} // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> tuple, tensor<16xf32>> - // CHECK: [[RES:%.*]] = "xla_hlo.get_tuple_element"([[SORT]]) {index = 1 : i32} + // CHECK: [[RES:%.*]] = "mhlo.get_tuple_element"([[SORT]]) {index = 1 : i32} // CHECK: return [[RES]] %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) return %0: tensor<16xf32> @@ -3748,12 +3748,12 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-LABEL: @random_shuffle_1D_10240 func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { - // CHECK: xla_hlo.rng_uniform - // CHECK: xla_hlo.sort - // CHECK: xla_hlo.get_tuple_element - // CHECK: xla_hlo.rng_uniform - // CHECK: xla_hlo.sort - // CHECK: xla_hlo.get_tuple_element + // CHECK: mhlo.rng_uniform + // CHECK: mhlo.sort + // CHECK: mhlo.get_tuple_element + // CHECK: mhlo.rng_uniform + // CHECK: mhlo.sort + // CHECK: mhlo.get_tuple_element %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) return %0: tensor<10240xf32> } @@ -3761,41 +3761,41 @@ func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { // CHECK-LABEL: @random_shuffle_3D // CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { - // CHECK: [[INDICES:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> - // CHECK: [[RNG_SHAPE:%.*]] = xla_hlo.constant dense<4> : tensor<1xi64> - // CHECK: [[RNG_LOWER:%.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: [[RNG_UPPER:%.*]] = xla_hlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "xla_hlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) + // CHECK: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> + // CHECK: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor + // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) - // CHECK: [[IV_INIT:%.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: [[WHILE_INIT:%.*]] = "xla_hlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]]) + // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[WHILE_INIT:%.*]] = "mhlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]]) - // CHECK: [[WHILE_OUT:%.*]] = "xla_hlo.while"([[WHILE_INIT]]) ( { + // CHECK: [[WHILE_OUT:%.*]] = "mhlo.while"([[WHILE_INIT]]) ( { // CHECK: ^{{.*}}([[COND_ARG:%.*]]: tuple, tensor<4xi32>, tensor<4xi32>>): - // CHECK: [[IV:%.*]] = "xla_hlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} - // CHECK: [[LIMIT:%.*]] = xla_hlo.constant dense<4> : tensor - // CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"} - // CHECK: "xla_hlo.return"([[CMP]]) + // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32} + // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor + // CHECK: [[CMP:%.*]] = "mhlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"} + // CHECK: "mhlo.return"([[CMP]]) // CHECK: }, { // CHECK: ^{{.*}}([[BODY_ARG:%.*]]: tuple, tensor<4xi32>, tensor<4xi32>>): - // CHECK: [[IV:%.*]] = "xla_hlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32} - // CHECK: [[SWAPS:%.*]] = "xla_hlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32} - // CHECK: [[INDICES:%.*]] = "xla_hlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32} - // CHECK: [[SRC_IDX:%.*]] = "xla_hlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP_IDX:%.*]] = "xla_hlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP:%.*]] = "xla_hlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor - // CHECK: [[TGT_IDX:%.*]] = "xla_hlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor} - // CHECK: [[INDICES1:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[INDICES2:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] - // CHECK: [[NEW_TUPLE:%.*]] = "xla_hlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) - // CHECK: "xla_hlo.return"([[NEW_TUPLE]]) + // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32} + // CHECK: [[SWAPS:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32} + // CHECK: [[INDICES:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32} + // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor + // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor} + // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor + // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]] + // CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) + // CHECK: "mhlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tuple, tensor<4xi32>, tensor<4xi32>> - // CHECK: [[SWAPED_INDICES:%.*]] = "xla_hlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32> - // CHECK: [[GATHER:%.*]] = "xla_hlo.gather"([[INPUT]], [[SWAPED_INDICES]]) + // CHECK: [[SWAPED_INDICES:%.*]] = "mhlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32> + // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[SWAPED_INDICES]]) // CHECK-SAME: dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<[1, 2, 3]> : tensor<3xi64>, start_index_map = dense<0> : tensor<1xi64>} // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> : tensor<3xi64> @@ -3814,16 +3814,16 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-LABEL: avgpool_valid_padding // CHECK-SAME: [[ARG:%.+]]: tensor<2x12x20x7xf16> func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> { - // CHECK: [[CONV32:%.+]] = "xla_hlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32> - // CHECK: [[INIT:%.+]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.+]] = "xla_hlo.reduce_window"([[CONV32]], [[INIT]]) ( { + // CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32> + // CHECK: [[INIT:%.+]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.+]] = "mhlo.reduce_window"([[CONV32]], [[INIT]]) ( { // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): - // CHECK: [[ADD:%.+]] = xla_hlo.add [[ARG1]], [[ARG2]] - // CHECK: "xla_hlo.return"([[ADD]]) + // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] + // CHECK: "mhlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> - // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> - // CHECK: [[CONV16:%.+]] = "xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> + // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor + // CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> return %0 : tensor<2x3x5x7xf16> @@ -3842,20 +3842,20 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> // CHECK-LABEL: func @avgpool_grad_valid_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK_SAME: broadcast_dimensions = dense<[]> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK_SAME: broadcast_dimensions = dense<> // CHECK_SAME: -> tensor<10x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> // CHECK-SAME: -> tensor<10x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor) -> () +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> // CHECK-SAME: window_strides = dense<1> @@ -3874,18 +3874,18 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24 // CHECK-LABEL: func @avgpool_3d_grad_valid_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: -> tensor<10x8x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor) -> () +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> // CHECK-SAME: window_strides = dense<1> @@ -3903,27 +3903,27 @@ func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor< // CHECK-LABEL: func @avgpool_grad_same_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> // CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> // CHECK-SAME: -> tensor<2x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> // CHECK-SAME: window_strides = dense<1> @@ -3942,27 +3942,27 @@ func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9x // CHECK-LABEL: func @avgpool_3d_grad_same_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> // CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x8x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> // CHECK-SAME: -> tensor<2x8x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> // CHECK-SAME: window_strides = dense<1> @@ -3980,27 +3980,27 @@ func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x1 // CHECK-LABEL: func @avgpool_grad_nchw_format( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> // CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x9x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> // CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> // CHECK-SAME: -> tensor<2x9x14x27xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> // CHECK-SAME: window_strides = dense<1> @@ -4019,27 +4019,27 @@ func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf // CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> // CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x9x8x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> // CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> // CHECK-SAME: -> tensor<2x9x8x14x27xf32> -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> // CHECK-SAME: window_strides = dense<1> : tensor<5xi64> @@ -4057,27 +4057,27 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8 // CHECK-LABEL: func @avgpool_grad_bf16( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { -// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = dense<[]> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK-SAME: broadcast_dimensions = dense<> // CHECK-SAME: -> tensor<10x12x16x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) +// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> // CHECK-SAME: -> tensor<10x25x33x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "xla_hlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> -// CHECK: %[[ZERO_F32:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( { +// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "mhlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> +// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( { // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor) -> () +// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> // CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<10x24x32x64xf32> -// CHECK: %[[RESULT_CONVERTED:.*]] = "xla_hlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> +// CHECK: %[[RESULT_CONVERTED:.*]] = "mhlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> // CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) @@ -4092,18 +4092,18 @@ func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xb // CHECK-LABEL: xla_sharding func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { - // CHECK-NEXT: "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, xla_hlo.sharding = ""} + // CHECK-NEXT: "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""} %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> return %0 : tensor<4x16xf32> } // CHECK-LABEL: inplace_update_one func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) - // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> // CHECK: return [[UPDATE]] @@ -4112,19 +4112,19 @@ func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: // CHECK-LABEL: inplace_update_three func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} - // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} - // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} - // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) - // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]]) - // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]]) - // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]]) + // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> @@ -4134,11 +4134,11 @@ func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, % // CHECK-LABEL: xla_dynamic_update_slice func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { - // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor - // CHECK: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor + // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor + // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> return %0 : tensor<4x16xf32> @@ -4146,9 +4146,9 @@ func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, // CHECK-LABEL: xla_dynamic_update_slice2 func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { - // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor + // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -4164,7 +4164,7 @@ func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> } : () -> tensor<3x4xi32> %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32> - // CHECK: xla_hlo.all_to_all + // CHECK: mhlo.all_to_all // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> return %result : tensor<10xf32> } @@ -4176,15 +4176,15 @@ func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: func @cumsum_static // CHECK-SAME: [[X:%.*]]: tensor<4xf32> func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[INIT:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "xla_hlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = xla_hlo.add [[A]], [[B]] : tensor - // CHECK: "xla_hlo.return"([[SUM]]) : (tensor) -> () + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = "xla_hlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: return [[CONVERT_REDUCE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> @@ -4232,17 +4232,17 @@ func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x7 // CHECK-LABEL: func @softplus_f16 // CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>) func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]]) - // CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.220700e-04> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]]) - // CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]]) - // CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) - // CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) + // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]]) + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) + // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) + // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16> // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16> @@ -4252,17 +4252,17 @@ func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { // CHECK-LABEL: func @softplus_bf16 // CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>) func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]]) - // CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<7.812500e-03> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]]) - // CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]]) - // CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) - // CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) + // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]]) + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) + // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) + // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16> // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16> @@ -4272,17 +4272,17 @@ func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { // CHECK-LABEL: func @softplus_f32 // CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>) func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]]) - // CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.1920929E-7> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]]) - // CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]]) - // CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) - // CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) + // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]]) + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) + // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) + // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32> @@ -4292,17 +4292,17 @@ func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK-LABEL: func @softplus_f64 // CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>) func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]]) - // CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<2.2204460492503131E-16> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]]) - // CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]]) - // CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) - // CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) + // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]]) + // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor + // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) + // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) + // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) + // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64> // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir deleted file mode 100644 index d20a8f4ac1d1e8..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s - -// CHECK-LABEL: func @remove_simple -func @remove_simple(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @remove_without_dealloc -func @remove_without_dealloc(%arg0: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @replace_dependency -func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @keep_copies -func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_not_be_removed -func @must_not_be_removed(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_be_removed_first -func @must_be_removed_first(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () -} - -// ----- - -// CHECK-LABEL: func @must_be_removed_second -func @must_be_removed_second(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, - %arg2: memref<2x2xf32>) { - %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () -} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir deleted file mode 100644 index dfc615bbef4ec8..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ /dev/null @@ -1,224 +0,0 @@ -// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s - -// CHECK-LABEL: @add -func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 - %4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<2xf32>, tensor<2xf32> -} - -// CHECK-LABEL: @add_unranked -func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 - %4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - -// CHECK-LABEL: @sub -func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 - %4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<2xf32>, tensor<2xf32> -} - -// CHECK-LABEL: @sub_unranked -func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 - %4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - -// CHECK-LABEL: @mul -func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32> - return %5, %6 : tensor<2xf32>, tensor<2xf32> -} - -// CHECK-LABEL: @mul_unranked -func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> - return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - -// CHECK-LABEL: @div -func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] - %4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - return %5, %6 : tensor<2xf32>, tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @div_unranked -func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] - %4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - -// CHECK-LABEL: @abs -func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]]) - %1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[VAL3]] - return %2 : tensor<2xf32> -} - -// CHECK-LABEL: @exp -func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) - - // CHECK: return [[VAL3]], [[VAL4]] - return %2, %3 : tensor<2xf32>, tensor<2xf32> -} - -// CHECK-LABEL: @exp_unranked -func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL3]], [[VAL4]] - return %2, %3 : tensor<*xf32>, tensor<*xf32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir b/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir deleted file mode 100644 index 7250fd4cc9415d..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: xla-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s - -// CHECK-LABEL: @testDebatch1 -func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { - // CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32> - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - // CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32> - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> - - return %0 : tensor<1x1x3xf32> -} - -// ----- - -// CHECK-LABEL: @testDebatch2 -func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { - // CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> - // CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32> - // CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> - // CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32> - - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> - return %0 : tensor<3x1x1xf32> -} - -// ----- - -// CHECK-LABEL: @testBatchPassthrough -func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> { - // CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1) - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> - return %0 : tensor<3x2x1xf32> -} - diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir deleted file mode 100644 index 746c3150b755f9..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s - -// CHECK-LABEL: @clampBroadcast -// CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) -func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { - // CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> - // CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> - // CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> - return %0 : tensor<4xf32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc index 70154f5190bc32..1a3f0c16247b84 100644 --- a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test { xla_builder_(name_, builder_, module_->getLoc()) {} string SetupTest() { - mlir::registerDialect(); + mlir::registerDialect(); return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } @@ -75,7 +75,7 @@ TEST_F(XlaBuilderTest, CreateToken) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr(GetMlirOpString(token), - R"("xla_hlo.create_token"() : () -> !xla_hlo.token)"); + R"("mhlo.create_token"() : () -> !mhlo.token)"); } TEST_F(XlaBuilderTest, Infeed) { @@ -85,7 +85,7 @@ TEST_F(XlaBuilderTest, Infeed) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(infeed), - R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple, !xla_hlo.token>)"); + R"("mhlo.infeed"(%0) {infeed_config = ""} : (!mhlo.token) -> tuple, !mhlo.token>)"); } TEST_F(XlaBuilderTest, Outfeed) { @@ -99,7 +99,7 @@ TEST_F(XlaBuilderTest, Outfeed) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(outfeed), - R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)"); + R"("mhlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !mhlo.token) -> !mhlo.token)"); } TEST_F(XlaBuilderTest, ConcatInDim) { @@ -112,7 +112,7 @@ TEST_F(XlaBuilderTest, ConcatInDim) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(concat), - R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)"); + R"("mhlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)"); } TEST_F(XlaBuilderTest, Tuple) { @@ -125,7 +125,7 @@ TEST_F(XlaBuilderTest, Tuple) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(tuple), - R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor) -> tuple, tensor>)"); + R"("mhlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor) -> tuple, tensor>)"); } TEST_F(XlaBuilderTest, GetTupleElement) { @@ -139,7 +139,7 @@ TEST_F(XlaBuilderTest, GetTupleElement) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(gte), - R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple, tensor>) -> tensor)"); + R"("mhlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple, tensor>) -> tensor)"); } TEST_F(XlaBuilderTest, Slice) { @@ -150,7 +150,7 @@ TEST_F(XlaBuilderTest, Slice) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(slice), - R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)"); + R"("mhlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)"); } TEST_F(XlaBuilderTest, Pad) { @@ -172,7 +172,7 @@ TEST_F(XlaBuilderTest, Pad) { TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); ExpectHasSubstr( GetMlirOpString(pad), - R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor) -> tensor<6x16xf32>)"); + R"("mhlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor) -> tensor<6x16xf32>)"); } } // namespace diff --git a/tensorflow/compiler/mlir/xla/tests/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/reshape.mlir deleted file mode 100644 index fe16e8c1c99e74..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/reshape.mlir +++ /dev/null @@ -1,149 +0,0 @@ -// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s - -// CHECK-LABEL: func @const_fold_collapse_to_scalar -func @const_fold_collapse_to_scalar() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor<1x1xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor - // CHECK-NEXT: return [[CST]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @const_fold_collapse_to_tensor -func @const_fold_collapse_to_tensor() -> tensor<2xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32> - %cst = xla_hlo.constant dense<42> : tensor<1x2xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func @const_fold_expand -func @const_fold_expand() -> tensor<1xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32> - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.reshape"(%cst) : (tensor) -> tensor<1xi32> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: func @const_fold_nontrivial -func @const_fold_nontrivial() -> tensor<16xi64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> - %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<16xi64> -} - -// ----- - -// CHECK-LABEL: func @const_fold_flatten -func @const_fold_flatten() -> tensor<16xi64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> - %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<16xi64> -} - -// ----- - -// CHECK-LABEL: func @const_fold_6 -func @const_fold_6() -> tensor<6xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<6xi32> -} - -// ----- - -// CHECK-LABEL: func @const_fold_same_shape -func @const_fold_same_shape() -> tensor<2x3xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ - // CHECK-SAME: [1, 2, 3], [4, 5, 6] - // CHECK-SAME: ]> : tensor<2x3xi32> - %cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @const_fold_float -func @const_fold_float() -> tensor<16xf64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64> - %cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> - // CHECK-NEXT: return [[CST]] - return %0 : tensor<16xf64> -} - -// ----- - -// CHECK-LABEL: func @non_const_same_shape -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-NEXT: return [[ARG]] - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> - return %0 : tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @non_const_chained_reshape -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) { - // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> - // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> - return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed -} - -// ----- - -// CHECK-LABEL: func @non_const_chained_reshape_unused_parent -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> - // CHECK-NEXT: return [[RES]] - return %1 : tensor<6xi32> -} - -// ----- - -// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: return [[ARG]] - return %1 : tensor<2x3xi32> -} - -// ----- - -// CHECK-LABEL: func @non_const_many_chained_reshapes -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> - %2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> - %3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> - %4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> - // CHECK-NEXT: return [[RES]] - return %4 : tensor<1x2x4x3xi32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir deleted file mode 100644 index cecd95f0ffe720..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s - -// Tests sinking constants to a while loop. - -// CHECK-LABEL: func @sink_const_to_while -func @sink_const_to_while(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.while - %c0 = xla_hlo.constant dense<1> : tensor - %c1 = xla_hlo.constant dense<2> : tensor - %0 = "xla_hlo.while"(%arg0) ( { - ^bb0(%arg1: tensor): - // CHECK: %[[ARG1A:.+]]: tensor - // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor - // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) - %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - // CHECK: %[[ARG1B:.+]]: tensor - // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor - // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] - %2 = xla_hlo.add %arg1, %arg1 : tensor - // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] - %3 = xla_hlo.add %c1, %2 : tensor - // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] - %4 = xla_hlo.add %c1, %3 : tensor - "xla_hlo.return"(%4) : (tensor) -> () - }) : (tensor) -> tensor - return %0 : tensor -} - -// Tests sinking constants to a conditional op. - -// CHECK-LABEL: func @sink_const_to_conditional -func @sink_const_to_conditional(%arg0: tensor) -> tensor { - %c0 = xla_hlo.constant dense<1> : tensor - %c1 = xla_hlo.constant dense<2> : tensor - %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> - // CHECK: xla_hlo.if - %2 = "xla_hlo.if"(%0, %1, %1) ( { - ^bb0(%arg1: tuple>): - // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor - %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], - %4 = xla_hlo.add %c0, %3 : tensor - %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> - "xla_hlo.return"(%5) : (tuple>) -> () - }, { - ^bb0(%arg1: tuple>): - // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor - %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], - %7 = xla_hlo.add %c1, %6 : tensor - %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> - "xla_hlo.return"(%8) : (tuple>) -> () - }) : (tensor, tuple>, tuple>) -> tuple> - %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor - return %9 : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir index 4d846a0603cd7b..9bcc0b92aa4cfc 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir @@ -12,9 +12,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) - %1 = "xla_hlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir index dba9e8b61ca8a9..579595682878a0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -5,18 +5,18 @@ func @main() -> tensor { %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor - %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + %0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -52,18 +52,18 @@ func @main() -> (tensor, tensor) { %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor - %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + %0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor - "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + %1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor - "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + %1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor - "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + %1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + "mhlo.return"(%1, %1) : (tensor, tensor) -> () }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) return %0#0, %0#1 : tensor, tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt index 2ff223cd480ef3..62f0d7a59e4f8e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -30,17 +30,17 @@ ENTRY %indexed_conditional () -> f32[] { // CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor // CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor // CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor -// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { +// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { // CHECK: ^bb0(%[[ARG_1:.*]]: tensor): -// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor -// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor) -> () +// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "mhlo.return"(%[[RES_1]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_2:.*]]: tensor): -// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor -// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor) -> () +// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "mhlo.return"(%[[RES_2]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_3:.*]]: tensor): -// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor -// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor) -> () +// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "mhlo.return"(%[[RES_3]]) : (tensor) -> () // CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor // CHECK: return %[[RESULT]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding.mlir b/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding.mlir index af11ccfdad6279..f76d51e0b4035b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding.mlir @@ -19,7 +19,7 @@ func @main(%arg0: tensor<10xf32>, %arg1: tensor) { // Test entry function with single dynamic parameter binding on an argument. -func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor) { +func @main(%arg0: tensor<10xf32> {mhlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor) { return } @@ -42,7 +42,7 @@ func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i3 // Test entry function with multiple dynamic parameter bindings on an argument. -func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor, %arg2: tensor) { +func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor, %arg2: tensor) { return } @@ -75,7 +75,7 @@ func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : // Test entry function with multiple dynamic parameter bindings on multiple // arguments. -func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor, %arg2: tensor, %arg3: tensor<10x8x6xi32> {xla_hlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor) { +func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor, %arg2: tensor, %arg3: tensor<10x8x6xi32> {mhlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor) { return } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding_invalid.mlir b/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding_invalid.mlir index 6a7c64733c0e95..244c25a40eb35e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding_invalid.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/dynamic_parameter_binding_invalid.mlir @@ -1,148 +1,148 @@ // RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo %s -o - 2>&1 | FileCheck %s -// Test bad `xla_hlo.padding_map` attribute type. +// Test bad `mhlo.padding_map` attribute type. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = ""}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = ""}) { return } -// CHECK: requires 'xla_hlo.padding_map' dict attribute at arg 1 +// CHECK: requires 'mhlo.padding_map' dict attribute at arg 1 // ----- -// Test missing `shape_indices` attribute in `xla_hlo.padding_map`. +// Test missing `shape_indices` attribute in `mhlo.padding_map`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {}}) { return } -// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1 +// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1 // ----- -// Test bad `shape_indices` attribute type in `xla_hlo.padding_map`. +// Test bad `shape_indices` attribute type in `mhlo.padding_map`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = ""}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = ""}}) { return } -// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1 +// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1 // ----- -// Test missing `padding_arg_indices` attribute in `xla_hlo.padding_map`. +// Test missing `padding_arg_indices` attribute in `mhlo.padding_map`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = []}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = []}}) { return } -// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1 +// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1 // ----- -// Test bad `padding_arg_indices` attribute type in `xla_hlo.padding_map`. +// Test bad `padding_arg_indices` attribute type in `mhlo.padding_map`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) { return } -// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1 +// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1 // ----- // Test mismatched `shape_indices` and `padding_arg_indices` lengths. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) { return } -// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'xla_hlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2 +// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'mhlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2 // ----- // Test non integer attribute in `shape_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) { return } -// CHECK: requires element 1 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute +// CHECK: requires element 1 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute // ----- // Test non integer attribute in `padding_arg_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) { return } -// CHECK: requires element 1 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute +// CHECK: requires element 1 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute // ----- // Test negative out of range shape index in `shape_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { return } -// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got -1 +// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got -1 // ----- // Test positive out of range shape index in `shape_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { return } -// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got 1 +// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got 1 // ----- // Test negative shape index in `shape_indices` for unranked argument. -func @main(%arg0: tensor, %arg1: tensor<*xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<*xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) { return } -// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be non-negative, got -1 +// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be non-negative, got -1 // ----- // Test duplicate shape indices in `shape_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) { return } -// CHECK: requires elements in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1 +// CHECK: requires elements in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1 // ----- // Test negative out of range shape index in `padding_arg_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) { return } -// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got -1 +// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got -1 // ----- // Test positive out of range shape index in `padding_arg_indices`. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) { return } -// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got 2 +// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got 2 // ----- // Test non scalar padding argument. -func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) { +func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) { return } @@ -152,7 +152,7 @@ func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {s // Test non integer type padding argument. -func @main(%arg0: tensor, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) { +func @main(%arg0: tensor, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) { return } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 79fcfa3614ab9a..9929bd85b43c93 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,9 +1,9 @@ // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s // CHECK: HloModule -func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.after_all"(%arg0, %arg1) : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0, %arg1) : (!mhlo.token, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK: ENTRY @@ -15,11 +15,11 @@ func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { // CHECK: HloModule func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = "xla_hlo.all_reduce"(%arg0) ({ + %0 = "mhlo.all_reduce"(%arg0) ({ // Perform max reduction inside the region ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.maximum %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, @@ -43,7 +43,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: HloModule func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { - %0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> + %0 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> return %0 : tuple, tensor<2xf32>, tensor<2xf32>> } @@ -60,7 +60,7 @@ func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf // CHECK: HloModule func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { - %0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> + %0 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> return %0 : tuple, tensor<2xf32>, tensor<2xf32>> } @@ -78,16 +78,16 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // CHECK: [[VAL_1:%.*]] = s32[4] parameter(0) // CHECK: [[VAL_2:%.*]] = s32[4] parameter(1) // CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32> + %0 = mhlo.atan2 %arg0, %arg1 : tensor<4xi32> // CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %1 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32> // CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %2 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> // CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> + %3 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> // CHECK: ROOT // CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]]) @@ -98,7 +98,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { - %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -112,7 +112,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK: [[ARG:%.*]] = s32[4] parameter(0) // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> return %0 : tensor<1x2x3x4xi32> } @@ -120,7 +120,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK: HloModule func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { - %result = "xla_hlo.broadcast_in_dim"(%arg0) { + %result = "mhlo.broadcast_in_dim"(%arg0) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<1xf32>) -> tensor<1x10xf32> return %result : tensor<1x10xf32> @@ -133,9 +133,9 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { // ----- // CHECK: HloModule -func @main() -> !xla_hlo.token { - %0 = "xla_hlo.create_token"() : () -> !xla_hlo.token - return %0 : !xla_hlo.token +func @main() -> !mhlo.token { + %0 = "mhlo.create_token"() : () -> !mhlo.token + return %0 : !mhlo.token } // CHECK: ROOT [[TOKEN:%.*]] = token[] after-all() @@ -150,7 +150,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> { } func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -181,8 +181,8 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { } func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - %1 = "xla_hlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "mhlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0, %1 : tensor<4xi32>, tensor<4xi32> } @@ -202,7 +202,7 @@ func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tens // CHECK: HloModule func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -217,7 +217,7 @@ func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { func @main(%arg0 : tensor<5x2xf32>, %arg1 : tensor<5x5xf32>, %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> { - %result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { + %result = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 1 : i64 } : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32> return %result : tensor<5x14xf32> @@ -279,7 +279,7 @@ func @main() { // CHECK: HloModule func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - %result = "xla_hlo.convolution"(%arg0, %arg1) { + %result = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = { input_batch_dimension = 0 : i64, @@ -312,7 +312,7 @@ func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> te // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { - %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -324,7 +324,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %0 = "mhlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -336,8 +336,8 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: HloModule func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> - %1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32> + %0 = mhlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> + %1 = "mhlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32> return %1 : tensor<10xf32> } @@ -354,7 +354,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: HloModule func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> { - %0 = "xla_hlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> + %0 = "mhlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } @@ -369,7 +369,7 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: HloModule func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -388,7 +388,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { // CHECK: HloModule func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> { - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16> return %0 : tensor<16x32xbf16> } @@ -408,7 +408,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> { func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // Simple einsum is lowered to HLO dot op. // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} - %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> + %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> return %0 : tensor<3x5xi32> } @@ -416,7 +416,7 @@ func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // CHECK: HloModule func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { - %0 = "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex> + %0 = "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex> return %0 : tensor<3x5xcomplex> } @@ -437,7 +437,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // CHECK-SAME: index_vector_dim=1 // CHECK-SAME: slice_sizes={1,1,300} // CHECK-SAME: indices_are_sorted=true - %0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> return %0 : tensor<10x300xf32> } @@ -445,8 +445,8 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // CHECK: HloModule func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { - %0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> - %1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor + %0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> + %1 = "mhlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor return %1 : tensor } @@ -461,7 +461,7 @@ func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { // CHECK: HloModule func @main(%arg0: tuple, tensor>) -> tensor { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -472,9 +472,9 @@ func @main(%arg0: tuple, tensor>) -> tensor { // ----- // CHECK: HloModule -func @main(%arg0: !xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> { - %0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> - return %0 : tuple, tensor>, !xla_hlo.token> +func @main(%arg0: !mhlo.token) -> tuple, tensor>, !mhlo.token> { + %0 = "mhlo.infeed"(%arg0) {infeed_config = "foobar"} : (!mhlo.token) -> tuple, tensor>, !mhlo.token> + return %0 : tuple, tensor>, !mhlo.token> } // CHECK: ENTRY @@ -485,7 +485,7 @@ func @main(%arg0: !xla_hlo.token) -> tuple, tensor>, !xl // CHECK: HloModule func @main() -> tensor<1x10xf32> { - %result = "xla_hlo.iota"() { + %result = "mhlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<1x10xf32> return %result : tensor<1x10xf32> @@ -498,10 +498,10 @@ func @main() -> tensor<1x10xf32> { // CHECK: HloModule func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -522,9 +522,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // ----- // CHECK: HloModule -func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @main(%data: tensor<3xi32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK: ENTRY @@ -536,7 +536,7 @@ func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { // CHECK: HloModule func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { - %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> + %0 = "mhlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> return %0 : tensor<13x19xf32> } @@ -549,15 +549,15 @@ func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { // ----- // CHECK: HloModule -func @main(%token: !xla_hlo.token) -> tuple, !xla_hlo.token> { - %0 = "xla_hlo.recv"(%token) { +func @main(%token: !mhlo.token) -> tuple, !mhlo.token> { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 3 : i64 // Host to device channel }, is_host_transfer = true - } : (!xla_hlo.token) -> tuple, !xla_hlo.token> - return %0 : tuple, !xla_hlo.token> + } : (!mhlo.token) -> tuple, !mhlo.token> + return %0 : tuple, !mhlo.token> } // CHECK: ENTRY @@ -569,15 +569,15 @@ func @main(%token: !xla_hlo.token) -> tuple, !xla_hlo.token> { // ----- // CHECK: HloModule -func @main(%token: !xla_hlo.token) -> tuple, !xla_hlo.token> { - %0 = "xla_hlo.recv"(%token) { +func @main(%token: !mhlo.token) -> tuple, !mhlo.token> { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 1 : i64 // Device to device channel }, is_host_transfer = false - } : (!xla_hlo.token) -> tuple, !xla_hlo.token> - return %0 : tuple, !xla_hlo.token> + } : (!mhlo.token) -> tuple, !mhlo.token> + return %0 : tuple, !mhlo.token> } // CHECK: ENTRY @@ -591,11 +591,11 @@ func @main(%token: !xla_hlo.token) -> tuple, !xla_hlo.token> { // CHECK: HloModule func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) { - %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { + %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor): // no predecessors - %fmax = "xla_hlo.maximum"(%fa, %fb) {} : (tensor, tensor) -> tensor - %imax = "xla_hlo.maximum"(%ia, %ib) {} : (tensor, tensor) -> tensor - "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> () + %fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor, tensor) -> tensor + %imax = "mhlo.maximum"(%ia, %ib) {} : (tensor, tensor) -> tensor + "mhlo.return"(%fmax, %imax) : (tensor, tensor) -> () }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) return %result0, %result1 : tensor<1xf32>, tensor<1xi32> } @@ -617,11 +617,11 @@ func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor) -> tensor<2x3x5x7xi32> { - %0 = xla_hlo.constant dense<-2147483648> : tensor - %1 = "xla_hlo.reduce_window"(%arg0, %0) ( { + %0 = mhlo.constant dense<-2147483648> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = xla_hlo.maximum %arg1, %arg2 : tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () }) { window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, @@ -646,7 +646,7 @@ func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> { // CHECK: HloModule func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32> return %0 : tensor<1x2xf32> } @@ -658,7 +658,7 @@ func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> { // CHECK: HloModule func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { - %result = "xla_hlo.reverse"(%arg0) { + %result = "mhlo.reverse"(%arg0) { dimensions = dense<[1,2]> : tensor<2xi64> } : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> return %result : tensor<10x11x12x13xf32> @@ -672,8 +672,8 @@ func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { // CHECK: HloModule func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { - %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %0 = "xla_hlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32> } @@ -686,10 +686,10 @@ func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { // CHECK: HloModule func @main() -> tensor<2x3x5xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = xla_hlo.constant dense<1.000000e+00> : tensor - %2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %3 = "mhlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> return %3 : tensor<2x3x5xf32> } @@ -702,10 +702,10 @@ func @main() -> tensor<2x3x5xf32> { // CHECK: HloModule func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors - %add = xla_hlo.add %lhs, %rhs : tensor - "xla_hlo.return"(%add) : (tensor) -> () + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () }) { scatter_dimension_numbers = { update_window_dims = dense<[1]> : tensor<1xi64>, @@ -737,7 +737,7 @@ func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -745,15 +745,15 @@ func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> // CHECK: HloModule func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ( { ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %2 = xla_hlo.add %arg3, %arg4 : tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () }) { window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> @@ -780,15 +780,15 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te // ----- // CHECK: HloModule -func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.send"(%arg, %token) { +func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.send"(%arg, %token) { channel_id = { handle = 5 : i64, type = 2 : i64 // Device to host channel }, is_host_transfer = true - } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token + } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK: ENTRY @@ -801,15 +801,15 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { // ----- // CHECK: HloModule -func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.send"(%arg, %token) { +func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.send"(%arg, %token) { channel_id = { handle = 5 : i64, type = 1 : i64 // Device to device channel }, is_host_transfer = false - } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token + } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK: ENTRY @@ -823,7 +823,7 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token { // CHECK: HloModule func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { - %0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + %0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> return %0 : tensor<4x4xf32> } @@ -837,7 +837,7 @@ func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { // CHECK: HloModule func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + %0 = "mhlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -850,7 +850,7 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { // CHECK: HloModule func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -865,7 +865,7 @@ func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor) -> // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - "xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> () + "mhlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> () return %arg0: tensor<2xi32> } @@ -880,7 +880,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0) // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0 : tensor<2x1x4x3xi32> } @@ -888,7 +888,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // CHECK: HloModule func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -901,7 +901,7 @@ func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // CHECK: HloModule func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor> { - %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor> + %result = "mhlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor> return %result : tuple, tensor> } @@ -916,17 +916,17 @@ func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) { // CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0) // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]]) - %expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> + %expm1 = "mhlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]]) - %log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> + %log1p = "mhlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1) // CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]]) - %not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> + %not = "mhlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> // CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]]) - %popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> + %popcnt = "mhlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32> } @@ -937,7 +937,7 @@ func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: [[VAL_1:%.*]] = pred[4] parameter(0) // CHECK: [[VAL_2:%.*]] = pred[4] parameter(1) - %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1> + %0 = mhlo.xor %arg0, %arg1 : tensor<4xi1> // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]]) return %0 : tensor<4xi1> } @@ -946,10 +946,10 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: HloModule func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -975,7 +975,7 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: HloModule func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { - %0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %0 = "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32> return %0 : tensor<16x16xf32> } @@ -988,8 +988,8 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { // Tests that the exported HLO module keeps parameter replication annotation. // CHECK: HloModule -func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<16x16xf32> { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> +func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> return %0 : tensor<16x16xf32> } @@ -1003,8 +1003,8 @@ func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_d // CHECK: HloModule func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> (tensor<2xf32>, tensor<2xf64>) { - %0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %1 = "xla_hlo.abs"(%arg1) : (tensor<2xcomplex>) -> (tensor<2xf64>) + %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %1 = "mhlo.abs"(%arg1) : (tensor<2xcomplex>) -> (tensor<2xf64>) return %0, %1 : tensor<2xf32>, tensor<2xf64> } @@ -1019,7 +1019,7 @@ func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> (ten // CHECK: HloModule func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> { - %0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8> + %0 = "mhlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8> return %0 : tensor<4xui8> } @@ -1031,7 +1031,7 @@ func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> { // CHECK: HloModule func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32> return %1 : tensor<*xi32> } @@ -1046,10 +1046,10 @@ func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> { // correctly in HloModule as frontend_attributes. // CHECK: HloModule -func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple, !xla_hlo.token> { - %0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token - %1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple, !xla_hlo.token> - return %1 : tuple, !xla_hlo.token> +func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> tuple, !mhlo.token> { + %0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token + %1 = "mhlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!mhlo.token) -> tuple, !mhlo.token> + return %1 : tuple, !mhlo.token> } // CHECK: ENTRY @@ -1068,9 +1068,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple, %token: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK-NOT: frontend_attributes @@ -1081,9 +1081,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token { // populated in HloModule. // CHECK: HloModule -func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token { - %0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token + return %0 : !mhlo.token } // CHECK-NOT: frontend_attributes diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 05d6a2a9af2aaf..86adcf0710f902 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -9,95 +9,95 @@ ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "xla_hlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "xla_hlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "xla_hlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = xla_hlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> + // CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "xla_hlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "xla_hlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "xla_hlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "xla_hlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. - // CHECK-NEXT: %16 = "xla_hlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "xla_hlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = xla_hlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> + // CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = xla_hlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> + // CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "xla_hlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "xla_hlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // CHECK-NEXT: %23 = "xla_hlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> // CHECK-NEXT: return %23 : tuple> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/if.mlir b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir index 6542966fc7c79c..f145f4fda0df76 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/if.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir @@ -4,13 +4,13 @@ // CHECK: %[[A0]] = (f32[]) parameter(0) func @then_branch(%arg0: tuple>) -> tuple> { // CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0 - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor // CHECK: %[[VAL1:.+]] = f32[] log(f32[] %[[VAL0]]) - %1 = "xla_hlo.log"(%0) : (tensor) -> tensor + %1 = "mhlo.log"(%0) : (tensor) -> tensor // CHECK: ROOT %[[VAl2:.+]] = (f32[]) tuple(f32[] %[[VAL1]]) - %2 = "xla_hlo.tuple"(%1) : (tensor) -> tuple> + %2 = "mhlo.tuple"(%1) : (tensor) -> tuple> return %2 : tuple> } @@ -18,13 +18,13 @@ func @then_branch(%arg0: tuple>) -> tuple> { // CHECK: %[[A0]] = (f32[]) parameter(0) func @else_branch(%arg0: tuple>) -> tuple> { // CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0 - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor // CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]]) - %1 = "xla_hlo.exponential"(%0) : (tensor) -> tensor + %1 = "mhlo.exponential"(%0) : (tensor) -> tensor // CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]]) - %2 = "xla_hlo.tuple"(%1) : (tensor) -> tuple> + %2 = "mhlo.tuple"(%1) : (tensor) -> tuple> return %2 : tuple> } @@ -35,30 +35,30 @@ func @main(%arg0: tensor) -> tuple> { %cst = constant dense<1.000000e+01> : tensor // CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT - %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[A0]]) - %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> + %1 = "mhlo.tuple"(%arg0) : (tensor) -> tuple> // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]] - %2 = "xla_hlo.if"(%0, %1, %1) ( { + %2 = "mhlo.if"(%0, %1, %1) ( { ^bb0(%arg1: tuple>): - %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - %7 = "xla_hlo.log"(%6) : (tensor) -> tensor - %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> - "xla_hlo.return"(%8) : (tuple>) -> () + %6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + %7 = "mhlo.log"(%6) : (tensor) -> tensor + %8 = "mhlo.tuple"(%7) : (tensor) -> tuple> + "mhlo.return"(%8) : (tuple>) -> () }, { ^bb0(%arg1: tuple>): - %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - %7 = "xla_hlo.exponential"(%6) : (tensor) -> tensor - %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> - "xla_hlo.return"(%8) : (tuple>) -> () + %6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + %7 = "mhlo.exponential"(%6) : (tensor) -> tensor + %8 = "mhlo.tuple"(%7) : (tensor) -> tuple> + "mhlo.return"(%8) : (tuple>) -> () }) : (tensor, tuple>, tuple>) -> tuple> // CHECK: %[[VAL4:.+]] = f32[] get-tuple-element((f32[]) %[[VAL3]]), index=0 - %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor // CHECK: ROOT %[[VAL5:.+]] = (f32[]) tuple(f32[] %[[VAL4]]) - %4 = "xla_hlo.tuple"(%3) : (tensor) -> tuple> + %4 = "mhlo.tuple"(%3) : (tensor) -> tuple> return %4 : tuple> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt index d2c6e669e9b285..28e98c1376ac1a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt @@ -23,31 +23,31 @@ ENTRY %tfcompile.20 { // CHECK: [[C0:%.+]] = constant %constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"} - // CHECK: [[R1:%.+]] = "xla_hlo.compare"([[A0]], [[C0]]) + // CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]]) %compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"} - // CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]]) + // CHECK: [[R2:%.+]] = "mhlo.tuple"([[A0]]) %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( { + // CHECK: [[R3:%.+]] = "mhlo.if"([[R1]], [[R2]], [[R2]]) ( { // CHECK: ^bb0([[A1:%.+]]: tuple>): - // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) - // CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]]) - // CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]]) - // CHECK: "xla_hlo.return"([[R9]]) + // CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]]) + // CHECK: [[R8:%.+]] = "mhlo.log"([[R7]]) + // CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]]) + // CHECK: "mhlo.return"([[R9]]) // CHECK: }, { // CHECK: ^bb0([[A1:%.+]]: tuple>): - // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) - // CHECK: [[R8:%.+]] = "xla_hlo.exponential"([[R7]]) - // CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]]) - // CHECK: "xla_hlo.return"([[R9]]) + // CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]]) + // CHECK: [[R8:%.+]] = "mhlo.exponential"([[R7]]) + // CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]]) + // CHECK: "mhlo.return"([[R9]]) // CHECK: }) %conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch, false_computation=%else_branch, metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R4:%.+]] = "xla_hlo.get_tuple_element"([[R3]]) + // CHECK: [[R4:%.+]] = "mhlo.get_tuple_element"([[R3]]) %get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R5:%.+]] = "xla_hlo.tuple"([[R4]]) + // CHECK: [[R5:%.+]] = "mhlo.tuple"([[R4]]) // CHECK: return [[R5]] ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 6336d6ed6882fc..2b7d44f4522941 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -13,20 +13,20 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %Arg_0.1 = f32[4]{0} parameter(0) %Arg_1.2 = f32[4]{0} parameter(1) - // CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "xla_hlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } // CHECK-LABEL: func @test_after_all -// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token [[PRIVATE]] +// CHECK-SAME: ([[VAL_0:%.*]]: !mhlo.token, [[VAL_1:%.*]]: !mhlo.token) -> !mhlo.token [[PRIVATE]] %test_after_all (token0: token[], token1: token[] ) -> token[] { token0 = token[] parameter(0) token1 = token[] parameter(1) - // CHECK-NEXT: "xla_hlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token + // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token ROOT after-all = token[] after-all(token0, token1) } @@ -41,10 +41,10 @@ add { // CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>) %test_all_reduce { input = f32[8] parameter(0) - // CHECK-NEXT: "xla_hlo.all_reduce"([[INPUT]]) + // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG0]], [[ARG1]] - // CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () + // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] + // CHECK: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK: }) { // CHECK-SAME: channel_handle = {handle = 1 : i64, type = 0 : i64} // CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64> @@ -57,7 +57,7 @@ add { %Arg_0.1 = pred[4] parameter(0) %Arg_1.2 = pred[4] parameter(1) - // CHECK-NEXT: xla_hlo.and %arg0, %arg1 + // CHECK-NEXT: mhlo.and %arg0, %arg1 ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2) } @@ -67,7 +67,7 @@ add { %Arg_0.1 = s32[4] parameter(0) %Arg_1.2 = s32[4] parameter(1) - // CHECK: xla_hlo.atan2 [[VAL_0]], [[VAL_1]] + // CHECK: mhlo.atan2 [[VAL_0]], [[VAL_1]] ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2) } @@ -75,10 +75,10 @@ add { %test_broadcast_in_dim { %Arg_0.1 = f32[1, 2] parameter(0) - // CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} - // CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} } @@ -90,7 +90,7 @@ add { %variance = f32[2] parameter(3) %grad_output = f32[2,2,2,2] parameter(4) - // CHECK: "xla_hlo.batch_norm_grad" + // CHECK: "mhlo.batch_norm_grad" // CHECK-SAME: epsilon = 1.000000e-03 : f32 // CHECK-SAME: feature_index = 1 : i64 ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1 @@ -113,7 +113,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> %test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { %a = f32[1,291,291] parameter(0) - // CHECK-NEXT: "xla_hlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true } @@ -124,7 +124,7 @@ add { %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: "xla_hlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) } @@ -132,7 +132,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> %test_collective_permute (input: f32[128,32]) -> f32[128,32] { %input = f32[128,32]{0,1} parameter(0) - // CHECK-NEXT: "xla_hlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> + // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} } @@ -143,14 +143,14 @@ add { %Arg_1.2 = f32[3] parameter(1) %Arg_2.3 = f32[3] parameter(2) - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -159,7 +159,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: "xla_hlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -168,7 +168,7 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[4, 2] parameter(1) - // CHECK-NEXT: "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1} } @@ -206,10 +206,10 @@ add { %test_conv { %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "xla_hlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "xla_hlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables @@ -217,7 +217,7 @@ add { // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %2 = "xla_hlo.convolution"(%1, %cst) { + // CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) { // CHECK-SAME: batch_group_count = 1 : i64 // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 0 : i64 @@ -241,10 +241,10 @@ add { %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> + // CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} } @@ -253,7 +253,7 @@ add { %test_convolve1D_padding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] { %input = f32[1,2,1] parameter(0) %filter = f32[1,1,1] parameter(1) - // CHECK: "xla_hlo.convolution" + // CHECK: "mhlo.convolution" // CHECK-SAME: padding = dense<{{\[\[}}1, 2]]> : tensor<1x2xi64> ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } @@ -263,13 +263,13 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.4 = f64[4] convert(f32[4] %Arg_1.2) - // CHECK-NEXT: xla_hlo.add %0, %1 + // CHECK-NEXT: mhlo.add %0, %1 ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4) } @@ -277,7 +277,7 @@ add { %test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -286,7 +286,7 @@ add { %test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] { %arg1 = f32[2,3] parameter(0) %arg2 = f32[5,5] parameter(1) -// CHECK: "xla_hlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true } @@ -295,7 +295,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -304,17 +304,17 @@ add { %Arg_0.1 = f32[1, 4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -325,17 +325,17 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[1, 4] parameter(1) - // CHECK-NEXT: [[R0:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} + // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest} - // CHECK-NEXT: [[R1:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} + // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default} - // CHECK-NEXT: [[R2:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1} } @@ -346,7 +346,7 @@ add { %start_idx_1 = s32[] parameter(1) %start_idx_2 = s32[] parameter(2) %start_idx_3 = s32[] parameter(3) - // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]]) + // CHECK: "mhlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]]) // CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64> ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32} } @@ -358,7 +358,7 @@ add { %Arg_2.3 = s32[] parameter(2) %Arg_3.4 = s32[] parameter(3) - // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) } @@ -368,7 +368,7 @@ add { %Arg_1.2 = f32[2] parameter(1) %Arg_2.3 = s32[] parameter(2) - // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) } @@ -376,7 +376,7 @@ add { %test_exponential (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1) } @@ -384,14 +384,14 @@ add { %test_expm1 (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "xla_hlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1) } // CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> %test_fft { %arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} - // CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT" + // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT" ROOT %fft.2 = c64[3,5]{1,0} fft(%arg0.1), fft_type=RFFT, fft_length={9}, metadata={op_type="RFFT" op_name="rfft"} } @@ -400,7 +400,7 @@ add { %test_floor (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) } @@ -409,7 +409,7 @@ add { %test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] { %arg.0 = f32[200,100,300] parameter(0) %arg.1 = s32[10,2] parameter(1) - // CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]]) + // CHECK: "mhlo.gather"([[ARG0]], [[ARG1]]) // CHECK-SAME: dimension_numbers // CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64> // CHECK-SAME: index_vector_dim = 1 : i64 @@ -430,7 +430,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { %Arg_0.1 = f32[4,2] parameter(0) - // CHECK-NEXT: "xla_hlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor + // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1} } @@ -438,15 +438,15 @@ add { %test_imag (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "xla_hlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } // CHECK-LABEL: func @test_infeed -// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple, !xla_hlo.token> +// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> tuple, !mhlo.token> %test_infeed (token0: token[]) -> (s32[3], token[]) { %token0 = token[] parameter(0) - // CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]]) + // CHECK-NEXT: "mhlo.infeed"([[TOKEN]]) // CHECK-SAME: infeed_config = "foobar" ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar" } @@ -454,13 +454,13 @@ add { // CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> %test_iota_1 () -> f32[4] { - // CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> ROOT %iota.0 = f32[4] iota(), iota_dimension=0 } // CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32> %test_iota_2 () -> f32[4, 5] { - // CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> + // CHECK-NEXT: "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1 } @@ -468,7 +468,7 @@ add { %test_log (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %log.2 = f32[16] log(f32[16] %arg0.1) } @@ -476,11 +476,11 @@ add { %test_log1p (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "xla_hlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1) } -// Test xla_hlo.map +// Test mhlo.map %map_computation { lhs = f32[] parameter(0) rhs = f32[] parameter(1) @@ -492,10 +492,10 @@ add { %test_map { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) -// CHECK: "xla_hlo.map"([[ARG_0]], [[ARG_1]]) ( { +// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ( { // CHECK: ^bb0([[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG_2]], [[ARG_3]] -// CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () +// CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]] +// CHECK: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation } @@ -507,7 +507,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -516,7 +516,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -525,7 +525,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -533,7 +533,7 @@ add { %test_negate (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "xla_hlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1) } @@ -541,7 +541,7 @@ add { %test_not (arg0.1: pred[16]) -> pred[16] { %arg0.1 = pred[16] parameter(0) - // CHECK: "xla_hlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1> + // CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1> ROOT %not.2 = pred[16] not(pred[16] %arg0.1) } @@ -550,16 +550,16 @@ add { %Arg_0.1 = pred[4] parameter(0) %Arg_1.2 = pred[4] parameter(1) - // CHECK-NEXT: xla_hlo.or %arg0, %arg1 + // CHECK-NEXT: mhlo.or %arg0, %arg1 ROOT %or.3 = pred[4] or(pred[4] %Arg_0.1, pred[4] %Arg_1.2) } // CHECK-LABEL: func @test_outfeed -// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token +// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token %test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] { %Arg_0.1 = s32[3] parameter(0) %Arg_1.2 = token[] parameter(1) - // CHECK-NEXT: "xla_hlo.outfeed"([[DATA]], [[TOKEN]]) + // CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]]) // CHECK-SAME: outfeed_config = "foobar" ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar" } @@ -569,7 +569,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0 } @@ -578,7 +578,7 @@ add { %Arg_0.1 = f32[4, 4, 4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6 } @@ -587,7 +587,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<10xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<10xf32> ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2 } @@ -595,7 +595,7 @@ add { %test_popcnt (arg0.1: s32[16]) -> s32[16] { %arg0.1 = s32[16] parameter(0) - // CHECK: "xla_hlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32> + // CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32> ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1) } @@ -604,7 +604,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -614,7 +614,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "xla_hlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]]) + // CHECK: "mhlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]]) ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal } @@ -624,7 +624,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "xla_hlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]]) + // CHECK: "mhlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]]) ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform } @@ -632,7 +632,7 @@ add { %test_real (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "xla_hlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1) } @@ -666,28 +666,28 @@ add { %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK: "xla_hlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]]) - // CHECK: xla_hlo.add{{.*}} : tensor - // CHECK: xla_hlo.add{{.*}} : tensor + // CHECK: "mhlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]]) + // CHECK: mhlo.add{{.*}} : tensor + // CHECK: mhlo.add{{.*}} : tensor // CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor, tensor) -> tuple, tensor> %reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1 - // CHECK: [[VAL2:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG2]]) - // CHECK: xla_hlo.add{{.*}} : tensor + // CHECK: [[VAL2:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG2]]) + // CHECK: mhlo.add{{.*}} : tensor // CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor) -> tensor %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3 - // CHECK: [[VAL3:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG1]]) - // CHECK: xla_hlo.add{{.*}} : tensor<4xf32> + // CHECK: [[VAL3:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG1]]) + // CHECK: mhlo.add{{.*}} : tensor<4xf32> // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32> %reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2 - // CHECK: [[VAL4:%.*]] = "xla_hlo.reduce"([[VAL3]], [[ARG2]]) - // CHECK: xla_hlo.add{{.*}} : tensor + // CHECK: [[VAL4:%.*]] = "mhlo.reduce"([[VAL3]], [[ARG2]]) + // CHECK: mhlo.add{{.*}} : tensor // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 - // CHECK: %4 = xla_hlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor + // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor %sub.5 = f32[] subtract(%reduce.3, %reduce.4) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) @@ -699,8 +699,8 @@ add { %Arg_0.1 = f32[2,17,31,7] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( { - // CHECK: xla_hlo.add {{.*}} : tensor + // CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) ( { + // CHECK: mhlo.add {{.*}} : tensor // CHECK: }) { // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> @@ -716,7 +716,7 @@ add { %test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) -// CHECK: xla_hlo.remainder [[VAL_0]], [[VAL_1]] +// CHECK: mhlo.remainder [[VAL_0]], [[VAL_1]] ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -724,7 +724,7 @@ add { %test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) - // CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0} } @@ -732,7 +732,7 @@ add { %test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] { %Arg_0.1 = f32[4, 4] parameter(0) - // CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1} } @@ -741,7 +741,7 @@ add { %test_rsqrt (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "xla_hlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) } @@ -767,10 +767,10 @@ add { // CHECK-LABEL: func @test_scatter // CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> -// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( { +// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( { // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]] -// CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () +// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] +// CHECK: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK: }) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers = { @@ -788,7 +788,7 @@ add { %Arg_1.2 = s32[2,3] parameter(1) %Arg_2.3 = s32[2,3] parameter(2) - // CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } @@ -814,14 +814,14 @@ add { ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select, scatter=%add_gather } -// CHECK: [[RESULT:%.*]] = "xla_hlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( { +// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( { // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): -// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[LHS]], [[RHS]]) -// CHECK: "xla_hlo.return"([[CMP]]) : (tensor) -> () +// CHECK: [[CMP:%.*]] = "mhlo.compare"([[LHS]], [[RHS]]) +// CHECK: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]] -// CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () +// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] +// CHECK: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK: }) { // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64> @@ -835,7 +835,7 @@ add { %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = s32[] parameter(1) - // CHECK-NEXT: "xla_hlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} } @@ -843,7 +843,7 @@ add { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -860,10 +860,10 @@ add { } // CHECK-LABEL: func @test_sort // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: "xla_hlo.sort"([[ARG]]) ( { +// CHECK: "mhlo.sort"([[ARG]]) ( { // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor -// CHECK: "xla_hlo.return"([[CMP]]) : (tensor) -> () +// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor +// CHECK: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> // CHECK-LABEL: func @test_subtract @@ -871,7 +871,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: xla_hlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -879,7 +879,7 @@ add { %test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} } @@ -887,7 +887,7 @@ add { %test_transpose { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK: "xla_hlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } @@ -896,7 +896,7 @@ add { %test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = f32[4,3] parameter(1) - // CHECK-NEXT: "xla_hlo.triangular_solve"([[ARG_A]], [[ARG_B]]) + // CHECK-NEXT: "mhlo.triangular_solve"([[ARG_A]], [[ARG_B]]) // CHECK-SAME: left_side = true // CHECK-SAME: lower = true // CHECK-SAME: transpose_a = "NO_TRANSPOSE" @@ -909,10 +909,10 @@ add { %Arg_0.1 = s32[1] parameter(0) %Arg_1.2 = f32[1, 2] parameter(1) - // CHECK-NEXT: %0 = "xla_hlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple> + // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple> %tuple.3 = (s32[1]) tuple(%Arg_0.1) - // CHECK: "xla_hlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) } @@ -932,14 +932,14 @@ add { // CHECK-LABEL: func @test_while(%arg0: tensor) -> tensor %test_while (arg0.1: s64[]) -> s64[] { %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "xla_hlo.while"(%arg0) ( { + // CHECK-NEXT: "mhlo.while"(%arg0) ( { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[CMP:%.*]] = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor, tensor) -> tensor - // CHECK-NEXT: "xla_hlo.return"([[CMP]]) : (tensor) -> () + // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[ADD:%.*]] = xla_hlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor - // CHECK-NEXT: "xla_hlo.return"([[ADD]]) : (tensor) -> () + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor + // CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK-NEXT: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond } @@ -950,7 +950,7 @@ add { %Arg_0.1 = pred[4] parameter(0) %Arg_1.2 = pred[4] parameter(1) - // CHECK: xla_hlo.xor [[VAL_0]], [[VAL_1]] + // CHECK: mhlo.xor [[VAL_0]], [[VAL_1]] ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2) } @@ -960,7 +960,7 @@ add { %Arg_0.1 = s32[4] parameter(0) %Arg_1.2 = s32[4] parameter(1) - // CHECK: xla_hlo.shift_left [[VAL_0]], [[VAL_1]] + // CHECK: mhlo.shift_left [[VAL_0]], [[VAL_1]] ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2) } @@ -970,7 +970,7 @@ add { %Arg_0.1 = s32[4] parameter(0) %Arg_1.2 = s32[4] parameter(1) - // CHECK: xla_hlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]] + // CHECK: mhlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]] ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2) } @@ -980,7 +980,7 @@ add { %Arg_0.1 = s32[4] parameter(0) %Arg_1.2 = s32[4] parameter(1) - // CHECK: xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]] + // CHECK: mhlo.shift_right_logical [[VAL_0]], [[VAL_1]] ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2) } @@ -992,8 +992,8 @@ add { %Arg_1.2 = c128[2] parameter(1) %abs.4 = f64[2] abs(c128[2] %Arg_1.2) - // CHECK: "xla_hlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> - // CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> + // CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) } @@ -1002,6 +1002,6 @@ add { %unsigned_int(Arg_0.1: u16[4]) -> u16[4] { %Arg_0.1 = u16[4] parameter(0) - // CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> + // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir b/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir index a0dc1798dc62ae..5e4b0c93a7ee88 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/input_output_aliasing.mlir @@ -6,7 +6,7 @@ // TUPLE-ARG-LABEL: ENTRY %main // TUPLE-ARG: // OutputIndex {0} aliases with input 0 at {0} func @main(%arg0: tensor<1xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<1xf32>) { - %0 = xla_hlo.constant dense<4.200000e+01> : tensor<1xf32> - %1 = xla_hlo.add %arg0, %0 : tensor<1xf32> + %0 = mhlo.constant dense<4.200000e+01> : tensor<1xf32> + %1 = mhlo.add %arg0, %0 : tensor<1xf32> return %1 : tensor<1xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir index 713a256d3ce783..1d236fc0d183be 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir @@ -9,6 +9,6 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) { // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo index 033621b7af91c1..d97c5150335f9b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -139,8 +139,8 @@ dynamic_parameter_binding { } # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { -# CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32> +# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. -# CHECK-NEXT: %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir index e68262ba9ff4c3..5e6ef9698729d0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir @@ -2,8 +2,8 @@ func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %1 = "xla_hlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "mhlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt index c43365e4f5bf49..126bc88ec7aa4e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -16,14 +16,14 @@ HloModule foo ENTRY %foo (arg0.1: s64[]) -> s64[] { %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} - // CHECK: "xla_hlo.while"(%arg0) ( { + // CHECK: "mhlo.while"(%arg0) ( { // CHECK: ^bb0 - // CHECK: "xla_hlo.compare" - // CHECK: "xla_hlo.return" + // CHECK: "mhlo.compare" + // CHECK: "mhlo.return" // CHECK: }, { // CHECK: ^bb0 - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" + // CHECK: mhlo.add + // CHECK: "mhlo.return" // CHECK: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir index c9c35f9fcca51a..61d7aadb23f4d4 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir @@ -2,7 +2,7 @@ module { func @main(%arg0: tensor) -> tensor { - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { // CHECK: [[R0:%.+]] ([[A0:.+]]: s64[]) -> s64[] { // CHECK: %[[A0]] = s64[] parameter(0) // CHECK: ROOT %add.4 = s64[] add(s64[] %[[A0]], s64[] %[[A0]]) @@ -10,12 +10,12 @@ module { // CHECK: %[[A0]] = s64[] parameter(0) // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor // CHECK: ENTRY %main.9 ([[A0:.+]]: s64[]) -> s64[] { diff --git a/tensorflow/compiler/mlir/xla/tests/tuple.mlir b/tensorflow/compiler/mlir/xla/tests/tuple.mlir deleted file mode 100644 index f22bc210c5700f..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/tuple.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s - -// CHECK-LABEL: func @fold_access -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func @fold_access(%arg : tensor) -> tensor { - // CHECK-NEXT: return [[ARG]] - %tuple = "xla_hlo.tuple"(%arg) : (tensor) -> tuple> - %element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple>) -> tensor - return %element : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/xla-hlo-fusion.mlir b/tensorflow/compiler/mlir/xla/tests/xla-hlo-fusion.mlir deleted file mode 100644 index 7cf06de5e018bd..00000000000000 --- a/tensorflow/compiler/mlir/xla/tests/xla-hlo-fusion.mlir +++ /dev/null @@ -1,97 +0,0 @@ -// RUN: tf-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s - -// CHECK-LABEL: func @multi_outputs_same -func @multi_outputs_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - %2 = "xla_hlo.add"(%1, %1) : (tensor, tensor) -> tensor - // CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.return - return %1, %2 : tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @multi_outputs_same_2 -func @multi_outputs_same_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor, tensor) { - %0 = "xla_hlo.abs"(%arg0) : (tensor) -> tensor - %1 = "xla_hlo.abs"(%arg1) : (tensor) -> tensor - %2 = "xla_hlo.add"(%0, %1) : (tensor, tensor) -> tensor - %3 = "xla_hlo.abs"(%0) : (tensor) -> tensor - %4 = "xla_hlo.abs"(%1) : (tensor) -> tensor - // CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.return - return %2, %3, %4 : tensor, tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @multi_outputs_not_sure_same -func @multi_outputs_not_sure_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg0) : (tensor, tensor) -> tensor - // CHECK-NOT: xla_hlo.fusion - %1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor, tensor) -> tensor - return %0, %1 : tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @reduce -func @reduce(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - // CHECK: %[[RET0:.*]] = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.return - // Currently we do not support fuse arguments and ops without direct producer-consumer - // relationship. Thus Reduce Op should not be fused with above two ops. - - %2 = xla_hlo.constant dense<0.000000e+00> : tensor - %3 = "xla_hlo.reduce"(%arg0, %2) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor - "xla_hlo.return"(%4) : (tensor) -> () - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor - %4 = "xla_hlo.add"(%3, %3) : (tensor, tensor) -> tensor - // Above two ops should not be fused since reduce op can not be - // fused with its consumer. - // CHECK-NOT: xla_hlo.fusion - - return %1, %4 : tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @reduce_2 -func @reduce_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - - %2 = xla_hlo.constant dense<0.000000e+00> : tensor - %3 = "xla_hlo.reduce"(%1, %2) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor - "xla_hlo.return"(%4) : (tensor) -> () - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor - // CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.constant - // CHECK-NEXT: xla_hlo.reduce - // CHECK: xla_hlo.return - - // Following op should not be fused with the above ops since reduce op can not be - // fused with its consumer. - // CHECK-NOT: xla_hlo.fusion - %4 = "xla_hlo.add"(%3, %3) : (tensor, tensor) -> tensor - return %1, %4 : tensor, tensor -} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 51e211ad402907..616b214a40cbdf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -62,10 +62,10 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { -constexpr char kShardingAttr[] = "xla_hlo.sharding"; +constexpr char kShardingAttr[] = "mhlo.sharding"; class LegalizeTF : public PassWrapper { public: @@ -286,10 +286,10 @@ static ConstOp GetMaxValueForType(Type ty, Location loc, // element type and the value. static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, OpBuilder *builder) { - return builder->create(loc, xla::GetScalarOfType(ty, raw_value)); + return builder->create(loc, hlo::GetScalarOfType(ty, raw_value)); } -// Creates an xla_hlo::SliceOp where the major dimensions have full size, and +// Creates an mhlo::SliceOp where the major dimensions have full size, and // the minor dimensions have the provided offsets and sizes. static Value SliceInMinorDims(Location loc, Value v, ArrayRef minor_starts, @@ -326,7 +326,7 @@ static llvm::SmallVector CreateFullIndexVectorFromMinorIndices( return indices; } -// Creates an xla_hlo::DynamicSliceOp where the major dimensions have full size, +// Creates an mhlo::DynamicSliceOp where the major dimensions have full size, // and the minor dimensions have the provided offsets and sizes. static Value DynamicSliceInMinorDims(Location loc, Value v, ArrayRef minor_starts, @@ -341,12 +341,12 @@ static Value DynamicSliceInMinorDims(Location loc, Value v, std::copy(minor_sizes.begin(), minor_sizes.end(), slice_sizes.begin() + major_dims); auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType()); - return builder->create( + return builder->create( loc, slice_type, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder)); } -// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero +// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero // offsets, and the minor dimensions have the provided offsets. static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update, ArrayRef minor_starts, @@ -359,7 +359,7 @@ static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update, llvm::makeArrayRef(dus_starts)); } -// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero +// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero // offsets, and the minor dimensions have the provided static offsets. static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, ArrayRef minor_starts, @@ -540,7 +540,7 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, loc, to_type, input, result_extents, broadcast_dims); } -// Creates a batch dot using xla_hlo::DotGeneralOp. +// Creates a batch dot using mhlo::DotGeneralOp. Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, bool transpose_rhs, int64_t num_batch_dims, ArrayAttr precision_config, OpBuilder *builder) { @@ -605,31 +605,30 @@ static Value ApplyReduction(Location loc, Value input, builder->getBoolAttr(false)); } -// Creates a xla_hlo.rng_uniform op with `builder` to generate `num_elements` +// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` // 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). -static xla_hlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements, - int lower_limit, - int upper_limit, - OpBuilder *builder) { +static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { auto i32_type = builder->getIntegerType(32); auto key_type = RankedTensorType::get({num_elements}, i32_type); - auto shape_tensor = builder->create( + auto shape_tensor = builder->create( loc, GetI64ElementsAttr({num_elements}, builder)); - auto lower = builder->create( + auto lower = builder->create( loc, builder->getI32IntegerAttr(lower_limit)); - auto upper = builder->create( + auto upper = builder->create( loc, builder->getI32IntegerAttr(upper_limit)); - return builder->create(loc, key_type, lower, upper, - shape_tensor); + return builder->create(loc, key_type, lower, upper, + shape_tensor); } using WhileBodyFnType = llvm::function_ref old_values, SmallVectorImpl *new_values, OpBuilder *builder)>; -// Creates a xla_hlo.while op with `builder` to loop `num_interations` times, +// Creates a mhlo.while op with `builder` to loop `num_interations` times, // each time calling the given `body_fn` on a set of values to generate a new // set of values. Returns the final set of values via `final_values`. The // initial set of values is passed in via `init_values`. @@ -659,16 +658,16 @@ static void CreateWhile32(Location loc, int num_iterations, init_values_with_loop_iv.reserve(value_count); // The initial value for the loop induction variable is 0. init_values_with_loop_iv.push_back( - builder->create(loc, builder->getI32IntegerAttr(0))); + builder->create(loc, builder->getI32IntegerAttr(0))); init_values_with_loop_iv.append(init_values.begin(), init_values.end()); // Prepare the initial tuple for the while op. auto init_tuple = - builder->create(loc, init_values_with_loop_iv); + builder->create(loc, init_values_with_loop_iv); auto tuple_type = init_tuple.getType(); // Create the while op. - auto while_op = builder->create(loc, init_tuple); + auto while_op = builder->create(loc, init_tuple); { OpBuilder::InsertionGuard guard(*builder); @@ -681,13 +680,13 @@ static void CreateWhile32(Location loc, int num_iterations, // Get the loop induction variable and compare it against the upper limit. auto loop_iv = builder->create(loc, arg, 0); - auto upper_limit = builder->create( + auto upper_limit = builder->create( loc, builder->getI32IntegerAttr(num_iterations)); StringAttr compare_direction = StringAttr::get("LT", builder->getContext()); - Value compare = builder->create( - loc, loop_iv, upper_limit, compare_direction); + Value compare = builder->create(loc, loop_iv, upper_limit, + compare_direction); - builder->create(loc, compare); + builder->create(loc, compare); } { @@ -714,16 +713,16 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = - builder->create(loc, builder->getI32IntegerAttr(1)); + builder->create(loc, builder->getI32IntegerAttr(1)); auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create( + auto plus_one = builder->create( loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); - Value updated_tuple = builder->create(loc, new_values); + Value updated_tuple = builder->create(loc, new_values); - builder->create(loc, updated_tuple); + builder->create(loc, updated_tuple); } final_values->reserve(init_values.size()); @@ -786,7 +785,7 @@ static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input, Value elem_type_tensor) { auto element_type = elem_type_tensor.getType().cast().getElementType(); - return builder->create(loc, input, element_type); + return builder->create(loc, input, element_type); } //===----------------------------------------------------------------------===// @@ -897,7 +896,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) { auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon)); return DenseElementsAttr::get(scalar_ty, value); } else if (element_ty.isBF16()) { - uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value; + uint16_t raw_epsilon = Eigen::NumTraits::epsilon().value; auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon)); return DenseElementsAttr::get(scalar_ty, value); } else if (element_ty.isF32()) { @@ -998,7 +997,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64)) + return hlo::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64)) .cast(); } @@ -1023,9 +1022,9 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( // Sort op utilities. //===----------------------------------------------------------------------===// -// Builds the region `body` for xla_hlo.sort's comparator: for each type in +// Builds the region `body` for mhlo.sort's comparator: for each type in // `element_types`, create two block arguments, one for lhs and one for rhs, and -// generates xla_hlo.compare op to compare them with the given `direction`. +// generates mhlo.compare op to compare them with the given `direction`. // // Note that this right now only does comparision on the first pair of block // arguments. @@ -1044,10 +1043,10 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, Location loc = body->getLoc(); StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); - Value compare = builder->create( + Value compare = builder->create( loc, block->getArgument(0), block->getArgument(1), compare_direction); - builder->create(loc, compare); + builder->create(loc, compare); } //===----------------------------------------------------------------------===// @@ -1110,7 +1109,7 @@ class ConvertBiasAddOp : public OpRewritePattern { // // Sample result for Conv2D: // -// %conv = "xla_hlo.convolution"(%input, %filter) { +// %conv = "mhlo.convolution"(%input, %filter) { // strides = [1, 2], // paddings = [[1, 0], [1, 1]], // ... @@ -1235,7 +1234,7 @@ class ConvertConvOp : public OpRewritePattern { new_shape.push_back(1); new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( + operands[1] = rewriter.create( op.getLoc(), RankedTensorType::get(new_shape, filter_ty.getElementType()), operands[1]); @@ -1319,16 +1318,16 @@ class ConvertBroadcastToOp : public OpRewritePattern { // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. // For a Rank-2 input, it creates the following ops: -// %1 = "xla_hlo.iota"() {iota_dimension = 0 : i64} -// %2 = "xla_hlo.iota"() {iota_dimension = 1 : i64} -// %3 = "xla_hlo.compare"(%1, %2) {comparison_direction = "EQ"} -// %4 = xla_hlo.constant dense<0.000000e+00> : tensor -// %5 = "xla_hlo.broadcast"(%4) -// %6 = "xla_hlo.select"(%3, %input, %5) -// %7 = "xla_hlo.reduce"(%6, %4) ( { +// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} +// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} +// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} +// %4 = mhlo.constant dense<0.000000e+00> : tensor +// %5 = "mhlo.broadcast"(%4) +// %6 = "mhlo.select"(%3, %input, %5) +// %7 = "mhlo.reduce"(%6, %4) ( { // ^bb0(%arg1: tensor, %arg2: tensor): -// %9 = xla_hlo.add %arg1, %arg2 : tensor -// "xla_hlo.return"(%9) : (tensor) -> () +// %9 = mhlo.add %arg1, %arg2 : tensor +// "mhlo.return"(%9) : (tensor) -> () // }) {dimensions = dense<0> : tensor<1xi64>} // // If the input's rank N is greater than 2, we will reshape it to R2 first and @@ -1353,7 +1352,7 @@ class ConvertDiagPartOp : public OpRewritePattern { new_size *= input_type.getDimSize(i); new_dims.push_back(input_type.getDimSize(i)); } - Value reshaped_input = rewriter.create( + Value reshaped_input = rewriter.create( op.getLoc(), RankedTensorType::get({new_size, new_size}, input_type.getElementType()), @@ -1484,29 +1483,29 @@ class ConvertFusedBatchNormGradBase RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create( + auto add_op = rewriter.create( loc, var, epsilon.getResult(), scalar_broadcast_dims); Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create( + auto sub_op = rewriter.create( loc, act, Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); - auto weighted_grad = rewriter.create(loc, grad, sub_op); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.scale(), scratch1); - x_backprop = rewriter.create( + rewriter.create(loc, op.scale(), scratch1); + x_backprop = rewriter.create( loc, grad, Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = rewriter.create(loc, scratch1, scratch2); + scale_backprop = rewriter.create(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); @@ -1559,8 +1558,8 @@ class ConvertFusedBatchNormV3Op // TODO(b/69928690): Support mixed precision in the XLA batch // normalization operators. As a workaround, create a new x with the same // element type as scale (which may be more precise than the input type). - Value bn_train_input = rewriter.create( - op.getLoc(), op.x(), scale_element_type); + Value bn_train_input = rewriter.create(op.getLoc(), op.x(), + scale_element_type); TensorType bn_train_input_type_tensor = bn_train_input.getType().cast(); @@ -1579,17 +1578,17 @@ class ConvertFusedBatchNormV3Op mean_var_type, mean_var_type}; Type result_type = TupleType::get(operand_types, rewriter.getContext()); - auto bn_train_op = rewriter.create( + auto bn_train_op = rewriter.create( op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(), op.epsilon(), feature_dim.getValue()); // HLO op outputs a tuple of tensors. Extract those results. auto bn_train_op_result = bn_train_op.getResult(); - Value y_out = rewriter.create( + Value y_out = rewriter.create( op.getLoc(), bn_train_op_result, 0); - Value batch_mean = rewriter.create( + Value batch_mean = rewriter.create( op.getLoc(), bn_train_op_result, 1); Value reserve_space_1 = batch_mean; - Value batch_variance = rewriter.create( + Value batch_variance = rewriter.create( op.getLoc(), bn_train_op_result, 2); // Apply Bessel's correction on the variance. @@ -1599,49 +1598,47 @@ class ConvertFusedBatchNormV3Op int sample_size_minus_one = std::max(1, sample_size - 1); double factor = static_cast(sample_size) / static_cast(sample_size_minus_one); - auto factor_const_op = rewriter.create( + auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - Value corrected_variance = rewriter.create( + Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); // Convert back to input type to stay aligned with expected output type // for TF op. - y_out = rewriter.create(op.getLoc(), y_out, - input_element_type); + y_out = rewriter.create(op.getLoc(), y_out, + input_element_type); float exponential_avg_factor = op.exponential_avg_factor().convertToFloat(); if (exponential_avg_factor != 1.0f) { - auto alpha = rewriter.create( + auto alpha = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, 1.0f - exponential_avg_factor)); - auto beta = rewriter.create( + auto beta = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); // new_running_mean = alpha * old_mean + beta * batch_mean. - auto alpha_mul_old_mean = rewriter.create( + auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.mean().getType(), alpha, op.mean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_mean = rewriter.create( + auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); - batch_mean = rewriter.create( + batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. - auto alpha_mul_old_variance = rewriter.create( + auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.variance().getType(), alpha, op.variance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_variance = - rewriter.create( - op.getLoc(), corrected_variance.getType(), beta, - corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); - corrected_variance = rewriter.create( + auto beta_mul_batch_variance = rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, /*broadcast_dimensions=*/DenseIntElementsAttr()); } @@ -1666,8 +1663,8 @@ class ConvertFusedBatchNormV3Op // Convert back to input type to stay aligned with expected output type // for TF op. - auto y_out = rewriter.create(op.getLoc(), bn_train_op, - input_element_type); + auto y_out = rewriter.create(op.getLoc(), bn_train_op, + input_element_type); // The mean, variance, and reserved space outputs of the batch norm op are // not used for inference. It doesn't matter what values we provide for @@ -1811,7 +1808,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. @@ -1915,7 +1912,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { Value divisor = GetScalarConstOfType(element_type, loc, window_count, &rewriter); auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - out_grad_divided = rewriter.create( + out_grad_divided = rewriter.create( loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims); } else { assert(op.padding() == "SAME"); @@ -1923,7 +1920,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { // are counted for the average of this window, not padded entries. // Build all-ones tensor of same shape as the original input. - ElementsAttr splat = xla::getSplat(&rewriter, orig_input_type, 1); + ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); auto all_ones_tensor = rewriter.create(loc, splat); // Get the same padding as for the original input. @@ -1947,8 +1944,8 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { BuildReduceBody(element_type, &window_counts.body(), &rewriter); // Divide `out_grad` by window counts. - out_grad_divided = rewriter.create( - loc, out_grad_type, out_grad, window_counts); + out_grad_divided = rewriter.create(loc, out_grad_type, + out_grad, window_counts); } // Get same padding as for original input. @@ -2053,7 +2050,7 @@ using ConvertAvgPool3DGradOp = // Sample result for VALID padding mode: // // %init = constant dense<...> : tensor -// %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"] +// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] // {window_dimensions = ..., window_strides = ... } // template @@ -2098,13 +2095,13 @@ using ConvertMaxPool3DOp = ConvertMaxPoolOp; // // will be converted into: // -// %pred = "xla_hlo.broadcast_in_dim"(%cond) +// %pred = "mhlo.broadcast_in_dim"(%cond) // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : // (tensor<1xi1>) -> tensor<2xi1> -// %on_false = "xla_hlo.broadcast_in_dim"(%e) +// %on_false = "mhlo.broadcast_in_dim"(%e) // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : // (tensor<1xi32>) -> tensor<2xi32> -// %select = "xla_hlo.select"(%pred, %t, %on_false) : +// %select = "mhlo.select"(%pred, %t, %on_false) : // (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> class ConvertSelectV2Op : public OpRewritePattern { public: @@ -2173,18 +2170,18 @@ class ConvertSelectV2Op : public OpRewritePattern { // Sample result with 2-d f16 inputs with B batches of with N elements each. // // // Create an array of 0.5 the shape of the input array. -// %half = xla_hlo.constant dense<5.000000e-01> : tensor -// %half_array = "xla_hlo.broadcast"(half) +// %half = mhlo.constant dense<5.000000e-01> : tensor +// %half_array = "mhlo.broadcast"(half) // {broadcast_sizes = dense<2> : tensor<1xi64>} // : (tensor) -> tensor<2xf32> // // // Compute Tanh of half the logits of the values. -// %halved_logits = xla_hlo.multiply %logits, %half_array : tensor<2xf32> -// %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32> +// %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32> +// %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32> // // // Have the result of Tanh and add 0.5. -// %halved_tanh = xla_hlo.multiply %tanh, %half : tensor<2xf32> -// %sigmoid = xla_hlo.add %halved_tanh, %half : tensor<2xf32> +// %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32> +// %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32> // class ConvertSigmoidOp : public OpRewritePattern { public: @@ -2198,7 +2195,7 @@ class ConvertSigmoidOp : public OpRewritePattern { Value operand = op.getOperand(); auto operand_ty = operand.getType().cast(); auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType()); - ElementsAttr attr = mlir::xla::getSplat(&rewriter, scalar_ty, 0.5); + ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5); auto scalar_half = rewriter.create(loc, attr); auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter); @@ -2227,15 +2224,15 @@ class ConvertSigmoidOp : public OpRewritePattern { // // stability. // %max = "tf.Max"(%input, %reduce_dim) // : (tensor, tensor<1xi64>) -> tensor -// %sub = "xla_hlo.subtract"(%inp, %max) {broadcast_dimensions = 0} +// %sub = "mhlo.subtract"(%inp, %max) {broadcast_dimensions = 0} // : (tensor, tensor) -> tensor // -// %exp = "xla_hlo.exponential"(%sub) : (tensor) -> tensor +// %exp = "mhlo.exponential"(%sub) : (tensor) -> tensor // %sum = "tf.Sum"(%exp, %reduce_dim) // : (tensor, tensor<1xi64>) -> tensor // // // Softmax computation: -// %softmax = "xla_hlo.divide"(%exp, %sum_f16) {broadcast_dimensions = 0} +// %softmax = "mhlo.divide"(%exp, %sum_f16) {broadcast_dimensions = 0} // : (tensor, tensor) -> tensor template class ConvertSoftmaxOp : public OpRewritePattern { @@ -2270,8 +2267,8 @@ class ConvertSoftmaxOp : public OpRewritePattern { /*keep_dims=*/rewriter.getBoolAttr(false)); auto max_logits_broadcast = CommonPrefixBroadcast(loc, logits, max_logits, rewriter); - auto shifted_logits = rewriter.create(loc, type, logits, - max_logits_broadcast); + auto shifted_logits = + rewriter.create(loc, type, logits, max_logits_broadcast); // Exponentiate the inputs. Value exp = rewriter.create(loc, type, shifted_logits); @@ -2285,11 +2282,11 @@ class ConvertSoftmaxOp : public OpRewritePattern { if (use_log) { Value log = rewriter.create(loc, sum); auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter); - rewriter.replaceOpWithNewOp(op, shifted_logits, - log_broadcast); + rewriter.replaceOpWithNewOp(op, shifted_logits, + log_broadcast); } else { auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter); - rewriter.replaceOpWithNewOp(op, exp, sum_broadcast); + rewriter.replaceOpWithNewOp(op, exp, sum_broadcast); } return success(); } @@ -2307,16 +2304,16 @@ class ConvertSoftmaxOp : public OpRewritePattern { // // will be converted into: // -// %const = xla_hlo.constant dense<1> : tensor -// %dim_0 = "xla_hlo.get_dimension_size"(%input) {dimension = 0 : i32} : +// %const = mhlo.constant dense<1> : tensor +// %dim_0 = "mhlo.get_dimension_size"(%input) {dimension = 0 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %prod_0 = xla_hlo.multiply %const, %dim_0 : tensor -// %dim_1 = "xla_hlo.get_dimension_size"(%input) {dimension = 1 : i32} : +// %prod_0 = mhlo.multiply %const, %dim_0 : tensor +// %dim_1 = "mhlo.get_dimension_size"(%input) {dimension = 1 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %prod_1 = xla_hlo.multiply %prod_0, %dim_1 : tensor -// %dim_2 = "xla_hlo.get_dimension_size"(%input) {dimension = 2 : i32} : +// %prod_1 = mhlo.multiply %prod_0, %dim_1 : tensor +// %dim_2 = "mhlo.get_dimension_size"(%input) {dimension = 2 : i32} : // (tensor<2x?x8xf32>) -> tensor -// %size = xla_hlo.multiply %prod_1, %dim_2 : tensor +// %size = mhlo.multiply %prod_1, %dim_2 : tensor class ConvertSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2336,7 +2333,7 @@ class ConvertSizeOp : public OpRewritePattern { auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -2470,17 +2467,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { // // will be converted into: // -// %0 = "xla_hlo.slice"(%input) { +// %0 = "mhlo.slice"(%input) { // limit_indices = dense<[4, 2]> : tensor<2xi64>, // start_indices = dense<0> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %1 = "xla_hlo.slice"(%input) { +// %1 = "mhlo.slice"(%input) { // limit_indices = dense<4> : tensor<2xi64>, // start_indices = dense<[0, 2]> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "xla_hlo.slice"(%input) { +// %2 = "mhlo.slice"(%input) { // limit_indices = dense<[4, 6]> : tensor<2xi64>, // start_indices = dense<[0, 4]> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : @@ -2563,17 +2560,17 @@ class ConvertSplitOp : public OpRewritePattern { // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) // // We will generate slices following slices: -// %0 = "xla_hlo.slice"(%input) { +// %0 = "mhlo.slice"(%input) { // limit_indices = dense<[4, 1]> : tensor<2xi64>, // start_indices = dense<0> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : // (tensor<4x6xf32>) -> tensor<4x1xf32> -// %1 = "xla_hlo.slice"(%input) { +// %1 = "mhlo.slice"(%input) { // limit_indices = dense<[4, 3]> : tensor<2xi64>, // start_indices = dense<[0, 1]> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "xla_hlo.slice"(%input) { +// %2 = "mhlo.slice"(%input) { // limit_indices = dense<[4, 6]> : tensor<2xi64>, // start_indices = dense<[0, 3]> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : @@ -2645,7 +2642,7 @@ class ConvertSplitVOp : public OpRewritePattern { for (int i = 0; i < op.getNumResults(); ++i) { end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; - slices.push_back(rewriter.create( + slices.push_back(rewriter.create( op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter))); @@ -2663,7 +2660,7 @@ class ConvertSplitVOp : public OpRewritePattern { // strides operands are converted to attributes with non-negative indexing. // // If the begin input is not a compile time constant, the begin input needs to -// be sliced and the slice needs to be lowered to xla_hlo.DynamicSlice. In this +// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this // case, strides must have a known value of 1 (otherwise we have insufficient // information to conform to XLA's op semantics). // @@ -2672,10 +2669,10 @@ class ConvertSplitVOp : public OpRewritePattern { // : tensor -> tensor // // If the %begin input is constant, output would be: -// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...} -// %sliced = "xla_hlo.Slice" (%input) +// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} +// %sliced = "mhlo.Slice" (%input) // {start_indices = ..., limit_indices = ..., strides = ...} -// %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor // class ConvertStridedSliceOp : public OpRewritePattern { public: @@ -2940,7 +2937,7 @@ class ConvertStridedSliceGradOp Type element_type = grad.getType().cast().getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), RankedTensorType::get(shape, element_type), grad); SmallVector padding_low, padding_high, padding_interm; @@ -2976,13 +2973,13 @@ class ConvertStridedSliceGradOp } if (!dims_to_reverse.empty()) { - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), grad.getType(), grad, GetI64ElementsAttr(dims_to_reverse, &rewriter)); } auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), grad, zero, GetI64ElementsAttr(padding_low, &rewriter), GetI64ElementsAttr(padding_high, &rewriter), @@ -2991,7 +2988,7 @@ class ConvertStridedSliceGradOp } }; -/// Converts the RangeOp tensorflow op to a xla_hlo.iota op with a scaling and +/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and /// offset applied to generate the range values. The output tensor needs to /// have a static shape. /// @@ -3000,11 +2997,11 @@ class ConvertStridedSliceGradOp /// : (tensor, tensor, tensor) -> tensor<5xf32> /// /// Output would be: -/// %iota = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> -/// %scaled = "xla_hlo.multiply"(%iota, %delta) +/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> +/// %scaled = "mhlo.multiply"(%iota, %delta) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> -/// %result = "xla_hlo.add"(%scaled, %offset) +/// %result = "mhlo.add"(%scaled, %offset) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> /// @@ -3022,12 +3019,12 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), - xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), - xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); } }; @@ -3071,23 +3068,23 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // some conversion to float for the operations. // // %size = ceil(abs((%limit - %start) / %delta)) - auto range = rewriter.create(op.getLoc(), limit, start); - auto abs = rewriter.create(op.getLoc(), range); + auto range = rewriter.create(op.getLoc(), limit, start); + auto abs = rewriter.create(op.getLoc(), range); // Delta is not necessarily the same type as start and limit. auto abs_cast = - rewriter.create(op.getLoc(), compute_type, abs); + rewriter.create(op.getLoc(), compute_type, abs); auto delta_cast = - rewriter.create(op.getLoc(), compute_type, delta); + rewriter.create(op.getLoc(), compute_type, delta); // Compute the total number of integer steps and convert to the HLO // dimension tensor. auto normalized = - rewriter.create(op.getLoc(), abs_cast, delta_cast); - auto ceil = rewriter.create(op.getLoc(), normalized); - auto steps = rewriter.create( + rewriter.create(op.getLoc(), abs_cast, delta_cast); + auto ceil = rewriter.create(op.getLoc(), normalized); + auto steps = rewriter.create( op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil); - auto reshape = rewriter.create( + auto reshape = rewriter.create( op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps); // Using the resulting length compute the correct range value: @@ -3095,19 +3092,19 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // %range = %start + %delta * iota(%size) auto out_scalar_type = RankedTensorType::get({}, getElementTypeOrSelf(result_type)); - auto start_out_cast = rewriter.create( - op.getLoc(), out_scalar_type, start); - auto delta_out_cast = rewriter.create( - op.getLoc(), out_scalar_type, delta); + auto start_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, start); + auto delta_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, delta); auto iota = rewriter.create( op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, delta_out_cast, - xla::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); - rewriter.replaceOpWithNewOp( + hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); + rewriter.replaceOpWithNewOp( op, result_type, scaled, start_out_cast, - xla::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); return success(); } }; @@ -3127,7 +3124,7 @@ ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { return builder->getI64TensorAttr(axis); } -/// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling +/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there /// is no type inference for the iota op. @@ -3153,37 +3150,37 @@ class ConvertLinSpaceOp : public OpRewritePattern { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. - auto step_numerator = rewriter.create( + auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), - xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); + hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( op.getLoc(), op.num(), result_type.getElementType()); if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( + step_denominator = rewriter.create( op.getLoc(), step_denominator.getType(), step_denominator, one, - xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); + hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create( + auto step = rewriter.create( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, - xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, + hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); // Scale the iota and add the offset. auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, - xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( + hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), - xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); } }; -/// Converts a generic OpTy tensorflow op to a xla_hlo.reduce op over +/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over /// ReductionOp. /// `is_accumulation` controls whether it uses higher precision for the actual /// reduction. This is set to false for ops like max where there is no precision @@ -3252,7 +3249,7 @@ class GenericConvertReductionOp : public OpRewritePattern { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create( + result = rewriter.create( loc, result, divisor.getResult(), broadcast_dims); } @@ -3272,10 +3269,10 @@ class GenericConvertReductionOp : public OpRewritePattern { // Converts Mean op to HLO Reduce op. // // %init = constant dense<...> : tensor -// %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] // {dimensions = ...} // %divisor = constant dense<...> : tensor -// %mean = "xla_hlo.divide"(%sum, %divisor) +// %mean = "mhlo.divide"(%sum, %divisor) class ConvertMeanOp : public GenericConvertReductionOp { public: @@ -3289,7 +3286,7 @@ class ConvertMeanOp // Converts Sum op to HLO Reduce op. // // %init = constant dense<...> : tensor -// %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] // {dimensions = ...} class ConvertSumOp : public GenericConvertReductionOp { @@ -3305,7 +3302,7 @@ class ConvertSumOp // Converts Max op to HLO Reduce op. // // %init = constant dense<...> : tensor -// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"] +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] // {dimensions = ...} class ConvertMaxOp : public GenericConvertReductionOp : tensor -// %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.minimum"] +// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] // {dimensions = ...} class ConvertMinOp : public GenericConvertReductionOp : tensor -// %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.multiply"] +// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] // {dimensions = ...} class ConvertProdOp : public GenericConvertReductionOp { @@ -3355,7 +3352,7 @@ class ConvertProdOp // Converts All op to HLO Reduce op. // // %init = constant dense<...> : tensor -// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"] +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] // {dimensions = ...} class ConvertAllOp : public GenericConvertReductionOp { @@ -3370,7 +3367,7 @@ class ConvertAllOp // Converts Any op to HLO Reduce op. // // %init = constant dense<...> : tensor -// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"] +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] // {dimensions = ...} class ConvertAnyOp : public GenericConvertReductionOp { @@ -3382,7 +3379,7 @@ class ConvertAnyOp } }; -// Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform +// Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform // a reduction on the original input and the corresponding index. The reduction // sub-computation selects the max (or min) value and the index for the value. // Derived: is the resulting derived class of this class. @@ -3454,13 +3451,13 @@ class ConvertArgMinMaxOp : public OpRewritePattern { } }; -// Converts tensorflow ArgMax op to xla_hlo operations. The actual +// Converts tensorflow ArgMax op to mhlo operations. The actual // implementation is in class ConvertArgMinMaxOp: // // %init_index = constant dense<...> : tensor // %init = constant dense<...> : tensor -// %reduce = "xla_hlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["xla_hlo.arg_max"] +// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["mhlo.arg_max"] class ConvertArgMaxOp : public ConvertArgMinMaxOp { public: @@ -3476,7 +3473,7 @@ class ConvertArgMaxOp // Converts TF TensorScatterUpdate op into Scatter Op with assignment: // -// %result = "xla_hlo.scatter"(%tensor, %indices, %updates) +// %result = "mhlo.scatter"(%tensor, %indices, %updates) // { dimensions = ... } // class ConvertTensorScatterUpdateOp @@ -3534,10 +3531,10 @@ class ConvertTensorScatterUpdateOp // For shape [S1, S2] and multiples [M1, M2], // MS1 = M1 * S1; MS2 = M2 * S2 // -// %broadcast = xla_hlo.broadcast_in_dim(%input) { +// %broadcast = mhlo.broadcast_in_dim(%input) { // broadcast_dimensions = [0, 2] // } -// %result = "xla_hlo.reshape"(%broadcast) : (tensor) +// %result = "mhlo.reshape"(%broadcast) : (tensor) // -> tensor class ConvertTileOp : public OpRewritePattern { public: @@ -3657,8 +3654,8 @@ using ConvertMaxPool3DGradOp = ConvertMaxPoolGradOp; // Converts tf.Conv?DBackpropInputOp into: -// %rev_filter = "xla_hlo.reverse"(%filter) -// %result = "xla_hlo.convolution"(%out_backprop, %rev_filter) +// %rev_filter = "mhlo.reverse"(%filter) +// %result = "mhlo.convolution"(%out_backprop, %rev_filter) template class ConvertConvBackpropInputOp : public OpRewritePattern { public: @@ -3821,7 +3818,7 @@ using ConvertConv3DBackpropInputOp = /*num_spatial_dims=*/3>; // Converts tf.Conv?DBackpropFilterOp into: -// %result = "xla_hlo.convolution"(%input, %out_backprop) +// %result = "mhlo.convolution"(%input, %out_backprop) template class ConvertConvBackpropFilterOp : public OpRewritePattern { public: @@ -4078,7 +4075,7 @@ class ConvertOneHotOp : public OpRewritePattern { loc, index_type, op.indices(), GetI64ElementsAttr(broadcast_dims, &rewriter)); - Value compare = rewriter.create( + Value compare = rewriter.create( loc, broadcast_indices, iota, StringAttr::get("EQ", rewriter.getContext())); Value on_value = rewriter.create( @@ -4111,13 +4108,13 @@ class ConvertOneHotOp : public OpRewritePattern { // // would be lowered to // -// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token -// %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} : -// (!xla_hlo.token) -> tuple, tensor<4xf32>>, -// !xla_hlo.token> -// %data = "xla_hlo.get_tuple_element"(%data_and_token) {index = 0} -// %0#0 = "xla_hlo.get_tuple_element"(%data) {index = 0} -// %0#1 = "xla_hlo.get_tuple_element"(%data) {index = 1} +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : +// (!mhlo.token) -> tuple, tensor<4xf32>>, +// !mhlo.token> +// %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0} +// %0#0 = "mhlo.get_tuple_element"(%data) {index = 0} +// %0#1 = "mhlo.get_tuple_element"(%data) {index = 1} // class ConvertInfeedDequeueTupleOp : public OpRewritePattern { @@ -4133,7 +4130,7 @@ class ConvertInfeedDequeueTupleOp // Infeed takes a single token operand. Generate the token using // create_token op to pass to the infeed op. auto token = rewriter.create( - op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext())); + op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); // Emit infeed op. // The result type of infeed is a tuple(tuple(result types), token type). @@ -4196,11 +4193,11 @@ class ConvertInfeedDequeueTupleOp // // would be lowered to // -// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> +// %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> // tuple, tensor<4xf32>> -// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token -// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} : -// (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} : +// (tuple, tensor<4xf32>>, !mhlo.token) -> !mhlo.token // class ConvertOutfeedEnqueueTupleOp : public OpRewritePattern { @@ -4209,7 +4206,7 @@ class ConvertOutfeedEnqueueTupleOp LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, PatternRewriter &rewriter) const override { - auto token_type = xla_hlo::TokenType::get(rewriter.getContext()); + auto token_type = mhlo::TokenType::get(rewriter.getContext()); auto tuple = rewriter.create(op.getLoc(), op.inputs()); auto token = rewriter.create(op.getLoc(), token_type); rewriter.create(op.getLoc(), token_type, tuple, token, @@ -4235,20 +4232,20 @@ class ConvertOutfeedEnqueueTupleOp // // We will get: // -// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> -// %2 = "xla_hlo.sort"(%input, %1) ( { +// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> +// %2 = "mhlo.sort"(%input, %1) ( { // ^bb0(%arg1: tensor, %arg2: tensor, // %arg3: tensor, %arg4: tensor): -// %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... -// "xla_hlo.return"(%7) : (tensor) -> () +// %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... +// "mhlo.return"(%7) : (tensor) -> () // }) {dimension = 1 : i64, is_stable = true} : ... -// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ... -// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ... -// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, +// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ... +// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ... +// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, // start_indices dense<0> : tensor<2xi64>, // strides = dense<1> : tensor<2xi64>} : // (tensor<16x16xf32>) -> tensor<16x8xf32> -// %6 = "xla_hlo.slice"(%4) ... +// %6 = "mhlo.slice"(%4) ... class ConvertTopKV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -4271,12 +4268,12 @@ class ConvertTopKV2Op : public OpRewritePattern { // Create an Itoa op for indices. auto i32_type = rewriter.getIntegerType(32); Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type); - Value iota_op = rewriter.create( + Value iota_op = rewriter.create( op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index)); // Create the sort op. It takes two inputs, one for the original input, the // other for the indices. - auto sort_op = rewriter.create( + auto sort_op = rewriter.create( op.getLoc(), llvm::ArrayRef{op.input(), iota_op}, last_dim_index, /*is_stable=*/true); BuildSortComparisonBody({input_type.getElementType(), i32_type}, @@ -4285,9 +4282,9 @@ class ConvertTopKV2Op : public OpRewritePattern { // Get the sorted input and index tuple element. auto tuple_first_element = - rewriter.create(op.getLoc(), sort_op, 0); + rewriter.create(op.getLoc(), sort_op, 0); auto tuple_second_element = - rewriter.create(op.getLoc(), sort_op, 1); + rewriter.create(op.getLoc(), sort_op, 1); SmallVector begin_indices(input_rank, 0); auto end_indices = llvm::to_vector<4>(input_type.getShape()); @@ -4297,13 +4294,13 @@ class ConvertTopKV2Op : public OpRewritePattern { // Get the slice for the top K elements. - Value values = rewriter.create( + Value values = rewriter.create( op.getLoc(), tuple_first_element, GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter)); - Value indices = rewriter.create( + Value indices = rewriter.create( op.getLoc(), tuple_second_element, GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), @@ -4346,12 +4343,12 @@ class ConvertUnpackOp : public OpRewritePattern { begin_indices[axis] = i; end_indices[axis] = i + 1; - auto slice_op = rewriter.create( + auto slice_op = rewriter.create( op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter)); // Reshape to drop the axis dimension. - auto reshape_op = rewriter.create( + auto reshape_op = rewriter.create( op.getLoc(), op.getType(i), slice_op); results.push_back(reshape_op); } @@ -4410,7 +4407,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { // 'operand' parameter to scatter to for the final scatter op. Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), op.getLoc(), &rewriter); - auto broadcasted_init = rewriter.create( + auto broadcasted_init = rewriter.create( op.getLoc(), output_type, init, GetI64ElementsAttr(output_shape, &rewriter)); @@ -4565,7 +4562,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto keys = CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, /*upper_limit=*/u32_max, &rewriter); - auto sorted = rewriter.create( + auto sorted = rewriter.create( op.getLoc(), llvm::ArrayRef{keys, current}); auto i32_type = rewriter.getIntegerType(32); BuildSortComparisonBody({i32_type, input_type.getElementType()}, @@ -4583,7 +4580,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Generate range(n) as the initial value for the indices to be swapped. auto indices_type = RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32)); - Value indices = rewriter.create( + Value indices = rewriter.create( op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); // Generate random numbers to be used as swaps for the indices. @@ -4609,21 +4606,21 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // We need to swap the indices[i] with indices[swaps[i]]. First get // these index values. - Value source_index = builder->create( + Value source_index = builder->create( loc, vec1_i32_type, indices, i, scalar_one); - Value swap_index = builder->create( + Value swap_index = builder->create( loc, scalar_i32_type, - builder->create(loc, vec1_i32_type, swaps, i, - scalar_one)); - Value target_index = builder->create( + builder->create(loc, vec1_i32_type, swaps, i, + scalar_one)); + Value target_index = builder->create( loc, vec1_i32_type, indices, swap_index, scalar_one); // Then perform the swap. // indices[i] <- indices[swaps[i]] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i)); // indices[swaps[i]] <- indices[i] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, source_index, llvm::makeArrayRef(swap_index)); @@ -4647,7 +4644,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter), /*index_vector_dim=*/rewriter.getI64IntegerAttr(1), rewriter.getContext()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.value(), swaped_indices, dims_attr, GetI64ElementsAttr(slice_sizes, &rewriter)); @@ -4666,7 +4663,7 @@ class ConvertXlaShardingOp : public OpRewritePattern { // using a string. if (!op._XlaSharding().hasValue()) return failure(); - auto custom_call = rewriter.create( + auto custom_call = rewriter.create( op.getLoc(), op.getType(), op.input(), /*call_target_name=*/rewriter.getStringAttr("Sharding"), /*has_side_effect=*/rewriter.getBoolAttr(false), @@ -4716,7 +4713,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { updates_type.getElementType())); auto cst = - rewriter.create(op.getLoc(), zero_attr).getResult(); + rewriter.create(op.getLoc(), zero_attr).getResult(); auto split_updates = rewriter.create( op.getLoc(), split_updates_type, cst, updates); @@ -4731,7 +4728,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { for (auto pair : llvm::zip(unpacked_indices.output(), split_updates.output())) { input_indices.front() = std::get<0>(pair); - input = rewriter.create( + input = rewriter.create( op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); } @@ -4759,7 +4756,7 @@ class ConvertXlaDynamicUpdateSliceOp auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, op.indices(), IntegerAttr::get(rewriter.getIntegerType(64), 0)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.input(), op.update(), unpacked_indices.output()); return success(); } @@ -5009,11 +5006,11 @@ class ConvertQrOp : public OpRewritePattern { Value iota = builder->create( loc, RankedTensorType::get({m}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value gtk = builder->create( + Value gtk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("GT", builder->getContext())); gtk = builder->create(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create( + Value x_after_k = builder->create( loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); // sigma = np.dot(x[k+1:], x[k+1:]) @@ -5025,15 +5022,15 @@ class ConvertQrOp : public OpRewritePattern { Value mu = builder->create( loc, builder->create(loc, alpha_sq, sigma.getResult(0))); - Value sigma_is_zero = builder->create( + Value sigma_is_zero = builder->create( loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); - Value alpha_is_negative = builder->create( + Value alpha_is_negative = builder->create( loc, alpha, zero, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); auto batch_size_one = builder->create( loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create( + Value signed_mu = builder->create( loc, builder->create(loc, mu.getType(), alpha_is_negative, batch_size_one, @@ -5051,7 +5048,7 @@ class ConvertQrOp : public OpRewritePattern { divisor = builder->create(loc, divisor.getType(), sigma_is_zero, batch_size_one, divisor); - Value eqk = builder->create( + Value eqk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); eqk = builder->create(loc, eqk, x_type.getElementType()); @@ -5065,7 +5062,7 @@ class ConvertQrOp : public OpRewritePattern { // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. // Note that the add performs a degenerate broadcast. - *v = builder->create( + *v = builder->create( loc, e_k, StaticBinaryBroadcast(loc, x_after_k, divisor, GetI64ElementsAttr(batch_dim_ids, builder), @@ -5143,7 +5140,7 @@ class ConvertQrOp : public OpRewritePattern { precision, builder); vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, precision, builder); - auto tau_x_vva = StaticBinaryBroadcast( + auto tau_x_vva = StaticBinaryBroadcast( loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), *builder); a = builder->create(loc, a, tau_x_vva); @@ -5155,12 +5152,12 @@ class ConvertQrOp : public OpRewritePattern { auto iota = builder->create( loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create( + Value predecessor_mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); predecessor_mask = builder->create(loc, predecessor_mask, a_type.getElementType()); - Value mask = builder->create( + Value mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); mask = builder->create(loc, mask, a_type.getElementType()); @@ -5190,7 +5187,7 @@ class ConvertQrOp : public OpRewritePattern { loc, RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create( + Value xa_mask = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); a = builder->create(loc, a_type, xa_mask, new_x, a); @@ -5227,7 +5224,7 @@ class ConvertQrOp : public OpRewritePattern { loc, taus.getType(), taus_zeros, GetI64ElementsAttr(taus.getType().cast().getShape(), builder)); - Value taus_mask = builder->create( + Value taus_mask = builder->create( loc, iota_n, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); auto taus_update = builder->create( @@ -5312,7 +5309,7 @@ class ConvertQrOp : public OpRewritePattern { loc, vs.getType(), zero, GetI64ElementsAttr(vs.getType().cast().getShape(), builder)); - auto compare = builder->create( + auto compare = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("GE", builder->getContext())); auto y = builder->create(loc, vs.getType(), compare, zero, vs); @@ -5460,23 +5457,23 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. if (legalize_chlo) { - xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); } ConversionTarget target(*context); if (legalize_chlo) { - target.addIllegalDialect(); + target.addIllegalDialect(); } else { - target.addLegalDialect(); + target.addLegalDialect(); } - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); target.addLegalOp(); if (!allow_partial_conversion) { - // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. + // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp. target.addLegalOp(); DenseSet nonlegalized_ops; LogicalResult result = @@ -5498,5 +5495,5 @@ std::unique_ptr> createLegalizeTFPass( return std::make_unique(allow_partial_conversion, legalize_chlo); } -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 934cdea7337f4e..09e94d9a84f7b1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -48,7 +48,7 @@ limitations under the License. using mlir::PassRegistration; namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { class LegalizeTFControlFlow : public PassWrapper> { @@ -67,7 +67,7 @@ namespace { void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { // De-tuple the results of the xla hlo if result. for (auto result_it : llvm::enumerate(replace)) { - auto get_tuple_value = builder->create( + auto get_tuple_value = builder->create( result_it.value().getLoc(), tuple, result_it.index()); result_it.value().replaceAllUsesWith(get_tuple_value); } @@ -95,10 +95,10 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, auto result = builder.create(loc, func, detupled_args).getResults(); if (!tuple_return) { - builder.create(loc, result); + builder.create(loc, result); } else { auto tuple_op = builder.create(loc, result); - builder.create(loc, tuple_op.getResult()); + builder.create(loc, tuple_op.getResult()); } } @@ -109,12 +109,12 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // XLA prefers tuple arguments for control flow due to XLA not supporting // multiple return values. SmallVector inputs(op.input()); - auto tuple_input = builder.create(loc, inputs); + auto tuple_input = builder.create(loc, inputs); // Create the new if op with tuple inputs. auto result_type = builder.getTupleType(op.getResultTypes()); - auto if_op = builder.create(loc, result_type, op.cond(), - tuple_input, tuple_input); + auto if_op = builder.create(loc, result_type, op.cond(), + tuple_input, tuple_input); // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo @@ -136,15 +136,15 @@ void LowerCase(TF::CaseOp op, ModuleOp module) { // XLA requires one argument per branch so we create a tuple of inputs to pass // to each branch. SmallVector inputs(op.input()); - auto tuple_input = builder.create(loc, inputs); + auto tuple_input = builder.create(loc, inputs); // Create replica of input tuple for each branch SmallVector n_tuple_inputs(op.branches().size(), tuple_input); // Create the new case op with tuple inputs. - auto case_op = builder.create( - loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs, - op.branches().size()); + auto case_op = + builder.create(loc, op.getResultTypes(), op.branch_index(), + n_tuple_inputs, op.branches().size()); // Import the regions for all branches. for (unsigned i = 0; i < op.branches().size(); ++i) { @@ -166,10 +166,10 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { // multiple return values. SmallVector inputs(op.input()); builder.setInsertionPoint(op); - Value tuple_input = builder.create(loc, inputs); + Value tuple_input = builder.create(loc, inputs); // Create the new while op with tuple inputs. - auto while_op = builder.create( + auto while_op = builder.create( loc, builder.getTupleType(op.getResultTypes()), tuple_input); // Import the regions for both the cond and body. These regions must be @@ -204,9 +204,9 @@ void LegalizeTFControlFlow::runOnOperation() { } }); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir -static PassRegistration cfpass( +static PassRegistration cfpass( "xla-legalize-tf-control-flow", "Legalize TensorFlow control flow to the XLA dialect"); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 643f1cba04a925..b1667f526d28e9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -259,7 +259,7 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), //===----------------------------------------------------------------------===// def CastElementsToI64Elements : NativeCodeCall< - "xla::ConvertElementsAttr(" + "hlo::ConvertElementsAttr(" "$0, $_builder.getIntegerType(64)).cast()">; def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), @@ -366,7 +366,7 @@ class GetDimensionSizeFromEnd: NativeCodeCall< // For now, this op needs to be created in C++ because the expected output type // cannot be inferred. class createIotaOp: NativeCodeCall< - "$_builder.create($0.getOwner()->getLoc(), " + "$_builder.create($0.getOwner()->getLoc(), " "Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">; // This op needs to be created in C++ because the generated Convert Op has no diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 54453406ef791e..d25b38d9ece917 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -37,7 +37,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -70,16 +69,16 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { template using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok -static bool IsOpWhitelisted(Operation* op) { - // White-listed TensorFlow ops are known to have well behaved tf2xla kernels +static bool IsOpAllowlisted(Operation* op) { + // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels // building valid MLIR using MlirHloBuilder. - // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for + // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. // clang-format off static llvm::SmallDenseSet ops = { @@ -343,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() { } LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { - if (!IsOpWhitelisted(op)) return success(); + if (!IsOpAllowlisted(op)) return success(); // Only static shaped operands are supported in XLA builders for now. for (Type ty : op->getOperandTypes()) { @@ -545,5 +544,5 @@ std::unique_ptr> createLegalizeTfWithTf2XlaPass( return std::make_unique(device_type); } -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc similarity index 89% rename from tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc rename to tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index a9a12f0dee31c3..519068893e7a4b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h" +#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include #include @@ -190,51 +190,51 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { using ::xla::HloOpcode; switch (instr->opcode()) { case HloOpcode::kAbs: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kAdd: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kAnd: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kCeil: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kComplex: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kCopy: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kCos: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kDivide: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kExp: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kImag: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kLog: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kMaximum: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kMinimum: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kMultiply: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kNegate: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kReal: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kRemainder: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kRsqrt: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kSelect: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kSign: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kSqrt: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kSubtract: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); case HloOpcode::kTanh: - return CreateOpWithoutAttrs(instr).status(); + return CreateOpWithoutAttrs(instr).status(); default: llvm::errs() << instr->ToString(); return tensorflow::errors::Internal( @@ -246,7 +246,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { StatusOr LhloDialectEmitter::EmitSortOp( HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); + TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr); sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); @@ -379,16 +379,16 @@ Status LhloDialectEmitter::Initialize() { block->addArgument(arg_type); allocations_[alloc] = block->getArguments().back(); args_attrs.emplace_back(); - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.params"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.params"), builder_.getIndexAttr(alloc->parameter_number())); } else { block->addArgument(MemRefType::get({alloc->size()}, i8_type_)); allocations_[alloc] = block->getArguments().back(); args_attrs.emplace_back(); - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.alloc"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.alloc"), builder_.getIndexAttr(alloc->index())); if (alloc->maybe_live_out()) - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.liveout"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.liveout"), builder_.getBoolAttr(true)); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h similarity index 75% rename from tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h rename to tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 9db490cb797452..ca40eb5804c01e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -33,46 +33,46 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // of the visitors. tensorflow::Status Initialize(); - LhloDialectEmitter(const xla::BufferAssignment& assignment, - const xla::HloComputation& computation, ModuleOp module) + LhloDialectEmitter(const ::xla::BufferAssignment& assignment, + const ::xla::HloComputation& computation, ModuleOp module) : assignment_(std::move(assignment)), computation_(computation), module_(module), builder_(module.getContext()), i8_type_(builder_.getIntegerType(8)) {} - xla::StatusOr EmitSortOp(xla::HloInstruction* instr); + ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); private: template - xla::StatusOr CreateOpWithoutAttrs(xla::HloInstruction* instr); + ::xla::StatusOr CreateOpWithoutAttrs(::xla::HloInstruction* instr); - tensorflow::Status DefaultAction(xla::HloInstruction* instr) final; + tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final; // Computation parameters don't need any specific handling when they are // visited, they are already processed when we enter a new computation. - tensorflow::Status HandleParameter(xla::HloInstruction* instr) final { + tensorflow::Status HandleParameter(::xla::HloInstruction* instr) final { return tensorflow::Status::OK(); } - tensorflow::Status HandleSort(xla::HloInstruction* instr) final; + tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in - // `current_shape`, and reconstruct a matching xla_lhlo::TupleOp. + // `current_shape`, and reconstruct a matching lmhlo::TupleOp. // Each leaf node is converted to an std.view op with corresponding offsets. // If no tuple presents, it simply returns a view of the buffer. - tensorflow::Status CreateView(const xla::HloInstruction* instr, - const xla::Shape& current_shape, + tensorflow::Status CreateView(const ::xla::HloInstruction* instr, + const ::xla::Shape& current_shape, ::xla::ShapeIndex* current_shape_index, SmallVectorImpl* values); // Helper function to create view/tuple of views to a buffer for a given // instruction result. - tensorflow::Status GetOrCreateView(const xla::HloInstruction* instr, + tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr, SmallVectorImpl* values); // Return an MLIR location for an HLO instruction. - Location getLocation(xla::HloInstruction* inst) { + Location getLocation(::xla::HloInstruction* inst) { return NameLoc::get(builder_.getIdentifier(inst->name()), builder_.getContext()); } @@ -84,7 +84,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // using a "slice" of the buffer allocation and providing shape, layout, and // Dtype. An MLIR view is used separately to model slices into the allocations // (see below). - llvm::DenseMap allocations_; + llvm::DenseMap allocations_; // This map provides access to MLIR buffers for each HLO instruction, keyed by // its buffer slice. A slice is contained in a BufferAllocation, and has an @@ -101,14 +101,14 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // // `slices_` is populated lazily in the `GetOrCreateView()` helper as we // process every instruction. - using SliceKey = std::tuple; + using SliceKey = std::tuple; llvm::DenseMap> slices_; // The BufferAssignment computed by XLA ahead of time. - const xla::BufferAssignment& assignment_; + const ::xla::BufferAssignment& assignment_; // The HLO module that will be converted. - const xla::HloComputation& computation_; + const ::xla::HloComputation& computation_; // This is the MLIR module in which a function will be created for every HLO // computation. @@ -123,10 +123,10 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // Populate the MLIR `module` with the computation from the `hlo_module` using // the provided buffer `assignment`. The returned `Status` indicates success // or failure in the conversion. -tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment, - const xla::HloModule& hlo_module, +tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment, + const ::xla::HloModule& hlo_module, ModuleOp module); } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 67951e6832eef4..bc261324055dd9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -29,7 +29,7 @@ template class OperationPass; class Pass; -namespace xla_hlo { +namespace mhlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. @@ -51,7 +51,7 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, bool legalize_chlo = true); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 9aae596a05e4ed..b684abde7a5b39 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -179,7 +179,7 @@ Shape TypeToShape(mlir::Type type) { } return ShapeUtil::MakeTupleShape(shapes); } - case mlir::xla_hlo::HLOTypes::Token: + case mlir::mhlo::HLOTypes::Token: return ShapeUtil::MakeTokenShape(); default: break; diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 9590688fda78a1..326c3ec4929345 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -774,6 +775,7 @@ def testNonAlignCorners3x2To6x4Batch2(self): class NonMaxSuppressionTest(xla_test.XLATestCase): + @test_util.disable_mlir_bridge("%1") def testNMS128From1024(self): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -808,6 +810,7 @@ def testNMS128From1024(self): self.assertEqual(indices_tf.size, max_output_size) + @test_util.disable_mlir_bridge("%1") def testNMS3From6Boxes(self): # Three boxes are selected based on IOU. boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -849,6 +852,7 @@ def testNMS3From6Boxes(self): self.assertEqual(num_valid, 3) self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + @test_util.disable_mlir_bridge("%1") def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -891,6 +895,7 @@ def testNMS3Then2WithScoreThresh(self): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + @test_util.disable_mlir_bridge("%1") def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -934,6 +939,7 @@ def testNMS3Then1WithScoreMaxThresh(self): self.assertEqual(num_valid, 1) self.assertAllClose(indices_tf[:num_valid], [3]) + @test_util.disable_mlir_bridge("%1") def testSelectFromContinuousOverLap(self): # Tests that a suppressed box does not itself suppress other boxes. @@ -978,6 +984,7 @@ def testSelectFromContinuousOverLap(self): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1015,6 +1022,7 @@ def testBatchedNMSFrom6(self): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1048,6 +1056,7 @@ def testBatchedNMSFrom6Max3(self): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1078,6 +1087,7 @@ def testBatchedNMSSingleFrom6Max3(self): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1107,6 +1117,7 @@ def testBatchedNMSSingleFrom6NoPad(self): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1140,6 +1151,7 @@ def testBatchedNMSBatchDimsFrom6Max3(self): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1175,6 +1187,7 @@ def testBatchedNMSScoreThresholdFrom6Max3(self): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1211,6 +1224,7 @@ def testBatchedNMSUnsortedInputFrom6(self): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1248,6 +1262,7 @@ def testBatchedNMSNoncanonicalizedInputFrom6(self): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1283,6 +1298,7 @@ def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 368cb5af2ed2ea..0718bd8cd65bcf 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -501,6 +501,7 @@ cc_library( copts = tf_copts(), deps = [ ":common_utils", + ":utils", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -570,6 +571,7 @@ cc_library( deps = [ "@com_google_absl//absl/algorithm:container", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib", ] + if_tensorrt([":tensorrt_lib"]), diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 5429aaf3362e24..c9210a1a1e7959 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -130,7 +130,9 @@ Status GetEngineInfo(const Graph* g, EngineInfo* info) { std::vector subgraph_nodes; // Topologically sorted nodes. std::set added_const_nodes; // Used to prevent double insertion. - std::set segment_devices; + // The device assignment accumulated from the compatible device assignments + // for the nodes in the segment. + DeviceNameUtils::ParsedName segment_device; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -144,36 +146,17 @@ Status GetEngineInfo(const Graph* g, const Node* node = *it; if (segment_nodes.count(node) == 0) continue; - std::string device_name; - if (!node->requested_device().empty()) { - device_name = node->requested_device(); - } else if (node->has_assigned_device_name()) { - // It appears that nodes will not have assigned devices at this point in - // execution. - device_name = node->assigned_device_name(); - } else { - VLOG(2) << "Node " << node->name() - << " neither have requested device nor assigned device"; - } - - if (!device_name.empty()) { - // If device is set, it means device placement may have been done before, - // so we need to assign a device for the TRTEngineOp if the assigned - // device is a GPU device. - DeviceNameUtils::ParsedName parsed_name; - const bool parse_succeeded = - DeviceNameUtils::ParseFullName(device_name, &parsed_name); - if (!parse_succeeded) { - VLOG(1) << "Failed to parse " - << (node->requested_device().empty() ? "assigned" : "requested") - << " device " << device_name << " of node " << node->name(); - } else if (parsed_name.type != "GPU") { - VLOG(1) << "Node " << node->name() - << " was assigned to a non-GPU device " << device_name; - } else { - segment_devices.insert(device_name); - } + absl::optional new_segment_device = + MergeIfCompatible(segment_device, GetDeviceName(node)); + if (!new_segment_device.has_value()) { + // The segmenter should guarantee that nodes in the same segment have + // compatible device assignments. + return errors::Internal( + "segment nodes have incompatible device assignments: ", + DeviceNameUtils::ParsedNameToString(segment_device), " vs ", + GetDeviceName(node), " to node ", node->name()); } + segment_device = *new_segment_device; subgraph_nodes.push_back(node); const int node_id = node->id(); @@ -273,13 +256,16 @@ Status GetEngineInfo(const Graph* g, info->engine_name = StrCat(scope_name, info->engine_name); VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name << "' to a GraphDef"; - if (segment_devices.size() == 1) { - info->device = *segment_devices.begin(); - } else if (segment_devices.size() > 1) { - LOG_WARNING_WITH_PREFIX - << "Detected multiple (" << segment_devices.size() - << ") devices for the segment. Picking first one to continue."; - info->device = *segment_devices.begin(); + if (segment_device.has_type) { + // If the accumulated device assignment for the segment has a device type, + // the segmenter guarantees the device type is GPU. Use the device + // assignment in this case. + if (segment_device.type != "GPU") { + return errors::Internal( + "segment device is not GPU: ", + DeviceNameUtils::ParsedNameToString(segment_device)); + } + info->device = DeviceNameUtils::ParsedNameToString(segment_device); } else { TfGpuId tf_gpu_id; PlatformGpuId platform_gpu_id; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index f2407fccfad6a1..369b339d01a12f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1429,6 +1429,18 @@ Status Converter::BuildCudaEngine( TF_RETURN_IF_ERROR(profiles->ConfigureBuilder( trt_builder_.get(), builder_config.get(), network())); } + + string precision_mode_str; + TF_RETURN_IF_ERROR( + TrtPrecisionModeToName(precision_mode_, &precision_mode_str)); + string trt_network_name = StrCat( + "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-", + "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_, + ", ", "Max-Batch-Size:", max_batch_size, ", ", + "Max-Workspace-Size:", max_workspace_size_bytes); + VLOG(1) << "Setting TensorRT network name to " << trt_network_name; + network()->setName(trt_network_name.c_str()); + VLOG(1) << "Building TensorRT engine"; engine->reset( trt_builder_->buildEngineWithConfig(*network(), *builder_config)); @@ -1449,19 +1461,6 @@ Status Converter::BuildCudaEngine( } } -#if IS_TRT_VERSION_GE(6, 0, 0, 0) - string precision_mode_str; - TF_RETURN_IF_ERROR( - TrtPrecisionModeToName(precision_mode_, &precision_mode_str)); - string trt_network_name = StrCat( - "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-", - "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_, - ", ", "Max-Batch-Size:", max_batch_size, ", ", - "Max-Workspace-Size:", max_workspace_size_bytes); - VLOG(1) << "Setting TensorRT network name to " << trt_network_name; - network()->setName(trt_network_name.c_str()); -#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0) - VLOG(1) << "Building TensorRT engine"; engine->reset(trt_builder_->buildCudaEngine(*network())); #endif diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a4b64ec0dc58d2..a69960005fca23 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -271,5 +271,44 @@ string GetLoadedTensorRTVersion() { return absl::StrCat(major, ".", minor, ".", patch); } +absl::string_view GetDeviceName(const Node* node) { + if (node->has_assigned_device_name()) { + return node->assigned_device_name(); + } + return node->requested_device(); +} + +absl::optional GetDeviceParsedName( + const Node* node) { + absl::string_view device_name = GetDeviceName(node); + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) { + return absl::nullopt; + } + return parsed_name; +} + +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, + const DeviceNameUtils::ParsedName& b) { + DeviceNameUtils::ParsedName merged_name = a; + if (!DeviceNameUtils::MergeDevNames(&merged_name, b, + /*allow_soft_placement=*/false) + .ok()) { + return absl::nullopt; + } + return merged_name; +} + +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, absl::string_view b) { + DeviceNameUtils::ParsedName b_parsed_name; + if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) { + return absl::nullopt; + } + + return MergeIfCompatible(a, b_parsed_name); +} + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 775616ff7aa55d..a0505c3f92298e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -133,6 +134,26 @@ bool AreShapesCompatible(const std::vector& actual_shapes, // input bindings, because the number of total input bindings equals the number // of profiles times the number of engine inputs. int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine); + +// Returns the string representation for the assigned device or the requested +// device of the given node. +absl::string_view GetDeviceName(const Node* node); + +// Returns the ParsedName representation for the assigned device or the +// requested device string of the given node. If the device string is invalid, +// returns absl::nullopt. +absl::optional GetDeviceParsedName( + const Node* node); + +// If the given two device assignments as compatible, returns the merge of the +// two assignments. Otherwise, returns absl::nullopt. +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b); +// Similar to the above, except that the second device assignment is represented +// by a string_view. +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, absl::string_view b); + #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index d9080b6f69a6a5..e7820ca41fe0d5 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/algorithm.h" @@ -664,11 +665,15 @@ ClusterBatchSize GetClusterBatchSizeForNode( void AddSegmentForNode(const grappler::GraphProperties* graph_properties, std::vector>* segments, - SimpleNode* node, bool use_implicit_batch) { + SimpleNode* node, + const DeviceNameUtils::ParsedName& device_name, + bool use_implicit_batch) { segments->emplace_back( - node, GetClusterBatchSizeForNode( - graph_properties, node == nullptr ? nullptr : node->tf_node(), - use_implicit_batch)); + node, + GetClusterBatchSizeForNode(graph_properties, + node == nullptr ? nullptr : node->tf_node(), + use_implicit_batch), + device_name); } } // namespace @@ -721,6 +726,10 @@ Status SegmentGraph(const Graph* tf_graph, std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); + if (!node) { + VLOG(3) << "Node " << i << " doesn't exist in the graph"; + continue; + } auto exclude_node = [&](absl::string_view reason) { VLOG(1) << "Not a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " @@ -730,7 +739,13 @@ Status SegmentGraph(const Graph* tf_graph, num_unsupported_ops++; node = nullptr; }; - if (options.exclude_node_list.count(node->name()) != 0) { + absl::optional device_name = + GetDeviceParsedName(node->tf_node()); + // GetDeviceParseName capitalizes the device type. + if (!device_name.has_value() || + (device_name->has_type && device_name->type != "GPU")) { + exclude_node("node can't be placed on GPU"); + } else if (options.exclude_node_list.count(node->name()) != 0) { exclude_node("excluded by segmenter option"); } else if (options.use_implicit_batch && !OperationCanBeTranslatedToImplicitBatch(graph_properties, @@ -759,7 +774,7 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op name: " << node->name(); } } - AddSegmentForNode(graph_properties, &node_segments, node, + AddSegmentForNode(graph_properties, &node_segments, node, *device_name, options.use_implicit_batch); } string msg = StrCat( @@ -805,6 +820,8 @@ Status SegmentGraph(const Graph* tf_graph, // contracting an output edge may unblock new edges for contracting. ClusterBatchSize expected_batch_size = node_segments[node->id()].BatchSize(); + DeviceNameUtils::ParsedName expected_device_name = + node_segments[node->id()].DeviceName(); VLOG(3) << "batch size " << expected_batch_size; while (true) { std::set contract_edges; @@ -817,26 +834,39 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "... ... Control Edge, Skipping"; continue; } + UnionFind* out_cluster = + &node_segments[out_edge->dst()->id()]; // Out node must be a TRT candidate. - if (node_segments[out_edge->dst()->id()].Value() == nullptr) { + if (out_cluster->Value() == nullptr) { VLOG(3) << "... ... not a TRT candidate"; continue; } // Out node must have compatible batch size. - ClusterBatchSize out_batch_size = - node_segments[out_edge->dst()->id()].BatchSize(); + ClusterBatchSize out_batch_size = out_cluster->BatchSize(); ClusterBatchSize merged_batch_size = expected_batch_size; if (!merged_batch_size.MergeIfCompatible(out_batch_size)) { - VLOG(3) << "... ... incompatible batch size " + VLOG(3) << "... ... incompatible batch sizes " << expected_batch_size.ToString() << " " << out_batch_size.ToString(); continue; } + + const DeviceNameUtils::ParsedName& out_device_name = + out_cluster->DeviceName(); + absl::optional merged_device_name = + MergeIfCompatible(expected_device_name, out_device_name); + if (!merged_device_name.has_value()) { + VLOG(3) << "... ... incompatible device names " + << expected_device_name << " " << out_device_name; + continue; + } + if (CanContractEdge(out_edge, graph)) { VLOG(3) << "... ... can contract. new batch size " << merged_batch_size.ToString(); contract_edges.insert(out_edge); expected_batch_size = merged_batch_size; + expected_device_name = *merged_device_name; } else { VLOG(3) << "... ... cannot contract, would form cycle"; } @@ -868,12 +898,14 @@ Status SegmentGraph(const Graph* tf_graph, graph->RemoveEdge(r); } } - ClusterBatchSize actual_batch_size = - node_segments[node->id()].BatchSize(); - if (expected_batch_size != actual_batch_size) { + if (expected_batch_size != node_segments[node->id()].BatchSize()) { return errors::Internal( "expected batch size is not the same as the actual batch size"); } + if (expected_device_name != node_segments[node->id()].DeviceName()) { + return errors::Internal( + "expected device name is not the same as the actual device name"); + } } } @@ -884,34 +916,9 @@ Status SegmentGraph(const Graph* tf_graph, // the segment tree) to the segment nodes set. std::map> sg_map; - // A map from the segment identifier (currently the name of the root node of - // the segment tree) to the device names that the nodes in the segment are - // assigned to. - // - // TODO(aaroey): nodes assigned to different devices should not be merged, - // fix this. - std::unordered_map> device_maps; - for (auto& u : node_segments) { if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node()); - auto tf_node = u.Value()->tf_node(); - // has_assigned_device_name() is expected to return true - // when called from optimization pass. However, since graph - // is converted back and forth between graph and graphdef, - // assigned devices demoted to requested devices. If the graph - // is passed directly to this module, assigned devices will be set. - if (tf_node->has_assigned_device_name()) { - device_maps[u.ParentValue()->name()].insert( - tf_node->assigned_device_name()); - } else if (!tf_node->requested_device().empty()) { - device_maps[u.ParentValue()->name()].insert( - tf_node->requested_device()); - } else { - VLOG(2) << "Node " << tf_node->name() - << " has no device assigned requested device is: " - << tf_node->requested_device(); - } } } @@ -1030,30 +1037,9 @@ Status SegmentGraph(const Graph* tf_graph, continue; } - const auto& dev_itr = device_maps.find(segment_root); - if (dev_itr == device_maps.end() || dev_itr->second.empty()) { - VLOG(1) << "No device assigned to segment " << segments->size(); - } else if (dev_itr->second.size() > 1) { - string s = StrCat("Segment ", segments->size(), - " has multiple devices attached: "); - for (const auto& dev : dev_itr->second) { - StrAppend(&s, dev, ", "); - } - LOG_WARNING_WITH_PREFIX << s; - } - segments->emplace_back(segment_nodes); } - if (VLOG_IS_ON(1)) { - for (const auto& d : device_maps) { - string s("Segment "); - StrAppend(&s, ": '", d.first, "' "); - for (const auto& dd : d.second) { - StrAppend(&s, dd, ", "); - } - VLOG(1) << "Devices " << s; - } - } + return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index f3bc5bfbee61ed..bf277328fe7c94 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -179,6 +179,69 @@ TEST_F(SegmentTest, Simple) { RunTest(&g, all_adds, all_adds, without_add3, {all_adds}); } +TEST_F(SegmentTest, WithDeviceAssignments) { + // feed + // // \\ + // add0 add1 + // | \ / + // | add2 + // | / \\ + // add3 add4 + // \ / + // + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4"}; + DisableImplicitBatchMode(); + + { + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + } + + { + // Assigning add1 to CPU to exclude it from the cluster. + add1.node()->set_assigned_device_name("/device:CPU:0"); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds - "add1"}); + add1.node()->set_assigned_device_name(""); + } + + { + // Assigning operations add3 and add4 to another GPU to exclude the + // operation from the cluster. + constexpr char kGpu0[] = "/device:GPU:0"; + add0.node()->set_assigned_device_name(kGpu0); + add1.node()->set_assigned_device_name(kGpu0); + add2.node()->set_assigned_device_name(kGpu0); + constexpr char kGpu1[] = "/device:GPU:1"; + add3.node()->set_assigned_device_name(kGpu1); + add4.node()->set_assigned_device_name(kGpu1); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {{"add0", "add1", "add2"}}); + } + + { + // Assigning the operations to two compatibile GPU devices resulting in + // one cluster with all operations. + constexpr char kGpuAny[] = "/device:GPU:*"; + add3.node()->set_assigned_device_name(kGpuAny); + add4.node()->set_assigned_device_name(kGpuAny); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + } +} + TEST_F(SegmentTest, AvoidCycle) { // feed // // \\ diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index b53615ec019215..b91f5771ce5486 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -149,9 +149,11 @@ template class UnionFind { public: UnionFind() : size_(1), parent_(nullptr) {} - explicit UnionFind(const T& v, ClusterBatchSize batch_size) + UnionFind(const T& v, ClusterBatchSize batch_size, + const DeviceNameUtils::ParsedName& device_name) : size_(1), cluster_batch_size_(batch_size), + cluster_device_name_(device_name), parent_(nullptr), value_(v) {} @@ -159,10 +161,16 @@ class UnionFind { // this object to the root of the cluster. int Size() { return FindRoot()->size_; } - // Returns the batch size of the cluster and compress the path from this + // Returns the batch size of the cluster and compresses the path from this // object to the root object. ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } + // Returns the device name of the cluster and compresses the path from this + // object to the root object. + const DeviceNameUtils::ParsedName& DeviceName() { + return FindRoot()->cluster_device_name_; + } + // Merges this cluster with 'other'. This cluster's size_ is updated to // the size of the merged cluster; the size_ of 'other' becomes inaccessible // as only the size_ of the root object is accessible. @@ -181,6 +189,7 @@ class UnionFind { int size_; ClusterBatchSize cluster_batch_size_; + DeviceNameUtils::ParsedName cluster_device_name_; UnionFind* parent_; T value_; }; @@ -192,12 +201,20 @@ Status UnionFind::Merge(UnionFind* other) { if (a == b) return Status::OK(); ClusterBatchSize batch_size = a->cluster_batch_size_; - bool merged = batch_size.MergeIfCompatible(other->cluster_batch_size_); - if (!merged) { - return errors::Internal("trying to merge incompatible cluster."); + if (!batch_size.MergeIfCompatible(other->cluster_batch_size_)) { + return errors::Internal( + "trying to merge clusters with incompatible batch sizes."); + } + + absl::optional device_name = + MergeIfCompatible(a->cluster_device_name_, other->cluster_device_name_); + if (!device_name.has_value()) { + return errors::Internal( + "trying to merge clusters with incompatible device assignment."); } a->cluster_batch_size_ = batch_size; + a->cluster_device_name_ = *device_name; b->parent_ = a; a->size_ += b->size_; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index bf6406a796b1a5..cac72925dfdc0d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -704,6 +704,7 @@ cc_library( srcs = ["mlir_bridge_pass.cc"], hdrs = ["mlir_bridge_pass.h"], deps = [ + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index c90261303f5266..1da34266460cfe 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -74,7 +74,8 @@ Status CondConstInputIndices( *(fbody->graph), &compile_time_const_arg_indices, /*compile_time_const_nodes=*/nullptr, flib_runtime)); } - for (int i = 0; i < compile_time_const_arg_indices.size(); i++) { + for (int i = 0, iter_limit = compile_time_const_arg_indices.size(); + i < iter_limit; i++) { if (compile_time_const_arg_indices[i]) { // The 0th input is the pred or branch index, which is not passed to the // branches. So the i'th input of a branch function corresponds to the diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index b6e84eabe8d5d2..5f6dcad55389a9 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -65,7 +65,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, /*compile_time_const_nodes=*/nullptr, ctx->function_library())); args->resize(expressions.size()); - for (int i = 0; i < args->size(); ++i) { + for (int i = 0, iter_limit = args->size(); i < iter_limit; ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); @@ -161,7 +161,8 @@ Status GraphCompiler::Compile() { for (auto* e : n->in_edges()) { if (e->IsControlEdge()) continue; const Node* src = e->src(); - TF_RET_CHECK(src->id() < output_registry.size()); + const int output_registry_size = output_registry.size(); + TF_RET_CHECK(src->id() < output_registry_size); const NodeOutputs& src_outputs = output_registry[src->id()]; tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output()); @@ -268,7 +269,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RET_CHECK(arguments.size() == expressions.size()); std::vector handles; - for (int64 i = 0; i < expressions.size(); ++i) { + for (int64 i = 0, iter_limit = expressions.size(); i < iter_limit; ++i) { if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; } @@ -312,7 +313,8 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } } - for (int64 i = 0; i < result.resource_updates.size(); i++) { + for (int64 i = 0, iter_limit = result.resource_updates.size(); i < iter_limit; + i++) { if (result.resource_updates[i].modified) { XlaResource* resource = expressions[result.resource_updates[i].input_index]->resource(); diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index a9385e0556438f..f7adae077f7f95 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -216,7 +216,8 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { } arg_nodes->clear(); for (const auto& index_node : indexed_arg_nodes) { - if (index_node.first != arg_nodes->size()) { + const int arg_nodes_size = arg_nodes->size(); + if (index_node.first != arg_nodes_size) { return errors::InvalidArgument( "Expected ", FunctionLibraryDefinition::kArgOp, " node with index ", arg_nodes->size(), ", but got index ", index_node.first); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index b60a13972a7c6b..e0bc2ba5052f7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -124,7 +124,8 @@ xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, // convolutions (as currently implemented). Status CheckConvAttrs(const ConvOpAttrs& attrs) { const int num_dims = attrs.num_spatial_dims + 2; - if (attrs.strides.size() != num_dims) { + const int attrs_strides_size = attrs.strides.size(); + if (attrs_strides_size != num_dims) { return errors::InvalidArgument("Sliding window strides field must specify ", num_dims, " dimensions"); } @@ -135,7 +136,8 @@ Status CheckConvAttrs(const ConvOpAttrs& attrs) { "Current implementation does not yet support strides in the batch and " "depth dimensions."); } - if (attrs.dilations.size() != num_dims) { + const int attrs_dilations_size = attrs.dilations.size(); + if (attrs_dilations_size != num_dims) { return errors::InvalidArgument("Dilations field must specify ", num_dims, " dimensions"); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 66545fc72cf11d..6d4393ee00684f 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -85,21 +85,8 @@ XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); XLAJIT_MAKE_UNARY(Sigmoid, xla::Logistic(x)); -// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -static xla::XlaOp Sign(xla::XlaBuilder* b, xla::XlaOp x) { - return b->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); - if (xla::primitive_util::IsComplexType(shape.element_type())) { - return xla::Sign(x); - } - auto gt = xla::Gt(x, xla::ZerosLike(x)); - auto lt = xla::Lt(x, xla::ZerosLike(x)); - return xla::ConvertElementType(gt, shape.element_type()) - - xla::ConvertElementType(lt, shape.element_type()); - }); -} - -XLAJIT_MAKE_UNARY(Sign, Sign(b, x)); +// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) { diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 7daff47e966b73..2ab86c78e44cdf 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -66,7 +66,7 @@ xla::StatusOr Expand(xla::XlaOp input, int64 dim) { // Move the newly created dimension to the end with a transpose. std::vector permutation; - for (int64 i = 0; i != expanded_shape.size(); ++i) { + for (int64 i = 0, iter_limit = expanded_shape.size(); i != iter_limit; ++i) { permutation.push_back(i); if (i == dim) { ++i; diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 720b81a50973c6..42a95bbb9f8836 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -72,7 +72,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); - for (int i = 0; i < host_tensors.size(); i++) { + for (int i = 0, iter_limit = host_tensors.size(); i < iter_limit; i++) { // Validate runtime shapes and fail if it doesn't match the contract. const Tensor* tensor = &host_tensors[i]; buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index c398e5f129e540..eefef26dc24356 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return Status::OK(); - if (!options.session_options->config.experimental().enable_mlir_bridge()) { + if (!IsEnabled(options.session_options->config)) { VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"; mlir_bridge_gauge_v1->GetCell()->Set(false); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index b7f8ef203f7318..f7541e634d4ca8 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ #include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" namespace tensorflow { @@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { llvm::StringRef name() const override { return "bridge"; } bool IsEnabled(const ConfigProto& config_proto) const override { - return config_proto.experimental().enable_mlir_bridge(); + return config_proto.experimental().enable_mlir_bridge() || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index af5ee1f2371b7f..abaeb305104d51 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -95,7 +95,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); - mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); return true; }(); @@ -142,7 +142,7 @@ Status ConvertGraphDefToXlaViaMlir( std::string* file_name = debug_info.mutable_files(i); size_t location = file_name->rfind(std::string(debug_info_path_begin_marker)); - if (location != -1) { + if (location != std::string::npos) { *file_name = file_name->substr(location + debug_info_path_begin_marker.length()); } diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 2fce6e7f0c7bd4..146694b775450c 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -55,7 +55,8 @@ xla::StatusOr MakeLayout(absl::Span minor_to_major, } std::vector dim_present(minor_to_major.size(), false); for (auto dim : minor_to_major) { - if (dim < 0 || dim >= minor_to_major.size()) { + const int minor_to_major_size = minor_to_major.size(); + if (dim < 0 || dim >= minor_to_major_size) { return errors::InvalidArgument("Layout dimension out of range: dim=", dim, " rank=", minor_to_major.size()); } @@ -204,7 +205,8 @@ Status GetShapeWithLayout( *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); } else { int64 rank = input_shape.rank(); - if (rank != minor_to_major.size()) { + const int64 minor_to_major_size = minor_to_major.size(); + if (rank != minor_to_major_size) { return errors::InvalidArgument( "Wrong number of layout attribute elements: rank=", rank, " elements=", minor_to_major.size()); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index bcdfd1c6a8ec5d..0454bbb771ab69 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -87,7 +87,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, *computation = std::move(*result.computation); int num_const_results = 0; - for (int i = 0; i < result.outputs.size(); ++i) { + for (int i = 0, iter_limit = result.outputs.size(); i < iter_limit; ++i) { // Ending up with const results (i.e. output args) is an error, since it // means that one or more fetches that the user specified will be dropped // from the generated function. It's most likely a configuration error, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c1aef3ff690c12..6d92fd97793736 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -64,7 +64,7 @@ Status CheckSignature(const DataTypeVector& types, return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); } - for (int i = 0; i < types.size(); ++i) { + for (int i = 0, iter_limit = types.size(); i < iter_limit; ++i) { // Don't perform type checks on resource variables and tensor // lists (DT_VARIANT) as we have to trick the type system in order to // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. @@ -192,7 +192,7 @@ Status BuildComputation( // replicate sharding is used. The first element is the output index, second // element is the sharding. std::unordered_map retval_index_and_sharding; - for (int i = 0; i < retvals.size(); ++i) { + for (int i = 0, iter_limit = retvals.size(); i < iter_limit; ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; output.type = retval.dtype(); @@ -356,7 +356,7 @@ Status BuildComputation( xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); // Copy specified sharding from retval_index_and_sharding. std::vector sharding_elems; - for (int i = 0; i < elems.size(); i++) { + for (int i = 0, iter_limit = elems.size(); i < iter_limit; i++) { const auto& iter = retval_index_and_sharding.find(i); TF_RET_CHECK(iter != retval_index_and_sharding.end()); const xla::OpSharding& sub_op_sharding = iter->second; @@ -365,7 +365,8 @@ Status BuildComputation( if (elem_shapes[i].IsTuple()) { const std::vector sub_sharding_elems = sub_sharding.tuple_elements(); - TF_RET_CHECK(sub_sharding_elems.size() == + const int64 sub_sharding_elems_size = sub_sharding_elems.size(); + TF_RET_CHECK(sub_sharding_elems_size == xla::ShapeUtil::GetLeafCount(elem_shapes[i])); for (const auto& sub_sharding_elem : sub_sharding_elems) { sharding_elems.push_back(sub_sharding_elem); @@ -700,7 +701,7 @@ Status XlaCompiler::CompileFunction( // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an // Xla op requires a compile-time constant input, and that input is shape of // an _Arg node. - for (int i = 0; i < args.size(); i++) { + for (int i = 0, iter_limit = args.size(); i < iter_limit; i++) { // Skip resource variables and tensor lists. DataType dtype; TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype)); @@ -942,7 +943,7 @@ Status XlaCompiler::BuildArguments( // to the d'th XLA input. Note that the value -1 corresponds to constants, or // other args that don't correspond to an input. std::vector arg_to_inputs(args.size(), -1); - for (int i = 0; i < input_to_args->size(); i++) { + for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; i++) { arg_to_inputs[input_to_args->at(i)] = i; } @@ -988,7 +989,7 @@ Status XlaCompiler::BuildArguments( : it->second; } std::vector is_same_across_replicas; - for (int i = 0; i < input_to_args->size(); ++i) { + for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { // Add an entry to is_same_across_replicas for every leaf buffer. is_same_across_replicas.insert( is_same_across_replicas.end(), @@ -1004,7 +1005,7 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (int i = 0; i < input_to_args->size(); ++i) { + for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); @@ -1045,7 +1046,7 @@ Status XlaCompiler::BuildArguments( } } - for (int i = 0; i < input_to_args->size(); ++i) { + for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); @@ -1365,7 +1366,7 @@ void SetTransfer(const string& key, absl::Span types, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); - for (int i = 0; i < types.size(); ++i) { + for (int i = 0, iter_limit = types.size(); i < iter_limit; ++i) { tf2xla::TensorMetadata* metadata = transfer->add_metadata(); metadata->set_type(types[i]); shapes[i].AsProto(metadata->mutable_shape()); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e49c944eeb3af1..c94c4805d53421 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -64,7 +64,8 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { - if (retvals_.size() <= index) { + const int64 retvals_size = retvals_.size(); + if (retvals_size <= index) { retvals_.resize(index + 1); } retvals_[index] = expression; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index a43608bd4348e1..e37f465918580b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -63,7 +63,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; if (x.name != y.name) return true; if (x.label != y.label) return true; // The registrations refer to the same Op: ensures they are compatible and - // are restricted to different device whitelists. + // are restricted to different device allowlists. if (x.compilation_only != y.compilation_only) { LOG(WARNING) << "Registrations of " << x.name << " have incompatible compilation_only settings."; @@ -84,14 +84,14 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_string_type settings."; return false; } - if (!x.has_device_whitelist && !y.has_device_whitelist) { + if (!x.has_device_allowlist && !y.has_device_allowlist) { LOG(WARNING) << "Duplicate registrations of " << x.name - << "with no device whitelists."; + << "with no device allowlists."; return false; } - if (x.has_device_whitelist && y.has_device_whitelist) { - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { + if (x.has_device_allowlist && y.has_device_allowlist) { + for (const auto& device : x.device_allowlist) { + if (y.device_allowlist.count(device) != 0) { LOG(WARNING) << "Multiple registrations of " << x.name << " on device " << device; return false; @@ -185,28 +185,28 @@ void XlaOpRegistry::RegisterCompilationKernels() { // The goal is to allow the co-existence of backend-specific kernels and // generic kernels. To achieve this, we enforce the following order of // registrations for one op: - // 1. Process op registration with device whitelists: + // 1. Process op registration with device allowlists: // this pass registers backend-specific kernels for this op. - // 2. Process op registration without device whitelists: + // 2. Process op registration without device allowlists: // this pass registers the kernels for all the other supported backends. for (auto& ops : registry.ops_) { const string& op_name = ops.first; std::vector>& op_registrations = ops.second; - // Partition the op registration so that the ones with device whitelists - // precede the one without device whitelist. + // Partition the op registration so that the ones with device allowlists + // precede the one without device allowlist. std::partition(op_registrations.begin(), op_registrations.end(), [](const std::unique_ptr& op_reg) { - return op_reg->has_device_whitelist; + return op_reg->has_device_allowlist; }); - // Collect a set of backend registered by ops with device whitelists. - // The op registration without whitelists will register a generic kernel + // Collect a set of backend registered by ops with device allowlists. + // The op registration without allowlists will register a generic kernel // for all other backends not in this set. - std::unordered_set whitelisted_backend; + std::unordered_set allowlisted_backend; for (auto& op_registration : op_registrations) { - if (op_registration->has_device_whitelist) { - whitelisted_backend.insert(op_registration->device_whitelist.begin(), - op_registration->device_whitelist.end()); + if (op_registration->has_device_allowlist) { + allowlisted_backend.insert(op_registration->device_allowlist.begin(), + op_registration->device_allowlist.end()); } } @@ -238,19 +238,19 @@ void XlaOpRegistry::RegisterCompilationKernels() { } for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted + // If the operator has a device allowlist, only register on allowlisted // devices. - if (op_registration->has_device_whitelist && - op_registration->device_whitelist.find(backend.first) == - op_registration->device_whitelist.end()) { + if (op_registration->has_device_allowlist && + op_registration->device_allowlist.find(backend.first) == + op_registration->device_allowlist.end()) { continue; } - // If the operator does NOT has a device whitelist, skip all devices + // If the operator does NOT has a device allowlist, skip all devices // that has already been registered. - if (!op_registration->has_device_whitelist && - whitelisted_backend.find(backend.first) != - whitelisted_backend.end()) { + if (!op_registration->has_device_allowlist && + allowlisted_backend.find(backend.first) != + allowlisted_backend.end()) { continue; } @@ -478,17 +478,17 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( absl::Span devices) { - registration_->has_device_whitelist = true; + registration_->has_device_allowlist = true; for (absl::string_view device : devices) { - registration_->device_whitelist.emplace(device); + registration_->device_allowlist.emplace(device); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( absl::string_view device) { - registration_->has_device_whitelist = true; - registration_->device_whitelist.emplace(device); + registration_->has_device_allowlist = true; + registration_->device_allowlist.emplace(device); return *this; } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 7839ae95dc0569..af720fb4bb932c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -258,10 +258,10 @@ class XlaOpRegistry { // Mapping from attribute name to a list of supported types. std::unordered_map> type_constraints; - // An optional whitelist of devices. If there is no whitelist, all devices + // An optional allowlist of devices. If there is no allowlist, all devices // are permitted. - bool has_device_whitelist = false; - std::unordered_set device_whitelist; + bool has_device_allowlist = false; + std::unordered_set device_allowlist; // Names of arguments that must be compile-time constants. std::unordered_set compile_time_constant_inputs; @@ -279,8 +279,8 @@ class XlaOpRegistry { // Returns true if registrations x and y can both be added to the registry. // This is always the case if they refer to different ops. If they refer to // the same op name, they must: have the same values for compilation_only, - // allow_resource_types and allow_variant_types; use a device_whitelist; and - // their whitelists must not intersect. + // allow_resource_types and allow_variant_types; use a device_allowlist; and + // their allowlists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); static Status CompileTimeConstantInputs(const NodeDef& node_def, @@ -319,7 +319,7 @@ class XlaOpRegistrationBuilder { // Starts an operator registration chain. static XlaOpRegistrationBuilder Name(absl::string_view name); - // Specifies a whitelist of devices on which the operator may run. + // Specifies a allowlist of devices on which the operator may run. XlaOpRegistrationBuilder& Device(absl::string_view devices); XlaOpRegistrationBuilder& Device(absl::Span devices); diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 6bd56a8df0a5d0..4836dff7fa0a20 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { builder, static_cast(Eigen::NumTraits::epsilon())); case BF16: - return ConstantR0(builder, bfloat16::epsilon()); + return ConstantR0( + builder, static_cast( + Eigen::NumTraits::epsilon())); case F32: return ConstantR0(builder, std::numeric_limits::epsilon()); case F64: @@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, Eigen::NumTraits::lowest()); case BF16: - return ConstantR0(builder, bfloat16::lowest()); + return ConstantR0( + builder, Eigen::NumTraits::lowest()); case F32: return ConstantR0(builder, -std::numeric_limits::max()); case F64: @@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, std::numeric_limits::min()); case BF16: - return ConstantR0(builder, bfloat16::min_positive_normal()); + return ConstantR0( + builder, std::numeric_limits::min()); case F32: return ConstantR0(builder, std::numeric_limits::min()); case F64: @@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, Eigen::NumTraits::highest()); case BF16: - return ConstantR0(builder, bfloat16::highest()); + return ConstantR0( + builder, Eigen::NumTraits::highest()); case F32: return ConstantR0(builder, std::numeric_limits::max()); case F64: @@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0( builder, Eigen::NumTraits::quiet_NaN()); case BF16: - return ConstantR0( - builder, bfloat16(std::numeric_limits::quiet_NaN())); + return ConstantR0( + builder, Eigen::NumTraits::quiet_NaN()); case F32: return ConstantR0(builder, std::numeric_limits::quiet_NaN()); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 89bc30e1a0e951..cc6a680c4e97eb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -424,9 +424,8 @@ StatusOr XlaBuilder::Build(int64 root_id, alias.param_number, alias.param_index.ToString().c_str()); } - TF_RETURN_IF_ERROR(config.SetUpAlias( - alias.output_index, alias.param_number, alias.param_index, - HloInputOutputAliasConfig::AliasKind::kUserAlias)); + TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number, + alias.param_index)); } *module->mutable_input_output_alias() = config.ToProto(); return Status::OK(); @@ -1562,7 +1561,8 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, - absl::optional> operand_shapes_with_layout) { + absl::optional> operand_shapes_with_layout, + bool has_side_effect) { return ReportErrorOrReturn([&]() -> StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1594,14 +1594,15 @@ XlaOp XlaBuilder::CustomCall( } } return CustomCallInternal(call_target_name, operands, shape, opaque, - operand_shapes_with_layout); + operand_shapes_with_layout, has_side_effect); }); } StatusOr XlaBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, - absl::optional> operand_shapes_with_layout) { + absl::optional> operand_shapes_with_layout, + bool has_side_effect) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); @@ -1612,13 +1613,15 @@ StatusOr XlaBuilder::CustomCallInternal( *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); } } + instr.set_custom_call_has_side_effect(has_side_effect); return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); } XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape, const string& opaque, - absl::optional> operand_shapes_with_layout) { + absl::optional> operand_shapes_with_layout, + bool has_side_effect) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -3385,27 +3388,29 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, - const string& opaque) { + const string& opaque, bool has_side_effect) { return builder->CustomCall(call_target_name, operands, shape, opaque, - /*operand_shapes_with_layout=*/absl::nullopt); + /*operand_shapes_with_layout=*/absl::nullopt, + has_side_effect); } XlaOp CustomCallWithComputation(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque) { - return builder->CustomCall(call_target_name, operands, computation, shape, - opaque, - /*operand_shapes_with_layout=*/absl::nullopt); + const Shape& shape, const string& opaque, + bool has_side_effect) { + return builder->CustomCall( + call_target_name, operands, computation, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect); } XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, absl::Span operand_shapes_with_layout, - const string& opaque) { + const string& opaque, bool has_side_effect) { return builder->CustomCall(call_target_name, operands, shape, opaque, - operand_shapes_with_layout); + operand_shapes_with_layout, has_side_effect); } XlaOp Complex(const XlaOp lhs, const XlaOp rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 12623d7912fe43..60bdc32e68d6ea 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -525,7 +525,8 @@ class XlaBuilder { XlaOp CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, - absl::optional> operand_shapes_with_layout); + absl::optional> operand_shapes_with_layout, + bool has_side_effect); // Internal version of CustomCall without computation that doesn't do op // specific error handling and expects arguments to be legal. CustomCall @@ -533,13 +534,15 @@ class XlaBuilder { virtual StatusOr CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, - absl::optional> operand_shapes_with_layout); + absl::optional> operand_shapes_with_layout, + bool has_side_effect); XlaOp CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape_with_layout, const string& opaque, - absl::optional> operand_shapes_with_layout); + absl::optional> operand_shapes_with_layout, + bool has_side_effect); XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, @@ -970,17 +973,16 @@ class XlaBuilder { absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, - const string& opaque); - friend XlaOp CustomCallWithComputation(XlaBuilder* builder, - const string& call_target_name, - absl::Span operands, - const XlaComputation& computation, - const Shape& shape, - const string& opaque); + const string& opaque, bool has_side_effect); + friend XlaOp CustomCallWithComputation( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const string& opaque, bool has_side_effect); friend XlaOp CustomCallWithLayout( XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, const string& opaque); + absl::Span operand_shapes_with_layout, const string& opaque, + bool has_side_effect); friend XlaOp Complex(XlaOp real, XlaOp imag, absl::Span broadcast_dimensions); friend XlaOp Conj(XlaOp operand); @@ -1674,14 +1676,15 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, // can encode arbitrarily large amounts of information. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, - const string& opaque = ""); + const string& opaque = "", bool has_side_effect = false); // Overload which constructs a custom call that applies an Xla computation. XlaOp CustomCallWithComputation(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque = ""); + const Shape& shape, const string& opaque = "", + bool has_side_effect = false); // Overload which constructs a custom call with fixed layouts. The operands will // have the layouts specified by |operand_shapes_with_layout| when provided to @@ -1692,7 +1695,8 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, - const string& opaque = ""); + const string& opaque = "", + bool has_side_effect = false); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD index 57eeb25bb4920a..8e14bc0f67ce5f 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -11,6 +11,7 @@ py_library( srcs = ["xla_sharding.py"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/compiler/xla:xla_data_proto_py", "//tensorflow/compiler/xla/python_api:types", "//tensorflow/compiler/xla/python_api:xla_shape", diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 40bf8f0c42b665..e05f69b1e8b00d 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -25,6 +25,8 @@ upper_tabs: path: /xla/operation_semantics - title: Shapes and layout path: /xla/shapes + - title: Aliasing + path: /xla/aliasing - title: Tiled layout path: /xla/tiled_layout - title: Use AOT compilation diff --git a/tensorflow/compiler/xla/g3doc/aliasing.md b/tensorflow/compiler/xla/g3doc/aliasing.md new file mode 100644 index 00000000000000..90cd24b3e8f928 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/aliasing.md @@ -0,0 +1,73 @@ +# Aliasing in XLA + +This document describes the aliasing API for XLA: when building an XLA program, +you can specify the desired aliasing between the input and output buffers. + +## Defining aliasing at compile-time + +For example, consider a trivial HLO module which simply adds `1` to its input: + +``` +HloModule increment + +ENTRY entry { + %p = f32[] parameter(0) + %c = f32[] constant(1) + ROOT %out = f32[] add(%p, %c) +} +``` + +This module will allocate two 4-byte buffers: one for the input `%p`, and one +for the output `%out`. + +However, it is often desirable to perform the update in-place (for example, if +in the frontend generating the expression the input variable is no longer alive +after the computation, as in the increment `p++`). + +To perform such an update efficiently, you can specify the input aliasing: + +``` +HloModule increment, input_output_alias={ {}: 0 } + +ENTRY entry { + %p = f32[] parameter(0) + %c = f32[] constant(1) + ROOT %out = f32[] add(%p, %c) +} +``` + +The format specifies that the entire output (marked by `{}`) is aliased to the +input parameter `0`. + +See the +[`XlaBuilder::SetUpAlias`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) +API to specify the aliasing programmatically. + +## Defining aliasing at run-time + +The aliasing defined in the previous step is specified during the _compilation_. +During the execution, you can choose whether actually to donate the buffer using +the +[`LocalClient::RunAsync`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/local_client.h) +API. + +Input buffers to the program are wrapped in +[`ExecutionInput`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/executable.h), +which in turn contain a tree of `MaybeOwningDeviceMemory`. If memory is +specified as _owning_ (ownership of the buffer is passed to the XLA runtime), +the buffer is actually donated, and the update is executed in-place, as +requested by the compile-time aliasing API. + +If, however, the buffer which is aliased at compile time is _not_ donated at +runtime, _copy-protection_ kicks in: an extra output buffer `O` is allocated, +and the contents of the input buffer `P` which was meant to be aliased are +copied into `O` (so effectively the program can execute as if the buffer `O` was +donated at runtime). + +## Frontend interop + +### TF/XLA + +In clusters of TensorFlow program compiled with XLA, all resource variable +updates are aliased at compile time (the aliasing at runtime depends on whether +anything else holds a reference to the resource variable tensor). diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 695ba9dee9385d..6e61e0600a018c 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -215,6 +215,7 @@ tf_cc_test( # TODO(phawkins): figure out TF test infra such that this only runs under GPU. "no_oss", "requires-gpu-nvidia", + "notap", ], deps = [ ":nvidia_gpu_device", diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 39fa6a1c2673d9..6e345b06e43cad 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -24,6 +24,7 @@ tf_proto_library_cc( has_services = 1, cc_api_version = 2, cc_grpc_version = 1, + create_service = True, protodeps = [ "//tensorflow/compiler/xla:xla_proto", ], diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8a75a7b01a9ace..bc024f7144bde2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -492,8 +492,7 @@ TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( - {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {})); auto buffers = RunBufferAssignment(module.get()); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 9c0ed45ad06886..3ee6b200da568a 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1586,11 +1586,9 @@ TEST_F(CopyInsertionTest, CrossingParameters) { builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 4); @@ -1621,11 +1619,9 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1689,8 +1685,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1731,8 +1726,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1773,8 +1767,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -2505,11 +2498,11 @@ ENTRY entry_computation { ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( /*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( /*output_index=*/{3}, /*param_number=*/1, - /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); InsertCopies(module.get()); @@ -2532,7 +2525,7 @@ ENTRY Entry { ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( /*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 1); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b9e10bfb083650..7f051d4d1b237a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -8,6 +8,7 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_openmp_copts") load(":build_defs.bzl", "runtime_copts") +load("//tensorflow/core/platform:build_config.bzl", "if_llvm_system_z_available") package( default_visibility = [":friends"], @@ -203,7 +204,9 @@ cc_library( ], "//conditions:default": [ ], - }), + }) + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep + ]), alwayslink = True, # Contains compiler registration ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 8ba6f5a7159442..0abcc91a1d7838 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -256,17 +256,15 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( se::DeviceMemoryBase argument_buffer = owning->Release(); *maybe_owning_memory = argument_buffer; result_buffer = argument_buffer; - if (alias->kind == HloInputOutputAliasConfig::kUserAlias) { - // This is a user alias, so a must alias. The caller is giving us the - // input buffer, but in case of error of the execute call, we should - // not be releasing it as it contains valid data (for example, it is a - // parameter which the user wants us to alias, in a gradient update - // computation). So we store the index into the result in the aliased - // vactor, which will be fed to the ExecutionOutput, which will be - // using the indices to drop the addresses from its own - // ScopedShapedBuffer result, if the ExecutionOutput is not committed. - result.AddAliasedIndex(index); - } + // The caller is giving us the + // input buffer, but in case of error of the execute call, we should + // not be releasing it as it contains valid data (for example, it is a + // parameter which the user wants us to alias, in a gradient update + // computation). So we store the index into the result in the aliased + // vactor, which will be fed to the ExecutionOutput, which will be + // using the indices to drop the addresses from its own + // ScopedShapedBuffer result, if the ExecutionOutput is not committed. + result.AddAliasedIndex(index); } else { VLOG(3) << "Using copy-protection: aliasing is specified, but the " "buffer is not donated; allocating a fresh buffer"; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61e7c..c0222010fd93d5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 3df9ef35bab7af..2231ecfa1e87a3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -122,6 +122,7 @@ extern const char* const kTracingStartSymbolName = extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; +extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId"; @@ -154,6 +155,34 @@ struct CollectivePermuteParticipantData : xla::ParticipantData { } }; +struct AllToAllParticipantData : xla::ParticipantData { + AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p, + xla::int64 device_ordinal_p, se::Stream* stream_p) + : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {} + + std::vector source_buffers; + std::vector destination_buffers; + int replica_id; + + // Replica ids participating in AllToAll, concatenation happens in the order + // of appearence. + std::vector replica_ids_to_copy_to; + + std::string ToString() const override { + auto addr_formatter = [](std::string* out, + const se::DeviceMemoryBase& mem) { + absl::StrAppend(out, absl::StrFormat("%p", mem.opaque())); + }; + return absl::StrFormat( + "AllToAllParticipantData{replica_id=%d, " + "replica_ids_to_copy_to=[%s], source_buffers=[%s], " + "destination_buffers=[%s]}", + replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "), + absl::StrJoin(source_buffers, ", ", addr_formatter), + absl::StrJoin(destination_buffers, ", ", addr_formatter)); + } +}; + // Inverses the encoding of a Shape protobuf into an LLVM global variable. xla::StatusOr DecodeSelfDescribingShapeConstant( const void* shape_ptr, xla::int32 size_bytes) { @@ -286,6 +315,70 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( namespace { +class CpuAllToAllRendezvous + : public xla::Rendezvous { + public: + explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k) + : xla::Rendezvous(k) {} + + protected: + xla::StatusOr RunCollectiveOp( + const AllToAllParticipantData& /*participant*/) override { + bool is_primary = InitializationBarrier(); + + if (is_primary) { + tensorflow::mutex_lock lock(mu_); + + CHECK(!participants_.empty()); + CHECK(!participants_[0].source_buffers.empty()); + int expected_buffer_size = participants_[0].source_buffers[0].size(); + + // Replica id -> position in participants_. + absl::flat_hash_map replica_id_map; + + for (int pos = 0; pos < participants_.size(); pos++) { + const AllToAllParticipantData& p = participants_[pos]; + CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size()); + CHECK_EQ(p.source_buffers.size(), participants_.size()); + for (int i = 0; i < p.source_buffers.size(); i++) { + CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size); + CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size); + } + replica_id_map[p.replica_id] = pos; + } + + for (AllToAllParticipantData& p : participants_) { + VLOG(3) << "Processing AllToAll participant data: " << p.ToString(); + for (int j = 0; j < p.source_buffers.size(); j++) { + for (int i = 0; i < p.replica_ids_to_copy_to.size(); i++) { + int replica_id = p.replica_ids_to_copy_to[i]; + int participant_num = xla::FindOrDie(replica_id_map, replica_id); + AllToAllParticipantData& other = participants_[participant_num]; + + // Sort by replica ordering. + std::vector destination_buffers = + other.destination_buffers; + absl::flat_hash_map buffers_index; + for (int idx = 0; idx < destination_buffers.size(); idx++) { + buffers_index[destination_buffers[idx].opaque()] = idx; + } + absl::c_sort( + destination_buffers, [&](const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b) { + return p.replica_ids_to_copy_to[buffers_index[a.opaque()]] < + p.replica_ids_to_copy_to[buffers_index[b.opaque()]]; + }); + + std::memcpy(destination_buffers[j].opaque(), + p.source_buffers[j].opaque(), expected_buffer_size); + } + } + } + } + return ParticipantImplOutput{is_primary, nullptr}; + } +}; + class CpuCollectivePermuteRendezvous : public xla::Rendezvous { public: @@ -486,6 +579,13 @@ GlobalCollectivePermuteRendezvousMap() { return m; } +xla::RefcountingHashMap& +GlobalAllToAllRendezvousMap() { + static auto& m = + *new xla::RefcountingHashMap; + return m; +} + int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) { if (run_options->stream()) { return run_options->stream()->parent()->device_ordinal(); @@ -524,6 +624,48 @@ xla::RendezvousKey GetRendezvousKey( } // namespace +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll( + const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present, + xla::int64 op_id, const void* replica_groups_str, + xla::int32 replica_groups_str_size, xla::int32 num_buffers, + xla::int64 buffer_size, void** source_buffers, void** destination_buffers) { + int device_ordinal = GetDeviceOrdinal(run_options); + xla::int32 replica_id = run_options->device_assignment() + ->ReplicaIdForDeviceOrdinal(device_ordinal) + .ValueOrDie(); + absl::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie(); + xla::RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, group, channel_id_present, op_id); + + AllToAllParticipantData participant(rendezvous_key, device_ordinal, + run_options->stream()); + participant.replica_id = replica_id; + participant.replica_ids_to_copy_to = + xla::GetParticipatingReplicas( + xla::GlobalDeviceId(device_ordinal), group, + run_options->device_assignment()->replica_count(), + *run_options->device_assignment()) + .ValueOrDie(); + for (int i = 0; i < num_buffers; i++) { + participant.source_buffers.emplace_back(source_buffers[i], buffer_size); + participant.destination_buffers.emplace_back(destination_buffers[i], + buffer_size); + } + auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) { + return absl::make_unique(k); + }; + TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant( + [&] { + return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent( + rendezvous_key, make_cpu_rendezvous); + }, + participant) + .status()); +} + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce( const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, xla::int32 replica_groups_str_size, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 492ce3f68b2fcf..ee75b97e4dc95f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -77,6 +77,7 @@ extern const char* const kCollectivePermuteSymbolName; extern const char* const kReplicaIdSymbolName; extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; +extern const char* const kAllToAllSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. @@ -181,6 +182,12 @@ extern void __xla_cpu_runtime_CollectivePermute( void* output_buffer, const void* source_target_pairs, xla::int32 source_target_pairs_size); +extern void __xla_cpu_runtime_AllToAll( + const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present, + xla::int64 op_id, const void* replica_groups_str, + xla::int32 replica_groups_str_size, xla::int32 num_buffers, + xla::int64 buffer_size, void** source_buffers, void** destination_buffers); + // Write the replica ID into the output buffer. extern void __xla_cpu_runtime_ReplicaId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 574d83c68c80f9..ee4bcf4cd359b0 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -986,8 +986,7 @@ DotImplementationStrategy GetDotImplementationStrategy( if (IsAlignedGemm(dot_info, target_machine_features)) { if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { - return primitive_util::IsFloatingPointType( - dot_info.result_shape.element_type()) + return options::UseLinalgForDot(config) ? DotImplementationStrategy::kLinalgMatmul : DotImplementationStrategy::kTiledLlvmIrGemm; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 0cfab50d0a3067..ebb2df2380567d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -359,7 +359,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { // to the output buffer of its corresponding operand. A GetTupleElement // instruction forwards a pointer to the tuple element buffer at the given // index. - auto operand = get_tuple_element->operand(0); + const HloInstruction* operand = get_tuple_element->operand(0); const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), @@ -462,64 +462,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value * shape_ptr, llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); - llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::FunctionType* acquire_type = llvm::FunctionType::get( - i8_ptr_type, - {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, - /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type}, - /*isVarArg=*/false); - - llvm::Function* acquire_func; - if (kind == XfeedKind::kInfeed) { - acquire_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction( - runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type) - .getCallee()); - } else { - acquire_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction( - runtime::kAcquireOutfeedBufferForPopulationSymbolName, - acquire_type) - .getCallee()); - } - acquire_func->setCallingConv(llvm::CallingConv::C); - - llvm::FunctionType* release_type = llvm::FunctionType::get( - b_.getVoidTy(), - {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, - /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type, - /*shape_length*/ int32_type}, - /*isVarArg=*/false); - - llvm::Function* release_func; - if (kind == XfeedKind::kInfeed) { - release_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction( - runtime::kReleaseInfeedBufferAfterDequeueSymbolName, - release_type) - .getCallee()); - } else { - release_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction( - runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, - release_type) - .getCallee()); - } - release_func->setCallingConv(llvm::CallingConv::C); - - // Implementation note: this call informs the runtime that it wants a buffer - // of size exactly 'length_32', and the runtime is responsible for - // check-failing the process if there is a mismatch, versus passing us back a - // buffer that we might overrun. - llvm::Value* acquired_pointer = Call( - acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), - shape_ptr, b_.getInt32(shape_length)}); + const char* acquire_func_name = + kind == XfeedKind::kInfeed + ? runtime::kAcquireInfeedBufferForDequeueSymbolName + : runtime::kAcquireOutfeedBufferForPopulationSymbolName; + + // Implementation note: this call informs the runtime that it wants a + // buffer of size exactly 'length_32', and the runtime is responsible for + // check-failing the process if there is a mismatch, versus passing us + // back a buffer that we might overrun. + llvm::Value* acquired_pointer = + EmitCallToFunc(acquire_func_name, + {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + shape_ptr, b_.getInt32(shape_length)}, + i8_ptr_type); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1), @@ -532,8 +490,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, /*SrcAlign=*/llvm::Align(1), length_32); } - Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), - acquired_pointer, shape_ptr, b_.getInt32(shape_length)}); + const char* release_func_name = + kind == XfeedKind::kInfeed + ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName + : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName; + EmitCallToFunc(release_func_name, + {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + acquired_pointer, shape_ptr, b_.getInt32(shape_length)}, + b_.getVoidTy()); return Status::OK(); } @@ -624,22 +588,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply()); CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply())); - llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( - b_.getVoidTy(), - {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), - b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), - b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, - /*isVarArg=*/false); - auto* key_value_sort_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction(runtime::kKeyValueSortSymbolName, - key_value_sort_type) - .getCallee()); - key_value_sort_func->setCallingConv(llvm::CallingConv::C); - key_value_sort_func->setDoesNotThrow(); llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca", &b_); @@ -659,12 +608,15 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { Store(size, slot_in_sizes_alloca); } - Call(key_value_sort_func, - {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), - b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->operand_count()), sizes, - b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), - GetProfileCountersArgument(), less_than_function}); + auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply()); + EmitCallToFunc( + runtime::kKeyValueSortSymbolName, + {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), values, + b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()), + GetExecutableRunOptionsArgument(), GetProfileCountersArgument(), + less_than_function}, + b_.getVoidTy()); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_); @@ -1138,16 +1090,6 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm::Type* ir_ptr_type = primitive_type == F16 ? b_.getHalfTy()->getPointerTo() : b_.getFloatTy()->getPointerTo(); - llvm::Type* int64_type = b_.getInt64Ty(); - llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); - llvm::FunctionType* conv_type = llvm::FunctionType::get( - b_.getVoidTy(), - {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, - int64_type, int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type}, - /*isVarArg=*/false); bool multi_threaded = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); bool use_mkl_dnn = @@ -1168,37 +1110,35 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " "conv2d function."; } - llvm::Function* conv_func = llvm::dyn_cast( - module_->getOrInsertFunction(fn_name, conv_type).getCallee()); - conv_func->setCallingConv(llvm::CallingConv::C); - conv_func->setDoesNotThrow(); - conv_func->setOnlyAccessesArgMemory(); - Call(conv_func, { - GetExecutableRunOptionsArgument(), - BitCast(GetEmittedValueFor(convolution), ir_ptr_type), - BitCast(lhs_address, ir_ptr_type), - BitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + EmitCallToFunc(fn_name, + { + GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }, + b_.getVoidTy(), /*does_not_throw=*/true, + /*only_accesses_arg_memory=*/true); return Status::OK(); } @@ -1234,36 +1174,26 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { // Args have been computed, make the call. llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); - llvm::Type* int32_type = b_.getInt32Ty(); - llvm::Type* int64_type = b_.getInt64Ty(); - llvm::FunctionType* fft_type = llvm::FunctionType::get( - b_.getVoidTy(), - {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, - int32_type, int64_type, int64_type, int64_type, int64_type}, - /*isVarArg=*/false); - bool multi_threaded_eigen = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); const char* fn_name = multi_threaded_eigen ? runtime::kEigenFftSymbolName : runtime::kEigenSingleThreadedFftSymbolName; - - llvm::Function* fft_func = llvm::dyn_cast( - module_->getOrInsertFunction(fn_name, fft_type).getCallee()); - fft_func->setCallingConv(llvm::CallingConv::C); - fft_func->setDoesNotThrow(); - fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - Call(fft_func, - {GetExecutableRunOptionsArgument(), - BitCast(GetEmittedValueFor(fft), int8_ptr_type), - BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), - b_.getInt32(operand->shape().element_type() == F64 || - operand->shape().element_type() == C128), - b_.getInt32(fft_rank), b_.getInt64(input_batch), - b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + EmitCallToFunc( + fn_name, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(operand->shape().element_type() == F64 || + operand->shape().element_type() == C128), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}, + b_.getVoidTy(), /*does_not_throw=*/true, + /*only_accesses_arg_memory=*/false, + /*only_accesses_inaccessible_mem_or_arg_mem=*/true); return Status::OK(); } @@ -1337,31 +1267,6 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { crs->to_apply()->ToString()); } - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* int32_type = b_.getInt32Ty(); - llvm::Type* int64_type = b_.getInt64Ty(); - llvm::FunctionType* all_reduce_func_ty = - llvm::FunctionType::get(b_.getVoidTy(), - {/*run_options=*/i8_ptr_type, - /*replica_groups=*/i8_ptr_type, - /*replica_groups_size=*/int32_type, - /*channel_id_present=*/int32_type, - /*op_id=*/int64_type, - /*reduction_kind=*/int32_type, - /*shape_ptr=*/i8_ptr_type, - /*shape_length=*/int32_type, - /*num_buffers=*/int32_type, - /*input_buffer=*/i8_ptr_type, - /*output_buffer=*/i8_ptr_type}, - /*isVarArg=*/false); - - auto all_reduce_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction(runtime::kAllReduceSymbolName, - all_reduce_func_ty) - .getCallee()); - all_reduce_func->setCallingConv(llvm::CallingConv::C); - std::string replica_groups = ReplicaGroupsToString(crs->replica_groups()); int32 replica_groups_size = replica_groups.size(); llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); @@ -1402,25 +1307,28 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { llvm_ir::EncodeSelfDescribingShapeConstant( crs->shape(), &shape_length, &b_)); - Call(all_reduce_func, - {/*run_options=*/GetExecutableRunOptionsArgument(), - /*replica_groups=*/replica_groups_v, - /*replica_groups_size=*/b_.getInt32(replica_groups_size), - - /*channel_id_present=*/ - b_.getInt32(static_cast(crs->channel_id().has_value())), - /*op_id=*/ - b_.getInt64(crs->channel_id().has_value() - ? *crs->channel_id() - : crs->GetModule()->unique_id()), - /*reduction_kind=*/ - b_.getInt32( - static_cast(*MatchReductionComputation(crs->to_apply()))), - /*shape_ptr=*/shape_ptr, - /*shape_length=*/b_.getInt32(shape_length), - /*num_buffers=*/b_.getInt32(crs->operand_count()), - /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type), - /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)}); + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + EmitCallToFunc( + runtime::kAllReduceSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*replica_groups=*/replica_groups_v, + /*replica_groups_size=*/b_.getInt32(replica_groups_size), + + /*channel_id_present=*/ + b_.getInt32(static_cast(crs->channel_id().has_value())), + /*op_id=*/ + b_.getInt64(crs->channel_id().has_value() + ? *crs->channel_id() + : crs->GetModule()->unique_id()), + /*reduction_kind=*/ + b_.getInt32( + static_cast(*MatchReductionComputation(crs->to_apply()))), + /*shape_ptr=*/shape_ptr, + /*shape_length=*/b_.getInt32(shape_length), + /*num_buffers=*/b_.getInt32(crs->operand_count()), + /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type), + /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)}, + b_.getVoidTy()); return Status::OK(); } @@ -1432,6 +1340,61 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { return HandleAllReduceMultipleReplica(crs); } +Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { + auto* instr = Cast(instruction); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); + CHECK(!instr->split_dimension() && instr->shape().IsTuple()) + << "Only tuple AllToAll is supported"; + + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + std::string replica_groups = + ReplicaGroupsToString(instruction->replica_groups()); + int32 replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + int64 buffer_size = -1; + std::vector input_buffer_ptrs; + std::vector output_buffer_ptrs; + + for (int64 i = 0; i < instruction->operand_count(); i++) { + const HloInstruction* op = instruction->operand(i); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(instruction, {i})); + const Shape& operand_shape = instruction->operand(i)->shape(); + CHECK(operand_shape.IsArray()) + << "Operands to all-to-all must be arrays: " << instruction->ToString(); + output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); + input_buffer_ptrs.push_back(GetEmittedValueFor(op)); + CHECK(buffer_size == -1 || buffer_size == out_slice.size()); + buffer_size = out_slice.size(); + } + + llvm::Value* input_buffers = + EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_); + llvm::Value* output_buffers = + EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_); + + EmitCallToFunc( + runtime::kAllToAllSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*channel_id_present=*/ + b_.getInt32(static_cast(instruction->channel_id().has_value())), + /*op_id=*/ + b_.getInt64(instruction->channel_id().has_value() + ? *instruction->channel_id() + : instruction->GetModule()->unique_id()), + /*replica_groups=*/replica_groups_v, + /*replica_groups_size=*/b_.getInt32(replica_groups_size), + /*num_buffers=*/b_.getInt32(instruction->operand_count()), + /*buffer_size=*/b_.getInt64(buffer_size), + /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type), + /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)}, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_); + return Status::OK(); +} + Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); std::string source_target_pairs = absl::StrJoin( @@ -1439,30 +1402,6 @@ Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { llvm::Value* source_target_pairs_v = b_.CreateGlobalStringPtr(source_target_pairs); - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* int32_type = b_.getInt32Ty(); - llvm::Type* int64_type = b_.getInt64Ty(); - llvm::FunctionType* collective_permute_func_ty = - llvm::FunctionType::get(b_.getVoidTy(), - { - /*run_options=*/i8_ptr_type, - /*channel_id_present=*/int32_type, - /*op_id=*/int64_type, - /*byte_size=*/int32_type, - /*input_buffer=*/i8_ptr_type, - /*output_buffer=*/i8_ptr_type, - /*source_target_pairs=*/i8_ptr_type, - /*source_target_pairs_size=*/int32_type, - }, - /*isVarArg=*/false); - - auto collective_permute_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction(runtime::kCollectivePermuteSymbolName, - collective_permute_func_ty) - .getCallee()); - collective_permute_func->setCallingConv(llvm::CallingConv::C); - Shape shape = crs->operand(0)->shape(); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, @@ -1473,44 +1412,37 @@ Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { assignment_.GetUniqueSlice(crs, {})); llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape); - Call(collective_permute_func, - {/*run_options=*/GetExecutableRunOptionsArgument(), - /*channel_id_present=*/ - b_.getInt32(static_cast(crs->channel_id().has_value())), - /*op_id=*/ - b_.getInt64(crs->channel_id().has_value() - ? *crs->channel_id() - : crs->GetModule()->unique_id()), - /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)), - /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type), - /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type), - /*source_target_pairs=*/source_target_pairs_v, - /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())}); + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + EmitCallToFunc( + runtime::kCollectivePermuteSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*channel_id_present=*/ + b_.getInt32(static_cast(crs->channel_id().has_value())), + /*op_id=*/ + b_.getInt64(crs->channel_id().has_value() + ? *crs->channel_id() + : crs->GetModule()->unique_id()), + /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)), + /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type), + /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type), + /*source_target_pairs=*/source_target_pairs_v, + /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())}, + b_.getVoidTy()); return Status::OK(); } Status IrEmitter::HandleReplicaId(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::FunctionType* replica_id_function_ty = - llvm::FunctionType::get(b_.getVoidTy(), - {/*run_options=*/i8_ptr_type, - /*output_buffer=*/i8_ptr_type}, - /*isVarArg=*/false); - auto* replica_id_func = llvm::dyn_cast( - module_ - ->getOrInsertFunction(runtime::kReplicaIdSymbolName, - replica_id_function_ty) - .getCallee()); - replica_id_func->setCallingConv(llvm::CallingConv::C); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, assignment_.GetUniqueSlice(hlo, {})); llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape()); - Call(replica_id_func, - {/*run_options=*/GetExecutableRunOptionsArgument(), - /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)}); - + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + EmitCallToFunc( + runtime::kReplicaIdSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)}, + b_.getVoidTy()); return Status::OK(); } @@ -2017,10 +1949,6 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { return DefaultAction(reduce); } -Status IrEmitter::HandleAllToAll(HloInstruction*) { - return Unimplemented("AllToAll is not implemented on CPU."); -} - Status IrEmitter::HandleSend(HloInstruction* send) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Send is not implemented on CPU."); @@ -2484,28 +2412,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { // TODO(b/66051036): Run the msan instrumentation pass instead. const llvm::DataLayout& dl = module_->getDataLayout(); llvm::Type* intptr_type = b_.getIntPtrTy(dl); - auto* msan_unpoison_ir_function = llvm::cast( - module_ - ->getOrInsertFunction( - "__msan_unpoison", - llvm::FunctionType::get( - /*Result=*/b_.getVoidTy(), - /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false)) - .getCallee()); - Call(msan_unpoison_ir_function, - {PointerCast(operands_alloca, i8_ptr_type), - llvm::ConstantInt::get( - intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}); - } - auto* custom_call_ir_function = llvm::dyn_cast( - module_ - ->getOrInsertFunction( - custom_call->custom_call_target(), - llvm::FunctionType::get( - /*Result=*/b_.getVoidTy(), - /*Params=*/{i8_ptr_type, operands_alloca->getType()}, - /*isVarArg=*/false)) - .getCallee()); + EmitCallToFunc( + "__msan_unpoison", + {PointerCast(operands_alloca, i8_ptr_type), + llvm::ConstantInt::get( + intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}, + b_.getVoidTy()); + } TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); // Write the tuple table if the output is a tuple. @@ -2526,7 +2439,8 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - Call(custom_call_ir_function, {output_address_arg, operands_alloca}); + EmitCallToFunc(custom_call->custom_call_target(), + {output_address_arg, operands_alloca}, b_.getVoidTy()); return Status::OK(); } @@ -2728,6 +2642,31 @@ llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt, call_args); } +llvm::Value* IrEmitter::EmitCallToFunc( + std::string func_name, const std::vector& arguments, + llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory, + bool only_accesses_inaccessible_mem_or_arg_mem) { + std::vector types; + types.reserve(arguments.size()); + absl::c_transform(arguments, std::back_inserter(types), + [&](llvm::Value* val) { return val->getType(); }); + llvm::FunctionType* func_type = + llvm::FunctionType::get(return_type, types, /*isVarArg=*/false); + auto func = llvm::dyn_cast( + module_->getOrInsertFunction(func_name, func_type).getCallee()); + func->setCallingConv(llvm::CallingConv::C); + if (does_not_throw) { + func->setDoesNotThrow(); + } + if (only_accesses_arg_memory) { + func->setOnlyAccessesArgMemory(); + } + if (only_accesses_inaccessible_mem_or_arg_mem) { + func->setOnlyAccessesInaccessibleMemOrArgMem(); + } + return b_.CreateCall(func, arguments); +} + void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, int64 element_count, PrimitiveType primitive_type, @@ -2749,10 +2688,10 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = - MemCpy(target, /*DstAlign=*/llvm::Align(element_alignment), source, - /*SrcAlign=*/llvm::Align(element_alignment), - element_count * primitive_type_size); + auto* memcpy_instruction = b_.CreateMemCpy( + target, /*DstAlign=*/llvm::Align(element_alignment), source, + /*SrcAlign=*/llvm::Align(element_alignment), + element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index cef9b817503279..3955deefbeadba 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -424,6 +425,14 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* EmitPrintf(absl::string_view fmt, absl::Span arguments); + // Emits a call to a non-variadic function `func_name` with arguments + // `arguments` assuming C calling convention. + llvm::Value* EmitCallToFunc( + std::string func_name, const std::vector& arguments, + llvm::Type* return_type, bool does_not_throw = true, + bool only_accesses_arg_memory = false, + bool only_accesses_inaccessible_mem_or_arg_mem = false); + // Assignment of the buffers needed by the computation and their shape // information. const BufferAssignment& assignment_; diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc index d17f4671327525..ff48f554ce69ff 100644 --- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -33,11 +33,19 @@ namespace { // Lower an MLIR module to an LLVM module. std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { + // When set, the LLVM backend will be allowed to reassociate floating-point + // reductions, which enables much more efficient "horizontal" SIMD + // implementations. + // TODO(kramerb): link this to the right option, command line flag, etc. + constexpr bool kReassociateFPReductions = true; + mlir::PassManager manager(module->getContext()); manager.addPass(mlir::createConvertLinalgToLoopsPass()); manager.addPass(mlir::createLowerAffinePass()); manager.addPass(mlir::createLowerToCFGPass()); - manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createConvertVectorToLLVMPass( + mlir::LowerVectorToLLVMOptions().setReassociateFPReductions( + kReassociateFPReductions))); CHECK(succeeded(manager.run(*module))); return mlir::translateModuleToLLVMIR(*module); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 4bdb601a3d19bf..631c6985b03fb4 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -241,6 +241,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); + REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index b7186c186f43a7..6ebbf62261457d 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() { custom_call_handler_); } +void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith( + HloInstruction* replace, HloInstruction* with) { + CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32))); + CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32))); + for (auto& kv : dynamic_mapping_) { + if (kv.second == replace) { + kv.second = with; + } + } +} + Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, const ShapeIndex& index) { diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 417f0289143429..607d68bd9c3448 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -68,6 +68,11 @@ class DynamicDimensionInference { SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1)); } + // For all tensors whose dynamic dimension is `replace`, replace them with + // `with`. + void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace, + HloInstruction* with); + friend class DynamicDimensionInferenceVisitor; private: diff --git a/tensorflow/compiler/xla/service/gpu/alias_passthrough_params.cc b/tensorflow/compiler/xla/service/gpu/alias_passthrough_params.cc index 5266379b573de6..d86ddd7cdf153e 100644 --- a/tensorflow/compiler/xla/service/gpu/alias_passthrough_params.cc +++ b/tensorflow/compiler/xla/service/gpu/alias_passthrough_params.cc @@ -37,12 +37,10 @@ StatusOr AliasPassthroughParams::Run(HloModule* module) { << " in module " << module->name() << " is passed-through to root tuple element " << i << ": " << root->shape().ToString(); - // Use must-alias semantics (kUserAlias) for pass-through params. TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias( /*output_index=*/{i}, /*param_number=*/root->operand(i)->parameter_number(), - /*param_index=*/{}, - HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); used_params.insert(root->operand(i)->parameter_number()); changed = true; } diff --git a/tensorflow/compiler/xla/service/gpu/alias_passthrough_params_test.cc b/tensorflow/compiler/xla/service/gpu/alias_passthrough_params_test.cc index d349b871bffe88..a3c88e7478412f 100644 --- a/tensorflow/compiler/xla/service/gpu/alias_passthrough_params_test.cc +++ b/tensorflow/compiler/xla/service/gpu/alias_passthrough_params_test.cc @@ -38,12 +38,8 @@ TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) { EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).ValueOrDie()); const auto& alias_config = module->input_output_alias_config(); EXPECT_EQ(0, alias_config.GetAliasedParameter({0}).value().parameter_number); - EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias, - alias_config.GetAliasedParameter({0}).value().kind); EXPECT_FALSE(alias_config.OutputHasAlias({1})); EXPECT_EQ(1, alias_config.GetAliasedParameter({2}).value().parameter_number); - EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias, - alias_config.GetAliasedParameter({2}).value().kind); } TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 2b31099d26f9c9..758bba90bd2aef 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -464,6 +464,8 @@ StatusOr> GpuCompiler::RunHloPasses( return std::move(module); } +// The order of `thunk_sequence` corresponds to +// `hlo_schedule->ThunkLaunchOrder()`. static Status CompileModuleToLlvmIrImpl( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, @@ -471,19 +473,18 @@ static Status CompileModuleToLlvmIrImpl( absl::optional cuda_compute_capability, const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function, int pointer_size, std::unique_ptr* llvm_module, - std::unique_ptr* stream_assignment, - std::unique_ptr* hlo_schedule, std::unique_ptr* buffer_assignment, - std::unique_ptr* thunk_sequence) { + std::unique_ptr* thunk_schedule) { *llvm_module = absl::make_unique("", *llvm_context); (*llvm_module)->setTargetTriple(target_triple); (*llvm_module)->setDataLayout(data_layout); - *stream_assignment = AssignStreams(*hlo_module); + std::unique_ptr stream_assignment = + AssignStreams(*hlo_module); TF_ASSIGN_OR_RETURN( - *hlo_schedule, - GpuHloSchedule::Build(*hlo_module, **stream_assignment, pointer_size)); + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size)); auto buffer_size_bytes_function = [pointer_size](const BufferValue& buffer_value) -> int64 { @@ -493,7 +494,7 @@ static Status CompileModuleToLlvmIrImpl( TF_ASSIGN_OR_RETURN( *buffer_assignment, BufferAssigner::Run( - hlo_module, (*hlo_schedule)->ConsumeHloOrdering(), + hlo_module, hlo_schedule->ConsumeHloOrdering(), buffer_size_bytes_function, /*color_alignment=*/ [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, @@ -518,9 +519,49 @@ static Status CompileModuleToLlvmIrImpl( { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); - TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); + + absl::flat_hash_map thunk_to_hlo; + ThunkSequence thunk_sequence; + absl::Span order = hlo_schedule->ThunkLaunchOrder(); + for (HloInstruction* instruction : order) { + TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter)); + TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction)); + std::unique_ptr thunks = ir_emitter.ConsumeThunkSequence(); + + // The invariants between each input HloInstruction* and output Thunk* are + // not all explicitly checked, but at least we can document them here: + // * The entry HloComputation shall not have dead code (all reachable from + // ROOT). + // * For each visit of HloInstruction, either none or one Thunk will be + // returned. + // * If there is a thunk returned, thunk->hlo_instruction() equals the + // input HloInstruction*. + TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString(); + if (!thunks->empty()) { + auto thunk = std::move(thunks->front()); + InsertOrDie(&thunk_to_hlo, thunk.get(), instruction); + thunk_sequence.push_back(std::move(thunk)); + } + } + // TODO(timshen): ThunkSchedule taking thunk_to_hlo is a bit awkward. To fix + // that, we can turn it into a proper pass, from: + // map -> (ThunkSchedule, [Thunk...]) + // to: + // map -> GenerateMultiStreamDepInfo() -> [(Thunk, + // DepInfo)...] + // + // where "DepInfo" is + // struct { + // int stream_number; + // std::vector dependencies; + // std::vector users; + // }; + // We might want to do this after MLIR migration. + *thunk_schedule = absl::make_unique( + std::make_unique(std::move(thunk_sequence)), + std::move(stream_assignment), std::move(thunk_to_hlo)); } - *thunk_sequence = ir_emitter.ConsumeThunkSequence(); + return Status::OK(); } @@ -563,16 +604,14 @@ StatusOr> GpuCompiler::RunBackend( }(); std::unique_ptr llvm_module; - std::unique_ptr stream_assignment; - std::unique_ptr hlo_schedule; std::unique_ptr buffer_assignment; - std::unique_ptr thunk_sequence; + std::unique_ptr thunk_schedule; TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( module.get(), &llvm_context, target_triple_, data_layout_, stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, - GetCanShareBuffer(), pointer_size_, &llvm_module, &stream_assignment, - &hlo_schedule, &buffer_assignment, &thunk_sequence)); + GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment, + &thunk_schedule)); if (user_pre_optimization_hook_) { user_pre_optimization_hook_(*llvm_module); @@ -609,9 +648,6 @@ StatusOr> GpuCompiler::RunBackend( CompileTargetBinary(module.get(), llvm_module.get(), gpu_version, stream_exec)); - auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_schedule", thunk_schedule->ToString()); @@ -667,16 +703,13 @@ StatusOr> CompileModuleToLlvmIr( absl::optional cuda_compute_capability, int pointer_size) { std::unique_ptr llvm_module; - std::unique_ptr stream_assignment; - std::unique_ptr hlo_schedule; std::unique_ptr buffer_assignment; - std::unique_ptr thunk_sequence; + std::unique_ptr thunk_schedule; TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( hlo_module, llvm_context, target_triple, data_layout, platform_name, gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, - pointer_size, &llvm_module, &stream_assignment, &hlo_schedule, - &buffer_assignment, &thunk_sequence)); + pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule)); return llvm_module; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index c3bc6489d737d0..89c5e123a48334 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -173,14 +173,12 @@ Status GpuExecutable::ExecuteThunks( std::map> thunk_to_finish_event; std::vector> deferred_host_callbacks; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - CHECK(thunk->hlo_instruction()); // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the // module, we won't get any data, but that's probably an OK trade-off. ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); - int32 stream_no = - thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); + int32 stream_no = thunk_schedule_->StreamNumberForThunk(thunk); se::Stream* stream = (stream_no == 0 ? main_stream : sub_streams[stream_no - 1].get()); @@ -188,8 +186,7 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString() << " on stream " + VLOG(2) << "Executing the thunk for " << thunk->name() << " on stream " << stream_no; const GpuExecutableRunOptions* gpu_options = run_options->run_options().gpu_executable_run_options(); @@ -499,17 +496,15 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( se::DeviceMemoryBase argument_buffer = owning->Release(); *maybe_owning_memory = argument_buffer; result_buffer = argument_buffer; - if (alias->kind == HloInputOutputAliasConfig::kUserAlias) { - // This is a user alias, so a must alias. The caller is giving us the - // input buffer, but in case of error from the execute call, we should - // not be releasing it as it contains valid data (for example, it is a - // parameter which the user wants us to alias, in a gradient update - // computation). So we store the index into the result in the aliased - // vector, which will be fed to the ExecutionOutput, which will use - // the indices to drop the addresses from its own ScopedShapedBuffer - // result, if the ExecutionOutput is not committed. - result.AddAliasedIndex(index); - } + // The caller is giving us the + // input buffer, but in case of error from the execute call, we should + // not be releasing it as it contains valid data (for example, it is a + // parameter which the user wants us to alias, in a gradient update + // computation). So we store the index into the result in the aliased + // vector, which will be fed to the ExecutionOutput, which will use + // the indices to drop the addresses from its own ScopedShapedBuffer + // result, if the ExecutionOutput is not committed. + result.AddAliasedIndex(index); } else if (src_hlo->opcode() != HloOpcode::kParameter) { // The guard is above is not to insert copy-protection when aliasing // pass-through params, as we do not need to write into the output diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 17f372679ee2f8..23b29df6ec8db5 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -119,7 +119,9 @@ void HloToIrBindings::EmitBasePointersForHlos( if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); - BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), + BindHloToIrValue(*non_io_hlo, + llvm_ir::EmitAllocaAtFunctionEntry( + pointee_type, /*name=*/"", b_), index); } else if (slice.allocation()->is_constant()) { llvm::Value* global_for_constant = module_->getGlobalVariable( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index a789e450b922d3..4c4ae47cd695b0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -149,10 +149,7 @@ IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation) { - // Initialize thunk_sequence_ to an empty list of thunks. - thunk_sequence_.reset(new ThunkSequence()); -} + hlo_computation_(hlo_computation) {} Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { bindings_.UnbindAllLocalIrValues(); @@ -475,7 +472,7 @@ Status IrEmitterUnnested::HandlePadToStatic(HloInstruction* pad_to_static) { .EmitLoop(ir_name, GetIndexTypeForKernel( pad_to_static, launch_dimensions.launch_bound(), &b_))); - thunk_sequence_->emplace_back(std::move(kernel_thunk)); + thunk_sequence_.emplace_back(std::move(kernel_thunk)); return Status::OK(); } @@ -584,7 +581,7 @@ Status IrEmitterUnnested::HandleSliceToDynamic( .EmitLoop(ir_name, GetIndexTypeForKernel( slice_to_dynamic, launch_dimensions.launch_bound(), &b_))); - thunk_sequence_->emplace_back(std::move(kernel_thunk)); + thunk_sequence_.emplace_back(std::move(kernel_thunk)); return Status::OK(); } @@ -1227,8 +1224,10 @@ Status IrEmitterUnnested::EmitScatter( llvm::Value* output_address = GetIrArray(*output_hlo, *output_hlo) .EmitArrayElementAddress(input_window_index, &b_); - llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType( - updates->shape().element_type(), module_)); + llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(updates->shape().element_type(), + module_), + "input_address", &b_); TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); Store(input_ir_value, input_address); @@ -2161,7 +2160,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop( BuildKernelThunk(&hlo, /*implements_whole_instruction=*/true); Status emit_status = EmitTargetElementLoopInThunk( hlo, body_emitter, kernel_thunk.get(), unroll_factor); - thunk_sequence_->emplace_back(std::move(kernel_thunk)); + thunk_sequence_.emplace_back(std::move(kernel_thunk)); return emit_status; } @@ -2423,15 +2422,18 @@ void IrEmitterUnnested::EmitPrologueForReduction( llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); - llvm::AllocaInst* reduction_input_address = Alloca(element_type); + llvm::AllocaInst* reduction_input_address = + llvm_ir::EmitAllocaAtFunctionEntry(element_type, + "reduction_input_address", &b_); reduction_input_addresses->push_back(reduction_input_address); int num_partial_results = reduction_info->GetNumPartialResults(); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = - Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results), - "partial_reduction_result." + llvm::Twine(i)); + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + element_type, /*ArraySize=*/b_.getInt32(num_partial_results), + ("partial_reduction_result." + llvm::Twine(i)).str(), &b_); partial_result_addresses->push_back(partial_result_address); // Initialize the partial result with the initial value of the reduction. @@ -2505,8 +2507,8 @@ void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( llvm::Value* partial_result_address) { for (int distance = 16; distance >= 1; distance /= 2) { int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = - Alloca(element_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "result_from_other_lane", &b_); // Bitcast cannot be applied to aggregate types (even packed ones), so // we bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffled_value_type = @@ -3750,7 +3752,7 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( GetIndexTypeForKernel( unnested_hlo, launch_dimensions.launch_bound(), &b_)); - thunk_sequence_->emplace_back(std::move(kernel_thunk)); + thunk_sequence_.emplace_back(std::move(kernel_thunk)); return emit_status; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index d8314f2895fd56..1be3b8dbd26ff3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -97,7 +97,7 @@ class IrEmitterUnnested : public IrEmitter, // Transfers the ownship of thunk_sequence_ out. std::unique_ptr ConsumeThunkSequence() { - return std::move(thunk_sequence_); + return std::make_unique(std::move(thunk_sequence_)); } Status DefaultAction(HloInstruction* hlo) override; @@ -145,10 +145,12 @@ class IrEmitterUnnested : public IrEmitter, // Emits LLVM global variables corresponding to constant instructions. Status EmitConstantGlobals(); + Status Postprocess(HloInstruction* hlo) override; + private: // Add a owning Thunk object to the thunk sequence. void AddThunkToThunkSequence(std::unique_ptr thunk) override { - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_.emplace_back(std::move(thunk)); } // Input = {static array, dynamic_dim0, dynamic_dim1} @@ -543,13 +545,11 @@ class IrEmitterUnnested : public IrEmitter, absl::optional thread_id_filter = absl::nullopt, absl::optional block_id_filter = absl::nullopt); - Status Postprocess(HloInstruction* hlo) override; - // Returns the last generated thunk. - Thunk* LastThunk() const { return thunk_sequence_->back().get(); } + Thunk* LastThunk() const { return thunk_sequence_.back().get(); } // The thunk sequence this IrEmitter generates for the input computation. - std::unique_ptr thunk_sequence_; + ThunkSequence thunk_sequence_; // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 20a8d3b13a93ff..25ab1b54f075ee 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -45,7 +45,11 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { return Status::OK(); } CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(), - outfeed_buffers->shape())); + outfeed_buffers->shape())) + << "XLA program outfeed request of shape " + << hlo_instruction()->operand(0)->shape().ToString() + << " did not match the runtime's outfeed buffer of shape " + << outfeed_buffers->shape().ToString(); TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus( [&](const ShapeIndex& index, std::unique_ptr* buffer) { diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc index abca1f0cf18dd6..215c2e627ae4e0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -34,7 +34,7 @@ namespace { class ReductionVectorizationTest : public GpuCodegenTest {}; -TEST_F(ReductionVectorizationTest, Power2) { +TEST_F(ReductionVectorizationTest, DISABLED_Power2) { const char* hlo_text = R"( HloModule ReducePower2 @@ -82,7 +82,7 @@ CHECK: ld.global.nc.f32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, TileFit) { +TEST_F(ReductionVectorizationTest, DISABLED_TileFit) { const char* hlo_text = R"( HloModule ReduceTileFit @@ -130,7 +130,7 @@ CHECK: ld.global.nc.f32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, EvenColumns) { +TEST_F(ReductionVectorizationTest, DISABLED_EvenColumns) { const char* hlo_text = R"( HloModule ReducePower2 @@ -183,7 +183,7 @@ CHECK: ld.global.nc.f32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, DisabledOddColumns) { +TEST_F(ReductionVectorizationTest, DISABLED_DisabledOddColumns) { const char* hlo_text = R"( HloModule ReduceTileFit @@ -212,7 +212,7 @@ CHECK-NOT: ld.global.u64 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, Exp) { +TEST_F(ReductionVectorizationTest, DISABLED_Exp) { const char* hlo_text = R"( HloModule DisableSin @@ -262,7 +262,7 @@ CHECK: ld.global.nc.f32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, DisableSin) { +TEST_F(ReductionVectorizationTest, DISABLED_DisableSin) { const char* hlo_text = R"( HloModule DisableSin diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index b1cfb826e5f88c..796c0adadd220b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -2,6 +2,7 @@ // CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { // CHECK: entry: +// CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 // CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]* // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 @@ -34,7 +35,6 @@ // CHECK: br label %[[VAL_22]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_21]] // CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]] -// CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32* // CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]] // CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5 @@ -77,6 +77,7 @@ ENTRY main { // CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) { // CHECK: entry: +// CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 // CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32* // CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0 @@ -100,7 +101,6 @@ ENTRY main { // CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], %[[VAL_55]] // CHECK: br label %[[VAL_56]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]] -// CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4 // CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4 // CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4 @@ -140,6 +140,7 @@ ENTRY main { // CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_98:.*]] = alloca i32, align 4 // CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0 // CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]* // CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0 @@ -172,7 +173,6 @@ ENTRY main { // CHECK: br label %[[VAL_87]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_86]] // CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]] -// CHECK: %[[VAL_98:.*]] = alloca i32, align 4 // CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32* // CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]] // CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5 @@ -233,6 +233,7 @@ ENTRY main { // CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) { // CHECK: entry: +// CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 // CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]* // CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0 @@ -261,7 +262,6 @@ ENTRY main { // CHECK: br label %[[VAL_137]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]] // CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]] -// CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4 // CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4 // CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4 diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index e9be41b74dee6c..d0477d374af49c 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -72,15 +72,18 @@ class Thunk { // generated from, but Thunk never uses this argument other than to save it // to Thunk::hlo_instruction, so it can be null. explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) - : kind_(kind), hlo_instruction_(hlo_instruction) {} + : kind_(kind), + hlo_instruction_(hlo_instruction), + name_(hlo_instruction_ ? hlo_instruction_->name() : "") {} virtual ~Thunk() {} Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; Kind kind() const { return kind_; } - const HloInstruction* hlo_instruction() const { return hlo_instruction_; } string profile_annotation() const { return profile_annotation_; } + absl::string_view name() const { return name_; } + // Constructs and caches the profile annotation string for this thunk and // any child thunks. virtual void ComputeAnnotations() { @@ -123,6 +126,8 @@ class Thunk { virtual Status ExecuteOnStream(const ExecuteParams& params) = 0; protected: + const HloInstruction* hlo_instruction() const { return hlo_instruction_; } + const HloModuleConfig& GetModuleConfig() const { return hlo_instruction()->GetModule()->config(); } @@ -142,6 +147,7 @@ class Thunk { private: Kind kind_; const HloInstruction* hlo_instruction_; + std::string name_; string profile_annotation_; }; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index daa5f33e5604c4..af9543b57d8e93 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -34,7 +34,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( // list if `operand` is assigned to a different stream. As an optimization, // we skip `operand`'s operands because `operand` depends on them already. if (stream_assignment_->StreamNumberForHlo(operand) != - stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) { + stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(&thunk))) { depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand)); } } else { @@ -50,22 +50,21 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + absl::flat_hash_map thunk_to_hlo) : thunks_(std::move(thunks)), - stream_assignment_(std::move(stream_assignment)) { - absl::flat_hash_map hlo_to_thunk; - for (const auto& thunk : *thunks_) { - InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); + stream_assignment_(std::move(stream_assignment)), + thunk_to_hlo_(std::move(thunk_to_hlo)) { + for (auto& thunk : *thunks_) { + thunk_total_order_.push_back(thunk.get()); } - for (HloInstruction* hlo : hlo_total_order) { - if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { - thunk_total_order_.push_back(*thunk); - } + absl::flat_hash_map hlo_to_thunk; + for (const auto& thunk : *thunks_) { + InsertOrDie(&hlo_to_thunk, thunk_to_hlo_.at(thunk.get()), thunk.get()); } for (const Thunk* thunk : thunk_total_order_) { - const auto* dst = thunk->hlo_instruction(); + const auto* dst = thunk_to_hlo_.at(thunk); CHECK(stream_assignment_->HasStreamAssigned(*dst)); for (const auto* src : dst->operands()) { AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk); @@ -116,13 +115,13 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { } int dst_stream = - stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction()); + stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(dst)); std::list& sources = FindOrDie(depends_on_, dst); for (auto iter = sources.begin(); iter != sources.end();) { const Thunk* src = *iter; // `dst` depends on `src`. int src_stream = - stream_assignment_->StreamNumberForHlo(*src->hlo_instruction()); + stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(src)); int src_order = FindOrDie(thunk_to_total_order, src); if (src_order <= last_dependency(dst_stream, src_stream)) { iter = sources.erase(iter); @@ -165,8 +164,8 @@ string ThunkSchedule::ToString() const { absl::string_view kind_str = ThunkKindToString(thunk->kind()); absl::StrAppend(&result, kind_str, string(max_thunk_kind_len - kind_str.length(), ' '), "\t"); - if (thunk->hlo_instruction() != nullptr) { - absl::StrAppend(&result, thunk->hlo_instruction()->ToString()); + if (thunk_to_hlo_.at(thunk) != nullptr) { + absl::StrAppend(&result, thunk_to_hlo_.at(thunk)->ToString()); } else { absl::StrAppend(&result, "(no HloInstruction)"); } @@ -176,8 +175,8 @@ string ThunkSchedule::ToString() const { for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), - " depends on ", dependency->hlo_instruction()->name(), + absl::StrAppend(&result, "\t", thunk_to_hlo_.at(dependent)->name(), + " depends on ", thunk_to_hlo_.at(dependency)->name(), "\n"); } } diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 549378debd5241..3801dc8aee8f30 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -46,9 +46,14 @@ namespace gpu { // "depends" on A. class ThunkSchedule { public: - ThunkSchedule(std::unique_ptr thunks, - std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + // `thunk_to_hlo` is an one-to-one map. Every thunk in this container maps to + // an HLO, but not every HLO ever exists produces a Thunk. + // + // thunk_to_hlo.keys() == set(thunks). + ThunkSchedule( + std::unique_ptr thunks, + std::unique_ptr stream_assignment, + absl::flat_hash_map thunk_to_hlo); // Returns the total order of executing all the thunks. const std::vector& TotalOrder() const { return thunk_total_order_; } @@ -62,8 +67,8 @@ class ThunkSchedule { // Delegates to StreamAssignment. int StreamCount() const { return stream_assignment_->StreamCount(); } - int StreamNumberForHlo(const HloInstruction& hlo) const { - return stream_assignment_->StreamNumberForHlo(hlo); + int StreamNumberForThunk(const Thunk* thunk) const { + return stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(thunk)); } string ToString() const; @@ -89,6 +94,8 @@ class ThunkSchedule { std::list empty_thunk_list_; std::unique_ptr stream_assignment_; + + absl::flat_hash_map thunk_to_hlo_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 134c8953b15972..960f60fe88216c 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -284,18 +284,6 @@ message HloScheduleProto { } message HloInputOutputAliasProto { - enum Kind { - // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 - // behavior and missing has_*() APIs. - UNDEFINED_ALIAS = 0; - // An alias setup by the user as must alias. A use setting USER_ALIAS is - // expecting the designed output to be dropped over the given input - // parameter number+index. - USER_ALIAS = 1; - // An alias setup by the compiler as part of its optimizations. - SYSTEM_ALIAS = 2; - } - // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) // and an output (described by a ShapeIndex of the root @@ -316,8 +304,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; - // The kind of alias to be setup. - Kind kind = 4; + reserved 4; + reserved "kind"; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 1ef007cc817751..2666cb0872d3c9 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -241,16 +241,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -287,16 +284,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -378,11 +372,9 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4b86a58c35c079..438aa6ff05f549 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -249,12 +249,7 @@ bool HloComputation::HasSideEffect() const { } bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) { - for (auto& to_be_delete : to_be_deleted_) { - if (to_be_delete.get() == inst) { - return true; - } - } - return false; + return inst->IsMarkedAsDead(); } Status HloComputation::RemoveInstructionAndUnusedOperands( @@ -320,6 +315,7 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, (*inst_it->second)->set_parent(nullptr); to_be_deleted_.emplace_back(inst_it->second->release()); to_be_deleted_.back()->DetachFromOperandsAndUsers(); + to_be_deleted_.back()->MarkAsDead(); instructions_.erase(inst_it->second); instruction_iterators_.erase(inst_it); return Status::OK(); @@ -549,7 +545,11 @@ string HloComputation::ToString( if (options.print_percent()) { s << "%"; } - s << PrintName(name(), options.print_ids()) << " "; + if (options.print_ids() || !IsEntryComputation()) { + // Exclude entry computation's name because it includes and leads to + // non-deterministic fingerprint. + s << PrintName(name(), options.print_ids()) << " "; + } } if (options.print_program_shape()) { diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index fdb311adb5d102..9415e20af7bd14 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -28,7 +28,7 @@ namespace { StatusOr ReplaceGetSize( HloInstruction* instr, - const DynamicDimensionInference* dynamic_dimension_inference) { + DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -47,11 +47,18 @@ StatusOr ReplaceGetSize( dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); if (dynamic_size != nullptr) { TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + // The dependency between a instruction and its dynamic dimensions is not + // modeled in the IR. As instr is being replaced by dynamic_size, also tell + // dynamic dimension inference that the instruction is being replaced. + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( + instr, dynamic_size); } else { int32 size = instr->operand(0)->shape().dimensions(dim); HloInstruction* new_instr = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, + new_instr); } return true; } @@ -95,14 +102,14 @@ StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { // // This will get static size of the op, which is incorrect. for (auto* computation : module->computations()) { - for (auto instruction : computation->instructions()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_get_size, ReplaceGetSize(instruction, &inference)); changed = changed || replaced_get_size; } } for (auto* computation : module->computations()) { - for (auto instruction : computation->instructions()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); changed = changed || replaced_set_size; } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc index d96f2db3c2654c..b1491e96095167 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" @@ -55,6 +56,24 @@ ENTRY gds { op::Multiply(op::Constant(), op::Constant())); } +TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + p_copy = s32[3,4] copy(p) + p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} + size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { auto module = ParseAndReturnUnverifiedModule(R"( HloModule _ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index c4825b06bc85c2..e123161720b3e1 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -26,10 +26,7 @@ bool HloInputOutputAliasConfig::OutputHasAlias( Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index, - AliasKind kind) { - TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) - << kind; + const ShapeIndex& param_index) { TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << "Trying to set up alias at " << output_index.ToString() << " which is an invalid index for shape " @@ -44,8 +41,7 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, param_number, param_index.ToString(), output_index.ToString(), alias_.element(output_index)->parameter_number, alias_.element(output_index)->parameter_index.ToString()); - (*alias_.mutable_element(output_index)) = - Alias(kind, param_number, param_index); + (*alias_.mutable_element(output_index)) = Alias(param_number, param_index); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -58,16 +54,6 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { [&](const ShapeIndex& index, const absl::optional& data) { if (data) { HloInputOutputAliasProto::AliasEntryProto entry; - switch (data->kind) { - case AliasKind::kUserAlias: - entry.set_kind(HloInputOutputAliasProto::USER_ALIAS); - break; - case AliasKind::kSystemAlias: - entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS); - break; - default: - LOG(FATAL) << "Unknown alias kind " << data->kind; - } for (int64 i : index) { entry.add_output_shape_index(i); } @@ -91,14 +77,8 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); - // Handle backward compatibility with existing protos, which only knew of - // system aliases. - AliasKind kind = AliasKind::kSystemAlias; - if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) { - kind = AliasKind::kUserAlias; - } TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index, kind)); + result.SetUpAlias(output_index, param_number, param_index)); } return result; } @@ -113,9 +93,9 @@ string HloInputOutputAliasConfig::ToString() const { ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:", - output_index.ToString(), AliasKindToString(alias.kind), - alias.parameter_number, alias.parameter_index.ToString())); + " OutputIndex %s is aliased with parameter %lld at %s:", + output_index.ToString(), alias.parameter_number, + alias.parameter_index.ToString())); }); return absl::StrJoin(pieces, "\n"); } @@ -124,30 +104,14 @@ string HloInputOutputAliasConfig::ToShortString() const { std::vector pieces; for (const auto& p : alias_) { const ShapeIndex& index = p.first; - absl::optional alias = p.second; - if (!alias) { - continue; + if (absl::optional alias = p.second) { + pieces.push_back( + absl::StrFormat("%s: %s", index.ToString(), alias->ToString())); } - pieces.push_back( - absl::StrFormat("%s: %s", index.ToString(), alias->ToString())); } return absl::StrJoin(pieces, ", "); } -absl::optional -HloInputOutputAliasConfig::ParameterAliasKind( - int64 param_number, const ShapeIndex& param_index) const { - absl::optional kind; - alias_.ForEachElement( - [&](const xla::ShapeIndex&, absl::optional alias) { - if (alias && alias->parameter_number == param_number && - alias->parameter_index == param_index) { - kind = alias->kind; - } - }); - return kind; -} - absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index 1a5b5f475bbb04..d5ca28e9387bbb 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -32,43 +32,22 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: - // The kind of aliases which can be set. A kUserAlias is one setup at - // compilation time by the user, and has to be respected. A kSystemAlias one - // might be setup by the compiler, if it decides it is convenient to do so. - enum AliasKind { - kUserAlias, - kSystemAlias, - }; - - static std::string AliasKindToString(AliasKind kind) { - switch (kind) { - case kUserAlias: - return "USER"; - case kSystemAlias: - return "SYSTEM"; - } - } - // Defines the alias information for a given output buffer. A given output // buffer shape index can refer only to one parameter+index. struct Alias { - Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index) - : kind(kind), - parameter_number(parameter_number), + Alias(int64 parameter_number, ShapeIndex parameter_index) + : parameter_number(parameter_number), parameter_index(std::move(parameter_index)) {} - AliasKind kind; int64 parameter_number; ShapeIndex parameter_index; std::string ToString() { - if (kind == kUserAlias) { - return absl::StrFormat("(%lld, %s)", parameter_number, - parameter_index.ToString()); + if (parameter_index.empty()) { + return absl::StrCat(parameter_number); } - return absl::StrFormat("(%lld, %s, %s)", parameter_number, - parameter_index.ToString(), - AliasKindToString(kind)); + return absl::StrFormat("(%lld, %s)", parameter_number, + parameter_index.ToString()); } }; @@ -82,19 +61,13 @@ class HloInputOutputAliasConfig { // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index, - AliasKind kind = AliasKind::kUserAlias); - - // Returns the kind of alias for the given parameter number and parameter - // index. - absl::optional ParameterAliasKind( - int64 param_number, const ShapeIndex& param_index) const; + const ShapeIndex& param_index); // Returns true if the given parameter is aliased with one of the output // buffers. bool ParameterHasAlias(int64 param_number, const ShapeIndex& param_index) const { - return ParameterAliasKind(param_number, param_index).has_value(); + return GetAliasedOutput(param_number, param_index).has_value(); } // Checks whether the provided output index has already been aliased. diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index 8293d495878b2c..d873b3f3b7e0ad 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -86,8 +86,7 @@ ENTRY main { TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); expect_aliased(/*output_index=*/{0}, /*param_number=*/1, /*param_index=*/{}, config); @@ -118,13 +117,11 @@ ENTRY main { TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{0}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{0})); TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{1}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{1})); expect_aliased(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, config); @@ -157,13 +154,11 @@ ENTRY main { TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -188,8 +183,7 @@ ENTRY main { TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -214,13 +208,11 @@ ENTRY main { TF_ASSERT_OK(config.SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); ASSERT_IS_NOT_OK(config.SetUpAlias( /*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{}, - /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); + /*param_index=*/{})); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e9a04583bdfa45..9957df41f1aefe 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1889,7 +1889,7 @@ Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { void HloInstruction::AppendOperand(HloInstruction* operand) { if (operand->parent() != nullptr) { - CHECK(!operand->parent()->IsMarkedAsDead(operand)) + DCHECK(!operand->parent()->IsMarkedAsDead(operand)) << "Operand " << operand->name() << " is already marked dead"; } operands_.push_back(operand); @@ -2839,7 +2839,8 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), shape_(shape), - name_(HloOpcodeString(opcode)) { + name_(HloOpcodeString(opcode)), + marked_as_dead_(false) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 95850a8d9da78c..8c50a9bb8fca01 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1931,6 +1931,8 @@ class HloInstruction { }; private: + friend class HloComputation; + // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, @@ -1953,15 +1955,11 @@ class HloInstruction { virtual bool IsElementwiseImpl( const absl::optional& operand_idx) const; - // Prints an operand to a string. + // Prints an operand to a string. Accessed by friend class HloInstruction. virtual string OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; - // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and - // OperandsToStringWithCanonicalNameMap() functions. - friend class HloComputation; - // See comments on Identical(). virtual bool IdenticalSlowPath( const HloInstruction& other, @@ -1990,6 +1988,13 @@ class HloInstruction { // given proto. Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + // Mark this instruction as dead. Accessed by friend class HloInstruction. + void MarkAsDead() { marked_as_dead_ = true; } + + // Has this instruction been marked as dead? Accessed by friend class + // HloInstruction. + bool IsMarkedAsDead() const { return marked_as_dead_; } + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -2071,6 +2076,10 @@ class HloInstruction { // outer-most dimension first). std::vector outer_dimension_partitions_; + // Intrusive flag used by HloComputation, whether this instruction has + // been marked as dead. + bool marked_as_dead_; + TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; diff --git a/tensorflow/compiler/xla/service/hlo_live_range_test.cc b/tensorflow/compiler/xla/service/hlo_live_range_test.cc index e2d320beffd6bf..7a45733d79bba1 100644 --- a/tensorflow/compiler/xla/service/hlo_live_range_test.cc +++ b/tensorflow/compiler/xla/service/hlo_live_range_test.cc @@ -237,8 +237,7 @@ TEST_F(HloLiveRangeTest, AliasedParameter) { HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); module_->AddEntryComputation(builder.Build()); // Set up alias of the first parameter. - TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias({}, 0, {})); HloSchedule schedule(module_.get()); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 220443d527640c..c715d016c4f0d3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -230,7 +230,10 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << PrintName(name(), options.print_ids()); + // When print_ids() is false, exclude module's name because it includes and + // leads to non-deterministic fingerprint. + s << "HloModule " + << (options.print_ids() ? PrintName(name(), options.print_ids()) : ""); if (has_schedule()) { TF_CHECK_OK(schedule().Verify()); s << ", is_scheduled=true"; @@ -687,6 +690,7 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); module->AddEntryComputation(std::move(cloned_computation)); + module->input_output_alias_config() = input_output_alias_config(); if (has_schedule() && schedule().Verify().ok()) { HloSchedule clone_schedule(module.get()); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 22cd34f337872b..d47be84e7fc046 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -547,39 +547,36 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) { } std::string errmsg = "Expected format: : (, " - ")"; + ") OR : "; if (!ParseToken(TokKind::kColon, errmsg)) { return false; } - if (!ParseToken(TokKind::kLparen, errmsg)) { - return false; - } - int64 param_num; - ParseInt64(¶m_num); - if (!ParseToken(TokKind::kComma, errmsg)) { - return false; - } - ShapeIndex param_idx; - if (!ParseShapeIndex(¶m_idx)) { - return false; - } - HloInputOutputAliasConfig::AliasKind alias_kind = - HloInputOutputAliasConfig::kUserAlias; - if (EatIfPresent(TokKind::kComma)) { - std::string type; - ParseName(&type); - if (type == "SYSTEM") { - alias_kind = HloInputOutputAliasConfig::kSystemAlias; - } else if (type == "USER") { - alias_kind = HloInputOutputAliasConfig::kUserAlias; - } else { - return TokenError("Unexpected aliasing kind; expected SYSTEM or USER"); + + if (lexer_.GetKind() != TokKind::kLparen) { + // Short form: "{0}: 0", output index "{}" is assumed. + int64 param_num; + ParseInt64(¶m_num); + data->emplace(std::piecewise_construct, std::forward_as_tuple(out), + std::forward_as_tuple(param_num, ShapeIndex{})); + } else { + // Long form: "{0}: (0, {0})", output index is explicitly specified. + if (!ParseToken(TokKind::kLparen, errmsg)) { + return false; + } + int64 param_num; + ParseInt64(¶m_num); + if (!ParseToken(TokKind::kComma, errmsg)) { + return false; + } + ShapeIndex param_idx; + if (!ParseShapeIndex(¶m_idx)) { + return false; + } + data->emplace(std::piecewise_construct, std::forward_as_tuple(out), + std::forward_as_tuple(param_num, param_idx)); + if (!ParseToken(TokKind::kRparen, errmsg)) { + return false; } - } - data->emplace(std::piecewise_construct, std::forward_as_tuple(out), - std::forward_as_tuple(alias_kind, param_num, param_idx)); - if (!ParseToken(TokKind::kRparen, errmsg)) { - return false; } if (!EatIfPresent(TokKind::kComma)) { @@ -627,9 +624,8 @@ bool HloParserImpl::ParseHloModule(HloModule* module) { if (aliasing_data) { HloInputOutputAliasConfig alias_config(module->result_shape()); for (auto& p : *aliasing_data) { - Status st = - alias_config.SetUpAlias(p.first, p.second.parameter_number, - p.second.parameter_index, p.second.kind); + Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number, + p.second.parameter_index); if (!st.ok()) { return TokenError(st.error_message()); } @@ -2678,7 +2674,9 @@ struct MinMaxFiniteValue { template <> struct MinMaxFiniteValue { - static double max() { return static_cast(bfloat16::highest()); } + static double max() { + return static_cast(Eigen::NumTraits::highest()); + } static double min() { return -max(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 800bf94b2583a9..484578e5e0e1da 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2399,7 +2399,7 @@ ENTRY c2 { TEST_F(HloParserTest, SimpleAliasing) { const string original = R"( -HloModule Module, input_output_alias={ {0}: (0, {0}, USER), {1}: (0, {1}, USER) } +HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) } ENTRY entry { %p = (f32[], f32[]) parameter(0) @@ -2419,24 +2419,38 @@ ENTRY entry { TEST_F(HloParserTest, SimpleAliasingShortForm) { const string original = R"( -HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) } +HloModule Module, input_output_alias={ {0}: 0, {1}: 1 } ENTRY entry { - %p = (f32[], f32[]) parameter(0) - %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0 - %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1 + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) ROOT %out = (f32[], f32[]) tuple(%p0, %p1) } )"; auto module = ParseAndReturnVerifiedModule(original); TF_ASSERT_OK(module.status()); std::unique_ptr parsed_module = module.ConsumeValueOrDie(); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}), + EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {}), ShapeIndex{0}); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}), + EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(1, {}), ShapeIndex{1}); } +TEST_F(HloParserTest, SimpleAliasingShortFormError) { + const string original = R"( +HloModule Module, input_output_alias={ {0}: A, {1}: 1 } + +ENTRY entry { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %out = (f32[], f32[]) tuple(%p0, %p1) +} + )"; + ExpectHasSubstr( + ParseAndReturnUnverifiedModule(original).status().error_message(), + "expects integer"); +} + TEST_F(HloParserTest, NestedAliasing) { const string original = R"( HloModule Module, input_output_alias={ {0, 0}: (0, {0}), {1, 1}: (0, {1}) } @@ -2539,22 +2553,6 @@ ENTRY entry { "expects integer"); } -TEST_F(HloParserTest, AliasingUnexpectedKind) { - const string original = R"( -HloModule Module, input_output_alias={ {0}: (0, {0}, UNKNOWN), {1}: (0, {1}, UNKNOWN) } - -ENTRY entry { - %p = (f32[], f32[]) parameter(0) - %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0 - %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1 - ROOT %out = (f32[], f32[]) tuple(%p0, %p1) -} - )"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "Unexpected aliasing kind"); -} - TEST_F(HloParserTest, MultipleRoots) { const string original = R"(HloModule multiple_roots: ENTRY consts { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index bfc6769660a61d..2166ecdd890037 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -746,7 +746,7 @@ Status MemoryUsageTracker::EndInstruction() { Buffer& buffer = buffers_.at(buffer_id); buffer.unfinished_user_count--; CHECK_GE(buffer.unfinished_user_count, 0) - << buffer.ToString() << " has negative unfinished use count."; + << buffer.ToString() << " has negative unfinished user count."; if (buffer.unfinished_user_count == 0) { // Buffer is now dead. VLOG(3) << " " << buffer.ToString() << " is now dead."; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 398f07d4a405b0..10e11e55291c74 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -3247,10 +3247,8 @@ TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) { TF_CHECK_OK(module->set_schedule(schedule)); // Make input {0} alias with output {0} and input {1} alias with output {1}. - TF_CHECK_OK(module->input_output_alias_config().SetUpAlias( - {0}, 0, {0}, HloInputOutputAliasConfig::AliasKind::kSystemAlias)); - TF_CHECK_OK(module->input_output_alias_config().SetUpAlias( - {1}, 0, {1}, HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0})); + TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1})); AssignMemorySpace(module.get()); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 04d9b09a9ae710..113c9764b40bfb 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -162,13 +162,13 @@ cc_library( "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", + "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo_copy_removal", "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", - "//tensorflow/compiler/mlir/hlo:xla_legalize_tanh_to_approximation", - "//tensorflow/compiler/mlir/hlo:xla_legalize_to_linalg", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 8ff45a67269caf..4b06cda8744a27 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -42,7 +42,7 @@ using ::mlir::RankedTensorType; using ::mlir::Type; using ::mlir::Value; -namespace hlo = ::mlir::xla_hlo; +namespace hlo = ::mlir::mhlo; // TODO(b/137624192) Use tablegen for this. StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, @@ -185,7 +185,7 @@ Status HloDialectEmitter::HandleConstant(HloInstruction* instr) { Status HloDialectEmitter::HandleGather(HloInstruction* instr) { HloGatherInstruction* gather = static_cast(instr); - mlir::xla_hlo::GatherDimensionNumbers dimension_numbers = + mlir::mhlo::GatherDimensionNumbers dimension_numbers = xla::CreateGatherDimensionNumbers(gather->gather_dimension_numbers(), builder_); mlir::DenseIntElementsAttr slice_sizes = CreateDenseIntElementsAttrFromVector( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index c89c43fa37d66a..648c44d9ac1e20 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -60,7 +60,7 @@ namespace xla { namespace mlir_gpu { namespace { -using ::mlir::xla_lhlo::FusionOp; +using ::mlir::lmhlo::FusionOp; // Replaces a FusionOp by the operations contained in its region. struct FusionOpRemover @@ -392,7 +392,7 @@ struct RewriteKernelSignature } }; -// Extract_element(xla_hlo_scalars_to_dimension_tensor(v_i), i) -> v_i +// Extract_element(mhlo_scalars_to_dimension_tensor(v_i), i) -> v_i // // We need to direct fusion to the inner loops. This cannot be done with // a passmanager alone ATM, as nested pass managers require operations to @@ -457,20 +457,20 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { } // Legalize from HLO to LHLO. - pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass()); + pm.addPass(::mlir::mhlo::createLegalizeToLhloPass()); // Moving `AllocOp`s and inserting missing `DeallocOp`s pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary LHLO copies. - pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::lmhlo::createLhloCopyRemovalPass()); // Transform LHLO operations to LinAlg. - pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass()); + pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. - pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true, - tiling_for_unrolling)); + pm.addPass(::mlir::lmhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true, + tiling_for_unrolling)); // Legalize reduce operations directly to GPU dialect. - pm.addPass(::mlir::xla_lhlo::createLegalizeToGpuPass()); + pm.addPass(::mlir::lmhlo::createLegalizeToGpuPass()); // Transform the Linalg operations inside of the loop nest into parallel // loops. pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass()); @@ -512,7 +512,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Approximate of requested. if (options.use_approximations) { pm.addNestedPass<::mlir::FuncOp>( - ::mlir::xla::createLegalizeTanhToApproximationPass()); + ::mlir::hlo::createLegalizeTanhToApproximationPass()); } // Move scalar operations into the launch to ensure smaller signatures. pm.addPass(absl::make_unique()); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 62ca93907520e4..194eb4618d34f0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -58,7 +58,7 @@ using ::xla::gpu::Thunk; using ::xla::gpu::ThunkEmitter; using ::xla::gpu::ThunkSequence; -namespace lhlo = ::mlir::xla_lhlo; +namespace lhlo = ::mlir::lmhlo; // TODO(b/137624192) Use tablegen for this. Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, @@ -202,15 +202,14 @@ LhloDialectEmitter::LhloDialectEmitter( mlir_module_(mlir_module), builder_(mlir_module_.getContext()), buffer_assignment_(assignment), - platform_(platform), - thunk_sequence_(new ThunkSequence()) { + platform_(platform) { LLVMDialect* llvmDialect = mlir_module.getContext()->getRegisteredDialect(); pointer_size_ = llvmDialect->getLLVMModule().getDataLayout().getPointerSize(); } void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr thunk) { - thunk_sequence_->push_back(std::move(thunk)); + thunk_sequence_.push_back(std::move(thunk)); } StatusOr LhloDialectEmitter::MaybeGetAllocationSlice( @@ -226,10 +225,6 @@ absl::string_view LhloDialectEmitter::platform_name() const { return platform_->Name(); } -Status LhloDialectEmitter::EmitComputation(const HloComputation& computation) { - return computation.root_instruction()->Accept(this); -} - StatusOr LhloDialectEmitter::CreateFunction( const HloInstruction& instr) { TF_ASSIGN_OR_RETURN(auto args, GetInstructionArgTypes(instr, builder_)); @@ -311,7 +306,7 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) { Status LhloDialectEmitter::HandleGather(HloInstruction* instr) { HloGatherInstruction* gather = static_cast(instr); - mlir::xla_hlo::GatherDimensionNumbers dim_numbers = + mlir::mhlo::GatherDimensionNumbers dim_numbers = xla::CreateGatherDimensionNumbers(gather->gather_dimension_numbers(), builder_); mlir::DenseIntElementsAttr slice_sizes = CreateDenseIntElementsAttrFromVector( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index 185c1e13bb7992..145d3681b16051 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -47,8 +47,6 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, ::mlir::ModuleOp mlir_module); ~LhloDialectEmitter() override = default; - Status EmitComputation(const HloComputation& computation); - // The following methods implement the DfsHloVisitor interface. // // Default action which emits code for most operations. Operations which are @@ -71,8 +69,10 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, Status FinishVisit(HloInstruction* root) override; // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::move(thunk_sequence_); + gpu::ThunkSequence ConsumeThunkSequence() { + gpu::ThunkSequence result; + std::swap(result, thunk_sequence_); + return result; } const absl::flat_hash_map& @@ -100,7 +100,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, // Cached pointer size extracted from the mlir module. unsigned pointer_size_; // The thunk sequence this IrEmitter generates for the input computation. - std::unique_ptr thunk_sequence_; + gpu::ThunkSequence thunk_sequence_; TF_DISALLOW_COPY_AND_ASSIGN(LhloDialectEmitter); }; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index beabc99a1738f5..eb901e59fd85a7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -488,8 +488,17 @@ StatusOr> MlirCompilerImpl::RunBackend( LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment, stream_exec->platform(), *mlir_module); - TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( - *emission_context.getHloModule()->entry_computation())); + absl::flat_hash_map> + hlo_to_thunk; + for (HloInstruction* instruction : hlo_schedule->ThunkLaunchOrder()) { + TF_RETURN_IF_ERROR(instruction->Visit(&lhlo_emitter)); + gpu::ThunkSequence thunks = lhlo_emitter.ConsumeThunkSequence(); + TF_RET_CHECK(thunks.size() <= 1) << instruction->ToString(); + if (!thunks.empty()) { + auto thunk = std::move(thunks.front()); + hlo_to_thunk[instruction] = std::move(thunk); + } + } TF_RETURN_IF_ERROR( module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); @@ -507,13 +516,26 @@ StatusOr> MlirCompilerImpl::RunBackend( TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module, ExtractKernelModule(*mlir_module)); - auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence(); for (auto entry : lhlo_emitter.InstructionToFunctionMap()) { TF_ASSIGN_OR_RETURN( auto thunk, TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module, buffer_assignment.get())); - thunk_sequence->push_back(std::move(thunk)); + hlo_to_thunk[entry.first] = std::move(thunk); + } + + absl::flat_hash_map thunk_to_hlo; + gpu::ThunkSequence thunk_sequence; + { + for (HloInstruction* hlo : hlo_schedule->ThunkLaunchOrder()) { + auto it = hlo_to_thunk.find(hlo); + if (it != hlo_to_thunk.end()) { + const HloInstruction* hlo = it->first; + auto& thunk = it->second; + thunk_to_hlo[thunk.get()] = hlo; + thunk_sequence.push_back(std::move(thunk)); + } + } } TF_RETURN_IF_ERROR( @@ -539,8 +561,8 @@ StatusOr> MlirCompilerImpl::RunBackend( gpu::PtxOptsFromConfig(config))); auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); + std::make_unique(std::move(thunk_sequence)), + std::move(stream_assignment), std::move(thunk_to_hlo)); if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo index 0927a6dc15d22b..ba29b0a17fd4d9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo @@ -6,5 +6,5 @@ ENTRY %Abs (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo index d8c20cfdab0311..37c163eb83e647 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo @@ -8,5 +8,5 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo index 05dbbebf197329..2603b925c76e75 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -10,13 +10,13 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { } // CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) -// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: "lmhlo.fusion"() ( { // CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]] // CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]] // CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]] -// CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]] -// CHECK: %[[MUL:.*]] = xla_hlo.multiply %[[ADD]], %[[REF0]] +// CHECK: %[[ADD:.*]] = mhlo.add %[[REF1]], %[[REF2]] +// CHECK: %[[MUL:.*]] = mhlo.multiply %[[ADD]], %[[REF0]] // CHECK: tensor_store %[[MUL]], %[[RESULT]] -// CHECK: "xla_lhlo.terminator"() +// CHECK: "lmhlo.terminator"() // CHECK-NEXT: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo index fd594b7eca513c..a57f427cedc9e8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo @@ -14,11 +14,11 @@ ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] { } // CHECK: func @reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref, %[[RES:.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { +// CHECK: "lmhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { // CHECK: ^bb0(%[[FARG0:.*]]: memref, %[[FARG1:.*]]: memref, %[[FRES:.*]]: memref): // CHECK: %[[LHS:.*]] = tensor_load %[[FARG0]] : memref // CHECK: %[[RHS:.*]] = tensor_load %[[FARG1]] : memref -// CHECK: %[[RES:.*]] = xla_hlo.add %[[LHS]], %[[RHS]] : tensor +// CHECK: %[[RES:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor // CHECK: tensor_store %[[RES]], %[[FRES]] : memref -// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK: "lmhlo.terminator"() : () -> () // CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref, [[REST]]) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo index 9a2736c019a812..366545c431fa9d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo @@ -7,7 +7,7 @@ ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { } // CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) { -// CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) +// CHECK: "lmhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) // CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: : ([[IN_T]], [[OUT_T]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo index 71014e17db8211..6bbddb61a746d7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo @@ -7,4 +7,4 @@ ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] { ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y) } -// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return] +// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: lmhlo.add; failed for testing: std.return] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo index 26a4131617ee13..f45fa1a55e2594 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo @@ -6,5 +6,5 @@ ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo index 99662951456cb8..2a34f494083d94 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo @@ -8,6 +8,6 @@ ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] { } // CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) { -// CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) +// CHECK: "lmhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) // CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo index 996ca0b2786858..99a4872b2282b5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo @@ -8,5 +8,5 @@ ENTRY %Complex (real: f32[2,2]{0,1}, imag: f32[2,2]{0,1}) -> c64[2,2] { } // CHECK: func @complex(%[[REAL:.*]]: [[BUF_F32:.*]], %[[IMAG:.*]]: [[BUF_F32]], %[[OUT:.*]]: [[BUF_C64:.*]]) { -// CHECK: "xla_lhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> () +// CHECK: "lmhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo index 0b858842db7864..06f29185aa1172 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo @@ -8,6 +8,6 @@ ENTRY %Concatenate (x: f32[2,3], y: f32[2,2]) -> f32[2,5] { } // CHECK: func @concatenate(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) { -// CHECK: "xla_lhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) +// CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) // CHECK: {dimension = 1 : i64} : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo index 632a44a79e7ff6..e0745c4763eac9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo @@ -7,6 +7,6 @@ ENTRY %Const () -> s32[100] { } // CHECK: func @constant(%[[ARG0:.*]]: memref) -// CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor} +// CHECK: "lmhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor} // CHECK: func @broadcast(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xi32>) -// CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} +// CHECK: "lmhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo index cc1acd03ad55d2..b4058da80192b5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo @@ -7,4 +7,4 @@ ENTRY %Copy (x: f32[2,4]) -> f32[2,4] { } // CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) { -// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () +// CHECK: "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo index 7a9b994eae6679..3a3dd22b338d4b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo @@ -9,5 +9,5 @@ ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} { // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, // CHECK-SAME: %[[RESULT:.*]]: memref<2x4xf32, #[[MAP0]]>) -// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) +// CHECK: "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) // CHECK-SAME: : (memref<2x4xf32>, memref<2x4xf32, #[[MAP0]]>) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo index 12c9c16d6895ee..8a00a56206c412 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo @@ -6,5 +6,5 @@ ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo index 741ebe1118e173..42cc605b2b63a8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo @@ -7,6 +7,6 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo index 66437757140243..f74cdef1473883 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo @@ -21,15 +21,15 @@ ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { } // CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) -// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: "lmhlo.fusion"() ( { // CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]] -// CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00> -// CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( { +// CHECK: %[[CT0:.*]] = mhlo.constant dense<0.000000e+00> +// CHECK: %[[RED:.*]] = "mhlo.reduce"(%0, %1) ( { // CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]]) -// CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]] -// CHECK: "xla_hlo.return"(%[[ADD]]) +// CHECK: %[[ADD:.*]] = mhlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]] +// CHECK: "mhlo.return"(%[[ADD]]) // CHECK: }) // CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]] -// CHECK: "xla_lhlo.terminator"() +// CHECK: "lmhlo.terminator"() // CHECK-NEXT: }) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo index 8dbd5dab178dfd..470ae348740a84 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo @@ -11,7 +11,7 @@ ENTRY %Gather (x: f32[100,10], y: s64[4,6]) -> f32[4,6,10] { // CHECK: func @gather(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], // CHECK-SAME: %[[RESULT:.*]]: [[RTYPE:.*]]) { -// CHECK-NEXT: "xla_lhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) { +// CHECK-NEXT: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) { // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>, // CHECK-SAME: index_vector_dim = 2 : i64, // CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo index 01d125fd866003..50ff5571dbee03 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo @@ -7,5 +7,5 @@ ENTRY %Imag (x: c64[2,2]{0,1}) -> f32[2,2] { } // CHECK: func @imag(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { -// CHECK: "xla_lhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: "lmhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo index eb97667886f30f..1755e4b0157ac7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo @@ -6,6 +6,6 @@ HloModule Iota } // CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { -// CHECK: "xla_lhlo.iota"(%[[OUT]]) +// CHECK: "lmhlo.iota"(%[[OUT]]) // CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo index 3a19bc2f7036ff..5f1156497b9d06 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo @@ -7,5 +7,5 @@ ENTRY %Log (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo index 45804cf8edd1ce..30557f134496eb 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo @@ -6,5 +6,5 @@ ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo index b1b02976a7d7b1..559a4db4914a5f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo @@ -7,5 +7,5 @@ ENTRY %Real (x: c64[2,2]{0,1}) -> f32[2,2] { } // CHECK: func @real(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { -// CHECK: "xla_lhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: "lmhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo index 97977e93d44bde..4c23a9854b1e17 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo @@ -19,13 +19,13 @@ ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] { // CHECK: func @"reduce-window"( // CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( { +// CHECK: "lmhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( { // CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, [[OUT:%.*]]: memref): // CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] // CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]] -// CHECK: [[OUT_TENSOR:%.*]] = xla_hlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]] +// CHECK: [[OUT_TENSOR:%.*]] = mhlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]] // CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] -// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK: "lmhlo.terminator"() : () -> () // CHECK: }) { // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> // CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]> diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo index 172e3224b779c9..6d3afb07f56a3b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo @@ -7,5 +7,5 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo index 44167bba987e22..11d18e88061cec 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo @@ -7,5 +7,5 @@ ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo index d900f56dcb2a24..bf25c69c524222 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo @@ -9,6 +9,6 @@ ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo index da138103bfbf68..46d29856828c59 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo @@ -30,23 +30,23 @@ ENTRY %SelectAndScatter (x: f32[128,64,112,112], // CHECK: func @"select-and-scatter"( // CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( { +// CHECK: "lmhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( { // CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, // CHECK-SAME: [[OUT:%.*]]: memref): // CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] // CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]] -// CHECK: [[OUT_TENSOR:%.*]] = "xla_hlo.compare" +// CHECK: [[OUT_TENSOR:%.*]] = "mhlo.compare" // CHECK-SAME: ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"} // CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] -// CHECK: xla_lhlo.terminator +// CHECK: lmhlo.terminator // CHECK: }, { // CHECK: ^bb0([[LHS_:%.*]]: memref, [[RHS_:%.*]]: memref, // CHECK-SAME: [[OUT_:%.*]]: memref): // CHECK: [[LHS_TENSOR_:%.*]] = tensor_load [[LHS_]] // CHECK: [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]] -// CHECK: [[OUT_TENSOR_:%.*]] = xla_hlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]] +// CHECK: [[OUT_TENSOR_:%.*]] = mhlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]] // CHECK: tensor_store [[OUT_TENSOR_]], [[OUT_]] -// CHECK: xla_lhlo.terminator +// CHECK: lmhlo.terminator // CHECK: }) { // CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]> diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo index 0a7afa69babf27..6acadb84e17f68 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo @@ -6,5 +6,5 @@ ENTRY %Sign (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo index 54bf947350b24b..4e47229397d293 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo @@ -7,6 +7,6 @@ ENTRY %Sqrt (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @sqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo index ff147c9041c1df..681c18aed29811 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo @@ -6,5 +6,5 @@ ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc index c1d401613d70cc..f160d7dafb286a 100644 --- a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc @@ -75,9 +75,8 @@ StatusOr OptimizeInputOutputBufferAlias::Build( const ShapeIndex& output_index = index; if (!alias_config->ParameterHasAlias(0, input_index) && !alias_config->OutputHasAlias(output_index)) { - TF_RETURN_IF_ERROR(alias_config->SetUpAlias( - output_index, 0, input_index, - HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + TF_RETURN_IF_ERROR( + alias_config->SetUpAlias(output_index, 0, input_index)); } entry.used = true; break; diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index a1903cd2746154..6c4cf2d786624f 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -1384,8 +1384,8 @@ Status CheckAndUpdateDeviceAssignmentsInWhileBody( [](const HloSharding& s) { return !s.HasUniqueDevice(); }); } if (is_spatially_partitioned) { - for (HloInstruction* domain : domain.exit_domains) { - domain->mutable_operand(0)->set_sharding(*sharding); + for (HloInstruction* d : domain.exit_domains) { + d->mutable_operand(0)->set_sharding(*sharding); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index f5466c632ac985..7459b3d3f1f926 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -108,7 +108,7 @@ class CollectiveOpsTest : public HloTestBase { } template - void TestAllOps() { + void TestAllOpsForReduce() { auto cast = [&](int value) { return static_cast(value); }; auto to_literal = [&](absl::Span values) { return LiteralUtil::CreateR1(values); @@ -183,39 +183,39 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) { } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) { - TestAllOps(); + TestAllOpsForReduce(); } XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) { @@ -593,6 +593,98 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { results[3])); } +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_EmptyReplicaGroups)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a = f32[2] constant({10, 10}) + b = f32[2] constant({20, 20}) + c = f32[2] constant({30, 30}) + d = f32[2] constant({40, 40}) + all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={} + a_prime = f32[2] get-tuple-element(all2all), index=0 + b_prime = f32[2] get-tuple-element(all2all), index=1 + c_prime = f32[2] get-tuple-element(all2all), index=2 + d_prime = f32[2] get-tuple-element(all2all), index=3 + ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0} + } + )"; + const int64 kNumReplicas = 4; + auto config = GetModuleConfigForTest(kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (int i = 0; i < kNumReplicas; i++) { + EXPECT_TRUE(LiteralTestUtil::NearOrEqual( + LiteralUtil::CreateR1({10, 10, 20, 20, 30, 30, 40, 40}), + results[i], ErrorSpec{1e-5, 1e-5})); + } +} + +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_OrderedReplicaGroups)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a = f32[2] constant({10, 10}) + b = f32[2] constant({20, 20}) + c = f32[2] constant({30, 30}) + d = f32[2] constant({40, 40}) + all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={{3,2,1,0}} + a_prime = f32[2] get-tuple-element(all2all), index=0 + b_prime = f32[2] get-tuple-element(all2all), index=1 + c_prime = f32[2] get-tuple-element(all2all), index=2 + d_prime = f32[2] get-tuple-element(all2all), index=3 + ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0} + } + )"; + const int64 kNumReplicas = 4; + auto config = GetModuleConfigForTest(kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (int i = 0; i < kNumReplicas; i++) { + EXPECT_TRUE(LiteralTestUtil::NearOrEqual( + LiteralUtil::CreateR1({40, 40, 30, 30, 20, 20, 10, 10}), + results[i], ErrorSpec{1e-5, 1e-5})); + } +} + +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_TwoReplicaGroups)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a = f32[2] constant({10, 10}) + b = f32[2] constant({20, 20}) + all2all = (f32[2], f32[2]) all-to-all(a, b), replica_groups={{2,1},{3,0}} + a_prime = f32[2] get-tuple-element(all2all), index=0 + b_prime = f32[2] get-tuple-element(all2all), index=1 + ROOT out = f32[4] concatenate(a_prime, b_prime), dimensions={0} + } + )"; + const int64 kNumReplicas = 4; + auto config = GetModuleConfigForTest(kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (int i = 0; i < kNumReplicas; i++) { + EXPECT_TRUE(LiteralTestUtil::NearOrEqual( + LiteralUtil::CreateR1({20, 20, 10, 10}), results[i], + ErrorSpec{1e-5, 1e-5})); + } +} + XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) { std::string hlo_string = R"( HloModule test diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index b83fed07e34e32..201c0da87f1409 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( // The largest negative number smaller than zero in bf16 that's not // denormalized. - std::make_pair(static_cast(-bfloat16::min_positive_normal()), + std::make_pair(static_cast( + -std::numeric_limits::min()), 0.0f), // Test odd and even values. std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f), diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3b5023457b2e09..311a4a38e8b573 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -233,8 +233,15 @@ StatusOr ReplayComputation(const HloSnapshot& module, if (opts.use_fake_data) { // Run fake computations with debug options ignoring XLA_FLAGS. Users very // likely want XLA_FLAGS only to apply to the "real" computation being run, - // not to the fake computations we use for generating arguments. + // not to the fake computations we use for generating arguments. There is + // an exception. ptxas can be called during the generation of fake + // data. As it is cached in the process memory, the flag affecting this call + // should not be ignored. + auto debug_opts_flags = GetDebugOptionsFromFlags(); auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + debug_opts.set_xla_gpu_asm_extra_flags( + debug_opts_flags.xla_gpu_asm_extra_flags()); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 926ba23c7af629..5381e3265b9fe4 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -303,11 +303,8 @@ Status RebuildOutputAliases( [&](const xla::ShapeIndex& output_index, const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { TF_RET_CHECK(alias.parameter_number < input_tuples.size()); - return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? output_tuple->AliasBufferFrom( - *input_tuples[alias.parameter_number], - alias.parameter_index, output_index) - : Status::OK(); + return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number], + alias.parameter_index, output_index); }; return input_output_alias.ForEachAliasWithStatus(alias_function); } @@ -332,17 +329,7 @@ xla::StatusOr> GetArgumentsBuffers( for (int64 i = 0; i < input_tuples.size(); ++i) { auto alias_checker = [&](const xla::ShapeIndex& index) -> xla::StatusOr { - // Only the buffers which the caller explicitly marked as aliased - // (kUserAlias), should create aliases. - // The XLA compiler might create opportunistic aliases (kSystemAlias) - // which need a different handling. With a system alias we know that XLA - // is going to reuse a given input parameter buffer for a given output, so - // unless it is known at call site that the input buffer has no more uses, - // a copy needs to be made at call site. With user specified alias the - // caller tells us that he expects a given output to land over the buffers - // of a given parametter. - if (input_output_alias.ParameterAliasKind(i, index) == - xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) { + if (input_output_alias.ParameterHasAlias(i, index)) { TF_RET_CHECK(!is_dynamic(i)); return true; } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 37bb14b4e070ae..18341a81df4465 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1297,7 +1297,6 @@ filegroup( srcs = [ # Sources for which we do not yet have granular targets. "//tensorflow/c/eager:srcs", - "//tensorflow/c/experimental/saved_model/core:mobile_srcs_only_runtime", "//tensorflow/c:srcs", "//tensorflow/core/common_runtime:mobile_srcs_only_runtime", "//tensorflow/core/common_runtime/eager:srcs", @@ -2260,7 +2259,6 @@ tf_cuda_library( "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/tpu:tpu_api_dlsym_initializer", "//tensorflow/core/util:einsum_op_util", "//tensorflow/core/util:padding", "//tensorflow/core/util:port", diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt index 9adb1a4056c0aa..3030c60fc2d01a 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt @@ -1,6 +1,5 @@ op { graph_op_name: "DecodeProtoV2" - visibility: HIDDEN in_arg { name: "bytes" description: <