xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_helpers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helper routines for XLA compilation.
17 
#include "tensorflow/compiler/tf2xla/xla_helpers.h"

#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/stream.h"
40 
41 namespace tensorflow {
42 
Zero(xla::XlaBuilder * b,DataType data_type)43 xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
44   xla::PrimitiveType type;
45   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
46   return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
47 }
48 
One(xla::XlaBuilder * b,DataType data_type)49 xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
50   xla::PrimitiveType type;
51   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
52   return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
53 }
54 
IntegerLiteral(xla::XlaBuilder * b,DataType data_type,int64_t value)55 xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
56                                       int64_t value) {
57   xla::PrimitiveType type;
58   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
59   return ::tensorflow::IntegerLiteral(b, type, value);
60 }
61 
FloatLiteral(xla::XlaBuilder * b,DataType data_type,double value)62 xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
63                                     double value) {
64   xla::PrimitiveType type;
65   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
66   return ::tensorflow::FloatLiteral(b, type, value);
67 }
68 
ReshapeLiteral(const xla::Literal & input,absl::Span<const int64_t> dimensions,xla::Literal * output)69 /* static */ Status XlaHelpers::ReshapeLiteral(
70     const xla::Literal& input, absl::Span<const int64_t> dimensions,
71     xla::Literal* output) {
72   if (input.shape().IsTuple()) {
73     return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
74   }
75   xla::Shape shape =
76       xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
77   int64_t elements_before = xla::ShapeUtil::ElementsIn(input.shape());
78   int64_t elements_after = xla::ShapeUtil::ElementsIn(shape);
79   if (elements_before != elements_after) {
80     return errors::InvalidArgument(
81         "Shapes before and after ReshapeLiteral have different numbers of "
82         "elements.");
83   }
84 
85   *output = input.Clone();
86   output->mutable_shape_do_not_use()->Swap(&shape);
87   return OkStatus();
88 }
89 
OneHot(xla::XlaBuilder * builder,int64_t depth,int axis,DataType index_type,const TensorShape & indices_shape,const xla::XlaOp & indices,const xla::XlaOp & on_value,const xla::XlaOp & off_value,xla::XlaOp * one_hot)90 Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
91                           DataType index_type, const TensorShape& indices_shape,
92                           const xla::XlaOp& indices, const xla::XlaOp& on_value,
93                           const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
94   // Broadcast the linspace constant across the indices along the new axis,
95   // and test equality at each position.
96   std::vector<int64_t> broadcast_dims(indices_shape.dims());
97   std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
98   std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
99 
100   TensorShape output_shape = indices_shape;
101   output_shape.InsertDim(axis, depth);
102   xla::Shape iota_shape;
103   TF_RETURN_IF_ERROR(
104       TensorShapeToXLAShape(index_type, output_shape, &iota_shape));
105 
106   // Selects the user-provided off_value and on_value values.
107   *one_hot = xla::Select(
108       xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
109       xla::Broadcast(on_value, output_shape.dim_sizes()),
110       xla::Broadcast(off_value, output_shape.dim_sizes()));
111   return OkStatus();
112 }
113 
// Returns the data type in which sums of `dtype` values should be
// accumulated. Narrow types are widened to 32 bits; every other type is
// returned unchanged.
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  switch (dtype) {
    // 16 bit floating point sums are accumulated in 32 bit to reduce the
    // precision loss from repeated floating point additions.
    case DT_BFLOAT16:
    case DT_HALF:
      return DT_FLOAT;
    // Small integer types are accumulated in 32 bit to avoid overflow.
    case DT_INT8:
    case DT_INT16:
      return DT_INT32;
    case DT_UINT8:
    case DT_UINT16:
      return DT_UINT32;
    default:
      return dtype;
  }
}
129 
// Builds an op converting `operand` to the XLA primitive type that
// `new_element_type` maps to. The result of the data type conversion is
// checked with TF_CHECK_OK.
xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType target_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &target_type));
  xla::XlaOp converted = xla::ConvertElementType(operand, target_type);
  return converted;
}
136 
IdentityShapeRepresentationFn()137 XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
138   return [](const TensorShape& shape, DataType dtype, bool use_fast_memory,
139             XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
140     xla::Shape xla_shape;
141     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
142     return xla_shape;
143   };
144 }
145 
ResolveDeviceAssignment(OpKernelContext * ctx,const XlaCompilationResult::CollectiveInfo & collective_info,xla::ExecutableRunOptions & run_options,xla::DeviceAssignment & device_assignment,xla::gpu::GpuExecutableRunOptions & gpu_options)146 Status ResolveDeviceAssignment(
147     OpKernelContext* ctx,
148     const XlaCompilationResult::CollectiveInfo& collective_info,
149     xla::ExecutableRunOptions& run_options,
150     xla::DeviceAssignment& device_assignment,
151     xla::gpu::GpuExecutableRunOptions& gpu_options) {
152   // TODO(nnigania): workaround for b/199436990
153   static const int kTimeoutSeconds = 1000;
154   if (ctx->collective_executor() == nullptr) {
155     return errors::InvalidArgument(
156         "CollectiveExecutor is required but not available");
157   }
158 
159   auto params = core::RefCountPtr<CollectiveParams>(new CollectiveParams());
160   params->name = "xla-reduction-compilation";
161   params->group.device_type =
162       DeviceType{static_cast<Device*>(ctx->device())->device_type()};
163   params->group.group_size = collective_info.group_size;
164   params->group.group_key = collective_info.group_key;
165   params->instance.type = REDUCTION_COLLECTIVE;
166   params->instance.impl_details.communication_hint = "nccl";
167   params->instance.impl_details.timeout_seconds = kTimeoutSeconds;
168   params->instance.impl_details.collective_name = "NcclReduce";
169   // TODO(cheshire): Avoid passing a dummy shape, TF runtime does not resolve
170   // devices otherwise.
171   params->instance.shape = TensorShape({1});
172 
173   VLOG(5) << "Using collective params to resolve device assignment: "
174           << params->ToString();
175 
176   Status st;
177   absl::Notification n;
178   ctx->collective_executor()->CompleteParamsAsync(
179       ctx->device()->attributes(), params.get(), ctx->cancellation_manager(),
180       [&](const Status& s) {
181         st = s;
182         n.Notify();
183       });
184   if (!n.WaitForNotificationWithTimeout(absl::Seconds(kTimeoutSeconds))) {
185     return errors::InvalidArgument("Timeout reached");
186   }
187   TF_RETURN_IF_ERROR(st);
188   VLOG(5) << "Collective params completed: " << params->ToString();
189 
190   // Identify the physical device associated with each replica.
191   device_assignment = xla::DeviceAssignment(params->group.group_size, 1);
192   for (int device_idx = 0; device_idx < params->group.group_size;
193        device_idx++) {
194     const DeviceAttributes& device = params->group.members[device_idx].device;
195     if (device.xla_global_id() == -1) {
196       if (params->group.device_type == DEVICE_TPU) {
197         return errors::InvalidArgument(
198             absl::StrCat("No global ID was set for TPU device ", device.name(),
199                          ". Try initializing the TPU system, e.g. "
200                          "`tf.tpu.experimental.initialize_tpu_system()`."));
201       } else if (params->group.device_type == DEVICE_GPU) {
202         return errors::Internal(
203             absl::StrCat("No global ID was set for ", device.name(),
204                          ". This is unexpected, please file a bug."));
205       } else {
206         // TODO(b/194942685): Implement CPU collectives.
207         return errors::Unimplemented(
208             absl::StrCat("Collectives are not yet implemented for ",
209                          params->group.device_type.type_string(),
210                          " devices when compiling with XLA. Attempted to "
211                          "compile a collective running on",
212                          device.name(),
213                          ". Please comment on b/194942685 or "
214                          "file a new bug if you don't have access."));
215       }
216     }
217     VLOG(2) << "Assigning physical id " << device.xla_global_id()
218             << " for replica " << device_idx << " (" << device.name() << ")";
219     device_assignment(device_idx, 0) = device.xla_global_id();
220   }
221   VLOG(5) << "Generated device assignment: " << device_assignment.ToString();
222   if (params->group.device_type == DEVICE_GPU) {
223     // For GPU collectives, `xla_global_id`s are arbitrary integers, and XLA
224     // requires a mapping from local device IDs to global device IDs.
225     const DeviceMgr* device_mgr = ctx->function_library()->device_mgr();
226     std::vector<xla::GlobalDeviceId> global_device_ids(
227         device_mgr->NumDeviceType(params->group.device_type.type_string()));
228 
229     for (int device_idx = 0; device_idx < params->group.group_size;
230          device_idx++) {
231       const DeviceAttributes& device_attributes =
232           params->group.members[device_idx].device;
233       Device* resolved_device = nullptr;
234       Status lookup_status =
235           device_mgr->LookupDevice(device_attributes.name(), &resolved_device);
236       if (lookup_status.ok()) {
237         // This is a local device, so include it in the mapping.
238         const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
239             resolved_device->tensorflow_accelerator_device_info();
240         global_device_ids[accelerator_device_info->stream->parent()
241                               ->device_ordinal()] =
242             device_attributes.xla_global_id();
243       }
244     }
245     gpu_options.set_gpu_global_device_ids(global_device_ids);
246   }
247   const std::string& communicator_key =
248       params->group.runtime_details.communicator_key;
249   gpu_options.set_nccl_unique_id_callback(
250       [=](const xla::gpu::NcclCliqueKey& key) { return communicator_key; });
251   run_options.set_device_assignment(&device_assignment);
252   run_options.set_gpu_executable_run_options(&gpu_options);
253   return OkStatus();
254 }
255 
256 }  // end namespace tensorflow
257