/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/callback.pb.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#endif

#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

namespace {

const char* const kTfCallbackCustomCall = "GenericTfCallbackGPU";

}  // namespace

static StatusOr<Tensor> TensorFromProto(const TensorProto& proto) {
  Tensor out;
  if (!out.FromProto(proto)) {
    return se::port::InternalError("Failed deserializing a TensorProto");
  }
  return out;
}

static StatusOr<TfCallbackData> CallbackDataFromProto(const char* opaque,
                                                      size_t opaque_len) {
  TfCallbackData callback_data;
  absl::string_view data{opaque, opaque_len};
  if (!callback_data.ParseFromString(std::string{data})) {  // NOLINT: OSS req
    return se::port::InternalError("Failed deserializing TfCallbackData");
  }
  return callback_data;
}

static StatusOr<std::string> SerializeCallbackData(
    const TfCallbackData& data) {
  return data.SerializeAsString();
}

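// Lowers this op to an XLA CustomCall targeting kTfCallbackCustomCall. The
// serialized TfCallbackData payload carries the NodeDef, the values of
// compile-time-constant inputs, and buffer descriptions for every input and
// output; the remaining inputs become custom-call operands, and output shapes
// (with upper bounds substituted for dynamic dimensions) are obtained from
// the op's shape inference function.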
Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel(
    int graph_def_version, const NodeDef& node_def,
    XlaOpKernelContext* ctx) const {
  const OpRegistrationData* data = OpRegistry::Global()->LookUp(node_def.op());
  if (data == nullptr) {
    return se::port::InternalError(
        absl::StrCat("Op is not registered: ", node_def.op()));
  }
  int num_inputs = ctx->num_inputs();
  int num_outputs = ctx->num_outputs();

  std::vector<Tensor> tensor_storage(num_inputs);
  std::vector<const Tensor*> input_tensors(num_inputs);
  std::vector<shape_inference::ShapeHandle> input_shapes;

  shape_inference::InferenceContext ic(
      graph_def_version, node_def, data->op_def,
      std::vector<shape_inference::ShapeHandle>(num_inputs), {}, {}, {});
  TF_RETURN_IF_ERROR(ic.construction_status());

  TfCallbackData callback_data;
  *callback_data.mutable_op() = node_def;

  TF_ASSIGN_OR_RETURN(
      std::vector<int> constant_inputs,
      XlaOpRegistry::CompileTimeConstantInputs(node_def, data->op_def));
  VLOG(1) << "Constant inputs: " << absl::StrJoin(constant_inputs, ", ");

  std::vector<xla::Shape> operand_shapes_with_layout;
  std::vector<xla::XlaOp> operands;
  for (int i = 0; i < num_inputs; ++i) {
    TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, ctx->InputXlaShape(i));
    if (absl::c_any_of(xla_shape.dynamic_dimensions(),
                       [](const bool is_dynamic) { return is_dynamic; })) {
      // TODO(cheshire): Support input dynamic dimensions.
      return se::port::InternalError(
          "Input dynamic dimensions are not supported for light outside "
          "compilation");
    }
    // TODO(cheshire): Use InputXlaShape.
    TensorShape shape = ctx->InputShape(i);
    TfCallbackData::InputBufferDescription input_description;

    *input_description.mutable_buffer_description()->mutable_shape() =
        shape.AsProto();
    input_description.mutable_buffer_description()->set_type(
        ctx->input_type(i));

    if (absl::c_linear_search(constant_inputs, i)) {
      // The kernel may need to read the value at run time (e.g. INT32 shape
      // inputs), so materialize the constant and store it in the callback.
      TF_ASSIGN_OR_RETURN(Tensor input_tensor, ctx->ConstantInputTensor(i));
      tensor_storage[i] = input_tensor;
      input_tensors[i] = &tensor_storage.at(i);
      input_tensor.AsProtoTensorContent(input_description.mutable_value());
    } else {
      input_tensors[i] = nullptr;
      operands.push_back(ctx->Input(i));
      operand_shapes_with_layout.push_back(xla_shape);
      xla::LayoutUtil::SetToDefaultLayout(&operand_shapes_with_layout.back());
    }

    *callback_data.add_inputs() = input_description;

    TF_ASSIGN_OR_RETURN(shape_inference::ShapeHandle handle,
                        ic.MakeShapeFromShapeTensor(shape));
    ic.SetInput(i, handle);
  }

  ic.set_input_tensors(input_tensors);

  TF_RETURN_IF_ERROR(data->shape_inference_fn(&ic));
  TF_ASSIGN_OR_RETURN(OutputDimensionBoundsMap output_dimension_bounds,
                      DynamicOutputDimensions(node_def, ctx));

  std::vector<xla::Shape> output_xla_shapes;
  for (int i = 0; i < num_outputs; ++i) {
    const DimensionBoundsMap& dimension_bounds = output_dimension_bounds[i];
    TfCallbackData::OutputBufferDescription output_description;
    output_description.mutable_buffer_description()->set_type(
        ctx->expected_output_dtype(i));

    TensorShapeProto output_tensor_shape_proto =
        ic.ShapeHandleToProto(ic.output(i));
    if (output_tensor_shape_proto.unknown_rank()) {
      return se::port::InternalError(
          absl::StrCat("Output ", i, " has unknown rank"));
    }

    int rank = output_tensor_shape_proto.dim_size();
    std::vector<bool> dynamic_dimensions(rank, false);

    // Replace dynamic dimensions in the output shape proto with their upper
    // bounds: that is the shape stored in the callback.
    for (int d = 0; d < output_tensor_shape_proto.dim_size(); ++d) {
      auto* dim = output_tensor_shape_proto.mutable_dim(d);
      auto it = dimension_bounds.find(d);

      if (dim->size() < 0) {
        if (it == dimension_bounds.end()) {
          return se::port::InternalError(absl::StrCat(
              "Bound not found for unknown dimension ", d, " of output ", i));
        }
        dim->set_size(it->second);
        dynamic_dimensions[d] = true;
        output_description.set_is_dynamically_padded(true);
      } else {
        if (it == dimension_bounds.end()) {
          continue;
        }
        if (it->second < dim->size()) {
          dim->set_size(it->second);
        }
      }
    }

    *output_description.mutable_buffer_description()->mutable_shape() =
        output_tensor_shape_proto;
    *callback_data.add_outputs() = output_description;

    TF_ASSIGN_OR_RETURN(
        TensorShape output_tensor_shape,
        TensorShape::BuildTensorShape(output_tensor_shape_proto));

    TF_ASSIGN_OR_RETURN(xla::Shape output_shape,
                        TensorShapeToXLAShape(ctx->expected_output_dtype(i),
                                              output_tensor_shape));

    // Set corresponding dynamic bounds on the output xla::Shape.
    for (int64_t d = 0; d < dynamic_dimensions.size(); ++d) {
      output_shape.set_dynamic_dimension(d, dynamic_dimensions[d]);
    }
    output_xla_shapes.push_back(output_shape);
  }

  xla::Shape output_shape =
      xla::ShapeUtil::MakeMaybeTupleShape(output_xla_shapes);

  VLOG(1) << "Created output shape: " << output_shape.ToString();

  TF_ASSIGN_OR_RETURN(std::string callback_data_serialized,
                      SerializeCallbackData(callback_data));

  xla::XlaOp out = xla::CustomCallWithLayout(
      ctx->builder(), kTfCallbackCustomCall, operands, output_shape,
      /*operand_shapes_with_layout=*/operand_shapes_with_layout,
      /*opaque=*/callback_data_serialized,
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{},
      /*literal=*/nullptr, xla::CustomCallSchedule::SCHEDULE_NONE,
      xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);

  for (int i = 0; i < num_outputs; ++i) {
    ctx->SetOutput(
        i, output_shape.IsTuple() ? xla::GetTupleElement(out, i) : out);
  }

  return OkStatus();
}

namespace {

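// Allocator that hands out a single pre-existing XLA-provided buffer instead
// of allocating new memory, so that a TF kernel writes its result directly
// into the buffer backing the corresponding custom-call output.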
class WriteIntoXlaBufferAllocator : public Allocator {
 public:
  WriteIntoXlaBufferAllocator(void* xla_buffer, size_t buffer_size,
                              absl::string_view description)
      : xla_buffer_(xla_buffer),
        buffer_size_(buffer_size),
        description_(description) {}

  std::string Name() override {
    return absl::StrCat("allocator-xla-", description_);
  }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    VLOG(1) << "Faking allocation of " << num_bytes
            << " bytes into xla buffer " << description_;

    if (num_bytes > buffer_size_) {
      LOG(ERROR) << "Failed allocation: requested larger size than the "
                    "underlying buffer";
      return nullptr;
    }
    return xla_buffer_;
  }

  // Do not perform our own memory management.
  void DeallocateRaw(void* ptr) override {
    VLOG(1) << "Not deallocating pointer " << ptr << " for " << description_;
  }

 private:
  void* xla_buffer_;
  size_t buffer_size_;
  std::string description_;
};

int GetNumConstants(const TfCallbackData& callback_data) {
  return absl::c_count_if(
      callback_data.inputs(),
      [](const auto& input) { return input.has_value(); });
}

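// Returns the index of the runtime buffer backing output `output_num`.
// Compile-time constants travel inside the callback payload and have no
// runtime buffer, so output buffers start right after the non-constant inputs.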
int GetOutputBufferId(int output_num, const TfCallbackData& callback_data) {
  return (callback_data.inputs_size() - GetNumConstants(callback_data)) +
         output_num;
}

int64_t BufferSize(const TfCallbackData::BufferDescription& descr) {
  TensorShape shape;
  CHECK(TensorShape::BuildTensorShape(descr.shape(), &shape).ok());  // Crash OK
  return shape.num_elements() * DataTypeSize(descr.type());
}

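// DeviceBase implementation that lets a TF GPU kernel run on the stream
// provided by the XLA custom call. Output allocations are routed (via
// AllocatorAttributes::scope_id) to WriteIntoXlaBufferAllocators wrapping the
// XLA output buffers.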
class TfCallbackDevice : public DeviceBase {
 public:
  explicit TfCallbackDevice(se::Stream* stream, void** buffers,
                            const TfCallbackData& callback_data)
      : DeviceBase(Env::Default()),
        stream_(stream),
        gpu_allocator_(GPUProcessState::singleton()->GetGPUAllocator(
            TfDeviceId{stream_->parent()->device_ordinal()})),
        cpu_allocator_(
            ProcessState::singleton()->GetCPUAllocator(/*numa_node=*/0)) {
    for (int i = 0; i < callback_data.outputs_size(); ++i) {
      int buffer_num = GetOutputBufferId(i, callback_data);
      VLOG(1) << "Binding output " << i << " to buffer "
              << buffers[buffer_num];
      int64_t buffer_size =
          BufferSize(callback_data.outputs(i).buffer_description());
      allocators_.emplace_back(buffers[buffer_num], buffer_size,
                               absl::StrCat("xla-output-", i));
    }

    accelerator_device_info_.stream = stream;
    set_tensorflow_accelerator_device_info(&accelerator_device_info_);
  }

  const string& name() const override { return name_; }

  PerOpGpuDevice* MakeGpuDevice() override {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    return new ConcretePerOpGpuDevice();
#else
    LOG(FATAL) << "CUDA-enabled build is required";  // Crash OK
#endif
  }

  Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
                               DeviceContext* dc,
                               Allocator* allocator) override {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    auto concrete_device = static_cast<ConcretePerOpGpuDevice*>(device);
    const void* gpu_stream = reinterpret_cast<const void*>(
        stream_->implementation()->GpuStreamMemberHack());
    concrete_device->Reinitialize(
        context, gpu_stream,
        /*platform_device_id=*/
        PlatformDeviceId(stream_->parent()->device_ordinal()), allocator,
        // TODO(cheshire): Pass meaningful scratch buffer.
        /*scratch=*/nullptr);
    return OkStatus();
#else
    LOG(FATAL) << "CUDA-enabled build is required";  // Crash OK
#endif
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attrs,
                                int64_t step_id) override {
    return &allocators_[attrs.scope_id - 1];
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    if (attr.on_host()) {
      if (attr.gpu_compatible()) {
        GPUProcessState* ps = GPUProcessState::singleton();
        return ps->GetGpuHostAllocator(0);
      } else {
        return cpu_allocator_;
      }
    } else {
      return gpu_allocator_;
    }
  }

 private:
  std::vector<WriteIntoXlaBufferAllocator> allocators_;
  se::Stream* stream_;  // NOLINT (used under GOOGLE_CUDA)
  Allocator* gpu_allocator_;
  Allocator* cpu_allocator_;
  AcceleratorDeviceInfo accelerator_device_info_;
  std::string name_ = "tf_callback_device";
};

// Populates the shape metadata section of each dynamically padded output with
// the actual dimensions of the allocated shape.
//
// The dimension vector is built on the host and then copied over to the GPU.
Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx,
                                      const TfCallbackData& callback_data,
                                      void** buffers, se::Stream* stream) {
  for (int i = 0; i < ctx.num_outputs(); i++) {
    if (callback_data.outputs(i).is_dynamically_padded()) {
      Tensor* allocated = ctx.mutable_output(i);
      TensorShape allocated_shape = allocated->shape();
      int num_dimensions = allocated_shape.dims();
      std::vector<int32_t> shape_info(num_dimensions);
      for (int d = 0; d < allocated_shape.dims(); d++) {
        shape_info[d] = allocated_shape.dim_size(d);
      }

      TF_ASSIGN_OR_RETURN(
          xla::Shape xla_shape,
          tensorflow::TensorShapeToXLAShape(
              callback_data.outputs(i).buffer_description().type(),
              callback_data.outputs(i).buffer_description().shape()));
      void* location = static_cast<char*>(allocated->data()) +
                       xla::ShapeUtil::ByteSizeOf(xla_shape);
      se::DeviceMemoryBase m{location, num_dimensions * sizeof(int32_t)};
      stream->ThenMemcpy(&m, shape_info.data(),
                         num_dimensions * sizeof(int32_t));
    }
  }
  return OkStatus();
}

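// Minimal DeviceContext that only exposes the compute stream.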
class FakeDeviceContext : public DeviceContext {
 public:
  explicit FakeDeviceContext(se::Stream* stream) : stream_(stream) {}
  se::Stream* stream() const override { return stream_; }

 private:
  se::Stream* stream_;
};

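// Implements the custom call: deserializes TfCallbackData from `opaque`,
// instantiates the TF kernel it describes, wires the kernel's inputs and
// outputs to the XLA buffers and stream, and runs Compute().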
Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque,
                    int opaque_len) {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      se::MultiPlatformManager::PlatformWithName("CUDA"));
  se::StreamExecutorConfig config;
  config.gpu_stream = stream_handle;
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      platform->GetExecutor(config));
  se::Stream* stream = executor->FindAllocatedStream(stream_handle);
  if (!stream) {
    return xla::InternalError("Stream not found for %p", stream_handle);
  }
  TF_ASSIGN_OR_RETURN(TfCallbackData callback_data,
                      CallbackDataFromProto(opaque, opaque_len));
  TfCallbackDevice device(stream, buffers, callback_data);

  std::vector<AllocatorAttributes> allocator_attributes;
  for (int output_idx = 0; output_idx < callback_data.outputs_size();
       ++output_idx) {
    AllocatorAttributes attr;
    // Repurpose `scope_id` to communicate which output this is. Shift by one
    // so the value stays greater than zero.
    attr.scope_id = output_idx + 1;
    allocator_attributes.push_back(attr);
  }

  Status nested_status;
  std::unique_ptr<OpKernel> kernel =
      CreateOpKernel(DeviceType(DEVICE_GPU),
                     /*device=*/&device,

                     // NB: the real allocator is passed via the device; the
                     // one here is only used during kernel construction.
                     // TODO(cheshire): Pass scratch allocator.
                     /*allocator=*/nullptr, callback_data.op(),
                     /*graph_def_version=*/1, &nested_status);
  TF_RETURN_IF_ERROR(nested_status);

  auto device_context =
      core::RefCountPtr<FakeDeviceContext>(new FakeDeviceContext(stream));

  OpKernelContext::Params params;
  params.output_attr_array = allocator_attributes.data();
  params.op_kernel = kernel.get();
  params.device = &device;
  params.ensure_eigen_gpu_device();
  params.op_device_context = device_context.get();

  absl::InlinedVector<TensorValue, 4> inputs;

  // Deques are used so that pointers to elements stay valid as we append.
  std::deque<WriteIntoXlaBufferAllocator> input_allocators;
  std::deque<Tensor> input_tensors;

  int constant_offset = 0;

  for (int i = 0; i < callback_data.inputs_size(); ++i) {
    DataType dt = callback_data.inputs(i).buffer_description().type();

    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape(
        callback_data.inputs(i).buffer_description().shape(), &shape));

    VLOG(2) << "Input shape: " << shape.DebugString();
    int64_t input_size = shape.num_elements() * DataTypeSize(dt);

    if (callback_data.inputs(i).has_value()) {
      // Value provided at compile time: reconstruct the tensor.
      TF_ASSIGN_OR_RETURN(Tensor input,
                          TensorFromProto(callback_data.inputs(i).value()));
      input_tensors.push_back(input);

      constant_offset++;
      VLOG(1) << "Input " << i << " is a tensor: " << input.DebugString();
    } else {
      VLOG(1) << "Reading into the buffer for the input " << i;

      // We only get a backing buffer for inputs that are *not* forced to be
      // constant at compile time.
      input_allocators.emplace_back(buffers[i - constant_offset], input_size,
                                    absl::StrCat("input-", i));
      input_tensors.emplace_back(&input_allocators.back(), dt, shape);
    }
    inputs.emplace_back(&input_tensors.back());
  }

  params.inputs = inputs;
  OpKernelContext ctx(&params, callback_data.outputs_size());
  kernel->Compute(&ctx);

  bool has_dynamic_outputs = absl::c_any_of(
      callback_data.outputs(),
      [](const auto& out) { return out.is_dynamically_padded(); });

  if (has_dynamic_outputs) {
    TF_RETURN_IF_ERROR(
        PopulateMetadataBufferIfNeeded(ctx, callback_data, buffers, stream));
  }

  TF_RETURN_IF_ERROR(ctx.status());
  return OkStatus();
}

void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque,
                       int opaque_len, XlaCustomCallStatus* status) {
  Status s = CallTfKernel(stream_handle, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().size());
  }
}

XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(kTfCallbackCustomCall,
                                         GenericTfCallback, "CUDA");

}  // namespace

LightOutsideCompilationOp::LightOutsideCompilationOp(
    OpKernelConstruction* context)
    : XlaOpKernel(context),
      def_(context->def()),
      graph_def_version_(context->graph_def_version()) {}

void LightOutsideCompilationOp::Compile(XlaOpKernelContext* ctx) {
  OP_REQUIRES_OK(
      ctx, CompileToCustomCallCallingTfKernel(graph_def_version_, def_, ctx));
}

}  // namespace tensorflow