/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for XLA compilation.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"

#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/stream.h"

namespace tensorflow {

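// Each scalar-constant helper below first maps the TensorFlow DataType to the
// corresponding XLA primitive type and then emits the matching constant on the
// builder.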
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}

xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}

xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                      int64_t value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::IntegerLiteral(b, type, value);
}

xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                    double value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::FloatLiteral(b, type, value);
}

/* static */ Status XlaHelpers::ReshapeLiteral(
    const xla::Literal& input, absl::Span<const int64_t> dimensions,
    xla::Literal* output) {
  if (input.shape().IsTuple()) {
    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
  }
  xla::Shape shape =
      xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
  int64_t elements_before = xla::ShapeUtil::ElementsIn(input.shape());
  int64_t elements_after = xla::ShapeUtil::ElementsIn(shape);
  if (elements_before != elements_after) {
    return errors::InvalidArgument(
        "Shapes before and after ReshapeLiteral have different numbers of "
        "elements.");
  }

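  // The element counts match, so the reshape is purely a metadata change:
  // clone the literal's data and swap in the new shape.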
  *output = input.Clone();
  output->mutable_shape_do_not_use()->Swap(&shape);
  return OkStatus();
}

Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
                          DataType index_type, const TensorShape& indices_shape,
                          const xla::XlaOp& indices, const xla::XlaOp& on_value,
                          const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
  // Broadcast the indices along the new `depth` axis and test equality against
  // an iota constant at each position.
  std::vector<int64_t> broadcast_dims(indices_shape.dims());
  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
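  // For example, with indices_shape = [2, 3] and axis = 1 this yields
  // broadcast_dims = {0, 2}: the indices are broadcast into every output
  // dimension except the new depth dimension.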

  TensorShape output_shape = indices_shape;
  output_shape.InsertDim(axis, depth);
  xla::Shape iota_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(index_type, output_shape, &iota_shape));

  // Select the user-provided on_value where the index matches the iota and
  // off_value everywhere else.
  *one_hot = xla::Select(
      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
      xla::Broadcast(on_value, output_shape.dim_sizes()),
      xla::Broadcast(off_value, output_shape.dim_sizes()));
  return OkStatus();
}

DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  // Upcast 16-bit sum reductions to 32 bits to reduce the precision loss from
  // repeated floating-point additions.
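  // For example, a DT_BFLOAT16 or DT_HALF sum accumulates in DT_FLOAT.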
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    return DT_FLOAT;
  }
  // Upcast small integer types to 32 bits to avoid overflow.
  if (dtype == DT_INT8 || dtype == DT_INT16) {
    return DT_INT32;
  }
  if (dtype == DT_UINT8 || dtype == DT_UINT16) {
    return DT_UINT32;
  }
  return dtype;
}

xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType convert_to;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
  return xla::ConvertElementType(operand, convert_to);
}

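// Returns a ShapeRepresentationFn that maps a TensorShape and DataType to the
// equivalent XLA shape with its default layout, ignoring the fast-memory and
// layout-preference hints.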
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  return [](const TensorShape& shape, DataType dtype, bool use_fast_memory,
            XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };
}

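// Resolves the XLA device assignment for a compiled collective op: the TF
// collective executor completes the group params, and each group member's
// xla_global_id becomes the physical device of one replica. For GPU groups,
// the NCCL communicator key and the local-to-global device-id mapping are
// also plumbed into `gpu_options`.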
Status ResolveDeviceAssignment(
    OpKernelContext* ctx,
    const XlaCompilationResult::CollectiveInfo& collective_info,
    xla::ExecutableRunOptions& run_options,
    xla::DeviceAssignment& device_assignment,
    xla::gpu::GpuExecutableRunOptions& gpu_options) {
  // TODO(nnigania): workaround for b/199436990
  static const int kTimeoutSeconds = 1000;
  if (ctx->collective_executor() == nullptr) {
    return errors::InvalidArgument(
        "CollectiveExecutor is required but not available");
  }

  auto params = core::RefCountPtr<CollectiveParams>(new CollectiveParams());
  params->name = "xla-reduction-compilation";
  params->group.device_type =
      DeviceType{static_cast<Device*>(ctx->device())->device_type()};
  params->group.group_size = collective_info.group_size;
  params->group.group_key = collective_info.group_key;
  params->instance.type = REDUCTION_COLLECTIVE;
  params->instance.impl_details.communication_hint = "nccl";
  params->instance.impl_details.timeout_seconds = kTimeoutSeconds;
  params->instance.impl_details.collective_name = "NcclReduce";
  // TODO(cheshire): Avoid passing a dummy shape; the TF runtime does not
  // resolve devices otherwise.
  params->instance.shape = TensorShape({1});

  VLOG(5) << "Using collective params to resolve device assignment: "
          << params->ToString();

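  // Complete the collective params asynchronously and block, with a timeout,
  // until the group membership has been resolved.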
  Status st;
  absl::Notification n;
  ctx->collective_executor()->CompleteParamsAsync(
      ctx->device()->attributes(), params.get(), ctx->cancellation_manager(),
      [&](const Status& s) {
        st = s;
        n.Notify();
      });
  if (!n.WaitForNotificationWithTimeout(absl::Seconds(kTimeoutSeconds))) {
    return errors::InvalidArgument(
        "Timed out while completing collective params");
  }
  TF_RETURN_IF_ERROR(st);
  VLOG(5) << "Collective params completed: " << params->ToString();

  // Identify the physical device associated with each replica.
  device_assignment = xla::DeviceAssignment(params->group.group_size, 1);
  for (int device_idx = 0; device_idx < params->group.group_size;
       device_idx++) {
    const DeviceAttributes& device = params->group.members[device_idx].device;
    if (device.xla_global_id() == -1) {
      if (params->group.device_type == DEVICE_TPU) {
        return errors::InvalidArgument(
            absl::StrCat("No global ID was set for TPU device ", device.name(),
                         ". Try initializing the TPU system, e.g. "
                         "`tf.tpu.experimental.initialize_tpu_system()`."));
      } else if (params->group.device_type == DEVICE_GPU) {
        return errors::Internal(
            absl::StrCat("No global ID was set for ", device.name(),
                         ". This is unexpected; please file a bug."));
      } else {
        // TODO(b/194942685): Implement CPU collectives.
        return errors::Unimplemented(
            absl::StrCat("Collectives are not yet implemented for ",
                         params->group.device_type.type_string(),
                         " devices when compiling with XLA. Attempted to "
                         "compile a collective running on ",
                         device.name(),
                         ". Please comment on b/194942685 or "
                         "file a new bug if you don't have access."));
      }
    }
    VLOG(2) << "Assigning physical id " << device.xla_global_id()
            << " for replica " << device_idx << " (" << device.name() << ")";
    device_assignment(device_idx, 0) = device.xla_global_id();
  }
  VLOG(5) << "Generated device assignment: " << device_assignment.ToString();
  if (params->group.device_type == DEVICE_GPU) {
    // For GPU collectives, `xla_global_id`s are arbitrary integers, and XLA
    // requires a mapping from local device IDs to global device IDs.
    const DeviceMgr* device_mgr = ctx->function_library()->device_mgr();
    std::vector<xla::GlobalDeviceId> global_device_ids(
        device_mgr->NumDeviceType(params->group.device_type.type_string()));

    for (int device_idx = 0; device_idx < params->group.group_size;
         device_idx++) {
      const DeviceAttributes& device_attributes =
          params->group.members[device_idx].device;
      Device* resolved_device = nullptr;
      Status lookup_status =
          device_mgr->LookupDevice(device_attributes.name(), &resolved_device);
      if (lookup_status.ok()) {
        // This is a local device, so include it in the mapping.
        const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
            resolved_device->tensorflow_accelerator_device_info();
        global_device_ids[accelerator_device_info->stream->parent()
                              ->device_ordinal()] =
            device_attributes.xla_global_id();
      }
    }
    gpu_options.set_gpu_global_device_ids(global_device_ids);
  }
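  // Every clique reuses the communicator key established by the TF collective
  // runtime as its NCCL unique id; the callback ignores the clique key.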
  const std::string& communicator_key =
      params->group.runtime_details.communicator_key;
  gpu_options.set_nccl_unique_id_callback(
      [=](const xla::gpu::NcclCliqueKey& key) { return communicator_key; });
  run_options.set_device_assignment(&device_assignment);
  run_options.set_gpu_executable_run_options(&gpu_options);
  return OkStatus();
}

}  // end namespace tensorflow