1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/get_compiler_ir.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/jit/compilability_check_util.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/jit/flags.h"
28 #include "tensorflow/compiler/jit/xla_launch_util.h"
29 #include "tensorflow/compiler/jit/xla_platform_info.h"
30 #include "tensorflow/compiler/tf2xla/const_analysis.h"
31 #include "tensorflow/compiler/xla/client/executable_build_options.h"
32 #include "tensorflow/compiler/xla/client/local_client.h"
33 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
34 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/lib/core/status.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/statusor.h"
40 #include "tensorflow/core/util/ptr_util.h"
41
42 namespace tensorflow {
43
BuildExecutable(xla::LocalClient * local_client,const XlaCompiler::CompilationResult & result,const XlaCompiler::Options & options,const bool xla_embed_ir_in_executable=false)44 static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
45 xla::LocalClient* local_client,
46 const XlaCompiler::CompilationResult& result,
47 const XlaCompiler::Options& options,
48 const bool xla_embed_ir_in_executable = false) {
49 std::vector<const xla::Shape*> argument_layouts(
50 result.xla_input_shapes.size());
51 for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
52 argument_layouts[i] = &result.xla_input_shapes[i];
53 }
54 xla::ExecutableBuildOptions build_options;
55 if (result.collective_info) {
56 build_options.set_num_replicas(result.collective_info->group_size);
57 }
58 build_options.set_device_ordinal(
59 options.device_ordinal != -1 ? options.device_ordinal
60 : local_client->default_device_ordinal());
61 build_options.set_result_layout(result.xla_output_shape);
62 build_options.set_device_allocator(options.device_allocator.get());
63 build_options.set_alias_passthrough_params(options.alias_passthrough_params);
64 build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
65 options.detailed_logging);
66 // If the embed_ir_in_executable is set, hlo_proto will be dumped in
67 // executable. The hlo_proto contains HLO modules and buffer assignment.
68 build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(
69 xla_embed_ir_in_executable);
70 TF_ASSIGN_OR_RETURN(
71 std::vector<std::unique_ptr<xla::LocalExecutable>> executables,
72 local_client->Compile(*result.computation, argument_layouts,
73 build_options));
74 TF_RET_CHECK(executables.size() == 1);
75 return std::move(executables[0]);
76 }
77
GetCompilerIr(IrExportStage stage,ProcessFunctionLibraryRuntime * pflr,absl::string_view func_name,Device * dev,EagerContext * context,absl::Span<const TensorHandle * const> inputs_handles)78 StatusOr<std::string> GetCompilerIr(
79 IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
80 absl::string_view func_name, Device* dev, EagerContext* context,
81 absl::Span<const TensorHandle* const> inputs_handles) {
82 // TODO(b/238830423): support GetCompilerIr on TFRT TPU device.
83 if (dev->device_type() != DEVICE_CPU &&
84 dev->tensorflow_accelerator_device_info()->stream == nullptr) {
85 return errors::Internal("GetCompilerIr is not supported on this device.");
86 }
87 NameAttrList function;
88 function.set_name(std::string{func_name});
89
90 FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
91 ResourceMgr* rmgr = dev->resource_manager();
92
93 const FunctionBody* fbody = nullptr;
94 std::vector<int> constant_arg_indices;
95 std::vector<int> resource_arg_indices;
96 TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
97 flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
98
99 MemoryTypeVector input_memory_types =
100 GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
101 MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
102
103 std::deque<Tensor> inputs_storage;
104 std::vector<const Tensor*> inputs;
105 inputs.reserve(inputs_handles.size());
106 for (int i = 0; i < inputs_handles.size(); i++) {
107 const TensorHandle* th = inputs_handles[i];
108 const Tensor* t;
109 // Handle owns the tensor.
110 TF_RETURN_IF_ERROR(th->Tensor(&t));
111 if (absl::c_binary_search(constant_arg_indices, i)) {
112 // Need to make sure it's on the host.
113 inputs_storage.emplace_back(t->dtype(), t->shape());
114 TF_RETURN_IF_ERROR(
115 th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
116 inputs.push_back(&inputs_storage.back());
117 } else {
118 inputs.push_back(t);
119 }
120 }
121
122 std::vector<VariableInfo> variable_infos;
123 TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
124 rmgr, dev, inputs, resource_arg_indices, &variable_infos));
125 TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
126
127 XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
128
129 XlaCompilationCache* cache;
130 TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
131 rmgr->default_container(), "xla_cache", &cache,
132 [&](XlaCompilationCache** cache_write_into) {
133 return BuildXlaCompilationCache(dev, flr, platform_info,
134 cache_write_into);
135 }));
136 core::ScopedUnref cache_ref(cache);
137
138 se::Stream* stream = nullptr;
139 if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
140 dev->tensorflow_accelerator_device_info()) {
141 stream = accelerator_device_info->stream;
142 }
143
144 XlaCompiler::Options options =
145 GenerateCompilerOptions(*cache, *flr, dev, stream, platform_info,
146 /*has_ref_vars=*/false);
147
148 XlaCompiler::CompileOptions compile_options;
149 compile_options.always_return_tuple = false;
150 compile_options.alias_resource_update = true;
151
152 XlaCompiler compiler(options);
153
154 StatusOr<std::vector<XlaCompiler::Argument>> args =
155 XlaComputationLaunchContext::BuildXlaCompilerArguments(
156 constant_arg_indices, inputs, variable_infos, dev);
157 TF_RETURN_IF_ERROR(args.status());
158
159 xla::LocalClient* local_client = cache->client();
160 XlaCompiler::CompilationResult result;
161 TF_RETURN_IF_ERROR(
162 compiler.CompileFunction(compile_options, function, *args, &result));
163
164 switch (stage) {
165 case IrExportStage::HLO:
166 case IrExportStage::HLO_NO_METADATA:
167 case IrExportStage::HLO_SERIALIZED: {
168 TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
169 result.computation->GetProgramShape());
170 xla::HloModuleConfig config(program_shape);
171 TF_ASSIGN_OR_RETURN(
172 std::unique_ptr<xla::HloModule> new_module,
173 xla::HloModule::CreateFromProto(result.computation->proto(), config));
174
175 xla::HloPrintOptions opts;
176 if (stage == IrExportStage::HLO_NO_METADATA) {
177 opts.set_print_metadata(false);
178 }
179
180 if (stage == IrExportStage::HLO_SERIALIZED) {
181 return new_module->ToProto().SerializeAsString();
182 } else {
183 return new_module->ToString(opts);
184 }
185 }
186 case IrExportStage::OPTIMIZED_HLO:
187 case IrExportStage::OPTIMIZED_HLO_SERIALIZED: {
188 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
189 BuildExecutable(local_client, result, options));
190 xla::Executable* new_executable = executable->executable();
191 if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) {
192 return new_executable->module().ToProto().SerializeAsString();
193 } else {
194 return new_executable->module().ToString();
195 }
196 }
197 case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: {
198 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
199 BuildExecutable(local_client, result, options,
200 /*xla_embed_ir_in_executable=*/true));
201 return executable->executable()->hlo_proto()->SerializeAsString();
202 }
203 case IrExportStage::OPTIMIZED_HLO_DOT: {
204 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
205 BuildExecutable(local_client, result, options));
206 StatusOr<std::string> graph = xla::RenderGraph(
207 *executable->executable()->module().entry_computation(),
208 "Visualization",
209 /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
210 /*hlo_execution_profile=*/nullptr,
211 /*hlo_render_options=*/{});
212 TF_RETURN_IF_ERROR(graph.status());
213 return *graph;
214 }
215 }
216 }
217
218 } // namespace tensorflow
219