/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {
namespace {
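// Canonicalization patterns generated by TableGen from declarative rewrite
// rules.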
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// _XlaHostComputeOp
//===----------------------------------------------------------------------===//

// This verifies that `_XlaHostComputeMlirOp` has a well-formed
// `host_mlir_module` attribute.
// For other attributes, there is no additional verification beyond the default.
LogicalResult _XlaHostComputeMlirOp::verify() {
  _XlaHostComputeMlirOp op = *this;
  // Extract the module and function.
  StringRef host_module = op.host_mlir_module();

  if (host_module.empty()) return success();

  mlir::OwningOpRef<mlir::ModuleOp> module_for_func;
  tensorflow::Status status = tensorflow::DeserializeMlirModule(
      host_module.str(), op->getContext(), &module_for_func);
  if (!status.ok()) {
    return op.emitError()
           << "attribute 'host_mlir_module' cannot be deserialized. "
           << status.error_message();
  }

  func::FuncOp func = module_for_func->lookupSymbol<func::FuncOp>("host_func");
  if (!func)
    return op.emitError()
           << "serialized module in attribute 'host_mlir_module' does not "
              "contain 'host_func' function.";

  if (op->getNumOperands() != func.getFunctionType().getNumInputs())
    return op.emitError()
           << "'host_func' has " << func.getFunctionType().getNumInputs()
           << " inputs and '_XlaHostComputeMlir' has " << op->getNumOperands()
           << " operands. Number of operands/inputs should be the same.";

  if (op->getNumResults() != func.getFunctionType().getNumResults())
    return op.emitError() << "'host_func' has "
                          << func.getFunctionType().getNumResults()
                          << " results and '_XlaHostComputeMlir' has "
                          << op->getNumResults()
                          << " results. Number of results should be the same.";

  return success();
}

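// Returns the 'host_func' function from the module serialized in the
// `host_mlir_module` attribute, or nullptr if deserialization fails. The
// returned function is owned by `*mlir_module`, so the caller must keep that
// module alive while using the result.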
func::FuncOp _XlaHostComputeMlirOp::GetHostFunc(
    mlir::OwningOpRef<mlir::ModuleOp>* mlir_module) {
  if (!tensorflow::DeserializeMlirModule(host_mlir_module().str(),
                                         this->getContext(), mlir_module)
           .ok())
    return nullptr;
  return (*mlir_module)->lookupSymbol<func::FuncOp>("host_func");
}

//===----------------------------------------------------------------------===//
// XLA Send/Recv ops
//===----------------------------------------------------------------------===//

// For XLA Send/Recv ops the key corresponds to the resource instance.
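// Ops that return the same string are treated as accessing the same resource
// instance, while ops with different keys can be considered independent.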

std::string _XlaRecvAtHostOp::GetResourceInstanceStr() { return key().str(); }

std::string _XlaRecvAtHostV2Op::GetResourceInstanceStr() { return key().str(); }

std::string _XlaSendFromHostOp::GetResourceInstanceStr() { return key().str(); }

std::string _XlaSendFromHostV2Op::GetResourceInstanceStr() {
  return key().str();
}

namespace {
// Builds a key of the form
// "<send_device>;<send_device_incarnation>;<recv_device>;<tensor_name>" so
// that matching send/recv pairs map to the same resource instance.
std::string GetRendezvousKey(const std::string& send_device,
                             const uint64_t send_device_incarnation,
                             const std::string& recv_device,
                             const std::string& tensor_name) {
  return absl::StrCat(send_device, ";", send_device_incarnation, ";",
                      recv_device, ";", tensor_name);
}
}  // namespace

std::string _HostRecvOp::GetResourceInstanceStr() {
  return GetRendezvousKey(send_device().str(), send_device_incarnation(),
                          recv_device().str(), tensor_name().str());
}

std::string _HostSendOp::GetResourceInstanceStr() {
  return GetRendezvousKey(send_device().str(), send_device_incarnation(),
                          recv_device().str(), tensor_name().str());
}

std::string _RecvOp::GetResourceInstanceStr() {
  return GetRendezvousKey(send_device().str(), send_device_incarnation(),
                          recv_device().str(), tensor_name().str());
}

std::string _SendOp::GetResourceInstanceStr() {
  return GetRendezvousKey(send_device().str(), send_device_incarnation(),
                          recv_device().str(), tensor_name().str());
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"