xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/utils/fallback_tensor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_
16 #define TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_
17 
18 #include "absl/types/variant.h"
19 #include "tensorflow/core/framework/tensor.h"
20 
21 namespace tensorflow {
22 namespace tfrt_stub {
23 
24 // A special tensor wrapper for immutable tensors that live a long time and are
25 // reused across steps in a program, eg. weights.
26 class ImmutableTensor {
27  public:
28   ImmutableTensor() = default;
29   // Create an ImmutableTensor by copying the content in `tensor`.
30   static ImmutableTensor Create(tensorflow::Tensor tensor);
31 
32   // Accessors for this underlying tensor. Users must not modify its content. It
33   // is guaranteed that RefCountIsOne() always return false for the tensor.
tensor()34   tensorflow::Tensor& tensor() { return tensor_; }
tensor()35   const tensorflow::Tensor& tensor() const { return tensor_; }
36 
37  private:
ImmutableTensor(tensorflow::Tensor tensor)38   explicit ImmutableTensor(tensorflow::Tensor tensor)
39       : tensor_(std::move(tensor)) {
40     DCHECK(!tensor_.RefCountIsOne())
41         << "Immutable tensors' buffers cannot be forwarded.";
42   }
43 
44   tensorflow::Tensor tensor_;
45 };
46 
47 // A wrapper class over normal tensors and immutable tensors. This class is used
48 // as the currency type in TFRT fallback execution. Note that this class does
49 // not own the underlying tensor if it is an immutable tensor.
50 class FallbackTensor {
51  public:
52   FallbackTensor() = default;
53 
FallbackTensor(const tensorflow::Tensor & tensor)54   explicit FallbackTensor(const tensorflow::Tensor& tensor) : tensor_(tensor) {}
FallbackTensor(tensorflow::Tensor && tensor)55   explicit FallbackTensor(tensorflow::Tensor&& tensor)
56       : tensor_(std::move(tensor)) {}
57 
FallbackTensor(ImmutableTensor * immutable_tensor)58   explicit FallbackTensor(ImmutableTensor* immutable_tensor)
59       : tensor_(immutable_tensor) {}
60 
is_immutable()61   bool is_immutable() const {
62     return absl::holds_alternative<ImmutableTensor*>(tensor_);
63   }
64 
tensor()65   tensorflow::Tensor& tensor() {
66     if (is_immutable()) return absl::get<ImmutableTensor*>(tensor_)->tensor();
67     return absl::get<tensorflow::Tensor>(tensor_);
68   }
tensor()69   const tensorflow::Tensor& tensor() const {
70     return const_cast<FallbackTensor*>(this)->tensor();
71   }
72 
73  private:
74   absl::variant<absl::monostate, tensorflow::Tensor, ImmutableTensor*> tensor_;
75 };
76 
77 }  // namespace tfrt_stub
78 }  // namespace tensorflow
79 
80 #endif  // TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_
81