#include <c10/core/alignment.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>

#include <ATen/Tensor.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <iterator>

namespace torch::jit {

namespace {

bool isUnmanagedSpecialCase(const ProcessedNode& pnode, size_t output_idx) {
  DCHECK(output_idx < pnode.outputs().size());
  static const auto to_maybe_copy_out_symbol =
      c10::Symbol::fromQualString("static_runtime::to_maybe_copy_out");
  // Heuristic and special case:
  // If to_maybe_copy_out did not actually do anything in the
  // first iteration, assume it will continue to not do anything
  // and avoid managing its output.
  return pnode.node()->kind() == to_maybe_copy_out_symbol &&
      pnode.Output(output_idx).isNone();
}

c10::FastMap<const Value*, at::Tensor*> tensorValueToTensor(
    const std::vector<ProcessedNode>& nodes,
    const c10::FastSet<const Value*>& managed_tensor_values) {
  c10::FastMap<const Value*, at::Tensor*> tensor_value_to_tensor;
  for (auto& pnode : nodes) {
    auto* node = pnode.node();
    for (const auto output_idx : c10::irange(node->outputs().size())) {
      auto* output = node->output(output_idx);

      if (managed_tensor_values.find(output) == managed_tensor_values.end()) {
        continue;
      }

      auto& ival = pnode.Output(output_idx);

      // ival is allowed to be None in special cases, e.g. to_maybe_copy_out
      DCHECK(
          ival.isTensor() ||
          (ival.isNone() && isUnmanagedSpecialCase(pnode, output_idx)));

      if (ival.isTensor()) {
        tensor_value_to_tensor.emplace(
            output,
            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
            const_cast<at::Tensor*>(&ival.toTensor()));
      }
    }
  }
  return tensor_value_to_tensor;
}

// Don't change the size if it is already aligned, otherwise increase the size
// to make it aligned.
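// For illustration (assuming the common 64-byte c10::gAlignment; the actual
// value is platform-dependent, e.g. smaller on mobile builds), sizes round up
// to the next multiple of the alignment:
//   compute_aligned_tensor_size(1)   -> 64
//   compute_aligned_tensor_size(64)  -> 64
//   compute_aligned_tensor_size(100) -> 128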
size_t compute_aligned_tensor_size(size_t nbytes) {
  // Note: everything below is size_t
  return (nbytes + c10::gAlignment - 1) & (~(c10::gAlignment - 1));
}

at::DataPtr allocate_buffer(size_t size) {
  at::Allocator* allocator = c10::GetCPUCachingAllocator();
  return allocator->allocate(size);
}

} // namespace

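// Greedily assigns each managed tensor value to a storage group. A group is
// returned to the free list when the tensor currently occupying it reaches
// its last use (per `ranges`), and a newly produced tensor prefers a free
// group over creating a new one. For example, if tensor A is last used before
// tensor C is produced, A's group becomes free and C can share A's storage
// slot instead of getting a new one.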
std::vector<StorageGroup> assignStorageToManagedTensors(
    graph_node_list nodes,
    const ManagedTensorRanges& ranges,
    const c10::FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor) {
  std::vector<StorageGroup> managed_tensor_groups;
  // This map records each Value*'s assigned storage group.
  c10::FastMap<const Value*, size_t> storage_group_mapping;
  // On each iteration, this vector stores the set of storage groups that
  // are available for re-use.
  std::vector<size_t> free_storage_groups;

  auto makeNewStorageGroup = [&](const Value* value) {
    const auto storage_group = managed_tensor_groups.size();
    storage_group_mapping.emplace(value, storage_group);
    auto* tensor_ptr = tensor_value_to_tensor.at(value);
    managed_tensor_groups.emplace_back(tensor_ptr);
  };

  auto assignToAvailableStorageGroup = [&](const Value* value) {
    DCHECK(!free_storage_groups.empty());
    const auto storage_group = free_storage_groups.back();
    TORCH_DCHECK_LT(storage_group, managed_tensor_groups.size());
    storage_group_mapping.emplace(value, storage_group);
    auto* tensor_ptr = tensor_value_to_tensor.at(value);
    managed_tensor_groups[storage_group].addTensor(tensor_ptr);
    free_storage_groups.pop_back();
  };

  auto isManagedTensor = [&](const Value* value) {
    return tensor_value_to_tensor.find(value) != tensor_value_to_tensor.end();
  };

  for (auto* node : nodes) {
    // Assign storage groups to outputs
    for (const auto output_idx : c10::irange(node->outputs().size())) {
      Value* output = node->output(output_idx);
      if (!isManagedTensor(output)) {
        continue;
      }
      if (free_storage_groups.empty()) {
        makeNewStorageGroup(output);
        continue;
      }
      assignToAvailableStorageGroup(output);
    }

    // This node may be the last use of some managed tensors. If so, we
    // can mark the corresponding storage groups as free.
    if (ranges.nodeFreesManagedTensors(node)) {
      const auto& new_free_tensors =
          ranges.availableTensorValuesAfterNode(node);
      for (auto* tensor_value : new_free_tensors) {
        // We need to check this here to handle special cases like
        // to_maybe_copy_out. We don't know if the tensor value is managed until
        // after the first iter, but `ranges` is initialized at load time!
        if (!isManagedTensor(tensor_value)) {
          continue;
        }
        const auto storage_group = storage_group_mapping.at(tensor_value);
        free_storage_groups.push_back(storage_group);
      }
    }
  }
  return managed_tensor_groups;
}

ManagedStorages::ManagedStorages()
    : storages_(nullptr), size_(0), capacity_(0) {}

ManagedStorages::~ManagedStorages() {
  deallocate();
}

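// Reserves raw, uninitialized bytes for `capacity` StorageImpl objects. The
// StorageImpls themselves are constructed in place later via append()
// (placement new) and destroyed explicitly in deallocate(), so this class
// manages their lifetimes manually rather than through new/delete.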
void ManagedStorages::allocate(size_t capacity) {
  TORCH_CHECK(!is_allocated(), "Must deallocate before allocating again");
  // `size_` should already be 0 if not allocated, so double check it here
  TORCH_INTERNAL_ASSERT(size_ == 0);
  capacity_ = capacity;
  storages_ = reinterpret_cast<at::StorageImpl*>(
      new unsigned char[capacity_ * sizeof(at::StorageImpl)]);
}

void ManagedStorages::deallocate() {
  if (is_allocated()) {
    for (const size_t idx : c10::irange(size_)) {
      storages_[idx].~StorageImpl();
    }
    delete[] reinterpret_cast<unsigned char*>(storages_);
    capacity_ = 0;
    size_ = 0;
    storages_ = nullptr;
  }
}

void ManagedStorages::append(at::StorageImpl& storageImpl) {
  TORCH_INTERNAL_ASSERT(size_ < capacity_);
  new (&storages_[size_]) at::StorageImpl(
      at::StorageImpl::use_byte_size_t(),
      storageImpl.nbytes(),
      storageImpl.allocator(),
      storageImpl.resizable());
  size_++;
}

namespace {

bool setIncludes(const c10::FastSet<const Value*>& set, const Value* v) {
  return set.find(v) != set.end();
}

std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
    BlockRunner* block_runner,
    const c10::FastSet<const Value*>& managed_output_tensor_values) {
  std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
  for (auto& pnode : block_runner->nodes()) {
    for (const auto i : c10::irange(pnode.outputs().size())) {
      auto& ival = pnode.Output(i);
      const auto* val = pnode.node()->outputs()[i];
      if (!setIncludes(managed_output_tensor_values, val) ||
          isUnmanagedSpecialCase(pnode, i)) {
        continue;
      }
      TORCH_CHECK(ival.isTensor());
      at::Tensor* tensor = &ival.toTensor();
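      // The recorded size starts at 0; deallocateOutputTensors() later updates
      // it to the high-water byte size observed for this tensor, which is what
      // allocateOutputTensors() uses to carve out its slice of output_buffer_.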
      managed_output_tensors.emplace_back(0, tensor);
    }
  }
  return managed_output_tensors;
}

} // namespace

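// Partitions each node output at construction time: managed tensors are
// counted here (their storage groups are assigned by the derived planner);
// managed output tensors and leaked values are skipped; everything else is
// tracked as unmanaged, either a scalar (nothing to free per iteration), a
// borrowed IValue (cleaned up via destroyBorrow), or an ordinary owned
// IValue (reset after each iteration).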
MemoryPlanner::MemoryPlanner(
    BlockRunner* block_runner,
    const BlockInfo& block_info,
    bool enable_out_variant,
    bool manage_output_tensors) {
  const auto& managed_tensor_values = block_info.managed_tensor_values();
  const auto& managed_output_tensor_values =
      block_info.managed_output_tensor_values();
  const auto& leaked_values = block_info.leaked_values();

  // collect unmanaged output ivalues
  c10::FastSet<IValue*> unmanaged_ivalues;
  c10::FastSet<IValue*> unmanaged_borrowed_ivalues;
  for (ProcessedNode& pnode : block_runner->nodes()) {
    const auto borrows_outputs = borrowsOutputs(pnode.node()->kind());
    for (const auto i : c10::irange(pnode.outputs().size())) {
      const Value* out_v = pnode.node()->outputs()[i];
      const bool in_managed_tensors = setIncludes(managed_tensor_values, out_v);
      const bool is_unmanaged_special_case = isUnmanagedSpecialCase(pnode, i);
      if (in_managed_tensors && !is_unmanaged_special_case) {
        ++num_managed_tensors_;
      }
      const bool in_managed_sets = in_managed_tensors ||
          // Managing output tensors might have been turned off, so we have to
          // check the flag here
          (manage_output_tensors &&
           setIncludes(managed_output_tensor_values, out_v)) ||
          setIncludes(leaked_values, out_v);

      if (in_managed_sets && !is_unmanaged_special_case) {
        continue;
      }
      if (doesNotHeapAllocateWhenStoredInIValue(*out_v->type())) {
        // Scalars do not need to be freed after each iteration.
        num_unmanaged_scalar_ivalues_++;
      } else if (borrows_outputs) {
        IValue& out = pnode.Output(i);
        unmanaged_borrowed_ivalues.insert(&out);
      } else {
        IValue& out = pnode.Output(i);
        unmanaged_ivalues.insert(&out);
      }
    }
  }
  for (IValue* output : block_runner->outputs()) {
    auto it = unmanaged_borrowed_ivalues.find(output);
    if (it != unmanaged_borrowed_ivalues.end()) {
      borrowed_ivalues_needing_incref_.push_back(output);
      unmanaged_borrowed_ivalues.erase(it);
    } else {
      unmanaged_ivalues.erase(output);
    }
  }

  // copy to unmanaged_ivalues_
  unmanaged_ivalues_.reserve(unmanaged_ivalues.size());
  unmanaged_ivalues_.insert(
      unmanaged_ivalues_.begin(),
      unmanaged_ivalues.begin(),
      unmanaged_ivalues.end());
  unmanaged_borrowed_ivalues_.reserve(unmanaged_borrowed_ivalues.size());
  unmanaged_borrowed_ivalues_.insert(
      unmanaged_borrowed_ivalues_.begin(),
      unmanaged_borrowed_ivalues.begin(),
      unmanaged_borrowed_ivalues.end());

  if (enable_out_variant && manage_output_tensors) {
    managed_output_tensors_ = assignStorageToOutputTensors(
        block_runner, managed_output_tensor_values);
  }
}

uint8_t* MemoryPlanner::allocateBuffer(size_t num_bytes) {
  buffer_ = allocate_buffer(num_bytes);
  uint8_t* start = static_cast<uint8_t*>(buffer_.get());
  buffer_start_ = start;
  buffer_end_ = start + num_bytes;
  return start;
}

void MemoryPlanner::allocateOutputTensors() {
  if (output_buffer_bytes_ == 0) {
    return;
  }
  TORCH_CHECK(
      !output_buffer_,
      "Previously allocated output_buffer_ was not deallocated properly.");
  output_buffer_ = allocate_buffer(output_buffer_bytes_);

  size_t offset = 0;
  uint8_t* start = static_cast<uint8_t*>(output_buffer_.get());

  for (const auto& ms : managed_output_tensors_) {
    auto tensor_size = ms.first;
    auto* tensor = ms.second;
    if (tensor_size == 0) {
      continue;
    }
    TORCH_DCHECK_LE(offset + tensor_size, output_buffer_bytes_);
    void* src = static_cast<void*>(start + offset);
    // NOTE: Populating `ctx` enables clients to take the ownership of a
    // tensor managed by Static Runtime. Some clients use "move" semantics to
    // pass a Tensor object to another holding object (e.g., a thrift message)
    // to avoid `memcpy`.
    // `torch::distributed::detail::WireDumpOp::dumpTensorData` is a concrete
    // example of doing this (See `torch::distributed::detail::hasDeleter`).
    // Since this output Tensor object is permanently owned by Static Runtime,
    // this ownership passing does *not* have an intended effect of keeping the
    // Tensor alive till the "owner" releases it: A premature call to
    // `StaticRuntime::deallocateOutputTensors` can destruct such a Tensor
    // object that a holding object believes to retain, causing it to read
    // corrupted values from an already destructed Tensor object. Therefore, a
    // client receiving Static Runtime-managed Tensors needs to be very careful
    // to call `StaticRuntime::deallocateOutputTensors` only after these
    // holding objects are gone.
    tensor->storage().set_data_ptr_noswap(
        at::DataPtr(src, /*ctx=*/src, nullptr, tensor->device()));
    tensor->storage().set_nbytes(tensor_size);
    offset += tensor_size;
  }
  TORCH_DCHECK_EQ(offset, output_buffer_bytes_);
}

void MemoryPlanner::allocate() {
  // TODO: Improve this once D31357486 is landed.
  allocateManagedTensors();
  allocateOutputTensors();
}

void MemoryPlanner::deallocate() {
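  // Graph outputs that were produced as borrows must outlive this cleanup:
  // copy-constructing an IValue from the borrow creates an owning reference
  // (bumping the refcount), after which the borrow itself can be destroyed
  // without dropping the underlying object.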
  for (auto& iv : borrowed_ivalues_needing_incref_) {
    auto old = std::move(*iv);
    *iv = IValue(old);
    c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(old);
  }
  // for unmanaged ivalues (either tensor or non-tensor), we reset the *iv so
  // that the objects pointed to by *iv may be reclaimed by reference counting
  for (auto& iv : unmanaged_ivalues_) {
    *iv = IValue();
  }
  for (auto& iv : unmanaged_borrowed_ivalues_) {
    c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(*iv);
  }
  // It's important to call this function after all other owning refs
  // of the managed StorageImpls are cleaned up. It can reset the
  // StorageImpl's refcount to (# tensors in storage group),
  // so destructing any owning refs afterwards will bring the refcount
  // lower than expected and trigger the debug assertion in
  // ~intrusive_ptr_target.
  deallocateManagedTensors();
  buffer_ = {};
}

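// Resets every managed output tensor's storage and records the largest
// (aligned) size each one has reached so far; the sum becomes
// output_buffer_bytes_, which sizes the buffer handed out by the next
// allocateOutputTensors() call.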
void MemoryPlanner::deallocateOutputTensors() {
  size_t output_buffer_bytes = 0;
  for (auto& ms : managed_output_tensors_) {
    auto* tensor = ms.second;
    size_t current_size =
        compute_aligned_tensor_size(tensor->storage().nbytes());
    tensor->storage().unsafeGetStorageImpl()->reset();
    if (current_size > ms.first) {
      ms.first = current_size;
    }
    output_buffer_bytes += ms.first;
  }
  output_buffer_bytes_ = output_buffer_bytes;
  output_buffer_ = {};
}

StandardMemoryPlanner::StandardMemoryPlanner(
    BlockRunner* block_runner,
    const BlockInfo& block_info,
    bool enable_out_variant,
    bool manage_output_tensors,
    bool optimize_memory)
    : MemoryPlanner(
          block_runner,
          block_info,
          enable_out_variant,
          manage_output_tensors) {
  const auto& managed_tensor_values = block_info.managed_tensor_values();
  if (enable_out_variant) {
    const auto tensor_value_to_tensor =
        tensorValueToTensor(block_runner->nodes(), managed_tensor_values);
    if (optimize_memory) {
      managed_tensors_ = assignStorageToManagedTensors(
          block_info.node_ptrs(),
          block_info.managed_tensor_ranges(),
          tensor_value_to_tensor);
    } else {
      for (auto& tensor : tensor_value_to_tensor) {
        managed_tensors_.emplace_back(tensor.second);
      }
    }
  }
}

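// Carves one flat CPU buffer into contiguous slices, one per storage group.
// Each slice is sized by the group's maxTensorSize() recorded during the
// previous iteration's deallocateManagedTensors(), and the group's shared
// StorageImpl is pointed at the slice. Every tensor in a group beyond the
// first counts as a "reused" tensor.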
void StandardMemoryPlanner::allocateManagedTensors() {
  if (managed_bytes_ == 0) {
    return;
  }
  DCHECK(!storages_.empty());
  size_t offset = 0;
  auto* start = allocateBuffer(managed_bytes_);

  reused_tensors_ = 0;
  size_t group_idx = 0;
  for (const size_t storages_idx : c10::irange(storages_.size())) {
    auto tensor_size = storages_nbytes_[storages_idx];
    if (tensor_size == 0) {
      group_idx++;
      continue;
    }
    at::StorageImpl* storageImpl = &storages_[storages_idx];
    TORCH_DCHECK_LE(offset + tensor_size, managed_bytes_);
    void* src = static_cast<void*>(start + offset);

#ifndef NDEBUG
    TORCH_DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize());
    for (auto* tensor : managed_tensors_[group_idx].group()) {
      TORCH_DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl());
    }
#endif
    TORCH_DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0);
    reused_tensors_ += managed_tensors_[group_idx].numManagedTensors() - 1;
    storageImpl->set_data_ptr_noswap(
        at::DataPtr(src, src, nullptr, c10::Device(c10::DeviceType::CPU)));
    storageImpl->set_nbytes(tensor_size);

    offset += tensor_size;
    group_idx++;
  }
  TORCH_DCHECK_EQ(offset, managed_bytes_);
}

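// On the first call, constructs the shared StorageImpl for each storage group
// and rewires every tensor in the group to it; on later calls, re-attaches any
// tensor whose Storage was swapped out during execution. In all cases it
// records each group's high-water (aligned) size into storages_nbytes_ and
// managed_bytes_ so the next allocateManagedTensors() knows how much to
// request.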
void StandardMemoryPlanner::deallocateManagedTensors() {
  managed_bytes_ = 0;
  // free memory used by outputs of ops in out variants
  // but keep the TensorImpl and StorageImpl around.

  // We don't have any guarantee that the model doesn't change the
  // Storage for managed tensors out from under us during execution,
  // so we have to check the Storages each time we deallocate.
  unsigned group_idx = 0;
  const bool first_time = storages_.empty();
  if (C10_UNLIKELY(first_time)) {
    if (storages_.is_allocated()) {
      storages_.deallocate();
    }
    storages_.allocate(managed_tensors_.size());
    storages_nbytes_.reserve(managed_tensors_.size());
  }
  for (auto& ms : managed_tensors_) {
    const auto& tensors = ms.group();
    size_t max = ms.maxTensorSize();
    for (auto& tensor : tensors) {
      const auto& storage = tensor->storage();
      size_t current_size = compute_aligned_tensor_size(storage.nbytes());
      at::StorageImpl* tensorStorageImpl = storage.unsafeGetStorageImpl();
      if (C10_UNLIKELY(first_time)) {
        tensorStorageImpl->reset();

        DCHECK(
            storages_.size() == group_idx || storages_.size() == group_idx + 1);
        if (storages_.size() == group_idx) {
          storages_.append(*tensorStorageImpl);
          storages_nbytes_.emplace_back(0);
        }
        at::StorageImpl* newImpl = &storages_[storages_.size() - 1];

        // We want to manage StorageImpls' lifetimes ourselves, but TensorImpl
        // expects to refcount them. unsafe_adapt_non_heap_allocated is our
        // escape hatch: it sets the reference count for the StorageImpl to an
        // impractically high value so that it will never get deallocated by
        // intrusive_ptr, leaving us free to manage its lifetime as we see fit.
        // (Note that allowing it to be deallocated by intrusive_ptr would be
        // UB, because that would entail deleting an object that wasn't
        // allocated with operator new.)
        //
        // For more information, see the doc comment for
        // intrusive_ptr::unsafe_adapt_non_heap_allocated.
        tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
            c10::intrusive_ptr<at::StorageImpl>::
                unsafe_adapt_non_heap_allocated(newImpl, tensors.size())));
      } else if (C10_UNLIKELY(tensorStorageImpl != &storages_[group_idx])) {
        tensorStorageImpl->reset();

        // If somehow the tensor got different storage, put it back to
        // the shared impl for this group.
        tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(
            at::Storage(c10::intrusive_ptr<at::StorageImpl>::
                            unsafe_adapt_non_heap_allocated(
                                &storages_[group_idx], tensors.size())));
      }
      TORCH_DCHECK_EQ(
          tensor->storage().unsafeGetStorageImpl(), &storages_[group_idx]);
      max = std::max(max, current_size);
    }
    // Static runtime does not know tensor sizes statically, so we use the
    // tensor size from the previous run to allocate tensors for the next
    // run (following C2 tradition), exploiting the fact that a tensor's
    // storage size does not have to match its actual size. The following
    // logic records the tensor storage size for the next run.
    storages_nbytes_[group_idx++] = max;
    ms.setMaxTensorSize(max);
    managed_bytes_ += max;
  }

  TORCH_DCHECK_EQ(storages_.size(), managed_tensors_.size());
  VLOG(1) << "managed_bytes: " << managed_bytes_;
}

} // namespace torch::jit