xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <charconv>
#include <cstdint>
#include <limits>
#include <string>
#include <system_error>
#include <vector>

#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/types.pb.h"
38 
39 namespace tensorflow {
40 namespace {
41 
42 void RefineDynamicShapes(XlaOpKernelContext *ctx, mlir::MLIRContext *context,
43                          mlir::OwningOpRef<mlir::ModuleOp> *module,
44                          int nr_dim_args, bool *dim_args_are_i64);
45 
46 void PopulateDimArgInputs(XlaOpKernelContext *ctx,
47                           std::vector<string> dim_args_spec,
48                           bool dim_args_are_i64,
49                           std::vector<xla::XlaOp> *inputs);
50 
51 class XlaCallModuleOp : public XlaOpKernel {
52  public:
XlaCallModuleOp(OpKernelConstruction * ctx)53   explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) {
54     OP_REQUIRES_OK(ctx, ctx->GetAttr("module", &module_str_));
55     std::vector<PartialTensorShape> expected_output_shapes;
56     OP_REQUIRES_OK(ctx, ctx->GetAttr("Sout", &expected_output_shapes));
57     std::vector<DataType> expected_output_dtypes;
58     OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &expected_output_dtypes));
59     OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec_));
60     OP_REQUIRES(ctx,
61                 expected_output_shapes.size() == expected_output_dtypes.size(),
62                 errors::InvalidArgument("The size of Sout (",
63                                         expected_output_shapes.size(),
64                                         ") must match the size of Tout (",
65                                         expected_output_dtypes.size(), ")"));
66     expected_nr_outputs_ = expected_output_shapes.size();
67   }
68 
Compile(XlaOpKernelContext * ctx)69   void Compile(XlaOpKernelContext *ctx) override {
70     // Code inpired by
71     // tensorflow/compiler/xla/python/mlir.cc::PyMlirModuleToXlaComputation
72     mlir::MLIRContext context;
73     mlir::OwningOpRef<mlir::ModuleOp> module;
74     context.loadDialect<mlir::func::FuncDialect>();
75     context.loadDialect<mlir::mhlo::MhloDialect>();
76     context.loadDialect<mlir::chlo::ChloDialect>();
77     context.loadDialect<mlir::TF::TensorFlowDialect>();
78     module = mlir::parseSourceString<mlir::ModuleOp>(
79         llvm::StringRef(module_str_), &context);
80     OP_REQUIRES(ctx, module,
81                 errors::InvalidArgument("Cannot deserialize MHLO computation"));
82     if (failed(module->verifyInvariants())) {
83       VLOG(1) << "MLIR verification failed.";
84       module->dump();
85       OP_REQUIRES(ctx, false,
86                   errors::InvalidArgument("Error verifying MHLO module"));
87     }
88 
89     int nr_dim_args = dim_args_spec_.size();
90     std::vector<xla::XlaOp> inputs(nr_dim_args + ctx->num_inputs());
91 
92     if (nr_dim_args > 0) {
93       bool dim_args_are_i64 = true;
94       RefineDynamicShapes(ctx, &context, &module, nr_dim_args,
95                           &dim_args_are_i64);
96       PopulateDimArgInputs(ctx, dim_args_spec_, dim_args_are_i64, &inputs);
97     }
98     for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
99       inputs[nr_dim_args + i] = ctx->Input(i);
100     }
101 
102     xla::XlaComputation xla_computation;
103     OP_REQUIRES_OK(
104         ctx, MlirToXlaComputation(*module, xla_computation, false, false));
105     xla::XlaOp output = xla::Call(ctx->builder(), xla_computation, inputs);
106 
107     // Check that the resulting computation returns the expected shape
108     OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx,
109                       ctx->builder()->GetShape(output));
110     VLOG(3) << "XlaCallModule compiled output shape : "
111             << xla::ShapeUtil::HumanString(found_output_shape);
112 
113     if (expected_nr_outputs_ == 1) {
114       ctx->SetOutput(0, output);
115     } else {
116       for (int i = 0; i < expected_nr_outputs_; ++i) {
117         ctx->SetOutput(i, xla::GetTupleElement(output, i));
118       }
119     }
120   }
121 
122  private:
123   string module_str_;
124   int expected_nr_outputs_;
125   std::vector<string> dim_args_spec_;
126 };
127 
128 // If there are dynamic shapes then resolve the unknown dimensions based on
129 // the static shapes of the actual arguments and shape inference.
RefineDynamicShapes(XlaOpKernelContext * ctx,mlir::MLIRContext * context,mlir::OwningOpRef<mlir::ModuleOp> * module,int nr_dim_args,bool * dim_args_are_i64)130 void RefineDynamicShapes(XlaOpKernelContext *ctx, mlir::MLIRContext *context,
131                          mlir::OwningOpRef<mlir::ModuleOp> *module,
132                          int nr_dim_args, bool *dim_args_are_i64) {
133   // Locate the 'main' function.
134   // This is the convention used by MlirToXlaComputation.
135   auto main = (*module)->lookupSymbol<mlir::func::FuncOp>("main");
136   OP_REQUIRES(ctx, main,
137               errors::InvalidArgument("Cannot find 'main' in MHLO module"));
138   VLOG(3) << "XlaCallModule main function: " << debugString(main);
139   mlir::Block &main_body = main.front();
140 
141   OP_REQUIRES(ctx,
142               nr_dim_args + ctx->num_inputs() == main_body.getNumArguments(),
143               errors::InvalidArgument(
144                   "Incorrect number of arguments for XlaCallModule. ",
145                   "The module expects ", main_body.getNumArguments(),
146                   " and dim_args_spec specifies ", nr_dim_args,
147                   " dimension arguments, but there are ", ctx->num_inputs(),
148                   " actual arguments"));
149   // Obtain static input types in MLIR terms.
150   mlir::Builder builder(context);
151 
152   std::vector<mlir::Type> static_input_types(main_body.getNumArguments());
153   // The dim_arg parameters already have known types.
154   for (int i = 0; i < nr_dim_args; ++i) {
155     static_input_types[i] = main_body.getArgument(i).getType();
156     *dim_args_are_i64 = (static_input_types[i].getIntOrFloatBitWidth() == 64);
157   }
158 
159   // Now the actual arguments
160   for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
161     OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(i));
162     std::vector<int64_t> xla_dimensions(xla_shape.dimensions().begin(),
163                                         xla_shape.dimensions().end());
164     OP_REQUIRES_VALUE(
165         mlir::Type element_type, ctx,
166         ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder));
167     mlir::Type type = mlir::RankedTensorType::get(xla_dimensions, element_type);
168     // TODO(burmako): This fails with an obscure compilation error.
169     // OP_REQUIRES_VALUE(
170     //     mlir::Type type, ctx,
171     //     ConvertShapeToType<mlir::RankedTensorType>(xla_shape, builder));
172     VLOG(3) << "XlaCallModule static input type #" << nr_dim_args + i << ": "
173             << debugString(type);
174     static_input_types[nr_dim_args + i] = type;
175   }
176 
177   // Refine 'main' argument types to use static input types instead.
178   // This will only change the argument types and will not propagate the
179   // additional type information further. For that, we'll need to run
180   // shape inference as explained below.
181   main.setType(
182       builder.getFunctionType(static_input_types, main->getResultTypes()));
183   for (auto i = 0; i < main_body.getNumArguments(); ++i) {
184     main_body.getArgument(i).setType(static_input_types[i]);
185   }
186 
187   // --tf-shape-inference, despite its TF-specific name, seems to be general
188   // enough to also work on MHLO. (Although it fails if it doesn't see a
189   // tf.versions attribute on the module, which we hackily attach).
190   auto tf_producer =
191       builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
192   (**module)->setAttr("tf.versions", builder.getDictionaryAttr({tf_producer}));
193 
194   // Run --tf-shape-inference.
195   mlir::PassManager pm(context);
196   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
197   OP_REQUIRES(ctx, mlir::succeeded(pm.run(**module)),
198               errors::InvalidArgument("MHLO shape inference failed"));
199   VLOG(3) << "XlaCallModule main function with inferred types: "
200           << debugString(*main);
201 }
202 
203 // Compute the dim_arg inputs based on the static shapes of the actual arguments
204 // and put them in the inputs vector.
PopulateDimArgInputs(XlaOpKernelContext * ctx,std::vector<string> dim_args_spec,bool dim_args_are_i64,std::vector<xla::XlaOp> * inputs)205 void PopulateDimArgInputs(XlaOpKernelContext *ctx,
206                           std::vector<string> dim_args_spec,
207                           bool dim_args_are_i64,
208                           std::vector<xla::XlaOp> *inputs) {
209   int nr_dim_args = dim_args_spec.size();
210   for (int i = 0; i < nr_dim_args; ++i) {
211     string dim_arg_spec = dim_args_spec[i];
212     size_t dot_pos = dim_arg_spec.find('.');
213     OP_REQUIRES(
214         ctx, dot_pos != string::npos && dot_pos + 1 < dim_arg_spec.size(),
215         errors::InvalidArgument("Cannot parse dim_args_spec ", dim_arg_spec));
216     int arg_idx = std::stoi(dim_arg_spec.substr(0, dot_pos));
217     int arg_axis_idx = std::stoi(
218         dim_arg_spec.substr(dot_pos + 1, dim_arg_spec.size() - dot_pos));
219     OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(arg_idx));
220 
221     int64_t dim_arg_val = xla_shape.dimensions()[arg_axis_idx];
222     VLOG(3) << "XlaCallModule dim_input[" << i << "] = " << dim_arg_val;
223     if (dim_args_are_i64) {
224       (*inputs)[i] = xla::ConstantR0<int64_t>(ctx->builder(), dim_arg_val);
225     } else {
226       (*inputs)[i] = xla::ConstantR0<int32_t>(ctx->builder(), dim_arg_val);
227     }
228   }
229 }
230 
231 REGISTER_XLA_OP(Name("XlaCallModule"), XlaCallModuleOp);
232 }  // namespace
233 }  // namespace tensorflow
234