xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/local_service.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/local_service.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/executable_build_options.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/service/backend.h"
29 #include "tensorflow/compiler/xla/service/computation_layout.h"
30 #include "tensorflow/compiler/xla/service/executable.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
35 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
36 #include "tensorflow/compiler/xla/service/platform_util.h"
37 #include "tensorflow/compiler/xla/shape_layout.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
44 
45 namespace xla {
46 
NewService(const ServiceOptions & options)47 /* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
48     const ServiceOptions& options) {
49   se::Platform* platform = options.platform();
50   if (platform == nullptr) {
51     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
52   }
53 
54   BackendOptions backend_options;
55   backend_options.set_platform(platform)
56       .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
57       .set_allowed_devices(options.allowed_devices());
58 
59   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
60                       Backend::CreateBackend(backend_options));
61 
62   std::unique_ptr<LocalService> service(
63       new LocalService(options, std::move(backend)));
64   return std::move(service);
65 }
66 
// Forwards construction straight to the Service base class, which takes
// ownership of the backend used to execute computations.
LocalService::LocalService(const ServiceOptions& options,
                           std::unique_ptr<Backend> execute_backend)
    : Service(options, std::move(execute_backend)) {}
70 
71 namespace {
72 
73 // Retrieves the parameter metadata for the given computation and parameter
74 // number.
75 //
76 // If the parameter number is invalid for this computation, nullopt is
77 // returned. When the return value has_value(), nullptr will never be
78 // the held value.
ParameterMetadata(const XlaComputation & computation,int parameter_number)79 std::optional<const OpMetadata*> ParameterMetadata(
80     const XlaComputation& computation, int parameter_number) {
81   for (const HloComputationProto& comp : computation.proto().computations()) {
82     if (comp.id() == computation.proto().entry_computation_id()) {
83       for (const HloInstructionProto& instr : comp.instructions()) {
84         if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
85             instr.parameter_number() == parameter_number) {
86           if (!instr.has_metadata()) {
87             return std::nullopt;
88           }
89           return &instr.metadata();
90         }
91       }
92     }
93   }
94   return std::nullopt;
95 }
96 
97 }  // namespace
98 
GetHloModuleConfig(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & build_options)99 StatusOr<std::unique_ptr<HloModuleConfig>> LocalService::GetHloModuleConfig(
100     const XlaComputation& computation,
101     const absl::Span<const Shape* const> argument_layouts,
102     const ExecutableBuildOptions& build_options) {
103   const HloModuleProto& proto = computation.proto();
104   TF_RET_CHECK(proto.has_host_program_shape());
105   ProgramShape program_shape(proto.host_program_shape());
106 
107   // Validate incoming layouts.
108   if (argument_layouts.size() != program_shape.parameters_size()) {
109     return InvalidArgument(
110         "Invalid number of arguments for computation: expected %d, got %u.",
111         program_shape.parameters_size(), argument_layouts.size());
112   }
113 
114   for (int i = 0; i < argument_layouts.size(); ++i) {
115     const Shape& argument_shape = *argument_layouts[i];
116     TF_RETURN_IF_ERROR(
117         ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
118     if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
119       std::optional<const OpMetadata*> metadata =
120           ParameterMetadata(computation, /*parameter_number=*/i);
121       auto metadata_string = [&metadata]() -> std::string {
122         if (!metadata.has_value()) {
123           return "";
124         }
125         CHECK(metadata.value() != nullptr);
126         const OpMetadata& m = *metadata.value();
127         if (!m.source_file().empty()) {
128           return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
129         }
130         return "";
131       };
132       return InvalidArgument(
133           "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
134           metadata_string(),
135           ShapeUtil::HumanString(program_shape.parameters(i)),
136           ShapeUtil::HumanString(argument_shape));
137     }
138   }
139   if (build_options.result_layout() != nullptr) {
140     TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
141                                            program_shape.result()));
142   }
143 
144   ExecutionOptions execution_options =
145       CreateExecutionOptions(build_options, &program_shape);
146 
147   return CreateModuleConfig(program_shape, argument_layouts,
148                             &execution_options);
149 }
150 
151 StatusOr<std::vector<std::unique_ptr<Executable>>>
CompileExecutables(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & build_options)152 LocalService::CompileExecutables(
153     const XlaComputation& computation,
154     const absl::Span<const Shape* const> argument_layouts,
155     const ExecutableBuildOptions& build_options) {
156   TF_ASSIGN_OR_RETURN(
157       std::unique_ptr<HloModuleConfig> module_config,
158       GetHloModuleConfig(computation, argument_layouts, build_options));
159 
160   VLOG(3) << "Computation Layout: "
161           << module_config->entry_computation_layout().ToString();
162 
163   TF_ASSIGN_OR_RETURN(
164       se::StreamExecutor * executor,
165       execute_backend_->stream_executor(build_options.device_ordinal()));
166 
167   // TODO(cjfj): Investigate why there are a couple of test failures when the
168   // single partition computations are built using `BuildExecutables`, fix it,
169   // and remove this special case (provided the performance if similar).
170   if (build_options.num_partitions() == 1) {
171     TF_ASSIGN_OR_RETURN(
172         std::unique_ptr<Executable> executable,
173         BuildExecutable(computation.proto(), std::move(module_config),
174                         execute_backend_.get(), executor,
175                         {build_options.device_allocator(),
176                          build_options.compile_thread_pool()},
177                         build_options.run_backend_only()));
178     std::vector<std::unique_ptr<Executable>> executables;
179     executables.push_back(std::move(executable));
180     return executables;
181   } else {
182     std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
183     module_configs.push_back(std::move(module_config));
184     // BuildExecutables uses the executors length to determine the number of
185     // cores per module, but otherwise only uses the first executor.
186     std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
187                                                executor);
188 
189     return BuildExecutables(
190         /*module_protos=*/{&computation.proto()}, std::move(module_configs),
191         execute_backend_.get(), {executors},
192         Compiler::CompileOptions{build_options.device_allocator(),
193                                  build_options.compile_thread_pool()},
194         build_options.run_backend_only());
195   }
196 }
197 
198 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAotResults(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & build_options)199 LocalService::CompileAotResults(
200     const XlaComputation& computation,
201     const absl::Span<const Shape* const> argument_layouts,
202     const ExecutableBuildOptions& build_options) {
203   TF_ASSIGN_OR_RETURN(
204       std::unique_ptr<HloModuleConfig> module_config,
205       GetHloModuleConfig(computation, argument_layouts, build_options));
206 
207   TF_ASSIGN_OR_RETURN(
208       se::StreamExecutor * executor,
209       execute_backend_->stream_executor(build_options.device_ordinal()));
210 
211   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
212   module_configs.push_back(std::move(module_config));
213   // BuildAotResults uses the executors length to determine the number of
214   // cores per module, but otherwise only uses the first executor.
215   std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
216                                              executor);
217 
218   return BuildAotResults(
219       /*module_protos=*/{&computation.proto()}, std::move(module_configs),
220       execute_backend_.get(), {executors},
221       Compiler::CompileOptions{build_options.device_allocator(),
222                                build_options.compile_thread_pool()},
223       build_options.run_backend_only());
224 }
225 
// Maps a replica number to the device ordinal assigned to it by the
// backend's computation placer, assuming a single computation
// (computation index 0, computation_count 1) spread across
// options_.number_of_replicas() replicas.
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return backend().computation_placer()->DeviceId(
      replica_number, /*computation=*/0, options_.number_of_replicas(),
      /*computation_count=*/1);
}
231 
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)232 StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
233     const GlobalDataHandle& data, int replica_number) {
234   TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
235   if (replica_number >= buffers.size()) {
236     return InvalidArgument(
237         "replica_number %d out of range; must be less than num_replicas = %u.",
238         replica_number, buffers.size());
239   }
240   return buffers[replica_number];
241 }
242 
// Registers one ScopedShapedBuffer per replica with the allocation tracker
// (which takes ownership) and returns the handle clients use to refer to
// the replicated data. `tag` is a human-readable label for the allocation.
StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers,
    const std::string& tag) {
  return allocation_tracker_.RegisterReplicatedBuffers(
      std::move(replicated_buffers), tag);
}
249 
250 }  // namespace xla
251