/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <numeric>
#include <string>
#include <utility>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/callback.pb.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#endif

#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

namespace {

const char* const kTfCallbackCustomCall = "GenericTfCallbackGPU";

}

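// Helpers for converting between the wire format stored in the custom call
// `opaque` field (TfCallbackData and TensorProto messages) and the in-memory
// objects used below.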
static StatusOr<Tensor> TensorFromProto(const TensorProto& proto) {
  Tensor out;
  if (!out.FromProto(proto)) {
    return se::port::InternalError("Failed deserializing a TensorProto");
  }
  return out;
}

static StatusOr<TfCallbackData> CallbackDataFromProto(const char* opaque,
                                                      size_t opaque_len) {
  TfCallbackData callback_data;
  absl::string_view data{opaque, opaque_len};
  callback_data.ParseFromString(std::string{data});  // NOLINT: OSS req
  return callback_data;
}

static StatusOr<std::string> SerializeCallbackData(const TfCallbackData& data) {
  return data.SerializeAsString();
}

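// Lowers the wrapped TF kernel into an XLA custom call: runs the op's shape
// inference function (feeding it compile-time constant inputs where
// available), records the input/output buffer descriptions in a
// TfCallbackData proto, and emits a CustomCall targeting
// kTfCallbackCustomCall with that proto serialized into the `opaque` field.
// Dynamically sized output dimensions are replaced by the upper bounds
// reported by DynamicOutputDimensions() and marked dynamic on the resulting
// xla::Shape.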
Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel(
    int graph_def_version, const NodeDef& node_def,
    XlaOpKernelContext* ctx) const {
  const OpRegistrationData* data = OpRegistry::Global()->LookUp(node_def.op());
  int num_inputs = ctx->num_inputs();
  int num_outputs = ctx->num_outputs();

  std::vector<Tensor> tensor_storage(num_inputs);
  std::vector<const Tensor*> input_tensors(num_inputs);
  std::vector<shape_inference::ShapeHandle> input_shapes;

  shape_inference::InferenceContext ic(
      graph_def_version, node_def, data->op_def,
      std::vector<shape_inference::ShapeHandle>(num_inputs), {}, {}, {});
  TF_RETURN_IF_ERROR(ic.construction_status());

  TfCallbackData callback_data;
  *callback_data.mutable_op() = node_def;

  TF_ASSIGN_OR_RETURN(
      std::vector<int> constant_inputs,
      XlaOpRegistry::CompileTimeConstantInputs(node_def, data->op_def));
  VLOG(1) << "Constant inputs we got: " << absl::StrJoin(constant_inputs, ", ");

  std::vector<xla::Shape> operand_shapes_with_layout;
  std::vector<xla::XlaOp> operands;
  for (int i = 0; i < num_inputs; ++i) {
    TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, ctx->InputXlaShape(i));
    if (absl::c_any_of(xla_shape.dynamic_dimensions(),
                       [](const bool is_dynamic) { return is_dynamic; })) {
      // TODO(cheshire): Support input dynamic dimensions.
      return se::port::InternalError(
          "Input dynamic dimensions are not supported for light outside "
          "compilation");
    }
    // TODO(cheshire): Use InputXlaShape.
    TensorShape shape = ctx->InputShape(i);
    TfCallbackData::InputBufferDescription input_description;

    *input_description.mutable_buffer_description()->mutable_shape() =
        shape.AsProto();
    input_description.mutable_buffer_description()->set_type(
        ctx->input_type(i));

    if (absl::c_linear_search(constant_inputs, i)) {
      // Assuming kernels want to read INT32 datatypes.
      TF_ASSIGN_OR_RETURN(Tensor input_tensor, ctx->ConstantInputTensor(i));
      tensor_storage[i] = input_tensor;
      input_tensors[i] = &tensor_storage.at(i);
      input_tensor.AsProtoTensorContent(input_description.mutable_value());
    } else {
      input_tensors[i] = nullptr;
      operands.push_back(ctx->Input(i));
      operand_shapes_with_layout.push_back(xla_shape);
      xla::LayoutUtil::SetToDefaultLayout(&operand_shapes_with_layout.back());
    }

    *callback_data.add_inputs() = input_description;

    TF_ASSIGN_OR_RETURN(shape_inference::ShapeHandle handle,
                        ic.MakeShapeFromShapeTensor(shape));
    ic.SetInput(i, handle);
  }

  ic.set_input_tensors(input_tensors);

  TF_RETURN_IF_ERROR(data->shape_inference_fn(&ic));
  TF_ASSIGN_OR_RETURN(OutputDimensionBoundsMap output_dimension_bounds,
                      DynamicOutputDimensions(node_def, ctx));

  std::vector<xla::Shape> output_xla_shapes;
  for (int i = 0; i < num_outputs; ++i) {
    const DimensionBoundsMap& dimension_bounds = output_dimension_bounds[i];
    TfCallbackData::OutputBufferDescription output_description;
    output_description.mutable_buffer_description()->set_type(
        ctx->expected_output_dtype(i));

    TensorShapeProto output_tensor_shape_proto =
        ic.ShapeHandleToProto(ic.output(i));
    if (output_tensor_shape_proto.unknown_rank()) {
      return se::port::InternalError(
          absl::StrCat("Output ", i, " has unknown rank"));
    }

    int rank = output_tensor_shape_proto.dim_size();
    std::vector<bool> dynamic_dimensions(rank, false);

    // Modify output tensor shape proto to replace dynamic dimensions with upper
    // bounds: that is the information we will be storing in the callback.
    for (int d = 0; d < output_tensor_shape_proto.dim_size(); ++d) {
      auto* dim = output_tensor_shape_proto.mutable_dim(d);
      auto it = dimension_bounds.find(d);

      if (dim->size() < 0) {
        if (it == dimension_bounds.end()) {
          return se::port::InternalError(absl::StrCat(
              "Bound for unknown dimension not found for dimension ", d));
        }
        dim->set_size(it->second);
        dynamic_dimensions[d] = true;
        output_description.set_is_dynamically_padded(true);
      } else {
        if (it == dimension_bounds.end()) {
          continue;
        }
        if (it->second < dim->size()) {
          dim->set_size(it->second);
        }
      }
    }

    *output_description.mutable_buffer_description()->mutable_shape() =
        output_tensor_shape_proto;
    *callback_data.add_outputs() = output_description;

    TF_ASSIGN_OR_RETURN(
        TensorShape output_tensor_shape,
        TensorShape::BuildTensorShape(output_tensor_shape_proto));

    TF_ASSIGN_OR_RETURN(xla::Shape output_shape,
                        TensorShapeToXLAShape(ctx->expected_output_dtype(i),
                                              output_tensor_shape));

    // Set corresponding dynamic bounds on the output xla::Shape.
    for (int64_t d = 0; d < dynamic_dimensions.size(); ++d) {
      output_shape.set_dynamic_dimension(d, dynamic_dimensions[d]);
    }
    output_xla_shapes.push_back(output_shape);
  }

  xla::Shape output_shape =
      xla::ShapeUtil::MakeMaybeTupleShape(output_xla_shapes);

  VLOG(1) << "Created output shape: " << output_shape.ToString();

  TF_ASSIGN_OR_RETURN(std::string callback_data_serialized,
                      SerializeCallbackData(callback_data));

  xla::XlaOp out = xla::CustomCallWithLayout(
      ctx->builder(), kTfCallbackCustomCall, operands, output_shape,
      /*operand_shapes_with_layout=*/operand_shapes_with_layout,
      /*opaque=*/callback_data_serialized,
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{},
      /*literal=*/nullptr, xla::CustomCallSchedule::SCHEDULE_NONE,
      xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);

  for (int i = 0; i < num_outputs; ++i) {
    ctx->SetOutput(i,
                   output_shape.IsTuple() ? xla::GetTupleElement(out, i) : out);
  }

  return OkStatus();
}

namespace {

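// Allocator that satisfies an allocation request by handing out a
// pre-existing XLA buffer instead of allocating new memory, so that the TF
// kernel writes its result directly into the memory XLA has already reserved
// for the custom call. Deallocation is a no-op since XLA owns the buffer.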
class WriteIntoXlaBufferAllocator : public Allocator {
 public:
  WriteIntoXlaBufferAllocator(void* xla_buffer, size_t buffer_size,
                              absl::string_view description)
      : xla_buffer_(xla_buffer),
        buffer_size_(buffer_size),
        description_(description) {}

  std::string Name() override {
    return absl::StrCat("allocator-xla-", description_);
  }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    VLOG(1) << "Faking allocation of " << num_bytes << " bytes into xla buffer "
            << description_;

    if (num_bytes > buffer_size_) {
      LOG(ERROR) << "Failed allocation: requested larger size than the "
                    "underlying buffer";
      return nullptr;
    }
    return xla_buffer_;
  }

  // Do not perform our own memory management.
  void DeallocateRaw(void* ptr) override {
    VLOG(1) << "Not deallocating pointer " << ptr << " for " << description_;
  }

 private:
  void* xla_buffer_;
  size_t buffer_size_;
  std::string description_;
};

int GetNumConstants(const TfCallbackData& callback_data) {
  return absl::c_count_if(callback_data.inputs(),
                          [&](const auto& input) { return input.has_value(); });
}

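// The custom call receives one buffer per non-constant input followed by one
// buffer per output; constant inputs travel inside TfCallbackData instead.
// Returns the index of the buffer backing `output_num` within that layout.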
int GetOutputBufferId(int output_num, const TfCallbackData& callback_data) {
  return (callback_data.inputs_size() - GetNumConstants(callback_data)) +
         output_num;
}

int64_t BufferSize(const TfCallbackData::BufferDescription& descr) {
  TensorShape shape;
  CHECK(TensorShape::BuildTensorShape(descr.shape(), &shape).ok());  // Crash OK
  return shape.num_elements() * DataTypeSize(descr.type());
}

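// Minimal DeviceBase implementation handed to the TF kernel at run time. It
// exposes the XLA stream as the accelerator stream and routes output
// allocations (identified via AllocatorAttributes::scope_id, see
// GetScopedAllocator) to WriteIntoXlaBufferAllocator instances wrapping the
// XLA output buffers.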
class TfCallbackDevice : public DeviceBase {
 public:
  explicit TfCallbackDevice(se::Stream* stream, void** buffers,
                            const TfCallbackData& callback_data)
      : DeviceBase(Env::Default()),
        stream_(stream),
        gpu_allocator_(GPUProcessState::singleton()->GetGPUAllocator(
            TfDeviceId{stream_->parent()->device_ordinal()})),
        cpu_allocator_(
            ProcessState::singleton()->GetCPUAllocator(/*numa_node=*/0)) {
    for (int i = 0; i < callback_data.outputs_size(); ++i) {
      int buffer_num = GetOutputBufferId(i, callback_data);
      VLOG(1) << "Binding output " << i << " to buffer " << buffers[buffer_num];
      int64_t buffer_size =
          BufferSize(callback_data.outputs(i).buffer_description());
      allocators_.emplace_back(buffers[buffer_num], buffer_size,
                               absl::StrCat("xla-output-", i));
    }

    accelerator_device_info_.stream = stream;
    set_tensorflow_accelerator_device_info(&accelerator_device_info_);
  }

  const string& name() const override { return name_; }

  PerOpGpuDevice* MakeGpuDevice() override {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    return new ConcretePerOpGpuDevice();
#else
    LOG(FATAL) << "CUDA-enabled build is required";  // Crash OK
#endif
  }

  Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
                               DeviceContext* dc,
                               Allocator* allocator) override {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    auto concrete_device = static_cast<ConcretePerOpGpuDevice*>(device);
    const void* gpu_stream = reinterpret_cast<const void*>(
        stream_->implementation()->GpuStreamMemberHack());
    concrete_device->Reinitialize(
        context, gpu_stream,
        /*platform_device_id=*/
        PlatformDeviceId(stream_->parent()->device_ordinal()), allocator,
        // TODO(cheshire): Pass meaningful scratch
        // buffer.
        /*scratch=*/nullptr);
    return OkStatus();
#else
    LOG(FATAL) << "CUDA-enabled build is required";  // Crash OK
#endif
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attrs,
                                int64_t step_id) override {
    return &allocators_[attrs.scope_id - 1];
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    if (attr.on_host()) {
      if (attr.gpu_compatible()) {
        GPUProcessState* ps = GPUProcessState::singleton();
        return ps->GetGpuHostAllocator(0);
      } else {
        return cpu_allocator_;
      }
    } else {
      return gpu_allocator_;
    }
  }

 private:
  std::vector<WriteIntoXlaBufferAllocator> allocators_;
  se::Stream* stream_;  // NOLINT (used under GOOGLE_CUDA)
  Allocator* gpu_allocator_;
  Allocator* cpu_allocator_;
  AcceleratorDeviceInfo accelerator_device_info_;
  std::string name_ = "tf_callback_device";
};

// Populate the output with actual dimensions of the allocated shapes.
//
// Populates the vector on the host and then copies it over to the GPU.
Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx,
                                      const TfCallbackData& callback_data,
                                      void** buffers, se::Stream* stream) {
  for (int i = 0; i < ctx.num_outputs(); i++) {
    if (callback_data.outputs(i).is_dynamically_padded()) {
      Tensor* allocated = ctx.mutable_output(i);
      TensorShape allocated_shape = allocated->shape();
      int num_dimensions = allocated_shape.dims();
      std::vector<int32_t> shape_info(num_dimensions);
      for (int d = 0; d < allocated_shape.dims(); d++) {
        int dim_size = allocated_shape.dim_size(d);
        shape_info[d] = dim_size;
      }

      TF_ASSIGN_OR_RETURN(
          xla::Shape xla_shape,
          tensorflow::TensorShapeToXLAShape(
              callback_data.outputs(i).buffer_description().type(),
              callback_data.outputs(i).buffer_description().shape()));
      void* location = static_cast<char*>(allocated->data()) +
                       xla::ShapeUtil::ByteSizeOf(xla_shape);
      se::DeviceMemoryBase m{location, num_dimensions * sizeof(int32_t)};
      stream->ThenMemcpy(&m, shape_info.data(),
                         num_dimensions * sizeof(int32_t));
    }
  }
  return OkStatus();
}

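// DeviceContext whose only job is to expose the stream the custom call is
// running on to the TF kernel.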
class FakeDeviceContext : public DeviceContext {
 public:
  explicit FakeDeviceContext(se::Stream* stream) { stream_ = stream; }
  se::Stream* stream() const override { return stream_; }

 private:
  se::Stream* stream_;
};

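// Body of the custom call: reconstructs the TF kernel described by the
// TfCallbackData serialized in `opaque`, wires its inputs and outputs to the
// XLA-provided `buffers`, and runs Compute() on the XLA stream. For
// dynamically padded outputs the actual dimension sizes are then written
// after the output data so that XLA can recover the dynamic shape.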
Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque,
                    int opaque_len) {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      se::MultiPlatformManager::PlatformWithName("CUDA"));
  se::StreamExecutorConfig config;
  config.gpu_stream = stream_handle;
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      platform->GetExecutor(config));
  se::Stream* stream = executor->FindAllocatedStream(stream_handle);
  if (!stream) {
    return xla::InternalError("Stream not found for %p", stream_handle);
  }
  TF_ASSIGN_OR_RETURN(TfCallbackData callback_data,
                      CallbackDataFromProto(opaque, opaque_len));
  TfCallbackDevice device(stream, buffers, callback_data);

  std::vector<AllocatorAttributes> allocator_attributes;
  for (int output_idx = 0; output_idx < callback_data.outputs_size();
       ++output_idx) {
    AllocatorAttributes attr;
    // Repurpose `scope_id` to communicate which output it is.
    // Shift by one to make it greater than zero.
    attr.scope_id = output_idx + 1;
    allocator_attributes.push_back(attr);
  }

  Status nested_status;
  std::unique_ptr<OpKernel> kernel =
      CreateOpKernel(DeviceType(DEVICE_GPU),
                     /*device=*/&device,

                     // NB: real allocator is passed with device, the one here
                     // is only called during the kernel construction.
                     // TODO(cheshire): Pass scratch allocator.
                     /*allocator=*/nullptr, callback_data.op(),
                     /*graph_def_version=*/1, &nested_status);
  TF_RETURN_IF_ERROR(nested_status);

  auto device_context =
      core::RefCountPtr<FakeDeviceContext>(new FakeDeviceContext(stream));

  OpKernelContext::Params params;
  params.output_attr_array = allocator_attributes.data();
  params.op_kernel = kernel.get();
  params.device = &device;
  params.ensure_eigen_gpu_device();
  params.op_device_context = device_context.get();

  absl::InlinedVector<TensorValue, 4> inputs;

  // Deque usage is important to avoid moving objects.
  std::deque<WriteIntoXlaBufferAllocator> input_allocators;
  std::deque<Tensor> input_tensors;

  int constant_offset = 0;

  for (int i = 0; i < callback_data.inputs_size(); ++i) {
    DataType dt = callback_data.inputs(i).buffer_description().type();

    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape(
        callback_data.inputs(i).buffer_description().shape(), &shape));

    VLOG(2) << "Input shape: " << shape.DebugString();
    int64_t input_size = shape.num_elements() * DataTypeSize(dt);

    if (callback_data.inputs(i).has_value()) {
      // Value provided at compile time: reconstruct the tensor.
      TF_ASSIGN_OR_RETURN(Tensor input,
                          TensorFromProto(callback_data.inputs(i).value()));
      input_tensors.push_back(input);

      constant_offset++;
      VLOG(1) << "Input " << i << " is a tensor: " << input.DebugString();
    } else {
      VLOG(1) << "Reading into the buffer for the input " << i;

      // We only get a backing input buffer for those inputs which are *not*
      // forced to be constant at compile time.
      input_allocators.emplace_back(buffers[i - constant_offset], input_size,
                                    absl::StrCat("input-", i));
      // Index with back() rather than `i`: the allocator deque only holds
      // entries for non-constant inputs, so it can be shorter than `i + 1`.
      input_tensors.emplace_back(&input_allocators.back(), dt, shape);
    }
    inputs.emplace_back(&input_tensors.back());
  }

  params.inputs = inputs;
  OpKernelContext ctx(&params, callback_data.outputs_size());
  kernel->Compute(&ctx);

  bool has_dynamic_outputs = absl::c_any_of(
      callback_data.outputs(),
      [](const auto& out) { return out.is_dynamically_padded(); });

  if (has_dynamic_outputs) {
    TF_RETURN_IF_ERROR(
        PopulateMetadataBufferIfNeeded(ctx, callback_data, buffers, stream));
  }

  TF_RETURN_IF_ERROR(ctx.status());
  return OkStatus();
}

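// Custom call entry point registered below; converts a non-OK status from
// CallTfKernel into an XlaCustomCallStatus failure, matching the
// API_VERSION_STATUS_RETURNING version used when the call is emitted.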
void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque,
                       int opaque_len, XlaCustomCallStatus* status) {
  Status s = CallTfKernel(stream_handle, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().size());
  }
}

XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(kTfCallbackCustomCall,
                                         GenericTfCallback, "CUDA");

}  // namespace

LightOutsideCompilationOp::LightOutsideCompilationOp(
    OpKernelConstruction* context)
    : XlaOpKernel(context),
      def_(context->def()),
      graph_def_version_(context->graph_def_version()) {}

void LightOutsideCompilationOp::Compile(XlaOpKernelContext* ctx) {
  OP_REQUIRES_OK(
      ctx, CompileToCustomCallCallingTfKernel(graph_def_version_, def_, ctx));
}

}  // namespace tensorflow