xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements TF runtime fallback tensor.
17 
18 #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
19 
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "tensorflow/c/tensor_interface.h"
23 #include "tensorflow/c/tf_datatype.h"
24 #include "tensorflow/c/tf_tensor.h"
25 #include "tensorflow/c/tf_tensor_internal.h"
26 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/runtime_fallback/util/tensor_util.h"
30 #include "tensorflow/core/runtime_fallback/util/type_util.h"
31 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
32 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
33 #include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
34 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
35 #include "tfrt/support/error_util.h"  // from @tf_runtime
36 #include "tfrt/support/ref_count.h"  // from @tf_runtime
37 #include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
38 #include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
39 #include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
40 #include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
41 
42 namespace tensorflow {
43 namespace tfd {
44 
45 using tfrt::DenseHostTensor;
46 using tfrt::DType;
47 using tfrt::Expected;
48 using tfrt::HostBuffer;
49 using tfrt::HostContext;
50 using tfrt::RCReference;
51 using tfrt::StringHostTensor;
52 using tfrt::Tensor;
53 using tfrt::TensorMetadata;
54 using tfrt::TensorShape;
55 
56 using OwnedTFStatus = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
57 
58 // If dtype is unsupported, only crash when converting this object to
59 // HostTensor.
RuntimeFallbackTensor(const TensorShape & shape,DType dtype,OwnedTensorHandle th)60 RuntimeFallbackTensor::RuntimeFallbackTensor(const TensorShape& shape,
61                                              DType dtype, OwnedTensorHandle th)
62     : Tensor(TensorMetadata(dtype, shape)), tensor_handle_{std::move(th)} {
63   assert(IsValid(dtype) && "Invalid dtype");
64 }
65 
GetShape(AbstractTensorInterface * tensor_interface)66 llvm::SmallVector<tfrt::Index, 4> GetShape(
67     AbstractTensorInterface* tensor_interface) {
68   llvm::SmallVector<tfrt::Index, 4> dims;
69   int64_t num_dims = tensor_interface->NumDims();
70   dims.reserve(num_dims);
71   for (int i = 0; i < num_dims; ++i) {
72     dims.push_back(tensor_interface->Dim(i));
73   }
74   return dims;
75 }
76 
CopyTfStringTensorToStringHostTensor(AbstractTensorInterface * tensor_interface,HostContext * host)77 Expected<StringHostTensor> CopyTfStringTensorToStringHostTensor(
78     AbstractTensorInterface* tensor_interface, HostContext* host) {
79   auto sht = StringHostTensor::CreateUninitialized(
80       TensorMetadata(DType(DType::String), GetShape(tensor_interface)), host);
81   if (!sht)
82     return tfrt::MakeStringError(
83         "failed to create uninitialized string tensor");
84 
85   assert(tensor_interface->Type() == DT_STRING);
86   const int64_t num_elems = tensor_interface->NumElements();
87   const tensorflow::tstring* tstrings =
88       reinterpret_cast<const tensorflow::tstring*>(tensor_interface->Data());
89 
90   auto strings = sht->strings();
91   for (int i = 0; i < num_elems; ++i) {
92     strings[i] = tstrings[i];
93   }
94 
95   return std::move(*sht);
96 }
97 
98 // TODO(jingdong): Format the tensor in more user-friendly format, especially
99 // for large tensors. See tensorflow::Tensor::DebugString().
Print(tfrt::raw_ostream & os) const100 void RuntimeFallbackTensor::Print(tfrt::raw_ostream& os) const {
101   tensorflow::Status status;
102   OwnedAbstractTensorInterface tensor_interface{
103       tensor_handle_->Resolve(&status)};
104   assert(status.ok());
105 
106   int rank = tensor_interface->NumDims();
107 
108   llvm::SmallVector<tfrt::Index, 4> dims;
109   for (auto i = 0; i < rank; ++i) {
110     dims.push_back(tensor_interface->Dim(i));
111   }
112 
113   DataType dtype = tensor_interface->Type();
114   os << "RuntimeFallbackTensor dtype = " << DataTypeString(dtype)
115      << ", shape = [";
116   llvm::interleaveComma(dims, os);
117   os << "], values = [";
118 
119   int64_t num_elements = tensor_interface->NumElements();
120   void* tensor_data = tensor_interface->Data();
121 
122   switch (dtype) {
123     case TF_DataType::TF_FLOAT:
124       PrintTensorValues<float>(tensor_data, num_elements, os);
125       break;
126     case TF_DataType::TF_DOUBLE:
127       PrintTensorValues<double>(tensor_data, num_elements, os);
128       break;
129     case TF_DataType::TF_INT32:
130       PrintTensorValues<int32_t>(tensor_data, num_elements, os);
131       break;
132     case TF_DataType::TF_INT64:
133       PrintTensorValues<int64_t>(tensor_data, num_elements, os);
134       break;
135     case TF_DataType::TF_INT8:
136       PrintTensorValues<int8_t>(tensor_data, num_elements, os);
137       break;
138     default:
139       os << "Unsupported tensor dtype " << dtype;
140       break;
141   }
142 
143   os << "]\n";
144 }
145 
146 tfrt::Expected<RuntimeFallbackTensor>
CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th,HostContext * host)147 CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th,
148                                               HostContext* host) {
149   int rank;
150   tensorflow::Status status = owned_th->NumDims(&rank);
151   if (!status.ok())
152     return tfrt::MakeStringError(tfrt::StrCat(
153         "error getting rank from TF tensor handle: ", status.error_message()));
154 
155   llvm::SmallVector<tfrt::Index, 4> dims;
156   for (auto i = 0; i < rank; ++i) {
157     int64_t dim;
158     status = owned_th->Dim(i, &dim);
159     if (!status.ok())
160       return tfrt::MakeStringError(
161           tfrt::StrCat("error getting dimension from TFE tensor handle: ",
162                        status.error_message()));
163     dims.push_back(dim);
164   }
165 
166   TensorShape shape{dims};
167   DataType dtype = owned_th->DataType();
168   return RuntimeFallbackTensor(shape, GetTfrtDtype(dtype), std::move(owned_th));
169 }
170 
MoveDHTToRuntimeFallbackTensor(DenseHostTensor && dht,HostContext * host)171 RuntimeFallbackTensor MoveDHTToRuntimeFallbackTensor(DenseHostTensor&& dht,
172                                                      HostContext* host) {
173   // TF_NewTensor takes the ownership of host_buffer.
174   RCReference<HostBuffer> host_buffer = dht.ReleaseBuffer();
175   tensorflow::Tensor tensor = MoveHostBufferToTfTensor(
176       std::move(host_buffer), dht.dtype(), dht.shape());
177 
178   // TODO(zhangqiaorjc): Use CreateLocalHandle with device args.
179   OwnedTensorHandle tensor_handle{
180       tensorflow::TensorHandle::CreateLocalHandle(tensor)};
181 
182   return RuntimeFallbackTensor(dht.shape(), dht.dtype(),
183                                std::move(tensor_handle));
184 }
185 
CopyRefDHTToRuntimeFallbackTensor(const DenseHostTensor & dht,HostContext * host)186 RuntimeFallbackTensor CopyRefDHTToRuntimeFallbackTensor(
187     const DenseHostTensor& dht, HostContext* host) {
188   // Do not copy the host buffer, TF_NewTensor simply CopyRef.
189   RCReference<HostBuffer> host_buffer = dht.buffer();
190   tensorflow::Tensor tensor = MoveHostBufferToTfTensor(
191       std::move(host_buffer), dht.dtype(), dht.shape());
192 
193   OwnedTensorHandle tensor_handle{
194       tensorflow::TensorHandle::CreateLocalHandle(tensor)};
195 
196   return RuntimeFallbackTensor(dht.shape(), dht.dtype(),
197                                std::move(tensor_handle));
198 }
199 
CopySHTToRuntimeFallbackTensor(const StringHostTensor & sht,HostContext * host)200 RuntimeFallbackTensor CopySHTToRuntimeFallbackTensor(
201     const StringHostTensor& sht, HostContext* host) {
202   tensorflow::Tensor tensor = CopyShtToTfTensor(sht);
203   OwnedTensorHandle tensor_handle{
204       tensorflow::TensorHandle::CreateLocalHandle(tensor)};
205 
206   return RuntimeFallbackTensor(sht.shape(), sht.dtype(),
207                                std::move(tensor_handle));
208 }
209 
210 }  // namespace tfd
211 }  // namespace tensorflow
212