/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the conversion functions between TFRuntimeFallback and
// GPU tensors.

#include "tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h"

#include "absl/strings/match.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/gpu/device/conversion_function.h"  // from @tf_runtime
#include "tfrt/gpu/device/device.h"  // from @tf_runtime
#include "tfrt/gpu/device/device_util.h"  // from @tf_runtime
#include "tfrt/gpu/gpu_types.h"  // from @tf_runtime
#include "tfrt/gpu/tensor/dense_gpu_tensor.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_utils.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

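// Wraps the given DenseGpuTensor in a RuntimeFallbackTensor backed by a TF
// TensorHandle. The GPU buffer itself is not copied; only a new reference to
// it is taken.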
static tfrt::Expected<RuntimeFallbackTensor>
CopyRefGpuTensorToRuntimeFallbackTensor(
    const tfrt::gpu::DenseGpuTensor& gpu_tensor, Device* device,
    Device* op_device, EagerContext* eager_ctx) {
  // Do not copy the GPU buffer contents; take a new reference (CopyRef) on
  // the buffer instead.
  tfrt::AsyncValueRef<tfrt::gpu::GpuBuffer> gpu_buffer =
      gpu_tensor.CopyBufferRef();
  tfrt::Expected<tensorflow::Tensor> tensor = MoveGpuBufferToTFTensor(
      std::move(gpu_buffer), gpu_tensor.dtype(), gpu_tensor.shape());
  if (!tensor) return tensor.takeError();

  OwnedTensorHandle tensor_handle{tensorflow::TensorHandle::CreateLocalHandle(
      std::move(tensor.get()), device, op_device, eager_ctx)};
  return RuntimeFallbackTensor(gpu_tensor.shape(), gpu_tensor.dtype(),
                               std::move(tensor_handle));
}

// Converts the RuntimeFallbackTensor to a GpuTensor (currently DenseGpuTensor
// only). If the source tensor is on CPU, copies the data to GPU. If the
// source tensor is already on GPU, only performs the type conversion.
// TODO(b/167254525): For TFRuntimeFallback tensor, create separate tensor
// types for different devices.
static tfrt::AsyncValueRef<tfrt::gpu::DenseGpuTensor>
ConvertRuntimeFallbackTensorToDenseGpuTensor(
    const RuntimeFallbackTensor& tensor, const tfrt::Device& src,
    const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) {
  auto* host_ctx = exec_ctx.host();

  auto tf_tensor_handle = tensor.GetTensorHandle();

  tensorflow::Status status;
  const char* device_name = tf_tensor_handle->DeviceName(&status);
  // Check the status before using device_name: on failure the returned name
  // is not meaningful.
  if (!status.ok()) {
    return EmitErrorAsync(
        exec_ctx, tfrt::StrCat("error getting device name from TensorHandle: ",
                               status.error_message()));
  }

  auto tensor_device_ref =
      host_ctx->GetDeviceManager()->GetDeviceRef<tfrt::Device>(device_name);

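  // The TF device name may not be registered with TFRT verbatim; fall back to
  // TFRT's default translation of the TF device name.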
  if (!tensor_device_ref) {
    tensor_device_ref =
        host_ctx->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
            ConvertTfDeviceNameToTfrtDefault(device_name));
  }

  if (!tensor_device_ref)
    return tfrt::EmitErrorAsync(
        exec_ctx,
        tfrt::StrCat("Failed to find a device with name: ", device_name));

  // Check if the underlying tensorflow::TensorHandle is already on GPU.
  // If so, just convert the RuntimeFallbackTensor to a GpuTensor.
  if (tensor_device_ref.get() == &dst) {
    tensorflow::TensorShape shape;
    tensorflow::Status status = tf_tensor_handle->Shape(&shape);
    if (!status.ok()) {
      return EmitErrorAsync(
          exec_ctx, tfrt::StrCat("error getting shape from TF tensor handle: ",
                                 status.error_message()));
    }

    auto tf_shape = shape.dim_sizes();
    DataType dtype = tf_tensor_handle->DataType();
    // Note that the GPU tensor might not be available yet, but since TF and
    // TFRT share the same stream, this is OK.
    const tensorflow::Tensor* tf_tensor = nullptr;
    status = tf_tensor_handle->Tensor(&tf_tensor);
    if (!status.ok()) {
      return EmitErrorAsync(exec_ctx,
                            tfrt::StrCat("error calling TensorHandle::Tensor: ",
                                         status.error_message()));
    }

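    // Determine which TFRT GPU platform (e.g. CUDA or ROCm) backs this tensor
    // handle so that the raw device pointer below can be tagged with it.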
    auto platform = tensorflow::tfd::GetTfrtGpuPlatform(tf_tensor_handle);

    void* data = tf_tensor->data();
    size_t size = tf_tensor->TotalBytes();

    // Need to add a reference here since we are transferring the ownership
    // of the tensorflow::TensorHandle and the underlying GPU buffer to
    // tfrt::DenseGpuTensor. Otherwise, the TensorHandle would be released
    // when the RuntimeFallbackTensor goes out of scope after the tensor
    // conversion, and the GPU buffer would be deleted as well.
    OwnedTensorHandle owned_tf_tensor_handle =
        OwnedTensorHandle{TensorHandleFromInterface(tf_tensor_handle->Copy())};

    // The OwnedTensorHandle holds a reference on the underlying TensorFlow
    // buffer and is kept alive by the GpuOneShotAllocator.
    auto allocator = tfrt::MakeAvailableAsyncValueRef<
        tfrt::gpu::GpuOneShotAllocator<OwnedTensorHandle>>(
        tfrt::gpu::wrapper::Pointer<void>(data, platform),
        std::move(owned_tf_tensor_handle));
    llvm::Expected<tfrt::gpu::GpuBuffer> gpu_buffer =
        tfrt::gpu::GpuBuffer::Allocate(std::move(allocator), size);
    if (!gpu_buffer) {
      return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat(gpu_buffer.takeError()));
    }

    // Create the DenseGpuTensor.
    tfrt::gpu::DenseGpuTensor gpu_tensor{
        tfrt::TensorShape(
            std::vector<tfrt::Index>(tf_shape.begin(), tf_shape.end())),
        GetTfrtDtype(dtype),
        tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuBuffer>(
            std::move(*gpu_buffer))};

    return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::DenseGpuTensor>(
        exec_ctx.host(), std::move(gpu_tensor));
  } else {
    // TODO(chuanhao): clean up the branch after cl/325503773. Currently this
    // branch is needed since we don't know what type of tensor the
    // RuntimeFallbackTensor holds.
    // The tensorflow::TensorHandle is on the host CPU.
    assert(tensor_device_ref.get() == &host_ctx->GetHostDevice());

    // Convert the TFRuntimeFallbackTensor to a DenseHostTensor.
    auto host_tensor_ref = tfrt::ConvertTensor(
        exec_ctx, tensor, src, src, tfrt::DenseHostTensor::kTensorType);

    if (!host_tensor_ref.get().IsTensorType(tfrt::DenseHostTensor::kTensorType))
      return EmitErrorAsync(exec_ctx,
                            "TFRuntimeFallbackTensor not converted to "
                            "DenseHostTensor.");
    llvm::Expected<tfrt::gpu::wrapper::CurrentContext> current_context =
        dst.SetCurrentContext();
    if (!current_context) {
      return tfrt::MakeErrorAsyncValueRef(
          tfrt::StrCat(current_context.takeError()));
    }

    auto expected_gpu_tensor =
        tfrt::gpu::ConvertDenseHostTensorToDenseGpuTensor(
            std::move(current_context.get()), dst.stream(), dst.allocator(),
            llvm::cast<tfrt::DenseHostTensor>(host_tensor_ref.get()), host_ctx);
    if (!expected_gpu_tensor) {
      return EmitErrorAsync(exec_ctx, expected_gpu_tensor.takeError());
    }
    return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::DenseGpuTensor>(
        exec_ctx.host(), std::move(expected_gpu_tensor.get()));
  }
}

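// Converts a DenseGpuTensor to a RuntimeFallbackTensor on the same GPU
// device. The underlying buffer is shared rather than copied, via
// CopyRefGpuTensorToRuntimeFallbackTensor above.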
static tfrt::AsyncValueRef<RuntimeFallbackTensor>
ConvertDenseGpuTensorToRuntimeFallbackTensor(
    const tfrt::gpu::DenseGpuTensor& tensor, const tfrt::gpu::GpuDevice& src,
    const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) {
  auto* host = exec_ctx.host();

  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  tensorflow::tfd::EagerContextResource* eager_context_resource =
      resource_context
          ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);

  tfrt::Expected<EagerContext*> eager_ctx_expected =
      eager_context_resource->GetTFEagerContext();
  if (!eager_ctx_expected)
    return EmitErrorAsync(exec_ctx, eager_ctx_expected.takeError());

  EagerContext* eager_ctx = eager_ctx_expected.get();

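  // Conversion across distinct GPU devices is not supported here; source and
  // destination must be the same device.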
  assert(&src == &dst);
  Device* device;
  Status status = eager_ctx->local_device_mgr()->LookupDevice(
      ToAbslStringView(dst.name()), &device);
  if (!status.ok())
    return EmitErrorAsync(exec_ctx,
                          tfrt::MakeStringError(tfrt::StrCat(
                              "error looking up gpu device from EagerContext: ",
                              status.error_message())));

  auto fallback_tensor = CopyRefGpuTensorToRuntimeFallbackTensor(
      tensor, device, device, eager_ctx);
  if (fallback_tensor) {
    return tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
        host, std::move(*fallback_tensor));
  } else {
    return EmitErrorAsync(exec_ctx, fallback_tensor.takeError());
  }
}

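// Registers the two conversion functions above with TFRT's tensor conversion
// function registry. A hypothetical call site (sketch only; the actual
// registration happens elsewhere during runtime-fallback setup) would look
// like:
//
//   tfrt::TensorConversionFnRegistry registry;
//   RegisterTFRuntimeFallbackTensorToGpuConversionFn(&registry);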
void RegisterTFRuntimeFallbackTensorToGpuConversionFn(
    tfrt::TensorConversionFnRegistry* registry) {
  registry->AddTensorConversionFn(
      TFRT_CONVERSION(ConvertRuntimeFallbackTensorToDenseGpuTensor));

  registry->AddTensorConversionFn(
      TFRT_CONVERSION(ConvertDenseGpuTensorToRuntimeFallbackTensor));
}

}  // namespace tfd
}  // namespace tensorflow