/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/compile.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tfcompile {

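// Optional quantization hook applied to the XLA computation before AOT
// compilation. Unset unless a client calls RegisterQuantizeFn(), and only
// invoked when --experimental_quantize is set.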
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;

bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
  if (*quantize_xla) return false;
  *quantize_xla = fn;
  return true;
}
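
// Illustrative registration (signature inferred from the call in CompileGraph
// below): a quantization library would typically register its rewrite at
// static-initialization time, e.g.
//
//   static bool registered = RegisterQuantizeFn(
//       [](const tf2xla::Config& config, xla::XlaComputation* computation) {
//         // Rewrite the computation in place.
//         return OkStatus();
//       });
//
// Only the first registration takes effect; later calls return false.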

namespace {

// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::XlaComputation& computation,
                  const xla::cpu::CpuAotCompilationOptions& aot_opts,
                  CompileResult* compile_result) {
  // Retrieves arg and result layouts from the computation.
  // TODO(toddw): Should we let the user choose the major/minor ordering?
  xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
      client->GetComputationShape(computation);
  if (!pshape_or.ok()) {
    return errors::Unknown("Couldn't get XLA program shape: ",
                           pshape_or.status().error_message());
  }
  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
  xla::ProgramShapeProto* pshape = &compile_result->program_shape;

  // AotXlaComputationInstance::argument_layouts is a vector of Shape
  // pointers. Accumulate the Shape objects themselves in a separate vector
  // while building the vector of pointers.
  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
  for (int i = 0; i < pshape->parameters_size(); ++i) {
    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
    arg_layout_ptrs[i] = &arg_layouts[i];
  }
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layout_ptrs);
  xla::Shape result_shape(pshape->result());
  instance.result_layout = &result_shape;
  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!aot_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           aot_or.status().error_message());
  }
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return OkStatus();
}

}  // namespace

Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
  se::Platform* cpu_platform =
      se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();
  xla::XlaComputation computation;

  bool use_mlir_hlo_lowering = false;
  bool use_mlir_bridge = false;
  if (!flags.mlir_components.empty() && flags.mlir_components != "None") {
    for (auto component : absl::StrSplit(flags.mlir_components, ',')) {
      if (component == "Bridge") {
        use_mlir_bridge = true;
      } else if (component == "HloLowering") {
        use_mlir_hlo_lowering = true;
      } else {
        return errors::Unknown("Unknown mlir_component ", component);
      }
    }
  }
  if (use_mlir_bridge) {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
        graph_def, config, &computation, flags.debug_info,
        flags.debug_info_path_begin_marker));
  } else {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  }

  if (flags.experimental_quantize && *quantize_xla) {
    TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
  }

  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());
    // Serialize the HloSnapshot deterministically so that all the outputs of
    // a tf_library genrule are deterministic.
    const size_t size = module->ByteSizeLong();
    auto serialized = absl::make_unique<char[]>(size);
    TF_RET_CHECK(
        SerializeToBufferDeterministic(*module, serialized.get(), size));
    TF_RETURN_IF_ERROR(
        WriteStringToFile(Env::Default(), flags.out_session_module,
                          absl::string_view(serialized.get(), size)));
  }
  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
  aot_opts.set_use_mlir_hlo_lowering(use_mlir_hlo_lowering);

  if (flags.sanitize_dataflow) {
    aot_opts.set_sanitize_dataflow(flags.sanitize_dataflow);
    aot_opts.set_sanitize_abilists_dataflow(absl::StrSplit(
        flags.sanitize_abilists_dataflow, ',', absl::SkipEmpty()));
  }

  return CompileXla(client, computation, aot_opts, compile_result);
}

static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

static absl::once_flag targets_init;

static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
#if TF_LLVM_S390X_AVAILABLE
  LLVMInitializeSystemZTarget();
  LLVMInitializeSystemZTargetInfo();
  LLVMInitializeSystemZTargetMC();
  LLVMInitializeSystemZAsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}

// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph
// node. For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
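//
// For example, given the regex below, a message such as
//   "error while evaluating {{node x_y_sum}}"
// becomes
//   "error while evaluating x_y_sum".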
static std::string InterpolateErrorMessage(std::string message) {
  // See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
  // Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
  static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
  RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");

  return message;
}

Status Main(const MainFlags& flags) {
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return OkStatus();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;

  Status status =
      CompileGraph(std::move(graph_def), config, flags, &compile_result);
  if (!status.ok()) {
    return errors::CreateWithUpdatedMessage(
        status, InterpolateErrorMessage(status.error_message()));
  }

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return OkStatus();
}

}  // namespace tfcompile
}  // namespace tensorflow