1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ 16 #define TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ 17 18 #include "absl/types/variant.h" 19 #include "tensorflow/core/framework/tensor.h" 20 21 namespace tensorflow { 22 namespace tfrt_stub { 23 24 // A special tensor wrapper for immutable tensors that live a long time and are 25 // reused across steps in a program, eg. weights. 26 class ImmutableTensor { 27 public: 28 ImmutableTensor() = default; 29 // Create an ImmutableTensor by copying the content in `tensor`. 30 static ImmutableTensor Create(tensorflow::Tensor tensor); 31 32 // Accessors for this underlying tensor. Users must not modify its content. It 33 // is guaranteed that RefCountIsOne() always return false for the tensor. tensor()34 tensorflow::Tensor& tensor() { return tensor_; } tensor()35 const tensorflow::Tensor& tensor() const { return tensor_; } 36 37 private: ImmutableTensor(tensorflow::Tensor tensor)38 explicit ImmutableTensor(tensorflow::Tensor tensor) 39 : tensor_(std::move(tensor)) { 40 DCHECK(!tensor_.RefCountIsOne()) 41 << "Immutable tensors' buffers cannot be forwarded."; 42 } 43 44 tensorflow::Tensor tensor_; 45 }; 46 47 // A wrapper class over normal tensors and immutable tensors. This class is used 48 // as the currency type in TFRT fallback execution. Note that this class does 49 // not own the underlying tensor if it is an immutable tensor. 50 class FallbackTensor { 51 public: 52 FallbackTensor() = default; 53 FallbackTensor(const tensorflow::Tensor & tensor)54 explicit FallbackTensor(const tensorflow::Tensor& tensor) : tensor_(tensor) {} FallbackTensor(tensorflow::Tensor && tensor)55 explicit FallbackTensor(tensorflow::Tensor&& tensor) 56 : tensor_(std::move(tensor)) {} 57 FallbackTensor(ImmutableTensor * immutable_tensor)58 explicit FallbackTensor(ImmutableTensor* immutable_tensor) 59 : tensor_(immutable_tensor) {} 60 is_immutable()61 bool is_immutable() const { 62 return absl::holds_alternative<ImmutableTensor*>(tensor_); 63 } 64 tensor()65 tensorflow::Tensor& tensor() { 66 if (is_immutable()) return absl::get<ImmutableTensor*>(tensor_)->tensor(); 67 return absl::get<tensorflow::Tensor>(tensor_); 68 } tensor()69 const tensorflow::Tensor& tensor() const { 70 return const_cast<FallbackTensor*>(this)->tensor(); 71 } 72 73 private: 74 absl::variant<absl::monostate, tensorflow::Tensor, ImmutableTensor*> tensor_; 75 }; 76 77 } // namespace tfrt_stub 78 } // namespace tensorflow 79 80 #endif // TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ 81