
#pragma once

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <c10/core/DispatchKey.h>

namespace at {

// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a PyTorch
// program.
//
// This is useful for backends that don't support aliasing, like XLA and
// Vulkan. It's also necessary in order to remove mutation from a program,
// which is needed in functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// and that backends like XLA/Vulkan are able to implement:
//
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside a
// FunctionalTensorWrapper, which knows about its aliased tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from
  // an underlying tensor that was created from a view. For example, the code
  // b = a.view1() will generate a constructor call to
  // FunctionalTensorWrapper(b, a, view1_meta).
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      const functionalization::ViewMeta& meta);

  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  }
  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  }
  void set_level(int64_t level) {
    level_ = level;
  }
  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  }

  void mark_mutation() {
    functional_storage_impl()->mark_mutation();
  }
  // Denotes a mutation that's hidden from autograd,
  // e.g. for the purposes of passing a tensor to a triton kernel.
  void mark_mutation_hidden_from_autograd() {
    functional_storage_impl()->mark_mutation_hidden_from_autograd();
  }
  void mark_mutation_during_no_grad_or_inference_mode() {
    functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  }
  // Are all the mutations happening to the tensor hidden from autograd?
  bool are_all_mutations_hidden_from_autograd() const {
    return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  }
  // Did all mutations happen under no_grad or inference_mode?
  // (We also need to ignore mutations fully hidden from autograd here.)
  bool are_all_mutations_under_no_grad_or_inference_mode() const {
    return functional_storage_impl()
        ->are_all_mutations_under_no_grad_or_inference_mode();
  }

  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
  }

  bool is_symbolic() const {
    return is_symbolic_;
  }

  // Runs the forward_fn of every ViewMeta collected in the current instance
  // to some other base.
  Tensor apply_view_metas(const Tensor& base);

  // Syncs the underlying tensor with its alias, if it's out of date. This
  // involves two steps: 1) Apply any pending updates/mutations to the alias.
  // 2) Replay the views (if any) to regenerate the current tensor off of the
  // updated alias.
  void sync_();
  // Performs step (2) of the sync: replays the views off of the updated
  // alias. This is its own public API because it's needed by view_inplace ops
  // like transpose_. See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync: applies pending updates to the alias. This
  // is its own public API because it's needed by functorch. functorch wants
  // to make sure that all input tensors to a functionalized program have been
  // properly synced so it can properly propagate mutations to inputs. It
  // can't just call sync_(), because the FunctionalTensorWrapper will look
  // like it has no aliases and sync_ will be a no-op. We use the reference
  // count on storage_ to determine if the wrapper is aliased, and by the time
  // functorch is ready to propagate updates to inputs, any intermediate views
  // of the input created by the program will have been deallocated. This
  // function also returns whether or not the base actually had any updates to
  // apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a
  // pending update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's
  // "generation". Separately, each tensor maintains its own "generation"
  // counter, which is used to determine if it's up-to-date with its alias.
  // The act of syncing a tensor will set a tensor's generation equal to its
  // alias's generation.
  bool is_up_to_date() const;
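  //
  // Illustrative sketch only (hypothetical caller code, not an API defined
  // here) of how these pieces fit together when a mutation hits a view:
  //
  //   // b is a view of a; the pass computed tmp = b.add(1)
  //   b_wrapper->replace_(tmp);        // swap in the mutated value
  //   b_wrapper->commit_update();      // queue a pending update on the alias
  //   if (!a_wrapper->is_up_to_date()) {
  //     a_wrapper->sync_();            // a picks up b's mutation
  //   }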
  // Freezes the storage of this tensor, preventing subsequent mutations.
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
  void mutate_view_meta(const at::functionalization::ViewMeta& meta);
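  //
  // For example (illustrative only), functionalizing a.transpose_(0, 1)
  // roughly does:
  //
  //   auto meta = /* ViewMeta with forward/reverse lambdas for
  //                  transpose(0, 1) */;
  //   a_wrapper->mutate_view_meta(meta);  // a now replays one more view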

  // Custom implementation of self.set_(src)
  void set__impl(const FunctionalTensorWrapper* other);

  // Custom implementation of resize_storage_bytes_(self, new_size)
  void storage_resize_(const c10::SymInt& new_size);

  // Returns whether the current tensor's data was ever mutated.
  bool has_data_mutation();

  // Returns whether the current FunctionalTensorWrapper
  // experienced a set_() call.
  bool was_storage_changed() {
    return was_storage_changed_;
  }

  void set_storage_changed() {
    was_storage_changed_ = true;
  }

  // A FunctionalTensor is considered a base if it's not a view of another
  // tensor.
  bool isBaseTensor() const {
    return view_metas_.empty();
  }

  c10::SymInt get_storage_size(bool before) {
    return functional_storage_impl()->get_storage_size(before);
  }

  // Returns whether the FunctionalTensor experienced an
  // untyped_storage().resize_() call.
  bool was_inductor_storage_resized() {
    return functional_storage_impl()->was_inductor_storage_resized();
  }

  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g.:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other, bool from_lazy_regenerate = false);

  bool is_multi_output_view() {
    return is_multi_output_view_;
  }

  // See Note [resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  // Replaces the storage with a new functional storage,
  // and clears the view_metas_ stack.
  // WARNING: Calling this function will sever the aliasing relationship
  // between the current FunctionalTensorWrapper and any of its outstanding
  // aliases. Please only call if you know what you're doing.
  void _unsafe_reset_storage();

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // FunctionalTensorWrapper overrides all of the custom size/stride
  // functions, so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;
  c10::Layout layout_impl() const override;

 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just
  // need to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  void copy_tensor_metadata_and_refresh(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_{};
  // These two counters are used for identifying
  // whether all the mutations on a given tensor are hidden from autograd or
  // not. If we have an input mutation that is hidden from autograd, then once
  // we convert the input mutation to a copy_() we know it will be safe to
  // hide the copy_() from autograd as well.
  bool has_metadata_mutation_ = false;
  bool is_multi_output_view_ = false;
  // Did the tensor experience a set_() call.
  bool was_storage_changed_ = false;
  // Did the tensor experience any view operation with symbolic int.
  bool is_symbolic_ = false;

  size_t generation_ = 0;
  std::vector<at::functionalization::ViewMeta> view_metas_;

 protected:
  static void copy_tensor_metadata(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change);
};

// Utility functions for the functionalization pass.

namespace functionalization {
namespace impl {

TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
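
// For illustration only: a typical (hypothetical) call site guards this
// unsafe cast with isFunctionalTensor(), declared just below:
//
//   if (at::functionalization::impl::isFunctionalTensor(t)) {
//     auto* wrapper =
//         at::functionalization::impl::unsafeGetFunctionalWrapper(t);
//     wrapper->sync_();
//   }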

TORCH_API bool isBaseTensor(const at::Tensor& tensor);

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API std::optional<Tensor> to_functional_tensor(
    const std::optional<Tensor>& tensor);
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API void freeze_functional_tensor(const Tensor& tensor);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API std::optional<Tensor> from_functional_tensor(
    const std::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const std::optional<Tensor>& t);
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
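
// For illustration only, a minimal wrap/run/unwrap round trip using the
// helpers above (a sketch; the dispatch-key guard is an assumption about
// typical caller code, not an API defined in this header):
//
//   at::Tensor a = at::ones({4});
//   at::Tensor wrapped = to_functional_tensor(a);
//   {
//     c10::impl::IncludeDispatchKeyGuard guard(
//         c10::DispatchKey::Functionalize);
//     at::Tensor b = wrapped.view({2, 2});  // recorded as a replayable view
//     b.add_(1);                            // mutation propagated to base
//   }
//   sync(wrapped);  // flush pending updates before unwrapping
//   at::Tensor out = from_functional_tensor(wrapped);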

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);

TORCH_API void mark_mutation_hidden_from_autograd(
    const Tensor& functional_tensor);

TORCH_API bool are_all_mutations_hidden_from_autograd(
    const Tensor& functional_tensor);

TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
    const Tensor& functional_tensor);

// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
    const Tensor& functional_tensor,
    const Tensor& other);
TORCH_API void propagate_xla_data(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void propagate_xla_data_direct(
    const Tensor& tensor,
    const Tensor& other);
TORCH_API void propagate_xla_data_direct(
    const ITensorListRef tensor,
    ITensorListRef other);

Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    const functionalization::ViewMeta& meta);
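
// Illustrative sketch only (hypothetical kernel code): a functionalization
// kernel for a view op typically computes the view on the unwrapped tensor,
// builds a ViewMeta describing how to replay (and invert) it, and wraps the
// result:
//
//   at::Tensor unwrapped = from_functional_tensor(self);
//   at::Tensor tmp = unwrapped.view(size);  // compute the raw view
//   auto meta = /* ViewMeta with forward/reverse lambdas for view(size) */;
//   return create_functional_tensor_with_view_meta(tmp, self, meta);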

void mutate_view_meta(
    const Tensor& self,
    const functionalization::ViewMeta& meta);

void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);

//  ~~~~~ TLS used in functionalization ~~~~~

TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }

  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  bool prev_;
};
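
// Example (illustrative sketch): temporarily replay real view ops instead of
// their view_copy replacements; the previous TLS value is restored on scope
// exit:
//
//   {
//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
//         /*reapply_views=*/true);
//     // ... functionalized code here replays views rather than view_copies
//   }  // previous TLS value restored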

} // namespace impl

// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      (const char*)Op::name, (const char*)Op::overload_name)
                  .typed<FuncType>();

    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};

template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;

template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
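
// Example usage (a sketch, assuming the ATEN_OP macro from ATen/Operators.h,
// which names an operator's at::_ops:: entry):
//
//   at::Tensor out = at::functionalization::functionalize_aten_op<
//       ATEN_OP(matmul)>::call(a, b);
//
// This routes the composite matmul kernel through functionalize_op_helper,
// functionalizing any mutations/views it performs internally.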

} // namespace functionalization
} // namespace at