#include <ATen/FunctionalTensorWrapper.h>

#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>

#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_to_copy.h>
#endif

namespace at {

void FunctionalTensorWrapper::set_constructor_metadata() {
  TORCH_INTERNAL_ASSERT(value_.defined());
  // Note: "level" is a concept that we don't know how to compute in core.
  // For now I'm retroactively setting this in functorch,
  // but once Open Multiple Dispatch lands we should be able to calculate this in core.
  level_ = -1;
  // mirror all of the generic tensor metadata onto the wrapper
  copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
  refresh_numel();
  refresh_contiguous();
  storage_access_should_throw_ = false;
  // In general, the sizes/strides metadata on a tensor can change as it is mutated,
  // and these changes need to be reflected in the metadata of the wrapper.
  set_allow_tensor_metadata_change(true);
  key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
  // All of the keys corresponding to functorch transforms should not be copied over.
  // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
  // to participate in the functorch transforms.
  key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
  // We override a bunch of _custom(), so make sure they get called
  // TODO: metadata copying may not actually be necessary then
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
  set_custom_device(true);
  // E.g. when running torch.compile under inference mode, any inputs that were created
  // outside of inference mode (and so are not inference tensors) should be wrapped in
  // functional wrappers that are also not inference tensors.
  version_counter_ = value_.unsafeGetTensorImpl()->version_counter();
}

FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
  : c10::TensorImpl(
      c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)),
      c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(),
      value.dtype()
    ),
    value_(value)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  set_constructor_metadata();
}

void FunctionalTensorWrapper::freeze_storage() const {
  functional_storage_impl()->freeze();
}

// Note [Functionalization: Alias Removal]
// When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)',
// we link `b` and `a` to a shared Alias object to preserve the aliasing relationship.
//
// How do we do that?
//
// Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl.
// It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor.
// When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl.
//
// As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them.
// When the user requests a tensor that's had a view taken, we check if it's up to date.
// If it's not up to date, we first replay all of the queued up mutations onto the alias, and then re-apply the current view
// on top of the newly updated alias.
//
// Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly?
// This behavior was taken from pytorch/xla, which inspired the alias-removal logic here.
// One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them,
// but never uses the other views later in the program (in which case we'll never update the alias).
// It also has downsides though: repeatedly applying mutations to the same view without syncing
// will silently use up more and more memory as more mutations are queued up.
//
// Corresponding diagram:
//
// b = a.view(...)
//
//        a                                                            b
//        |                                                            |    If the user asks for b and it's out of date,
//       \/                                                           \/    we regenerate b by replaying its views from the alias.
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |  FunctionalTensorWrapper  |                       |  FunctionalTensorWrapper  |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |    value    |   storage   |                       |   storage   |    value    |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
//        |               \                               /                 |
//        |                \                             /                  |
//        |               . - - - - - - - - - - - - .                       |
//        |               |  FunctionalStorageImpl  |                       |
//        |               . - - - - - - - - - - - - .                       |
//        |               |          Alias          |                       |
//        |               . - - - - - - - - - - - - .                       |
//        |               /    mutations to a or b                          |
//        |              /     are queued onto Alias                        |
//        |             /                                                   |
//       \/            /                                                   \/
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |         TensorImpl        |                       |         TensorImpl        |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |    value    |   storage   |                       |   storage   |    value    |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
//        |                                                                 |
//        |                                                                 |
//        |                                                                 |
//        |    In this picture the two tensor views have                    |
//        |    their own storages, but backends like functorch              |
//       \/    are allowed to re-alias underneath the pass                 \/
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |    underlying_storage     |                       |     underlying_storage    |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
//
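// A small worked example of the lazy replay described above (a sketch at the Python level,
// assuming the functionalization pass is active, e.g. under torch.func.functionalize):
//
//   a = torch.ones(4)
//   b = a.view(2, 2)   # a and b now share the same FunctionalStorageImpl / Alias
//   a.add_(1)          # the mutation is queued on the Alias; b is now out of date
//   b.sum()            # before b is used, it gets sync_()'d: the queued add_ is replayed
//                      # onto the alias and b is regenerated by re-applying its ViewMetas
//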
// This constructor is only used by view ops.
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
  : c10::TensorImpl(
      c10::DispatchKeySet(DispatchKey::Functionalize),
      view_value.dtype(),
      view_value.device()
    ),
    value_(view_value),
    is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
    was_storage_changed_(base->was_storage_changed_),
    is_symbolic_(base->is_symbolic_)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  set_constructor_metadata();
  // Copy the original tensor's ViewMeta vector and push the current one.
  if (!base->view_metas_.empty()) {
    view_metas_ = base->view_metas_; // copy
  }
  view_metas_.push_back(meta);
  maybe_mark_symbolic(meta);
  storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}

functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
  return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}

void FunctionalTensorWrapper::commit_update() {
  auto storage_impl = functional_storage_impl();
  storage_impl->add_update(value_, view_metas_);
  // As an optimization, we used to mark the tensor here as "up-to-date".
  // That way, code like:
  //   x = torch.ones(1'000'000)
  //   x[0].add_(1)
  // doesn't result in an unnecessary materialization of the base.
  // This optimization would leave the slice temporarily with incorrect
  // strides/storage_offset though, and DCE should handle that optimization anyway.
  // generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::is_up_to_date() const {
  auto alias_generation = functional_storage_impl()->generation();
  return generation_ == alias_generation;
}

// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
  view_metas_.push_back(meta);
  // Manually track the fact that this tensor received a metadata mutation!
  has_metadata_mutation_ = true;
  // Mark this tensor as symbolic if any symbolic inputs are used by the view operation.
  maybe_mark_symbolic(meta);
  // Note [Functionalization Pass - Inplace View Ops]
  // So, these ops are special - they're mutation AND view ops. They get special codegen.
  // An example is transpose_, e.g. `a.transpose_()`
  // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
  at::AutoDispatchSkipFunctionalize guard;
  value_ = meta.forward_fn(value_, meta.out_index);
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
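
// As a concrete illustration of an inplace-view op (a sketch, assuming the
// functionalization pass is active):
//
//   a = torch.ones(2, 3)
//   a.transpose_(0, 1)
//
// transpose_() does not change which storage `a` aliases. Instead, mutate_view_meta()
// appends the transpose ViewMeta to a's view_metas_ and updates value_ by applying the
// view directly (with functionalization disabled), so that any later regeneration of `a`
// from its base replays the transpose as well.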

// Note [Functionalization: Mutation Removal]
// Mutation removal is used to take a program like this:
//
//   a.add_(b)
//
// and replace it with a slightly different program that has the same semantics:
//
//   tmp = a.add(b)
//   a.replace_(tmp)
//
// Where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend.
// This is useful for backends that aren't able to handle certain types of mutations, like functorch.
//
// Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program:
//
// Before:
//   tensor.add_(batched_tensor)
//
// After:
//   tmp = tensor.add(batched_tensor)
//   tensor.replace_(tmp)
//
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorImpl, which wraps the underlying tensor.
void FunctionalTensorWrapper::replace_(const Tensor& other, bool from_lazy_regenerate) {
  // TODO: going to need to change this if we want nested functionalize() transforms.
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
  value_ = other;
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
  // We need to propagate that metadata mutation to the wrapper (new size).
  auto sizes_ = value_.sym_sizes();
  auto strides_ = value_.sym_strides();
  auto storage_offset_ = value_.sym_storage_offset();
  set_sizes_and_strides(sizes_, strides_, storage_offset_);
  if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
    // .to() should not re-entrantly go through functionalization.
    at::AutoDispatchSkipFunctionalize guard;
    // and we want _to_copy() to show up in the graph, not the composite .to() operator
    // (this can happen if autograd has already run by the time we enter this code)
    value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
    TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  }
  // If this is a lazy regeneration, then it is guaranteed that we have already
  // done the mutation for the storage alias (when we originally performed the mutation),
  // so no counter update may be needed.
  // Example: if a mutation happens to a view under no_grad,
  // we won't call replace_() on the other alias until the alias is later used, which
  // might not be until after the no_grad region is exited.
  // Therefore, replace_() cannot unconditionally rely on the current no_grad state.
  if (!from_lazy_regenerate) {
    mark_mutation();
    if (!at::GradMode::is_enabled() || InferenceMode::is_enabled()) {
      // This mutation happened under no_grad or inference_mode
      mark_mutation_during_no_grad_or_inference_mode();
    }
  }
}
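
// For reference, the codegen'd functionalization kernel for an in-place op ties
// replace_() and commit_update() together. A rough sketch (not the literal generated
// code) of what runs for `a.add_(b)` under the pass:
//
//   sync(a); sync(b);                            // make sure inputs are up to date first
//   auto a_ = from_functional_tensor(a);
//   auto b_ = from_functional_tensor(b);
//   at::Tensor tmp;
//   {
//     at::AutoDispatchSkipFunctionalize guard;
//     tmp = a_.add(b_);                          // run the functional variant underneath
//   }
//   at::functionalization::impl::replace_(a, tmp);   // the "a.replace_(tmp)" from the note above
//   at::functionalization::impl::commit_update(a);   // queue the update on a's storage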

bool FunctionalTensorWrapper::has_data_mutation() {
  // Current tensor's data was mutated if its storage saw any mutations.
  return functional_storage_impl()->generation() > 0;
}

void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
  // self.set_(src) will cause self to have all of the tensor properties of src.
  value_ = other->value_;
  generation_ = other->generation_;
  view_metas_ = other->view_metas_;
  is_symbolic_ = other->is_symbolic_;
  // FREEZE the old storage, preventing mutations to it.
  // This is a huge pain to handle properly in all cases, so we ban it.
  functional_storage_impl()->freeze();
  // Unsafely swap out the storage with other's storage,
  // disconnecting `self` from its view chain.
  storage_ = other->storage_;
  // Explicitly mark the tensor as having its storage changed from set_().
  // Otherwise, we don't actually have a 100% accurate way to check this.
  // (We could check if the updated value has a different storage than the original value,
  // but this won't also let us uniquely determine if the tensor **also**
  // experienced a data mutation).
  was_storage_changed_ = true;

  auto sizes_ = value_.sym_sizes();
  auto strides_ = value_.sym_strides();
  auto storage_offset_ = value_.sym_storage_offset();
  set_sizes_and_strides(sizes_, strides_, storage_offset_);
}
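
// A small example of the case this is modelling (a sketch at the Python level):
//
//   x = torch.randn(4)
//   x_view = x.view(2, 2)
//   y = torch.randn(10)
//   x.set_(y)   # x now reports y's storage/sizes/strides and leaves its old view chain;
//               # x's old storage (still shared with x_view) is frozen, so further
//               # mutations through it are banned rather than kept in sync.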

void FunctionalTensorWrapper::storage_resize_(const c10::SymInt& new_size) {
  auto curr_storage_size = value_.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
  // storage resizing is severely limited: we only support resizing either to zero, or from zero bytes.
  TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size);
  // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want
  // resize_() calls to actually emit any ops in the functional graph.
  // How does it work?
  // Resizing up (old size == 0):
  //   We do nothing in this case.
  //   The expectation is that for the user code to be valid, the next op that runs against the current tensor "x"
  //   will be a x.copy_(y) (or similar), that will fully overwrite the data of x.
  //   If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call
  //   (otherwise the eager code would be invalid),
  //   and therefore functionalization will regenerate the aliases off of the result of `x.copy_(y)`.
  // Resizing down (new size == 0):
  //   We also do nothing in this case. The assumption is that after resizing a tensor down,
  //   it is fully unused in the program (unless it is later resized back up first and has data copied in),
  //   although it might be saved for backward, which happens in FSDP.
  //   The expected pattern is that the param will then be resized back up from zero in the backward.

  // Mark the tensor as having its storage resized.
  // This is so we can detect it for inputs in AOTAutograd and error / emit
  // an input mutation resize_() appropriately
  functional_storage_impl()->mark_inductor_storage_resize(new_size);
}
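
// The expected usage pattern here (a sketch of the FSDP-style flow described above; names
// and sizes are illustrative):
//
//   param.untyped_storage().resize_(0)       # free the storage once the op that needed it is done
//   ...
//   param.untyped_storage().resize_(nbytes)  # later (e.g. in backward), re-allocate it...
//   param.copy_(gathered_param)              # ...and fully overwrite the data before any reads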

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
  // Note [resize_() in functionalization pass]
  // resize_() is a special operator in functionalization because it can reallocate its underlying storage.
  // This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
  //
  // However, functionalization currently bans the following code:
  //   a = torch.ones(2)
  //   b = a.view(2)
  //   b.resize_(4) # b is a view tensor, that we are trying to increase the storage size of
  //
  // Why is this code difficult to handle?
  // The functionalization pass currently keeps aliases in sync by making the following assumptions:
  // - The "base" tensor always refers to "all of the data"
  // - Whenever you have b = view_op(a), "b" should always refer to a subset of "a"'s memory.
  //
  // The code above breaks that assumption: b.resize_(4) actually needs to update "a"
  // to tell it that it is now actually some slice of a pre-existing larger storage.
  // We also can no longer re-generate "b" fully from "a", since "a" now refers to a slice of "b"'s data.
  //
  // This is probably fixable in theory, but:
  // - the fix would likely complicate the functionalization logic quite a bit.
  // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
  // - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor.
  //
  // Given all of the above, for now we're just banning the above usage.
  TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  TORCH_CHECK(view_metas_.empty(), "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  // If this tensor is not a view (and has no outstanding views taken out on it),
  // then it's safe to throw out the old storage and replace it with the new, larger one.
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
  value_ = other;
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  generation_ = 0;
  // And update the metadata on the wrapper to reflect the new sizes and strides
  set_sizes_and_strides(value_.sizes(), value_.strides());
  refresh_numel();
  // (Technically we should be guaranteed that the tensor was already contiguous,
  // since it's guaranteed not to have been a view. Doesn't hurt to run it though.)
  refresh_contiguous();
  // Swapping out the storage of a tensor (aka from a resize_() call) will update the sizes and strides of the tensor,
  // so we need to record the fact that metadata was mutated.
  has_metadata_mutation_ = true;
}
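
// The case this function does support (a sketch): growing a non-view tensor, e.g. via an
// out= op that needs a larger output buffer:
//
//   out = torch.empty(0)
//   torch.add(x, y, out=out)   # out gets resized up; since out is not a view and has no
//                              # outstanding views, a fresh, larger FunctionalStorageImpl
//                              # is swapped in here.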

void FunctionalTensorWrapper::_unsafe_reset_storage() {
  // Reset the storage with the current value_ tensor as the base
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value_));
  // Reset the generation so that it matches the new storage
  generation_ = 0;
  // Clear any pre-existing view metas so that base and value_ are semantically the same
  view_metas_.clear();
}

void FunctionalTensorWrapper::sync_() {
  if (is_up_to_date()) {
    return;
  }
  apply_updates();
  regenerate_from_base();
}

Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
  auto t = base;

  // Reapply views to get the viewed tensor from the base in alias_
  for (auto& view_meta: view_metas_) {
    t = view_meta.forward_fn(t, view_meta.out_index);
  }

  return t;
}

void FunctionalTensorWrapper::regenerate_from_base() {
  at::AutoDispatchSkipFunctionalize guard;
  auto storage_impl = functional_storage_impl();
  auto t = storage_impl->base();

  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  t = apply_view_metas(t);
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));

  replace_(t, /*from_lazy_regenerate=*/true);
  generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::apply_updates() {
  // Apply all updates on alias_
  auto storage_impl = functional_storage_impl();
  return storage_impl->apply_updates();
}
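
// Putting the last few functions together, a short trace of the lazy sync path
// (a sketch; generation numbers are illustrative):
//
//   1. b = a.view(2, 2); a.add_(1)   ->  the update is queued on the shared storage,
//                                        whose generation becomes 1 while b stays at 0
//   2. b is used by some op          ->  b is not up to date, so sync_() runs
//   3. apply_updates()               ->  replays the queued add_ onto the storage's base tensor
//   4. regenerate_from_base()        ->  re-applies b's ViewMetas on the refreshed base and
//                                        records the new generation on b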

const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
  return "FunctionalTensorWrapper";
}

void FunctionalTensorWrapper::copy_tensor_metadata(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  TensorImpl::copy_tensor_metadata(
      src_impl,
      dest_impl,
      version_counter,
      allow_tensor_metadata_change);

  // FunctionalTensorWrapper-specific fields.
  dest_impl->value_ = src_impl->value_;
  dest_impl->level_ = src_impl->level_;
  dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
  dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
  dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
  dest_impl->is_symbolic_ = src_impl->is_symbolic_;
  dest_impl->generation_ = src_impl->generation_;
  dest_impl->view_metas_ = src_impl->view_metas_;
}

void FunctionalTensorWrapper::copy_tensor_metadata_and_refresh(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  copy_tensor_metadata(src_impl, dest_impl, version_counter, allow_tensor_metadata_change);
  dest_impl->refresh_numel();
  dest_impl->refresh_contiguous();
}

template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
  }

  auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
  copy_tensor_metadata_and_refresh(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(impl.get());
  copy_tensor_metadata_and_refresh(
      /*src_impl=*/functional_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
}

c10::Device FunctionalTensorWrapper::device_custom() const {
  return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sizes();
}
at::IntArrayRef FunctionalTensorWrapper::strides_custom() const {
  return value_.unsafeGetTensorImpl()->strides();
}
int64_t FunctionalTensorWrapper::dim_custom() const {
  return value_.unsafeGetTensorImpl()->dim();
}
int64_t FunctionalTensorWrapper::numel_custom() const {
  return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_strides_custom() const {
  return value_.unsafeGetTensorImpl()->sym_strides();
}
c10::SymInt FunctionalTensorWrapper::sym_size_custom(int64_t d) const {
  return value_.unsafeGetTensorImpl()->sym_size(d);
}
c10::SymInt FunctionalTensorWrapper::sym_storage_offset_custom() const {
  return value_.unsafeGetTensorImpl()->sym_storage_offset();
}
c10::Layout FunctionalTensorWrapper::layout_impl() const {
  return value_.unsafeGetTensorImpl()->layout();
}

namespace functionalization {
namespace impl {

Tensor to_functional_tensor(const Tensor& tensor) {
  // Note [Wrapped Numbers <> Functionalization]
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor));
  return at::detail::make_tensor<FunctionalTensorWrapper>(tensor);
}
std::optional<Tensor> to_functional_tensor(const std::optional<Tensor>& tensor) {
  if (tensor.has_value()) {
    return std::make_optional<Tensor>(to_functional_tensor(*tensor));
  }
  return std::nullopt;
}
c10::List<::std::optional<Tensor>> to_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
  c10::List<::std::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_functional_tensor(t_list[i]));
  }
  return outputs;
}
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outputs.push_back(to_functional_tensor(tensor));
  }
  return outputs;
}

Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
  // Note [Wrapped Numbers <> Functionalization]
  if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  if (isFunctionalTensor(tensor)) {
    auto impl = unsafeGetFunctionalWrapper(tensor);
    return impl->value();
  } else {
    // If the current tensor is not functional, then raise an error
    // if assert_functional is true. Otherwise, return the input.
    TORCH_INTERNAL_ASSERT(!assert_functional)
    return tensor;
  }
}
std::optional<Tensor> from_functional_tensor(const std::optional<Tensor>& t, bool assert_functional) {
  if (t.has_value()) {
    return std::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
  }
  return std::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
    // it on a non-functional input,
    // but from_functional_tensor(TensorList) can receive a list containing both
    // functional and non-functional tensors.
    // Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor).
    // When that happens, we're okay with only unwrapping the functional tensors.
    outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
  }
  return outputs;
}
c10::List<::std::optional<Tensor>> from_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
  c10::List<::std::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
  }
  return outputs;
}

void sync(const Tensor& t) {
  if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
    // Note [Wrapped Numbers <> Functionalization]
    // Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors)
    // get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher.
    // That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway.
    return;
  }
  // Not every tensor that hits a functionalization kernel is necessarily a functional tensor.
  // For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel
  // to sync xla_tensor, but not cpu_tensor.
  if (!at::functionalization::impl::isFunctionalTensor(t)) {
    return;
  }
  auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
  functional_impl->sync_();
}
void sync(const std::optional<Tensor>& t) {
  if (t.has_value()) {
    sync(*t);
  }
}
void sync(ITensorListRef t_list) {
  for (const auto& t : t_list) {
    sync(t);
  }
}
void sync(const c10::List<::std::optional<Tensor>>& t_list) {
  for (const auto i : c10::irange(t_list.size())) {
    sync(t_list[i]);
  }
}

void replace_(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}

void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) {
    replace_(*functional_tensor_it++, *other_it++);
  }
}

void propagate_xla_data(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  if (functional_tensor.key_set().has(c10::DispatchKey::XLA)) {
    at::_propagate_xla_data(at::functionalization::impl::unsafeGetFunctionalWrapper(functional_tensor)
        ->value(), other);
  }
}

void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) {
    propagate_xla_data(*functional_tensor_it++, *other_it++);
  }
}

void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) {
  if (tensor.key_set().has(c10::DispatchKey::XLA)) {
    at::_propagate_xla_data(tensor, other);
  }
}

void propagate_xla_data_direct(const ITensorListRef tensor,
                               ITensorListRef other) {
  auto tensor_it = tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(tensor.size())) {
    propagate_xla_data_direct(*tensor_it++, *other_it++);
  }
}

void commit_update(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}

void commit_update(ITensorListRef functional_tensor) {
  for (const auto& t : functional_tensor) {
    commit_update(t);
  }
}

void unsafe_reset_storage(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->_unsafe_reset_storage();
}

void mark_mutation_hidden_from_autograd(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->mark_mutation_hidden_from_autograd();
}

bool are_all_mutations_hidden_from_autograd(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  return unsafeGetFunctionalWrapper(functional_tensor)->are_all_mutations_hidden_from_autograd();
}

bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  return unsafeGetFunctionalWrapper(functional_tensor)->are_all_mutations_under_no_grad_or_inference_mode();
}

bool isFunctionalTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}

bool isBaseTensor(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
  return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}

bool isFunctionalTensor(const std::optional<Tensor>& t) {
  if (t.has_value()) {
    return isFunctionalTensor(*t);
  } else {
    return false;
  }
}

bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
  if (t_list.empty()) return false;
  auto functional_count = 0;
  for (const auto i : c10::irange(t_list.size())) {
    if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
    if (isFunctionalTensor(t_list[i])) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
  if (list.size() == 0) return false;
  auto functional_count = 0;
  for (const auto& tensor : list) {
    if (!tensor.defined()) continue;
    if (isFunctionalTensor(tensor)) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

bool isFunctionalTensor(ITensorListRef list) {
  return isFunctionalTensorIListRef(list);
}

void freeze_functional_tensor(const Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
  functional_base_impl->freeze_storage();
}

Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
  if (out_idx != 0) {
    // Note [out_idx in ViewMeta]
    // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
    // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
    meta = meta.to_out_idx(out_idx);
  }
  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}

std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
  std::vector<Tensor> outputs(view_to_wrap.size());
  int64_t i = 0;
  for (const auto& tensor : view_to_wrap) {
    outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
    i++;
  }
  return outputs;
}

void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  self_impl->mutate_view_meta(meta);
}

// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} op's reference implementation with meta tensors.
// The output meta tensor's stride info serves as a reference for what the correct strides should be.
void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) {
  out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sym_sizes(), reference_out.sym_strides(), reference_out.sym_storage_offset());
}

void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) {
  TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size());
  for (const auto i : c10::irange(reference_outs.size())) {
    set_sizes_strides_offset(outs[i], reference_outs[i]);
  }
}
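
// A rough sketch of how this is used by the view kernels in the pass (not the literal
// generated code; names are illustrative):
//
//   // run the view op once on meta tensors to get canonical sizes/strides/offset
//   auto reference_out = view_op(self_meta, args...);
//   // run the actual view op underneath the pass to get the wrapped output
//   auto out = ...;
//   set_sizes_strides_offset(out, reference_out);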

thread_local bool _functionalizationReapplyViews;

bool getFunctionalizationReapplyViewsTLS() {
  return _functionalizationReapplyViews;
}
void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
  _functionalizationReapplyViews = reapply_views;
}

} // namespace impl


// Given an **out-of-place** op that might internally call view/inplace ops,
// this function will "functionalize" it.
// That is, it will call the operator, but remove any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  // Wrap all tensor-like inputs into FunctionalTensorWrappers.
  // When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (t.defined()) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
          "The composite op functionalization fallback expects its inputs all not to be functional tensors");
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[arguments_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[arguments_begin + idx] = t_new;
    }
  }

  {
    // Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
    // the output in a functional tensor based on TLS.
    // In this code, we're re-entrantly entering functionalization in the same call-stack,
    // so we need to manually fix up TLS as if it hadn't already been called.
    auto curr_tls = c10::impl::tls_local_dispatch_key_set();
    auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
    tls_reenable_functionalize.set_included(curr_tls.included_);
    tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
    c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
    // Ideally, we would have a way to directly call a kernel registered to
    // the `CompositeExplicitAutograd` key.
    // We can't do that today, so redispatching to the Meta key should be a reasonably good proxy
    // (it won't work in cases where an op has both a CompositeExplicitAutograd kernel
    // AND a dedicated meta kernel, but that probably shouldn't ever happen).
    op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
  }

  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      at::functionalization::impl::sync(t);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      at::functionalization::impl::sync(tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      at::functionalization::impl::sync(opt_tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
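
// A hedged usage sketch (assuming the functionalize_aten_op<> wrapper declared alongside
// this helper in FunctionalTensorWrapper.h): a backend like XLA/LTC that wants the
// functionalized composite implementation of an op calls, e.g.
//
//   auto out = at::functionalization::functionalize_aten_op<ATEN_OP(block_diag)>::call(tensors);
//
// instead of at::block_diag(tensors); the wrapper boxes the arguments and routes the
// call through functionalize_op_helper().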


} // namespace functionalization
} // namespace at