1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements conversion function between TFRuntimeFallback and Gpu
17 // tensors.
18
19 #include "tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h"
20
21 #include "absl/strings/match.h"
22 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
23 #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
24 #include "tensorflow/core/runtime_fallback/util/attr_util.h"
25 #include "tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h"
26 #include "tensorflow/core/runtime_fallback/util/type_util.h"
27 #include "tfrt/gpu/device/conversion_function.h" // from @tf_runtime
28 #include "tfrt/gpu/device/device.h" // from @tf_runtime
29 #include "tfrt/gpu/device/device_util.h" // from @tf_runtime
30 #include "tfrt/gpu/gpu_types.h" // from @tf_runtime
31 #include "tfrt/gpu/tensor/dense_gpu_tensor.h" // from @tf_runtime
32 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
33 #include "tfrt/host_context/execution_context.h" // from @tf_runtime
34 #include "tfrt/host_context/host_buffer.h" // from @tf_runtime
35 #include "tfrt/host_context/host_context.h" // from @tf_runtime
36 #include "tfrt/support/error_util.h" // from @tf_runtime
37 #include "tfrt/tensor/conversion_registry.h" // from @tf_runtime
38 #include "tfrt/tensor/conversion_utils.h" // from @tf_runtime
39 #include "tfrt/tensor/tensor.h" // from @tf_runtime
40
41 namespace tensorflow {
42 namespace tfd {
43
44 static tfrt::Expected<RuntimeFallbackTensor>
CopyRefGpuTensorToRuntimeFallbackTensor(const tfrt::gpu::DenseGpuTensor & gpu_tensor,Device * device,Device * op_device,EagerContext * eager_ctx)45 CopyRefGpuTensorToRuntimeFallbackTensor(
46 const tfrt::gpu::DenseGpuTensor& gpu_tensor, Device* device,
47 Device* op_device, EagerContext* eager_ctx) {
48 // Do not copy the gpu buffer content, CopyRef on the buffer instead.
49 tfrt::AsyncValueRef<tfrt::gpu::GpuBuffer> gpu_buffer =
50 gpu_tensor.CopyBufferRef();
51 tfrt::Expected<tensorflow::Tensor> tensor = MoveGpuBufferToTFTensor(
52 std::move(gpu_buffer), gpu_tensor.dtype(), gpu_tensor.shape());
53 if (!tensor) return tensor.takeError();
54
55 OwnedTensorHandle tensor_handle{tensorflow::TensorHandle::CreateLocalHandle(
56 std::move(tensor.get()), device, op_device, eager_ctx)};
57 return RuntimeFallbackTensor(gpu_tensor.shape(), gpu_tensor.dtype(),
58 std::move(tensor_handle));
59 }
60
61 // Convert the RuntimeFallbackTensor to a GpuTensor (currently DenseGpuTensor
62 // only). If the source tensor is on CPU, copy the data to GPU. If the source
63 // tensor is already on GPU, just do type conversion.
64 // TODO(b/167254525): For TFRuntimeFallback tensor, create separate tensor
65 // types for different devices.
66 static tfrt::AsyncValueRef<tfrt::gpu::DenseGpuTensor>
ConvertRuntimeFallbackTensorToDenseGpuTensor(const RuntimeFallbackTensor & tensor,const tfrt::Device & src,const tfrt::gpu::GpuDevice & dst,const tfrt::ExecutionContext & exec_ctx)67 ConvertRuntimeFallbackTensorToDenseGpuTensor(
68 const RuntimeFallbackTensor& tensor, const tfrt::Device& src,
69 const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) {
70 auto* host_ctx = exec_ctx.host();
71
72 auto tf_tensor_handle = tensor.GetTensorHandle();
73
74 tensorflow::Status status;
75 const char* device_name = tf_tensor_handle->DeviceName(&status);
76
77 auto tensor_device_ref =
78 host_ctx->GetDeviceManager()->GetDeviceRef<tfrt::Device>(device_name);
79
80 if (!tensor_device_ref) {
81 tensor_device_ref =
82 host_ctx->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
83 ConvertTfDeviceNameToTfrtDefault(device_name));
84 }
85
86 if (!tensor_device_ref)
87 return tfrt::EmitErrorAsync(
88 exec_ctx,
89 tfrt::StrCat("Failed to find a device with name: ", device_name));
90
91 if (!status.ok()) {
92 return EmitErrorAsync(
93 exec_ctx, tfrt::StrCat("error getting device name from TensorHandle: ",
94 status.error_message()));
95 }
96
97 // Check if the underlying tensorflow::TensorHandle is already on GPU.
98 // If so, just convert the RuntimeFallbackTensor to GpuTensor.
99 if (tensor_device_ref.get() == &dst) {
100 tensorflow::TensorShape shape;
101 tensorflow::Status status = tf_tensor_handle->Shape(&shape);
102 if (!status.ok()) {
103 return EmitErrorAsync(
104 exec_ctx, tfrt::StrCat("error getting shape from TF tensor handle: ",
105 status.error_message()));
106 }
107
108 auto tf_shape = shape.dim_sizes();
109 DataType dtype = tf_tensor_handle->DataType();
110 // Note that GPU tensor might not be available yet. But since TF
111 // and TFRT share the same stream, this is ok.
112 const tensorflow::Tensor* tf_tensor = nullptr;
113 status = tf_tensor_handle->Tensor(&tf_tensor);
114 if (!status.ok()) {
115 return EmitErrorAsync(exec_ctx,
116 tfrt::StrCat("error calling TensorHandle::Tensor: ",
117 status.error_message()));
118 }
119
120 auto platform = tensorflow::tfd::GetTfrtGpuPlatform(tf_tensor_handle);
121
122 void* data = tf_tensor->data();
123 size_t size = tf_tensor->TotalBytes();
124
125 // Need to add a reference here since we are transferring the ownership
126 // of the Tensorflow::TensorHandle and the underlying GPU buffer to
127 // tfrt::DenseGpuTensor. Otherwise, the TensorHandle will be released
128 // when he RuntimeFallbackTensor goes out of scope after the tensor
129 // conversion. The GPU buffer will be deleted as well.
130 OwnedTensorHandle owned_tf_tensor_handle =
131 OwnedTensorHandle{TensorHandleFromInterface(tf_tensor_handle->Copy())};
132
133 // The OwnedTensorHandle holds a reference on underlying Tensorflow buffer
134 // and is held alive by GpuOneShotAllocator.
135 auto allocator = tfrt::MakeAvailableAsyncValueRef<
136 tfrt::gpu::GpuOneShotAllocator<OwnedTensorHandle>>(
137 tfrt::gpu::wrapper::Pointer<void>(data, platform),
138 std::move(owned_tf_tensor_handle));
139 llvm::Expected<tfrt::gpu::GpuBuffer> gpu_buffer =
140 tfrt::gpu::GpuBuffer::Allocate(std::move(allocator), size);
141 if (!gpu_buffer) {
142 return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat(gpu_buffer.takeError()));
143 }
144
145 // create DenseGpuTensor.
146 tfrt::gpu::DenseGpuTensor gpu_tensor{
147 tfrt::TensorShape(
148 std::vector<tfrt::Index>(tf_shape.begin(), tf_shape.end())),
149 GetTfrtDtype(dtype),
150 tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuBuffer>(
151 std::move(*gpu_buffer))};
152
153 return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::DenseGpuTensor>(
154 exec_ctx.host(), std::move(gpu_tensor));
155 } else {
156 // TODO(chuanhao): clean up the branch after cl/325503773. Currently this
157 // branch is needed since we don't know what type of tensor that
158 // RuntimeFallbackTensor holds.
159 // tensorflow::TensorHandle is on host CPU.
160 assert(tensor_device_ref.get() == &host_ctx->GetHostDevice());
161
162 // Convert the TFRuntimeFallbackTensor to DenseHostTensor.
163 auto host_tensor_ref = tfrt::ConvertTensor(
164 exec_ctx, tensor, src, src, tfrt::DenseHostTensor::kTensorType);
165
166 if (!host_tensor_ref.get().IsTensorType(tfrt::DenseHostTensor::kTensorType))
167 return EmitErrorAsync(exec_ctx,
168 "TFRuntimeFallbackTensor not converted to "
169 "DenseHostTensor.");
170 llvm::Expected<tfrt::gpu::wrapper::CurrentContext> current_context =
171 dst.SetCurrentContext();
172 if (!current_context) {
173 return tfrt::MakeErrorAsyncValueRef(
174 tfrt::StrCat(current_context.takeError()));
175 }
176
177 auto expected_gpu_tensor =
178 tfrt::gpu::ConvertDenseHostTensorToDenseGpuTensor(
179 std::move(current_context.get()), dst.stream(), dst.allocator(),
180 llvm::cast<tfrt::DenseHostTensor>(host_tensor_ref.get()), host_ctx);
181 if (!expected_gpu_tensor) {
182 return EmitErrorAsync(exec_ctx, expected_gpu_tensor.takeError());
183 }
184 return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::DenseGpuTensor>(
185 exec_ctx.host(), std::move(expected_gpu_tensor.get()));
186 }
187 }
188
189 static tfrt::AsyncValueRef<RuntimeFallbackTensor>
ConvertDenseGpuTensorToRuntimeFallbackTensor(const tfrt::gpu::DenseGpuTensor & tensor,const tfrt::gpu::GpuDevice & src,const tfrt::gpu::GpuDevice & dst,const tfrt::ExecutionContext & exec_ctx)190 ConvertDenseGpuTensorToRuntimeFallbackTensor(
191 const tfrt::gpu::DenseGpuTensor& tensor, const tfrt::gpu::GpuDevice& src,
192 const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) {
193 auto* host = exec_ctx.host();
194
195 tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
196 tensorflow::tfd::EagerContextResource* eager_context_resource =
197 resource_context
198 ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
199 tensorflow::tfd::kEagerContextResourceName);
200
201 tfrt::Expected<EagerContext*> eager_ctx_expected =
202 eager_context_resource->GetTFEagerContext();
203 if (!eager_ctx_expected)
204 return EmitErrorAsync(exec_ctx, eager_ctx_expected.takeError());
205
206 EagerContext* eager_ctx = eager_ctx_expected.get();
207
208 assert(&src == &dst);
209 Device* device;
210 Status status = eager_ctx->local_device_mgr()->LookupDevice(
211 ToAbslStringView(dst.name()), &device);
212 if (!status.ok())
213 return EmitErrorAsync(exec_ctx,
214 tfrt::MakeStringError(tfrt::StrCat(
215 "error looking up gpu device from EagerContext: ",
216 status.error_message())));
217
218 auto fallback_tensor = CopyRefGpuTensorToRuntimeFallbackTensor(
219 tensor, device, device, eager_ctx);
220 if (fallback_tensor) {
221 return tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
222 host, std::move(*fallback_tensor));
223 } else {
224 return EmitErrorAsync(exec_ctx, fallback_tensor.takeError());
225 }
226 }
227
// Registers the bidirectional RuntimeFallbackTensor <-> DenseGpuTensor
// conversion functions defined above with the given TFRT tensor-conversion
// registry so ConvertTensor can dispatch to them.
void RegisterTFRuntimeFallbackTensorToGpuConversionFn(
    tfrt::TensorConversionFnRegistry* registry) {
  // RuntimeFallbackTensor -> DenseGpuTensor (copies host data to GPU when
  // the source tensor lives on CPU; aliases the buffer when already on GPU).
  registry->AddTensorConversionFn(
      TFRT_CONVERSION(ConvertRuntimeFallbackTensorToDenseGpuTensor));

  // DenseGpuTensor -> RuntimeFallbackTensor (buffer reference, no copy).
  registry->AddTensorConversionFn(
      TFRT_CONVERSION(ConvertDenseGpuTensorToRuntimeFallbackTensor));
}
236
237 } // namespace tfd
238 } // namespace tensorflow
239