xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/get_compiler_ir.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/jit/get_compiler_ir.h"

#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/ptr_util.h"
41 
42 namespace tensorflow {
43 
BuildExecutable(xla::LocalClient * local_client,const XlaCompiler::CompilationResult & result,const XlaCompiler::Options & options,const bool xla_embed_ir_in_executable=false)44 static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
45     xla::LocalClient* local_client,
46     const XlaCompiler::CompilationResult& result,
47     const XlaCompiler::Options& options,
48     const bool xla_embed_ir_in_executable = false) {
49   std::vector<const xla::Shape*> argument_layouts(
50       result.xla_input_shapes.size());
51   for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
52     argument_layouts[i] = &result.xla_input_shapes[i];
53   }
54   xla::ExecutableBuildOptions build_options;
55   if (result.collective_info) {
56     build_options.set_num_replicas(result.collective_info->group_size);
57   }
58   build_options.set_device_ordinal(
59       options.device_ordinal != -1 ? options.device_ordinal
60                                    : local_client->default_device_ordinal());
61   build_options.set_result_layout(result.xla_output_shape);
62   build_options.set_device_allocator(options.device_allocator.get());
63   build_options.set_alias_passthrough_params(options.alias_passthrough_params);
64   build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
65       options.detailed_logging);
66   // If the embed_ir_in_executable is set, hlo_proto will be dumped in
67   // executable. The hlo_proto contains HLO modules and buffer assignment.
68   build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(
69       xla_embed_ir_in_executable);
70   TF_ASSIGN_OR_RETURN(
71       std::vector<std::unique_ptr<xla::LocalExecutable>> executables,
72       local_client->Compile(*result.computation, argument_layouts,
73                             build_options));
74   TF_RET_CHECK(executables.size() == 1);
75   return std::move(executables[0]);
76 }
77 
GetCompilerIr(IrExportStage stage,ProcessFunctionLibraryRuntime * pflr,absl::string_view func_name,Device * dev,EagerContext * context,absl::Span<const TensorHandle * const> inputs_handles)78 StatusOr<std::string> GetCompilerIr(
79     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
80     absl::string_view func_name, Device* dev, EagerContext* context,
81     absl::Span<const TensorHandle* const> inputs_handles) {
82   // TODO(b/238830423): support GetCompilerIr on TFRT TPU device.
83   if (dev->device_type() != DEVICE_CPU &&
84       dev->tensorflow_accelerator_device_info()->stream == nullptr) {
85     return errors::Internal("GetCompilerIr is not supported on this device.");
86   }
87   NameAttrList function;
88   function.set_name(std::string{func_name});
89 
90   FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
91   ResourceMgr* rmgr = dev->resource_manager();
92 
93   const FunctionBody* fbody = nullptr;
94   std::vector<int> constant_arg_indices;
95   std::vector<int> resource_arg_indices;
96   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
97       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
98 
99   MemoryTypeVector input_memory_types =
100       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
101   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
102 
103   std::deque<Tensor> inputs_storage;
104   std::vector<const Tensor*> inputs;
105   inputs.reserve(inputs_handles.size());
106   for (int i = 0; i < inputs_handles.size(); i++) {
107     const TensorHandle* th = inputs_handles[i];
108     const Tensor* t;
109     // Handle owns the tensor.
110     TF_RETURN_IF_ERROR(th->Tensor(&t));
111     if (absl::c_binary_search(constant_arg_indices, i)) {
112       // Need to make sure it's on the host.
113       inputs_storage.emplace_back(t->dtype(), t->shape());
114       TF_RETURN_IF_ERROR(
115           th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
116       inputs.push_back(&inputs_storage.back());
117     } else {
118       inputs.push_back(t);
119     }
120   }
121 
122   std::vector<VariableInfo> variable_infos;
123   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
124       rmgr, dev, inputs, resource_arg_indices, &variable_infos));
125   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
126 
127   XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
128 
129   XlaCompilationCache* cache;
130   TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
131       rmgr->default_container(), "xla_cache", &cache,
132       [&](XlaCompilationCache** cache_write_into) {
133         return BuildXlaCompilationCache(dev, flr, platform_info,
134                                         cache_write_into);
135       }));
136   core::ScopedUnref cache_ref(cache);
137 
138   se::Stream* stream = nullptr;
139   if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
140           dev->tensorflow_accelerator_device_info()) {
141     stream = accelerator_device_info->stream;
142   }
143 
144   XlaCompiler::Options options =
145       GenerateCompilerOptions(*cache, *flr, dev, stream, platform_info,
146                               /*has_ref_vars=*/false);
147 
148   XlaCompiler::CompileOptions compile_options;
149   compile_options.always_return_tuple = false;
150   compile_options.alias_resource_update = true;
151 
152   XlaCompiler compiler(options);
153 
154   StatusOr<std::vector<XlaCompiler::Argument>> args =
155       XlaComputationLaunchContext::BuildXlaCompilerArguments(
156           constant_arg_indices, inputs, variable_infos, dev);
157   TF_RETURN_IF_ERROR(args.status());
158 
159   xla::LocalClient* local_client = cache->client();
160   XlaCompiler::CompilationResult result;
161   TF_RETURN_IF_ERROR(
162       compiler.CompileFunction(compile_options, function, *args, &result));
163 
164   switch (stage) {
165     case IrExportStage::HLO:
166     case IrExportStage::HLO_NO_METADATA:
167     case IrExportStage::HLO_SERIALIZED: {
168       TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
169                           result.computation->GetProgramShape());
170       xla::HloModuleConfig config(program_shape);
171       TF_ASSIGN_OR_RETURN(
172           std::unique_ptr<xla::HloModule> new_module,
173           xla::HloModule::CreateFromProto(result.computation->proto(), config));
174 
175       xla::HloPrintOptions opts;
176       if (stage == IrExportStage::HLO_NO_METADATA) {
177         opts.set_print_metadata(false);
178       }
179 
180       if (stage == IrExportStage::HLO_SERIALIZED) {
181         return new_module->ToProto().SerializeAsString();
182       } else {
183         return new_module->ToString(opts);
184       }
185     }
186     case IrExportStage::OPTIMIZED_HLO:
187     case IrExportStage::OPTIMIZED_HLO_SERIALIZED: {
188       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
189                           BuildExecutable(local_client, result, options));
190       xla::Executable* new_executable = executable->executable();
191       if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) {
192         return new_executable->module().ToProto().SerializeAsString();
193       } else {
194         return new_executable->module().ToString();
195       }
196     }
197     case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: {
198       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
199                           BuildExecutable(local_client, result, options,
200                                           /*xla_embed_ir_in_executable=*/true));
201       return executable->executable()->hlo_proto()->SerializeAsString();
202     }
203     case IrExportStage::OPTIMIZED_HLO_DOT: {
204       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
205                           BuildExecutable(local_client, result, options));
206       StatusOr<std::string> graph = xla::RenderGraph(
207           *executable->executable()->module().entry_computation(),
208           "Visualization",
209           /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
210           /*hlo_execution_profile=*/nullptr,
211           /*hlo_render_options=*/{});
212       TF_RETURN_IF_ERROR(graph.status());
213       return *graph;
214     }
215   }
216 }
217 
218 }  // namespace tensorflow
219