/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/local_service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

/* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }

  BackendOptions backend_options;
  backend_options.set_platform(platform)
      .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
      .set_allowed_devices(options.allowed_devices());

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
                      Backend::CreateBackend(backend_options));

  std::unique_ptr<LocalService> service(
      new LocalService(options, std::move(backend)));
  return std::move(service);
}
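
// Usage sketch (illustrative only; the replica count and the surrounding
// error-handling context are assumptions, and callers normally reach this
// through LocalClient rather than constructing the service directly):
//
//   ServiceOptions service_options;
//   service_options.set_number_of_replicas(1);
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalService> service,
//                       LocalService::NewService(service_options));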

LocalService::LocalService(const ServiceOptions& options,
                           std::unique_ptr<Backend> execute_backend)
    : Service(options, std::move(execute_backend)) {}

namespace {

// Retrieves the parameter metadata for the given computation and parameter
// number.
//
// If the parameter number is invalid for this computation, nullopt is
// returned. When the return value has_value(), nullptr will never be
// the held value.
std::optional<const OpMetadata*> ParameterMetadata(
    const XlaComputation& computation, int parameter_number) {
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() == computation.proto().entry_computation_id()) {
      for (const HloInstructionProto& instr : comp.instructions()) {
        if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
            instr.parameter_number() == parameter_number) {
          if (!instr.has_metadata()) {
            return std::nullopt;
          }
          return &instr.metadata();
        }
      }
    }
  }
  return std::nullopt;
}

}  // namespace

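// Builds an HloModuleConfig for `computation`: validates the given argument
// layouts against the entry computation's program shape (attaching parameter
// source locations to mismatch errors where metadata is available), checks
// any requested result layout, and folds the build options into the module
// configuration.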
StatusOr<std::unique_ptr<HloModuleConfig>> LocalService::GetHloModuleConfig(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options) {
  const HloModuleProto& proto = computation.proto();
  TF_RET_CHECK(proto.has_host_program_shape());
  ProgramShape program_shape(proto.host_program_shape());

  // Validate incoming layouts.
  if (argument_layouts.size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "Invalid number of arguments for computation: expected %d, got %u.",
        program_shape.parameters_size(), argument_layouts.size());
  }

  for (int i = 0; i < argument_layouts.size(); ++i) {
    const Shape& argument_shape = *argument_layouts[i];
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
    if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
      std::optional<const OpMetadata*> metadata =
          ParameterMetadata(computation, /*parameter_number=*/i);
      auto metadata_string = [&metadata]() -> std::string {
        if (!metadata.has_value()) {
          return "";
        }
        CHECK(metadata.value() != nullptr);
        const OpMetadata& m = *metadata.value();
        if (!m.source_file().empty()) {
          return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
        }
        return "";
      };
      return InvalidArgument(
          "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
          metadata_string(),
          ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shape));
    }
  }
  if (build_options.result_layout() != nullptr) {
    TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
                                           program_shape.result()));
  }

  ExecutionOptions execution_options =
      CreateExecutionOptions(build_options, &program_shape);

  return CreateModuleConfig(program_shape, argument_layouts,
                            &execution_options);
}

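// Compiles `computation` into one executable per partition. Single-partition
// computations take a direct BuildExecutable path (see the TODO below);
// otherwise the module is compiled via BuildExecutables with the executor
// list sized to the partition count.
//
// Usage sketch (illustrative only; `service`, `computation`, and `arg_shape`
// are assumed to exist in the caller):
//
//   ExecutableBuildOptions build_options;
//   build_options.set_device_ordinal(0);
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::unique_ptr<Executable>> executables,
//       service->CompileExecutables(computation, {&arg_shape},
//                                   build_options));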
StatusOr<std::vector<std::unique_ptr<Executable>>>
LocalService::CompileExecutables(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      GetHloModuleConfig(computation, argument_layouts, build_options));

  VLOG(3) << "Computation Layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      se::StreamExecutor * executor,
      execute_backend_->stream_executor(build_options.device_ordinal()));

  // TODO(cjfj): Investigate why there are a couple of test failures when the
  // single partition computations are built using `BuildExecutables`, fix it,
  // and remove this special case (provided the performance is similar).
  if (build_options.num_partitions() == 1) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Executable> executable,
        BuildExecutable(computation.proto(), std::move(module_config),
                        execute_backend_.get(), executor,
                        {build_options.device_allocator(),
                         build_options.compile_thread_pool()},
                        build_options.run_backend_only()));
    std::vector<std::unique_ptr<Executable>> executables;
    executables.push_back(std::move(executable));
    return executables;
  } else {
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
    module_configs.push_back(std::move(module_config));
    // BuildExecutables uses the length of `executors` to determine the number
    // of cores per module, but otherwise only uses the first executor.
    std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                               executor);

    return BuildExecutables(
        /*module_protos=*/{&computation.proto()}, std::move(module_configs),
        execute_backend_.get(), {executors},
        Compiler::CompileOptions{build_options.device_allocator(),
                                 build_options.compile_thread_pool()},
        build_options.run_backend_only());
  }
}
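// Ahead-of-time variant of CompileExecutables: produces AotCompilationResults
// via BuildAotResults instead of runnable Executables, using the same
// module-config and executor-list setup as the multi-partition path above.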
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalService::CompileAotResults(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      GetHloModuleConfig(computation, argument_layouts, build_options));

  TF_ASSIGN_OR_RETURN(
      se::StreamExecutor * executor,
      execute_backend_->stream_executor(build_options.device_ordinal()));

  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  module_configs.push_back(std::move(module_config));
  // BuildAotResults uses the length of `executors` to determine the number of
  // cores per module, but otherwise only uses the first executor.
  std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                             executor);

  return BuildAotResults(
      /*module_protos=*/{&computation.proto()}, std::move(module_configs),
      execute_backend_.get(), {executors},
      Compiler::CompileOptions{build_options.device_allocator(),
                               build_options.compile_thread_pool()},
      build_options.run_backend_only());
}

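// Maps a replica number to a device ordinal by consulting the backend's
// computation placer, assuming a single computation per replica
// (computation_count == 1).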
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return backend().computation_placer()->DeviceId(
      replica_number, /*computation=*/0, options_.number_of_replicas(),
      /*computation_count=*/1);
}

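// Resolves `data` to the buffer registered for `replica_number`, failing if
// the replica number is not smaller than the number of registered replicas.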
StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
  if (replica_number >= buffers.size()) {
    return InvalidArgument(
        "replica_number %d out of range; must be less than num_replicas = %u.",
        replica_number, buffers.size());
  }
  return buffers[replica_number];
}

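// Hands the replicated buffers to the allocation tracker, which takes
// ownership and returns a GlobalDataHandle that can later be resolved with
// GlobalDataToShapedBuffer.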
StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers,
    const std::string& tag) {
  return allocation_tracker_.RegisterReplicatedBuffers(
      std::move(replicated_buffers), tag);
}

}  // namespace xla