1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file declares TF runtime fallback tensor. 17 18 #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ 19 #define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ 20 21 #include "llvm/ADT/STLExtras.h" 22 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" 23 #include "tfrt/support/forward_decls.h" // from @tf_runtime 24 #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime 25 #include "tfrt/tensor/host_tensor.h" // from @tf_runtime 26 #include "tfrt/tensor/string_host_tensor.h" // from @tf_runtime 27 #include "tfrt/tensor/tensor.h" // from @tf_runtime 28 29 namespace tensorflow { 30 namespace tfd { 31 32 class RuntimeFallbackTensor final 33 : public tfrt::Tensor, 34 public tfrt::TensorTraits<RuntimeFallbackTensor> { 35 public: 36 explicit RuntimeFallbackTensor(const tfrt::TensorShape& shape, 37 tfrt::DType dtype, OwnedTensorHandle th); 38 39 void Print(tfrt::raw_ostream& os) const override; 40 41 // Note that this method does not add ref to the return tensor_handle. GetTensorHandle()42 TensorHandle* GetTensorHandle() const { return tensor_handle_.get(); } 43 44 // Tensor type name for RuntimeFallbackTensor. name()45 static const char* name() { return "RuntimeFallback"; } 46 47 private: 48 template <typename T> PrintTensorValues(void * data,ssize_t size,llvm::raw_ostream & os)49 static void PrintTensorValues(void* data, ssize_t size, 50 llvm::raw_ostream& os) { 51 llvm::ArrayRef<T> elements = 52 llvm::makeArrayRef(static_cast<T*>(data), size); 53 llvm::interleaveComma(elements, os); 54 } 55 56 OwnedTensorHandle tensor_handle_; 57 }; 58 59 llvm::SmallVector<tfrt::Index, 4> GetShape( 60 AbstractTensorInterface* tensor_interface); 61 62 tfrt::Expected<tfrt::StringHostTensor> CopyTfStringTensorToStringHostTensor( 63 AbstractTensorInterface* tensor_interface, tfrt::HostContext* host); 64 65 tfrt::Expected<RuntimeFallbackTensor> 66 CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th, 67 tfrt::HostContext* host); 68 69 RuntimeFallbackTensor MoveDHTToRuntimeFallbackTensor( 70 tfrt::DenseHostTensor&& dht, tfrt::HostContext* host); 71 72 RuntimeFallbackTensor CopyRefDHTToRuntimeFallbackTensor( 73 const tfrt::DenseHostTensor& dht, tfrt::HostContext* host); 74 75 RuntimeFallbackTensor CopySHTToRuntimeFallbackTensor( 76 const tfrt::StringHostTensor& sht, tfrt::HostContext* host); 77 78 } // namespace tfd 79 } // namespace tensorflow 80 81 #endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ 82