xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file declares the TF runtime fallback tensor (RuntimeFallbackTensor)
// and helpers for converting between it and TFRT host tensors.
17 
18 #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_
19 #define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
23 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
24 #include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
25 #include "tfrt/tensor/host_tensor.h"  // from @tf_runtime
26 #include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
27 #include "tfrt/tensor/tensor.h"  // from @tf_runtime
28 
29 namespace tensorflow {
30 namespace tfd {
31 
32 class RuntimeFallbackTensor final
33     : public tfrt::Tensor,
34       public tfrt::TensorTraits<RuntimeFallbackTensor> {
35  public:
36   explicit RuntimeFallbackTensor(const tfrt::TensorShape& shape,
37                                  tfrt::DType dtype, OwnedTensorHandle th);
38 
39   void Print(tfrt::raw_ostream& os) const override;
40 
41   // Note that this method does not add ref to the return tensor_handle.
GetTensorHandle()42   TensorHandle* GetTensorHandle() const { return tensor_handle_.get(); }
43 
44   // Tensor type name for RuntimeFallbackTensor.
name()45   static const char* name() { return "RuntimeFallback"; }
46 
47  private:
48   template <typename T>
PrintTensorValues(void * data,ssize_t size,llvm::raw_ostream & os)49   static void PrintTensorValues(void* data, ssize_t size,
50                                 llvm::raw_ostream& os) {
51     llvm::ArrayRef<T> elements =
52         llvm::makeArrayRef(static_cast<T*>(data), size);
53     llvm::interleaveComma(elements, os);
54   }
55 
56   OwnedTensorHandle tensor_handle_;
57 };
58 
59 llvm::SmallVector<tfrt::Index, 4> GetShape(
60     AbstractTensorInterface* tensor_interface);
61 
62 tfrt::Expected<tfrt::StringHostTensor> CopyTfStringTensorToStringHostTensor(
63     AbstractTensorInterface* tensor_interface, tfrt::HostContext* host);
64 
65 tfrt::Expected<RuntimeFallbackTensor>
66 CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th,
67                                               tfrt::HostContext* host);
68 
69 RuntimeFallbackTensor MoveDHTToRuntimeFallbackTensor(
70     tfrt::DenseHostTensor&& dht, tfrt::HostContext* host);
71 
72 RuntimeFallbackTensor CopyRefDHTToRuntimeFallbackTensor(
73     const tfrt::DenseHostTensor& dht, tfrt::HostContext* host);
74 
75 RuntimeFallbackTensor CopySHTToRuntimeFallbackTensor(
76     const tfrt::StringHostTensor& sht, tfrt::HostContext* host);
77 
78 }  // namespace tfd
79 }  // namespace tensorflow
80 
81 #endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_
82