/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/compile.h"

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tfcompile {

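// Hook that lets an external component quantize the XLA computation before
// AOT compilation; llvm::ManagedStatic lazily constructs the std::function
// holder in a thread-safe way. A hypothetical registration sketch (the pass
// name is illustrative, not a real symbol):
//
//   static bool registered = RegisterQuantizeFn(MyQuantizePass);
//
// Only the first registration takes effect; later calls return false.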
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;

bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
  if (*quantize_xla) return false;
  *quantize_xla = fn;
  return true;
}

namespace {

// Compiles the XLA computation into executable code.
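// On success, *compile_result holds the AOT compilation result (object file
// data), the program shape proto, the entry point symbol name, and the
// pointer size of the target triple.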
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::XlaComputation& computation,
                  const xla::cpu::CpuAotCompilationOptions& aot_opts,
                  CompileResult* compile_result) {
  // Retrieves arg and result layouts from the computation.
  // TODO(toddw): Should we let the user choose the major/minor ordering?
  xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
      client->GetComputationShape(computation);
  if (!pshape_or.ok()) {
    return errors::Unknown("Couldn't get XLA program shape: ",
                           pshape_or.status().error_message());
  }
  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
  xla::ProgramShapeProto* pshape = &compile_result->program_shape;

  // AotXlaComputationInstance::argument_layouts is a vector of Shape
  // pointers. Accumulate the Shape objects themselves in a separate vector
  // while building the vector of pointers.
  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
  for (int i = 0; i < pshape->parameters_size(); ++i) {
    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
    arg_layout_ptrs[i] = &arg_layouts[i];
  }
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layout_ptrs);
  xla::Shape result_shape(pshape->result());
  instance.result_layout = &result_shape;
  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!aot_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           aot_or.status().error_message());
  }
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return OkStatus();
}

}  // namespace

Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
  se::Platform* cpu_platform =
      se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();
  xla::XlaComputation computation;

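  // --mlir_components is a comma-separated list. "Bridge" converts the
  // GraphDef to XLA via the MLIR bridge (ConvertGraphDefToXlaViaMlir);
  // "HloLowering" turns on MLIR-based HLO lowering in the CPU backend.
  // Empty or "None" keeps the non-MLIR paths.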
  bool use_mlir_hlo_lowering = false;
  bool use_mlir_bridge = false;
  if (!flags.mlir_components.empty() && flags.mlir_components != "None") {
    for (auto component : absl::StrSplit(flags.mlir_components, ',')) {
      if (component == "Bridge") {
        use_mlir_bridge = true;
      } else if (component == "HloLowering") {
        use_mlir_hlo_lowering = true;
      } else {
        return errors::Unknown("Unknown mlir_component ", component);
      }
    }
  }
  if (use_mlir_bridge) {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
        graph_def, config, &computation, flags.debug_info,
        flags.debug_info_path_begin_marker));
  } else {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  }

  if (flags.experimental_quantize && *quantize_xla) {
    TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
  }

  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());
    // Serialize the HloSnapshot deterministically so that all the outputs of a
    // tf_library genrule are deterministic.
    const size_t size = module->ByteSizeLong();
    auto serialized = absl::make_unique<char[]>(size);
    TF_RET_CHECK(
        SerializeToBufferDeterministic(*module, serialized.get(), size));
    TF_RETURN_IF_ERROR(
        WriteStringToFile(Env::Default(), flags.out_session_module,
                          absl::string_view(serialized.get(), size)));
  }
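
  // Build the CPU AOT options from the target flags. BigPic requests
  // -fPIC-style position-independent code (see RelocationModel in
  // xla/service/cpu/cpu_compiler.h).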
  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
  aot_opts.set_use_mlir_hlo_lowering(use_mlir_hlo_lowering);

  if (flags.sanitize_dataflow) {
    aot_opts.set_sanitize_dataflow(flags.sanitize_dataflow);
    aot_opts.set_sanitize_abilists_dataflow(absl::StrSplit(
        flags.sanitize_abilists_dataflow, ',', absl::SkipEmpty()));
  }

  return CompileXla(client, computation, aot_opts, compile_result);
}

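// Reads `fname` as a text proto when it ends in ".pbtxt", and as a binary
// proto otherwise.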
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

static absl::once_flag targets_init;

static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
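  // For each architecture, the Target, TargetInfo, TargetMC, and AsmPrinter
  // components must all be registered before LLVM can emit object code for
  // it. AArch64 and SystemZ are only linked in on some platforms, hence the
  // feature guards.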
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
#if TF_LLVM_S390X_AVAILABLE
  LLVMInitializeSystemZTarget();
  LLVMInitializeSystemZTargetInfo();
  LLVMInitializeSystemZTargetMC();
  LLVMInitializeSystemZAsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}

// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph node.
// For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
static std::string InterpolateErrorMessage(std::string message) {
  // See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
  // Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
  static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
  RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");

  return message;
}

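// Driver entry point: reads --config and --graph, compiles the graph, and
// writes the object file, metadata object, and generated C++ header named
// by the out_* flags.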
Status Main(const MainFlags& flags) {
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
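  // With --dump_fetch_nodes, just print the sorted, de-duplicated fetch node
  // names as a comma-separated list and exit without compiling.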
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return OkStatus();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;

  Status status =
      CompileGraph(std::move(graph_def), config, flags, &compile_result);
  if (!status.ok()) {
    return errors::CreateWithUpdatedMessage(
        status, InterpolateErrorMessage(status.error_message()));
  }

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return OkStatus();
}

}  // namespace tfcompile
}  // namespace tensorflow