1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/tensor_util.h"
16
17 #include <assert.h>
18 #include <sys/types.h>
19
20 #include <cstring>
21 #include <string>
22 #include <utility>
23
24 #include "absl/container/inlined_vector.h"
25 #include "absl/strings/str_cat.h"
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/tstring.h"
33 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
34 #include "tensorflow/core/runtime_fallback/util/tensor_util.h"
35 #include "tensorflow/core/runtime_fallback/util/type_util.h"
36 #include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime
37 #include "tfrt/dtype/dtype.h" // from @tf_runtime
38 #include "tfrt/host_context/host_buffer.h" // from @tf_runtime
39 #include "tfrt/host_context/host_context.h" // from @tf_runtime
40 #include "tfrt/support/error_util.h" // from @tf_runtime
41 #include "tfrt/support/forward_decls.h" // from @tf_runtime
42 #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime
43 #include "tfrt/tensor/host_tensor.h" // from @tf_runtime
44 #include "tfrt/tensor/scalar_host_tensor.h" // from @tf_runtime
45 #include "tfrt/tensor/string_host_tensor.h" // from @tf_runtime
46 #include "tfrt/tensor/tensor.h" // from @tf_runtime
47 #include "tfrt/tensor/tensor_shape.h" // from @tf_runtime
48
49 namespace tfrt {
50 namespace {
51
52 using ::tensorflow::StatusOr;
53
CopyScalarHostTensorToTFTensor(const AnyScalarHostTensor & tensor)54 llvm::Expected<tensorflow::Tensor> CopyScalarHostTensorToTFTensor(
55 const AnyScalarHostTensor& tensor) {
56 auto element_byte_size = GetHostSize(tensor.dtype());
57 if (element_byte_size == 0) {
58 return MakeStringError(
59 "Failed to convert ScalarHostTensor to tensorflow::Tensor: "
60 "unsupported dtype: ",
61 tensor.dtype());
62 }
63
64 llvm::SmallVector<Index, 4> dims;
65 tensor.shape().GetDimensions(&dims);
66
67 auto tf_dtype = tensorflow::tfd::GetTfDataType(tensor.dtype());
68 tensorflow::Tensor tf_tensor(
69 tf_dtype, tensorflow::TensorShape(
70 llvm::SmallVector<int64_t, 4>(dims.begin(), dims.end())));
71
72 // This can be a DCHECK instead of returninng an error because TFRT's
73 // ScalarHostTensor only supports these types.
74 DCHECK(DataTypeCanUseMemcpy(tf_dtype));
75
76 // TODO(tfrt-devs): Hide the following the logic of generating a full size
77 // buffer for the scalar host tensor under Tensor (and optimize if
78 // necessary), so we don't have to re-implement it every time we need it.
79 char* begin = reinterpret_cast<char*>(tf_tensor.data());
80 for (int i = 0; i < tf_tensor.NumElements(); ++i) {
81 std::memcpy(begin, tensor.data(), element_byte_size);
82 begin += element_byte_size;
83 }
84 return tf_tensor;
85 }
86
ConvertTFDTypeToTFRTDType(tensorflow::DataType dtype)87 StatusOr<DType> ConvertTFDTypeToTFRTDType(tensorflow::DataType dtype) {
88 switch (dtype) {
89 #define DTYPE(TFRT_DTYPE, TF_DTYPE) \
90 case tensorflow::TF_DTYPE: \
91 return DType(DType::TFRT_DTYPE);
92 #include "tensorflow/core/tfrt/utils/dtype.def"
93 default:
94 return tensorflow::errors::Internal(absl::StrCat(
95 "unsupported tensorflow dtype: ", tensorflow::DataType_Name(dtype)));
96 }
97 }
98
ConvertTFRTDTypeToTFDType(DType dtype)99 StatusOr<tensorflow::DataType> ConvertTFRTDTypeToTFDType(DType dtype) {
100 switch (dtype) {
101 #define DTYPE(TFRT_DTYPE, TF_DTYPE) \
102 case DType::TFRT_DTYPE: \
103 return tensorflow::TF_DTYPE;
104 #include "tensorflow/core/tfrt/utils/dtype.def"
105 default:
106 return tensorflow::errors::Internal(
107 StrCat("unsupported tfrt dtype: ", dtype));
108 }
109 }
110
111 } // namespace
112
TFRTTensorToTFTensor(const Tensor & tensor,HostContext * host)113 llvm::Expected<tensorflow::Tensor> TFRTTensorToTFTensor(const Tensor& tensor,
114 HostContext* host) {
115 if (auto* knfbt = llvm::dyn_cast<tensorflow::KernelFallbackTensor>(&tensor)) {
116 return *knfbt->GetTensor();
117 }
118 // TODO(tfrt-devs): The following logic should be better provided by
119 // Tensor so we don't have to re-implement it.
120 if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor)) {
121 return tensorflow::tfd::MoveHostBufferToTfTensor(
122 dht->buffer(), dht->dtype(), dht->shape());
123 }
124 if (auto* sht = llvm::dyn_cast<StringHostTensor>(&tensor)) {
125 return tensorflow::tfd::CopyShtToTfTensor(*sht);
126 }
127 if (auto* scalar = llvm::dyn_cast<AnyScalarHostTensor>(&tensor)) {
128 return CopyScalarHostTensorToTFTensor(*scalar);
129 }
130 return MakeStringError("Unsupported conversion format for ",
131 tensor.tensor_type().name());
132 }
133
TFTensorToTFRTTensorHandle(const tensorflow::Tensor & tf_tensor,HostContext * host_ctx)134 AsyncValueRef<TensorHandle> TFTensorToTFRTTensorHandle(
135 const tensorflow::Tensor& tf_tensor, HostContext* host_ctx) {
136 auto knfbt = MakeAvailableAsyncValueRef<tensorflow::KernelFallbackTensor>(
137 host_ctx, tf_tensor);
138 return MakeAvailableAsyncValueRef<TensorHandle>(
139 host_ctx, host_ctx->GetHostDeviceRef(), knfbt->metadata(),
140 std::move(knfbt));
141 }
142
CreateTensorHandleFromTFTensor(const tensorflow::Tensor & tensor,HostContext * host)143 StatusOr<TensorHandle> CreateTensorHandleFromTFTensor(
144 const tensorflow::Tensor& tensor, HostContext* host) {
145 // TODO(chky): Handle non-trivial types such as strings.
146 TF_ASSIGN_OR_RETURN(auto dtype, ConvertTFDTypeToTFRTDType(tensor.dtype()));
147 auto shape = tensor.shape().dim_sizes();
148 TensorMetadata metadata(dtype, TensorShape(llvm::SmallVector<Index, 4>(
149 shape.begin(), shape.end())));
150
151 if (dtype == DType::String) {
152 auto sht_ref =
153 StringHostTensor::MakeConstructedAsyncValueRef(metadata, host);
154 auto to = sht_ref->strings();
155 auto from = tensor.flat<tensorflow::tstring>();
156 for (int i = 0, e = to.size(); i < e; ++i) {
157 to[i] = from(i);
158 }
159 sht_ref.SetStateConcrete();
160 return TensorHandle(host->GetHostDeviceRef(), metadata, std::move(sht_ref));
161 }
162
163 auto dht_ref = DenseHostTensor::MakeConstructedAsyncValueRef(metadata, host);
164
165 auto& dht = dht_ref.get();
166 assert(dht.DataSizeInBytes() ==
167 tensor.NumElements() * tensorflow::DataTypeSize(tensor.dtype()));
168 std::memcpy(dht_ref.get().data(), tensor.data(), dht.DataSizeInBytes());
169
170 dht_ref.SetStateConcrete();
171 return TensorHandle(host->GetHostDeviceRef(), metadata, std::move(dht_ref));
172 }
173
CreateTFTensorFromTensorHandle(const TensorHandle & tensor_handle)174 StatusOr<tensorflow::Tensor> CreateTFTensorFromTensorHandle(
175 const TensorHandle& tensor_handle) {
176 const auto& metadata = tensor_handle.GetAvailableMetadata();
177 TF_ASSIGN_OR_RETURN(auto dtype, ConvertTFRTDTypeToTFDType(metadata.dtype));
178 llvm::SmallVector<Index, 4> shape;
179 metadata.shape.GetDimensions(&shape);
180 const auto& host_tensor = tensor_handle.GetAsyncTensor()->get<HostTensor>();
181
182 if (auto* kernel_fallback_tensor =
183 llvm::dyn_cast<tensorflow::KernelFallbackTensor>(&host_tensor)) {
184 return *kernel_fallback_tensor->GetTensor();
185 }
186
187 if (llvm::isa<StringHostTensor>(host_tensor)) {
188 assert(dtype == tensorflow::DT_STRING);
189 const auto& sht = llvm::cast<StringHostTensor>(host_tensor);
190 tensorflow::Tensor tensor(
191 tensorflow::DT_STRING,
192 tensorflow::TensorShape(
193 llvm::SmallVector<int64_t, 4>(shape.begin(), shape.end())));
194 auto from = sht.strings();
195 auto to = tensor.flat<tensorflow::tstring>();
196 for (int i = 0, e = from.size(); i < e; ++i) {
197 to(i).assign(from[i].data(), from[i].size());
198 }
199 return tensor;
200 }
201
202 if (llvm::isa<DenseHostTensor>(host_tensor)) {
203 const auto& dht = llvm::cast<DenseHostTensor>(host_tensor);
204 tensorflow::Tensor tensor(
205 dtype, tensorflow::TensorShape(
206 llvm::SmallVector<int64_t, 4>(shape.begin(), shape.end())));
207
208 assert(dht.DataSizeInBytes() ==
209 tensor.NumElements() * tensorflow::DataTypeSize(tensor.dtype()));
210 std::memcpy(tensor.data(), dht.data(), dht.DataSizeInBytes());
211 return tensor;
212 }
213
214 return tensorflow::errors::Internal("unknown host tensor type");
215 }
216
217 } // namespace tfrt
218