/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <string>
17 #include <vector>
18
19 #include "mlir/IR/OwningOpRef.h" // from @llvm-project
20 #include "mlir/Parser/Parser.h" // from @llvm-project
21 #include "mlir/Pass/PassManager.h" // from @llvm-project
22 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
31 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
32 #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/core/framework/op_kernel.h"
36 #include "tensorflow/core/framework/op_requires.h"
37 #include "tensorflow/core/framework/types.pb.h"
38
39 namespace tensorflow {
40 namespace {
41
// Refines the unknown/dynamic dimensions in the 'main' function of `module`
// using the static shapes of the actual arguments in `ctx`, then runs shape
// inference to propagate the refined types. Also sets *dim_args_are_i64 from
// the declared integer width of the module's dimension arguments.
// Defined below the XlaCallModuleOp class.
void RefineDynamicShapes(XlaOpKernelContext *ctx, mlir::MLIRContext *context,
                         mlir::OwningOpRef<mlir::ModuleOp> *module,
                         int nr_dim_args, bool *dim_args_are_i64);

// Computes the values of the dimension arguments described by
// `dim_args_spec` (entries of the form "<arg_idx>.<axis_idx>") from the
// static shapes of the actual arguments, and stores them as scalar constants
// in inputs[0 .. dim_args_spec.size() - 1]. Defined below the
// XlaCallModuleOp class.
void PopulateDimArgInputs(XlaOpKernelContext *ctx,
                          std::vector<string> dim_args_spec,
                          bool dim_args_are_i64,
                          std::vector<xla::XlaOp> *inputs);
50
// Kernel for the XlaCallModule TF op: deserializes the MHLO module carried in
// the "module" attribute, optionally refines dynamic shapes using dimension
// arguments described by "dim_args_spec", lowers the module to an
// XlaComputation, and emits a call to it.
class XlaCallModuleOp : public XlaOpKernel {
 public:
  explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) {
    // "module" holds the textual MHLO module to compile.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("module", &module_str_));
    // "Sout"/"Tout" describe the expected output shapes and dtypes; only
    // their (matching) count is used here, to know how many outputs to set.
    std::vector<PartialTensorShape> expected_output_shapes;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Sout", &expected_output_shapes));
    std::vector<DataType> expected_output_dtypes;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &expected_output_dtypes));
    // Each entry is "<arg_idx>.<axis_idx>": a dimension argument taken from
    // the shape of an actual argument.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec_));
    OP_REQUIRES(ctx,
                expected_output_shapes.size() == expected_output_dtypes.size(),
                errors::InvalidArgument("The size of Sout (",
                                        expected_output_shapes.size(),
                                        ") must match the size of Tout (",
                                        expected_output_dtypes.size(), ")"));
    expected_nr_outputs_ = expected_output_shapes.size();
  }

  void Compile(XlaOpKernelContext *ctx) override {
    // Code inspired by
    // tensorflow/compiler/xla/python/mlir.cc::PyMlirModuleToXlaComputation
    mlir::MLIRContext context;
    mlir::OwningOpRef<mlir::ModuleOp> module;
    // Load every dialect the parsed module may reference, before parsing.
    context.loadDialect<mlir::func::FuncDialect>();
    context.loadDialect<mlir::mhlo::MhloDialect>();
    context.loadDialect<mlir::chlo::ChloDialect>();
    context.loadDialect<mlir::TF::TensorFlowDialect>();
    module = mlir::parseSourceString<mlir::ModuleOp>(
        llvm::StringRef(module_str_), &context);
    OP_REQUIRES(ctx, module,
                errors::InvalidArgument("Cannot deserialize MHLO computation"));
    if (failed(module->verifyInvariants())) {
      VLOG(1) << "MLIR verification failed.";
      module->dump();
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Error verifying MHLO module"));
    }

    // The computation's parameters are the dimension arguments followed by
    // the op's actual inputs.
    int nr_dim_args = dim_args_spec_.size();
    std::vector<xla::XlaOp> inputs(nr_dim_args + ctx->num_inputs());

    if (nr_dim_args > 0) {
      bool dim_args_are_i64 = true;
      // Resolve dynamic dimensions from the static shapes of the actual
      // arguments, then materialize the dimension values as constants in
      // inputs[0 .. nr_dim_args - 1].
      RefineDynamicShapes(ctx, &context, &module, nr_dim_args,
                          &dim_args_are_i64);
      PopulateDimArgInputs(ctx, dim_args_spec_, dim_args_are_i64, &inputs);
    }
    for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
      inputs[nr_dim_args + i] = ctx->Input(i);
    }

    xla::XlaComputation xla_computation;
    OP_REQUIRES_OK(
        ctx, MlirToXlaComputation(*module, xla_computation, false, false));
    xla::XlaOp output = xla::Call(ctx->builder(), xla_computation, inputs);

    // Check that the resulting computation returns the expected shape
    // NOTE(review): the shape is only logged below, never compared against
    // Sout/Tout — confirm whether an actual check was intended.
    OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx,
                      ctx->builder()->GetShape(output));
    VLOG(3) << "XlaCallModule compiled output shape : "
            << xla::ShapeUtil::HumanString(found_output_shape);

    // A single result is returned directly; multiple results come back from
    // the called computation as a tuple and are unpacked here.
    if (expected_nr_outputs_ == 1) {
      ctx->SetOutput(0, output);
    } else {
      for (int i = 0; i < expected_nr_outputs_; ++i) {
        ctx->SetOutput(i, xla::GetTupleElement(output, i));
      }
    }
  }

 private:
  string module_str_;        // Textual MHLO module (the "module" attr).
  int expected_nr_outputs_;  // Number of results declared via Sout/Tout.
  std::vector<string> dim_args_spec_;  // "arg_idx.axis_idx" specs.
};
127
// If there are dynamic shapes then resolve the unknown dimensions based on
// the static shapes of the actual arguments and shape inference.
//
// Rewrites the argument types of the module's 'main' function to the static
// types of the actual arguments (known at op-compile time), then runs the
// --tf-shape-inference pass to propagate the refined types through the
// function body. Also sets *dim_args_are_i64 from the declared type of the
// dimension arguments (the first `nr_dim_args` arguments of 'main').
void RefineDynamicShapes(XlaOpKernelContext *ctx, mlir::MLIRContext *context,
                         mlir::OwningOpRef<mlir::ModuleOp> *module,
                         int nr_dim_args, bool *dim_args_are_i64) {
  // Locate the 'main' function.
  // This is the convention used by MlirToXlaComputation.
  auto main = (*module)->lookupSymbol<mlir::func::FuncOp>("main");
  OP_REQUIRES(ctx, main,
              errors::InvalidArgument("Cannot find 'main' in MHLO module"));
  VLOG(3) << "XlaCallModule main function: " << debugString(main);
  mlir::Block &main_body = main.front();

  // 'main' must take the dimension arguments followed by the op's inputs.
  OP_REQUIRES(ctx,
              nr_dim_args + ctx->num_inputs() == main_body.getNumArguments(),
              errors::InvalidArgument(
                  "Incorrect number of arguments for XlaCallModule. ",
                  "The module expects ", main_body.getNumArguments(),
                  " and dim_args_spec specifies ", nr_dim_args,
                  " dimension arguments, but there are ", ctx->num_inputs(),
                  " actual arguments"));
  // Obtain static input types in MLIR terms.
  mlir::Builder builder(context);

  std::vector<mlir::Type> static_input_types(main_body.getNumArguments());
  // The dim_arg parameters already have known types.
  for (int i = 0; i < nr_dim_args; ++i) {
    static_input_types[i] = main_body.getArgument(i).getType();
    // NOTE(review): overwritten each iteration, so only the last dimension
    // argument's bit width is reported — presumably all dimension arguments
    // share one integer width; confirm with the module producer.
    *dim_args_are_i64 = (static_input_types[i].getIntOrFloatBitWidth() == 64);
  }

  // Now the actual arguments: build a ranked tensor type from each input's
  // XLA shape (static at this point in compilation).
  for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
    OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(i));
    std::vector<int64_t> xla_dimensions(xla_shape.dimensions().begin(),
                                        xla_shape.dimensions().end());
    OP_REQUIRES_VALUE(
        mlir::Type element_type, ctx,
        ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder));
    mlir::Type type = mlir::RankedTensorType::get(xla_dimensions, element_type);
    // TODO(burmako): This fails with an obscure compilation error.
    // OP_REQUIRES_VALUE(
    //     mlir::Type type, ctx,
    //     ConvertShapeToType<mlir::RankedTensorType>(xla_shape, builder));
    VLOG(3) << "XlaCallModule static input type #" << nr_dim_args + i << ": "
            << debugString(type);
    static_input_types[nr_dim_args + i] = type;
  }

  // Refine 'main' argument types to use static input types instead.
  // This will only change the argument types and will not propagate the
  // additional type information further. For that, we'll need to run
  // shape inference as explained below.
  main.setType(
      builder.getFunctionType(static_input_types, main->getResultTypes()));
  for (auto i = 0; i < main_body.getNumArguments(); ++i) {
    main_body.getArgument(i).setType(static_input_types[i]);
  }

  // --tf-shape-inference, despite its TF-specific name, seems to be general
  // enough to also work on MHLO. (Although it fails if it doesn't see a
  // tf.versions attribute on the module, which we hackily attach).
  auto tf_producer =
      builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
  (**module)->setAttr("tf.versions", builder.getDictionaryAttr({tf_producer}));

  // Run --tf-shape-inference.
  mlir::PassManager pm(context);
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  OP_REQUIRES(ctx, mlir::succeeded(pm.run(**module)),
              errors::InvalidArgument("MHLO shape inference failed"));
  VLOG(3) << "XlaCallModule main function with inferred types: "
          << debugString(*main);
}
202
203 // Compute the dim_arg inputs based on the static shapes of the actual arguments
204 // and put them in the inputs vector.
PopulateDimArgInputs(XlaOpKernelContext * ctx,std::vector<string> dim_args_spec,bool dim_args_are_i64,std::vector<xla::XlaOp> * inputs)205 void PopulateDimArgInputs(XlaOpKernelContext *ctx,
206 std::vector<string> dim_args_spec,
207 bool dim_args_are_i64,
208 std::vector<xla::XlaOp> *inputs) {
209 int nr_dim_args = dim_args_spec.size();
210 for (int i = 0; i < nr_dim_args; ++i) {
211 string dim_arg_spec = dim_args_spec[i];
212 size_t dot_pos = dim_arg_spec.find('.');
213 OP_REQUIRES(
214 ctx, dot_pos != string::npos && dot_pos + 1 < dim_arg_spec.size(),
215 errors::InvalidArgument("Cannot parse dim_args_spec ", dim_arg_spec));
216 int arg_idx = std::stoi(dim_arg_spec.substr(0, dot_pos));
217 int arg_axis_idx = std::stoi(
218 dim_arg_spec.substr(dot_pos + 1, dim_arg_spec.size() - dot_pos));
219 OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(arg_idx));
220
221 int64_t dim_arg_val = xla_shape.dimensions()[arg_axis_idx];
222 VLOG(3) << "XlaCallModule dim_input[" << i << "] = " << dim_arg_val;
223 if (dim_args_are_i64) {
224 (*inputs)[i] = xla::ConstantR0<int64_t>(ctx->builder(), dim_arg_val);
225 } else {
226 (*inputs)[i] = xla::ConstantR0<int32_t>(ctx->builder(), dim_arg_val);
227 }
228 }
229 }
230
231 REGISTER_XLA_OP(Name("XlaCallModule"), XlaCallModuleOp);
232 } // namespace
233 } // namespace tensorflow
234