/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/utils/tensor_util.h"

#include <assert.h>
#include <sys/types.h>

#include <cstring>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h"  // from @tf_runtime

namespace tfrt {
namespace {

using ::tensorflow::StatusOr;

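// Converts a TFRT ScalarHostTensor to a tensorflow::Tensor by allocating a
// dense buffer of the full shape and replicating the scalar value into every
// element.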
llvm::Expected<tensorflow::Tensor> CopyScalarHostTensorToTFTensor(
    const AnyScalarHostTensor& tensor) {
  auto element_byte_size = GetHostSize(tensor.dtype());
  if (element_byte_size == 0) {
    return MakeStringError(
        "Failed to convert ScalarHostTensor to tensorflow::Tensor: "
        "unsupported dtype: ",
        tensor.dtype());
  }

  llvm::SmallVector<Index, 4> dims;
  tensor.shape().GetDimensions(&dims);

  auto tf_dtype = tensorflow::tfd::GetTfDataType(tensor.dtype());
  tensorflow::Tensor tf_tensor(
      tf_dtype, tensorflow::TensorShape(
                    llvm::SmallVector<int64_t, 4>(dims.begin(), dims.end())));

  // This can be a DCHECK instead of returning an error because TFRT's
  // ScalarHostTensor only supports these types.
  DCHECK(DataTypeCanUseMemcpy(tf_dtype));

  // TODO(tfrt-devs): Hide the following logic of generating a full-size
  // buffer for the scalar host tensor under Tensor (and optimize if
  // necessary), so we don't have to re-implement it every time we need it.
  char* begin = reinterpret_cast<char*>(tf_tensor.data());
  for (int i = 0; i < tf_tensor.NumElements(); ++i) {
    std::memcpy(begin, tensor.data(), element_byte_size);
    begin += element_byte_size;
  }
  return tf_tensor;
}

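// Maps a tensorflow::DataType to the corresponding TFRT DType using the
// mappings listed in dtype.def; returns an internal error for unsupported
// dtypes.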
StatusOr<DType> ConvertTFDTypeToTFRTDType(tensorflow::DataType dtype) {
  switch (dtype) {
#define DTYPE(TFRT_DTYPE, TF_DTYPE) \
  case tensorflow::TF_DTYPE:        \
    return DType(DType::TFRT_DTYPE);
#include "tensorflow/core/tfrt/utils/dtype.def"
    default:
      return tensorflow::errors::Internal(absl::StrCat(
          "unsupported tensorflow dtype: ", tensorflow::DataType_Name(dtype)));
  }
}

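// Maps a TFRT DType back to the corresponding tensorflow::DataType using the
// same dtype.def mappings; returns an internal error for unsupported dtypes.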
StatusOr<tensorflow::DataType> ConvertTFRTDTypeToTFDType(DType dtype) {
  switch (dtype) {
#define DTYPE(TFRT_DTYPE, TF_DTYPE) \
  case DType::TFRT_DTYPE:           \
    return tensorflow::TF_DTYPE;
#include "tensorflow/core/tfrt/utils/dtype.def"
    default:
      return tensorflow::errors::Internal(
          StrCat("unsupported tfrt dtype: ", dtype));
  }
}

}  // namespace

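// Converts a TFRT tensor to a tensorflow::Tensor. A KernelFallbackTensor
// already wraps a tensorflow::Tensor and is returned directly; dense, string,
// and scalar host tensors are converted (copying where necessary).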
llvm::Expected<tensorflow::Tensor> TFRTTensorToTFTensor(const Tensor& tensor,
                                                        HostContext* host) {
  if (auto* knfbt = llvm::dyn_cast<tensorflow::KernelFallbackTensor>(&tensor)) {
    return *knfbt->GetTensor();
  }
  // TODO(tfrt-devs): The following logic should be better provided by
  // Tensor so we don't have to re-implement it.
  if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor)) {
    return tensorflow::tfd::MoveHostBufferToTfTensor(
        dht->buffer(), dht->dtype(), dht->shape());
  }
  if (auto* sht = llvm::dyn_cast<StringHostTensor>(&tensor)) {
    return tensorflow::tfd::CopyShtToTfTensor(*sht);
  }
  if (auto* scalar = llvm::dyn_cast<AnyScalarHostTensor>(&tensor)) {
    return CopyScalarHostTensorToTFTensor(*scalar);
  }
  return MakeStringError("Unsupported conversion format for ",
                         tensor.tensor_type().name());
}

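// Wraps a tensorflow::Tensor in a KernelFallbackTensor and returns it as an
// available TensorHandle on the host device.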
AsyncValueRef<TensorHandle> TFTensorToTFRTTensorHandle(
    const tensorflow::Tensor& tf_tensor, HostContext* host_ctx) {
  auto knfbt = MakeAvailableAsyncValueRef<tensorflow::KernelFallbackTensor>(
      host_ctx, tf_tensor);
  return MakeAvailableAsyncValueRef<TensorHandle>(
      host_ctx, host_ctx->GetHostDeviceRef(), knfbt->metadata(),
      std::move(knfbt));
}

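// Creates a TensorHandle backed by a TFRT host tensor that owns a copy of the
// given tensorflow::Tensor's data.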
StatusOr<TensorHandle> CreateTensorHandleFromTFTensor(
    const tensorflow::Tensor& tensor, HostContext* host) {
  // TODO(chky): Handle non-trivial types such as strings.
  TF_ASSIGN_OR_RETURN(auto dtype, ConvertTFDTypeToTFRTDType(tensor.dtype()));
  auto shape = tensor.shape().dim_sizes();
  TensorMetadata metadata(dtype, TensorShape(llvm::SmallVector<Index, 4>(
                                     shape.begin(), shape.end())));

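  // Strings cannot be memcpy'd, so copy them element by element into a
  // StringHostTensor.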
  if (dtype == DType::String) {
    auto sht_ref =
        StringHostTensor::MakeConstructedAsyncValueRef(metadata, host);
    auto to = sht_ref->strings();
    auto from = tensor.flat<tensorflow::tstring>();
    for (int i = 0, e = to.size(); i < e; ++i) {
      to[i] = from(i);
    }
    sht_ref.SetStateConcrete();
    return TensorHandle(host->GetHostDeviceRef(), metadata, std::move(sht_ref));
  }

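  // All other supported dtypes are trivially copyable; copy the flat buffer
  // into a DenseHostTensor.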
  auto dht_ref = DenseHostTensor::MakeConstructedAsyncValueRef(metadata, host);

  auto& dht = dht_ref.get();
  assert(dht.DataSizeInBytes() ==
         tensor.NumElements() * tensorflow::DataTypeSize(tensor.dtype()));
  std::memcpy(dht_ref.get().data(), tensor.data(), dht.DataSizeInBytes());

  dht_ref.SetStateConcrete();
  return TensorHandle(host->GetHostDeviceRef(), metadata, std::move(dht_ref));
}

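// Converts an available TensorHandle back into a tensorflow::Tensor,
// dispatching on the concrete host tensor type it holds.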
StatusOr<tensorflow::Tensor> CreateTFTensorFromTensorHandle(
    const TensorHandle& tensor_handle) {
  const auto& metadata = tensor_handle.GetAvailableMetadata();
  TF_ASSIGN_OR_RETURN(auto dtype, ConvertTFRTDTypeToTFDType(metadata.dtype));
  llvm::SmallVector<Index, 4> shape;
  metadata.shape.GetDimensions(&shape);
  const auto& host_tensor = tensor_handle.GetAsyncTensor()->get<HostTensor>();

  if (auto* kernel_fallback_tensor =
          llvm::dyn_cast<tensorflow::KernelFallbackTensor>(&host_tensor)) {
    return *kernel_fallback_tensor->GetTensor();
  }

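  // String tensors are copied element by element into tstrings.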
  if (llvm::isa<StringHostTensor>(host_tensor)) {
    assert(dtype == tensorflow::DT_STRING);
    const auto& sht = llvm::cast<StringHostTensor>(host_tensor);
    tensorflow::Tensor tensor(
        tensorflow::DT_STRING,
        tensorflow::TensorShape(
            llvm::SmallVector<int64_t, 4>(shape.begin(), shape.end())));
    auto from = sht.strings();
    auto to = tensor.flat<tensorflow::tstring>();
    for (int i = 0, e = from.size(); i < e; ++i) {
      to(i).assign(from[i].data(), from[i].size());
    }
    return tensor;
  }

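  // Dense host tensors are memcpy'd into a freshly allocated
  // tensorflow::Tensor of the same dtype and shape.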
  if (llvm::isa<DenseHostTensor>(host_tensor)) {
    const auto& dht = llvm::cast<DenseHostTensor>(host_tensor);
    tensorflow::Tensor tensor(
        dtype, tensorflow::TensorShape(
                   llvm::SmallVector<int64_t, 4>(shape.begin(), shape.end())));

    assert(dht.DataSizeInBytes() ==
           tensor.NumElements() * tensorflow::DataTypeSize(tensor.dtype()));
    std::memcpy(tensor.data(), dht.data(), dht.DataSizeInBytes());
    return tensor;
  }

  return tensorflow::errors::Internal("unknown host tensor type");
}

}  // namespace tfrt