1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements TF runtime fallback tensor.
17
18 #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
19
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "tensorflow/c/tensor_interface.h"
23 #include "tensorflow/c/tf_datatype.h"
24 #include "tensorflow/c/tf_tensor.h"
25 #include "tensorflow/c/tf_tensor_internal.h"
26 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/runtime_fallback/util/tensor_util.h"
30 #include "tensorflow/core/runtime_fallback/util/type_util.h"
31 #include "tfrt/dtype/dtype.h" // from @tf_runtime
32 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
33 #include "tfrt/host_context/host_buffer.h" // from @tf_runtime
34 #include "tfrt/host_context/host_context.h" // from @tf_runtime
35 #include "tfrt/support/error_util.h" // from @tf_runtime
36 #include "tfrt/support/ref_count.h" // from @tf_runtime
37 #include "tfrt/tensor/conversion_registry.h" // from @tf_runtime
38 #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime
39 #include "tfrt/tensor/string_host_tensor.h" // from @tf_runtime
40 #include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime
41
42 namespace tensorflow {
43 namespace tfd {
44
45 using tfrt::DenseHostTensor;
46 using tfrt::DType;
47 using tfrt::Expected;
48 using tfrt::HostBuffer;
49 using tfrt::HostContext;
50 using tfrt::RCReference;
51 using tfrt::StringHostTensor;
52 using tfrt::Tensor;
53 using tfrt::TensorMetadata;
54 using tfrt::TensorShape;
55
56 using OwnedTFStatus = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
57
58 // If dtype is unsupported, only crash when converting this object to
59 // HostTensor.
RuntimeFallbackTensor(const TensorShape & shape,DType dtype,OwnedTensorHandle th)60 RuntimeFallbackTensor::RuntimeFallbackTensor(const TensorShape& shape,
61 DType dtype, OwnedTensorHandle th)
62 : Tensor(TensorMetadata(dtype, shape)), tensor_handle_{std::move(th)} {
63 assert(IsValid(dtype) && "Invalid dtype");
64 }
65
GetShape(AbstractTensorInterface * tensor_interface)66 llvm::SmallVector<tfrt::Index, 4> GetShape(
67 AbstractTensorInterface* tensor_interface) {
68 llvm::SmallVector<tfrt::Index, 4> dims;
69 int64_t num_dims = tensor_interface->NumDims();
70 dims.reserve(num_dims);
71 for (int i = 0; i < num_dims; ++i) {
72 dims.push_back(tensor_interface->Dim(i));
73 }
74 return dims;
75 }
76
CopyTfStringTensorToStringHostTensor(AbstractTensorInterface * tensor_interface,HostContext * host)77 Expected<StringHostTensor> CopyTfStringTensorToStringHostTensor(
78 AbstractTensorInterface* tensor_interface, HostContext* host) {
79 auto sht = StringHostTensor::CreateUninitialized(
80 TensorMetadata(DType(DType::String), GetShape(tensor_interface)), host);
81 if (!sht)
82 return tfrt::MakeStringError(
83 "failed to create uninitialized string tensor");
84
85 assert(tensor_interface->Type() == DT_STRING);
86 const int64_t num_elems = tensor_interface->NumElements();
87 const tensorflow::tstring* tstrings =
88 reinterpret_cast<const tensorflow::tstring*>(tensor_interface->Data());
89
90 auto strings = sht->strings();
91 for (int i = 0; i < num_elems; ++i) {
92 strings[i] = tstrings[i];
93 }
94
95 return std::move(*sht);
96 }
97
98 // TODO(jingdong): Format the tensor in more user-friendly format, especially
99 // for large tensors. See tensorflow::Tensor::DebugString().
Print(tfrt::raw_ostream & os) const100 void RuntimeFallbackTensor::Print(tfrt::raw_ostream& os) const {
101 tensorflow::Status status;
102 OwnedAbstractTensorInterface tensor_interface{
103 tensor_handle_->Resolve(&status)};
104 assert(status.ok());
105
106 int rank = tensor_interface->NumDims();
107
108 llvm::SmallVector<tfrt::Index, 4> dims;
109 for (auto i = 0; i < rank; ++i) {
110 dims.push_back(tensor_interface->Dim(i));
111 }
112
113 DataType dtype = tensor_interface->Type();
114 os << "RuntimeFallbackTensor dtype = " << DataTypeString(dtype)
115 << ", shape = [";
116 llvm::interleaveComma(dims, os);
117 os << "], values = [";
118
119 int64_t num_elements = tensor_interface->NumElements();
120 void* tensor_data = tensor_interface->Data();
121
122 switch (dtype) {
123 case TF_DataType::TF_FLOAT:
124 PrintTensorValues<float>(tensor_data, num_elements, os);
125 break;
126 case TF_DataType::TF_DOUBLE:
127 PrintTensorValues<double>(tensor_data, num_elements, os);
128 break;
129 case TF_DataType::TF_INT32:
130 PrintTensorValues<int32_t>(tensor_data, num_elements, os);
131 break;
132 case TF_DataType::TF_INT64:
133 PrintTensorValues<int64_t>(tensor_data, num_elements, os);
134 break;
135 case TF_DataType::TF_INT8:
136 PrintTensorValues<int8_t>(tensor_data, num_elements, os);
137 break;
138 default:
139 os << "Unsupported tensor dtype " << dtype;
140 break;
141 }
142
143 os << "]\n";
144 }
145
146 tfrt::Expected<RuntimeFallbackTensor>
CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th,HostContext * host)147 CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th,
148 HostContext* host) {
149 int rank;
150 tensorflow::Status status = owned_th->NumDims(&rank);
151 if (!status.ok())
152 return tfrt::MakeStringError(tfrt::StrCat(
153 "error getting rank from TF tensor handle: ", status.error_message()));
154
155 llvm::SmallVector<tfrt::Index, 4> dims;
156 for (auto i = 0; i < rank; ++i) {
157 int64_t dim;
158 status = owned_th->Dim(i, &dim);
159 if (!status.ok())
160 return tfrt::MakeStringError(
161 tfrt::StrCat("error getting dimension from TFE tensor handle: ",
162 status.error_message()));
163 dims.push_back(dim);
164 }
165
166 TensorShape shape{dims};
167 DataType dtype = owned_th->DataType();
168 return RuntimeFallbackTensor(shape, GetTfrtDtype(dtype), std::move(owned_th));
169 }
170
MoveDHTToRuntimeFallbackTensor(DenseHostTensor && dht,HostContext * host)171 RuntimeFallbackTensor MoveDHTToRuntimeFallbackTensor(DenseHostTensor&& dht,
172 HostContext* host) {
173 // TF_NewTensor takes the ownership of host_buffer.
174 RCReference<HostBuffer> host_buffer = dht.ReleaseBuffer();
175 tensorflow::Tensor tensor = MoveHostBufferToTfTensor(
176 std::move(host_buffer), dht.dtype(), dht.shape());
177
178 // TODO(zhangqiaorjc): Use CreateLocalHandle with device args.
179 OwnedTensorHandle tensor_handle{
180 tensorflow::TensorHandle::CreateLocalHandle(tensor)};
181
182 return RuntimeFallbackTensor(dht.shape(), dht.dtype(),
183 std::move(tensor_handle));
184 }
185
CopyRefDHTToRuntimeFallbackTensor(const DenseHostTensor & dht,HostContext * host)186 RuntimeFallbackTensor CopyRefDHTToRuntimeFallbackTensor(
187 const DenseHostTensor& dht, HostContext* host) {
188 // Do not copy the host buffer, TF_NewTensor simply CopyRef.
189 RCReference<HostBuffer> host_buffer = dht.buffer();
190 tensorflow::Tensor tensor = MoveHostBufferToTfTensor(
191 std::move(host_buffer), dht.dtype(), dht.shape());
192
193 OwnedTensorHandle tensor_handle{
194 tensorflow::TensorHandle::CreateLocalHandle(tensor)};
195
196 return RuntimeFallbackTensor(dht.shape(), dht.dtype(),
197 std::move(tensor_handle));
198 }
199
CopySHTToRuntimeFallbackTensor(const StringHostTensor & sht,HostContext * host)200 RuntimeFallbackTensor CopySHTToRuntimeFallbackTensor(
201 const StringHostTensor& sht, HostContext* host) {
202 tensorflow::Tensor tensor = CopyShtToTfTensor(sht);
203 OwnedTensorHandle tensor_handle{
204 tensorflow::TensorHandle::CreateLocalHandle(tensor)};
205
206 return RuntimeFallbackTensor(sht.shape(), sht.dtype(),
207 std::move(tensor_handle));
208 }
209
210 } // namespace tfd
211 } // namespace tensorflow
212