/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_

#include <cstdint>
#include <optional>
#include <string>
#include <utility>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerUnion.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
namespace tensorflow {

// Record JitRt kernel compilation time for a given session name.
void RecordCompileTime(const std::string& model_name, const std::string& kernel,
                       std::optional<size_t> specialization,
                       absl::Duration compile_time);
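
// A minimal usage sketch (the model name, kernel name, and duration below are
// hypothetical, chosen only for illustration): record how long it took to
// compile one specialization of a kernel for a given model.
//
//   absl::Duration compile_time = absl::Milliseconds(1500);
//   RecordCompileTime("my_model", "my_kernel", /*specialization=*/0,
//                     compile_time);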

// A set of helper classes to convert results returned from the compiled
// functions (memrefs or async memrefs) to the Tensorflow Tensors that can be
// seamlessly passed to the Tensorflow fallback kernels.

// MemrefTensorBuffer wraps a memref returned from the compiled kernel into
// the Tensorflow `TensorBuffer` that can be used to construct a `Tensor`.
class MemrefTensorBuffer : public TensorBuffer {
 public:
  MemrefTensorBuffer(void* base_ptr, void* data, size_t size, bool owner)
      : TensorBuffer(data), base_ptr_(base_ptr), size_(size), owner_(owner) {}

  ~MemrefTensorBuffer() override {
    if (owner_) free(base_ptr_);
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    proto->set_requested_bytes(size());
    proto->set_allocator_name("tf_jitrt");
  }

  size_t size() const override { return size_; }
  bool OwnsMemory() const override { return owner_; }
  TensorBuffer* root_buffer() override { return this; }

 private:
  void* base_ptr_;
  size_t size_;
  bool owner_;
};
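
// A minimal sketch (hypothetical pointers and sizes, not part of this header)
// of wrapping a memref's data buffer and building a Tensor from it, mirroring
// what ConvertTensor below does for newly allocated memrefs:
//
//   void* base = ...;  // allocated pointer of the memref
//   void* data = ...;  // aligned data pointer of the memref
//   auto* buffer =
//       new MemrefTensorBuffer(base, data, /*size=*/1024, /*owner=*/true);
//   Tensor tensor(DT_FLOAT, TensorShape({256}),
//                 core::RefCountPtr<MemrefTensorBuffer>(buffer));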

struct TensorflowConversionContext {
  // Keep track of compiled kernel operands to detect input to output
  // forwarding, and tensors returned multiple times.
  using TensorOrBuffer = llvm::PointerUnion<const Tensor*, TensorBuffer*>;

  TensorflowConversionContext(size_t num_operands, size_t num_results)
      : num_pending_results(num_results) {
    runtime_tensors.reserve(num_operands + num_results - 1);
  }

  // Ensure that the context is always moved around instead of copying.
  TensorflowConversionContext(const TensorflowConversionContext&) = delete;
  TensorflowConversionContext(TensorflowConversionContext&&) = default;

  // Memrefs that are already materialized as runtime tensors:
  //   1. Tensor operands that we got from the caller.
  //   2. Tensor buffers that we constructed for newly allocated memrefs.
  llvm::SmallDenseMap<const void*, TensorOrBuffer> runtime_tensors;

  // The number of results that are waiting for the conversion.
  size_t num_pending_results;
};
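
// A minimal sketch (hypothetical operand/result counts) of setting up the
// conversion context for a compiled kernel with two operands and one result:
//
//   TensorflowConversionContext ctx(/*num_operands=*/2, /*num_results=*/1);
//   // `ctx.runtime_tensors` is populated during result conversion as operands
//   // are forwarded and newly allocated memrefs are wrapped into buffers.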

namespace internal {
// The returned memref can point into statically allocated memory that we can't
// pass to `free` (memref.global). The LLVM lowering of `memref.global` sets the
// allocated pointer to the magic value 0xDEADBEEF.
template <typename T, int rank>
inline bool IsStaticStorageDuration(StridedMemRefType<T, rank>* memref) {
  return reinterpret_cast<std::intptr_t>(memref->basePtr) == 0xDEADBEEF;
}
}  // namespace internal
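
// A minimal sketch (hypothetical memref) of the static-storage check: a memref
// backed by `memref.global` has its allocated pointer set to 0xDEADBEEF, so
// its storage must never be passed to `free`.
//
//   StridedMemRefType<float, 1> memref = ...;  // returned by a compiled kernel
//   bool owns_memory = !internal::IsStaticStorageDuration(&memref);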

// Converts a StridedMemRefType to a Tensor. This struct satisfies
// ReturnStridedMemref's concept (see jitrt.h).
struct ConvertTensor {
  using ResultType = tfrt_stub::FallbackTensor;
  using ConversionContext = TensorflowConversionContext;

  template <typename T, int rank>
  static llvm::ArrayRef<int64_t> Sizes(StridedMemRefType<T, rank>* memref) {
    return memref->sizes;
  }

  template <typename T>
  static llvm::ArrayRef<int64_t> Sizes(StridedMemRefType<T, 0>* memref) {
    return {};
  }

  template <typename T, int rank>
  static Tensor Convert(ConversionContext& ctx, void* memref_ptr) {
    auto* memref = static_cast<StridedMemRefType<T, rank>*>(memref_ptr);
    auto memref_sizes = Sizes(memref);

    // Convert TFRT data type into Tensorflow data type.
    auto dtype = tfd::GetTfDataType(tfrt::GetDType<T>());

    // Build a Tensorflow TensorShape from memref sizes.
    TensorShape shape(memref_sizes);

    // Check if the returned memref already has a corresponding runtime tensor.
    auto it = ctx.runtime_tensors.find(memref->data);
    ConversionContext::TensorOrBuffer runtime_tensor =
        it != ctx.runtime_tensors.end() ? it->second : nullptr;

    // Forward operand tensor to the result.
    if (auto* operand = runtime_tensor.dyn_cast<const Tensor*>()) {
      Tensor result;
      auto st = result.BitcastFrom(*operand, dtype, shape);
      assert(st.ok() && "failed to bitcast from forwarded tensor");
      (void)st;
      return result;
    }

    // The same memref was returned multiple times.
    if (auto* buffer = runtime_tensor.dyn_cast<TensorBuffer*>()) {
      buffer->Ref();
      auto ptr = core::RefCountPtr<TensorBuffer>(buffer);
      return Tensor(dtype, std::move(shape), std::move(ptr));
    }

    // This is a newly allocated memref, and we need to wrap it into the
    // runtime tensor buffer to pass it back to the caller as a Tensor.
    size_t size = sizeof(T);
    for (int64_t dim : memref_sizes) size *= dim;

    // Create a TensorBuffer from the returned memref.
    TF_ANNOTATE_MEMORY_IS_INITIALIZED(memref->data, size);
    auto* buffer = new MemrefTensorBuffer(
        memref->basePtr, memref->data, size,
        /*owner=*/!internal::IsStaticStorageDuration(memref));

    // Construct a tensor from the memory buffer.
    auto ptr = core::RefCountPtr<MemrefTensorBuffer>(buffer);
    Tensor tensor(dtype, std::move(shape), std::move(ptr));

    // Keep track of memrefs already used to construct runtime tensors.
    if (--ctx.num_pending_results > 0)
      ctx.runtime_tensors.try_emplace(memref->data, buffer);

    // Incorrect alignment will lead to a segfault in the downstream Tensorflow
    // kernels; check it before returning to the runtime.
    if (internal::IsStaticStorageDuration(memref)) {
      DCHECK(tensor.IsAligned()) << "global memref is not aligned";
    } else {
      DCHECK(tensor.IsAligned()) << "allocated memref is not aligned";
    }

    return tensor;
  }
};
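
// A minimal sketch (hypothetical memref and template arguments) of converting
// a compiled kernel's result into a Tensor via ConvertTensor:
//
//   StridedMemRefType<float, 2> memref = ...;  // returned by the compiled kernel
//   TensorflowConversionContext ctx(/*num_operands=*/0, /*num_results=*/1);
//   Tensor result = ConvertTensor::Convert<float, 2>(ctx, &memref);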

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_