/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <string>
#include <utility>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "absl/time/time.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerUnion.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime

namespace tensorflow {

// Records JitRt kernel compilation time for the given model and kernel,
// optionally keyed by the kernel specialization id.
void RecordCompileTime(const std::string& model_name, const std::string& kernel,
                       std::optional<size_t> specialization,
                       absl::Duration compile_time);

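// Usage sketch (illustrative only; the model/kernel names and the timing
// pattern below are hypothetical, not part of this header):
//
//   absl::Time start = absl::Now();
//   // ... compile the JitRt kernel ...
//   RecordCompileTime("my_model", "cluster_0", /*specialization=*/std::nullopt,
//                     absl::Now() - start);
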
// A set of helper classes to convert results returned from the compiled
// functions (memrefs or async memrefs) to Tensorflow Tensors that can be
// seamlessly passed to the Tensorflow fallback kernels.

// MemrefTensorBuffer wraps a memref returned from the compiled kernel into
// the Tensorflow `TensorBuffer` that can be used to construct a `Tensor`.
class MemrefTensorBuffer : public TensorBuffer {
 public:
  MemrefTensorBuffer(void* base_ptr, void* data, size_t size, bool owner)
      : TensorBuffer(data), base_ptr_(base_ptr), size_(size), owner_(owner) {}

  ~MemrefTensorBuffer() override {
    if (owner_) free(base_ptr_);
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    proto->set_requested_bytes(size());
    proto->set_allocator_name("tf_jitrt");
  }

  size_t size() const override { return size_; }
  bool OwnsMemory() const override { return owner_; }
  TensorBuffer* root_buffer() override { return this; }

 private:
  void* base_ptr_;
  size_t size_;
  bool owner_;
};

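// Usage sketch (illustrative only, not how the runtime calls it): wrapping a
// heap-allocated buffer so that a Tensor takes ownership of it. With `owner`
// set to true, the buffer calls `free(base_ptr)` when its last reference is
// dropped.
//
//   void* data = malloc(1024);
//   auto* buf = new MemrefTensorBuffer(/*base_ptr=*/data, /*data=*/data,
//                                      /*size=*/1024, /*owner=*/true);
//   Tensor t(DT_FLOAT, TensorShape({256}),
//            core::RefCountPtr<TensorBuffer>(buf));
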
struct TensorflowConversionContext {
  // Keep track of compiled kernel operands to detect input-to-output
  // forwarding, and tensors returned multiple times.
  using TensorOrBuffer = llvm::PointerUnion<const Tensor*, TensorBuffer*>;

  TensorflowConversionContext(size_t num_operands, size_t num_results)
      : num_pending_results(num_results) {
    // The map holds at most all operands plus all results except the last
    // one, which is never inserted (see ConvertTensor::Convert below).
    runtime_tensors.reserve(num_operands + num_results - 1);
  }

  // Ensure that the context is always moved around instead of copied.
  TensorflowConversionContext(const TensorflowConversionContext&) = delete;
  TensorflowConversionContext(TensorflowConversionContext&&) = default;

  // Memrefs that are already materialized as runtime tensors:
  //   1. Tensor operands that we got from the caller.
  //   2. Tensor buffers that we constructed for newly allocated memrefs.
  llvm::SmallDenseMap<const void*, TensorOrBuffer> runtime_tensors;

  // The number of results that are still waiting for conversion.
  size_t num_pending_results;
};

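// Usage sketch (illustrative only): a context for a kernel with two operands
// and three results. Callers are expected to pre-populate `runtime_tensors`
// with the operand tensors they got from the caller; `ConvertTensor::Convert`
// below then consults the map to forward operands and decrements
// `num_pending_results` as results are materialized.
//
//   TensorflowConversionContext ctx(/*num_operands=*/2, /*num_results=*/3);
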
namespace internal {
// The returned memref can point into statically allocated memory that we can't
// pass to `free` (memref.global). The LLVM lowering of `memref.global` sets
// the allocated pointer to the magic value 0xDEADBEEF.
template <typename T, int rank>
inline bool IsStaticStorageDuration(StridedMemRefType<T, rank>* memref) {
  return reinterpret_cast<std::intptr_t>(memref->basePtr) == 0xDEADBEEF;
}
}  // namespace internal

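// Example (illustrative only): a memref backed by a `memref.global` gets the
// 0xDEADBEEF marker as its allocated pointer from the LLVM lowering, so it
// must not be handed to `free`:
//
//   StridedMemRefType<float, 2> memref;
//   memref.basePtr = reinterpret_cast<float*>(0xDEADBEEF);
//   assert(internal::IsStaticStorageDuration(&memref));
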
// Converts a StridedMemRefType to a Tensor. This struct satisfies
// ReturnStridedMemref's concept (see jitrt.h).
struct ConvertTensor {
  using ResultType = tfrt_stub::FallbackTensor;
  using ConversionContext = TensorflowConversionContext;

  template <typename T, int rank>
  static llvm::ArrayRef<int64_t> Sizes(StridedMemRefType<T, rank>* memref) {
    return memref->sizes;
  }

  template <typename T>
  static llvm::ArrayRef<int64_t> Sizes(StridedMemRefType<T, 0>* memref) {
    return {};
  }

  template <typename T, int rank>
  static Tensor Convert(ConversionContext& ctx, void* memref_ptr) {
    auto* memref = static_cast<StridedMemRefType<T, rank>*>(memref_ptr);
    auto memref_sizes = Sizes(memref);

    // Convert the TFRT data type into a Tensorflow data type.
    auto dtype = tfd::GetTfDataType(tfrt::GetDType<T>());

    // Build a Tensorflow TensorShape from the memref sizes.
    TensorShape shape(memref_sizes);

    // Check if the returned memref already has a corresponding runtime tensor.
    auto it = ctx.runtime_tensors.find(memref->data);
    ConversionContext::TensorOrBuffer runtime_tensor =
        it != ctx.runtime_tensors.end() ? it->second : nullptr;

    // Forward the operand tensor to the result.
    if (auto* operand = runtime_tensor.dyn_cast<const Tensor*>()) {
      Tensor result;
      auto st = result.BitcastFrom(*operand, dtype, shape);
      assert(st.ok() && "failed to bitcast from forwarded tensor");
      (void)st;
      return result;
    }

    // The same memref was already returned once; share its buffer.
    if (auto* buffer = runtime_tensor.dyn_cast<TensorBuffer*>()) {
      buffer->Ref();
      auto ptr = core::RefCountPtr<TensorBuffer>(buffer);
      return Tensor(dtype, std::move(shape), std::move(ptr));
    }

    // This is a newly allocated memref, and we need to wrap it into a runtime
    // tensor buffer to pass it back to the caller as a Tensor.
    size_t size = sizeof(T);
    for (int64_t dim : memref_sizes) size *= dim;

    // Create a TensorBuffer from the returned memref.
    TF_ANNOTATE_MEMORY_IS_INITIALIZED(memref->data, size);
    auto* buffer = new MemrefTensorBuffer(
        memref->basePtr, memref->data, size,
        /*owner=*/!internal::IsStaticStorageDuration(memref));

    // Construct a tensor from the memory buffer.
    auto ptr = core::RefCountPtr<MemrefTensorBuffer>(buffer);
    Tensor tensor(dtype, std::move(shape), std::move(ptr));

    // Keep track of memrefs already used to construct runtime tensors.
    if (--ctx.num_pending_results > 0)
      ctx.runtime_tensors.try_emplace(memref->data, buffer);

    // Incorrect alignment will lead to a segfault in the downstream Tensorflow
    // kernels; check it before returning to the runtime.
    if (internal::IsStaticStorageDuration(memref)) {
      DCHECK(tensor.IsAligned()) << "global memref is not aligned";
    } else {
      DCHECK(tensor.IsAligned()) << "allocated memref is not aligned";
    }

    return tensor;
  }
};

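// Usage sketch (illustrative only; in practice `Convert` is invoked by the
// JitRt return-value conversion machinery rather than called directly):
//
//   // `memref` was returned by the compiled function.
//   StridedMemRefType<float, 2>* memref = ...;
//   TensorflowConversionContext ctx(/*num_operands=*/0, /*num_results=*/1);
//   Tensor result = ConvertTensor::Convert<float, 2>(ctx, memref);
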
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_