#pragma once

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <c10/core/DispatchKey.h>

namespace at {

// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a PyTorch
// program.
//
// This is useful for backends that don't support aliasing, like XLA and
// Vulkan. It's also necessary in order to remove mutation from a program,
// which is needed in Functorch.
//
// Consider this program:
//
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy().
//                       # Backends like XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)             # Our functionalization pass machinery knows that a
//                       # and b are aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its aliased tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
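//
// As a rough illustration (not a complete description): in the program above,
// both `a` and `b` end up wrapped in FunctionalTensorWrapper objects that
// share a single FunctionalStorageImpl (the "alias"). `b`'s wrapper
// additionally records the ViewMeta describing how it was computed from `a`,
// which is what lets the pass replay `b`'s mutation onto `a`.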
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from
  // an underlying tensor that was created from a view. For example, the code
  //
  //   b = a.view1()
  //
  // will generate a constructor call to
  // FunctionalTensorWrapper(b, a, view1_meta).
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      const functionalization::ViewMeta& meta);

  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  };
  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  };
  void set_level(int64_t level) {
    level_ = level;
  }
  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  };

  void mark_mutation() {
    functional_storage_impl()->mark_mutation();
  }
  // Denotes a mutation that's hidden from autograd,
  // e.g. for the purposes of passing a tensor to a triton kernel
  void mark_mutation_hidden_from_autograd() {
    functional_storage_impl()->mark_mutation_hidden_from_autograd();
  }
  void mark_mutation_during_no_grad_or_inference_mode() {
    functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  }
  // Are all the mutations happening to the tensor hidden from autograd
  bool are_all_mutations_hidden_from_autograd() const {
    return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  }
  // Did all mutations happen under no_grad or inference_mode
  // (We also need to ignore mutations fully hidden from autograd here)
  bool are_all_mutations_under_no_grad_or_inference_mode() const {
    return functional_storage_impl()
        ->are_all_mutations_under_no_grad_or_inference_mode();
  }

  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
  }

  bool is_symbolic() const {
    return is_symbolic_;
  }

  // Runs the forward_fn of every ViewMeta collected in the current instance
  // to some other base.
  Tensor apply_view_metas(const Tensor& base);

  // Syncs the underlying tensor with its alias, if it's out of date. This
  // involves two steps: 1) Apply any pending updates/mutations to the alias,
  // 2) Replay the views (if any) to regenerate the current tensor off of the
  // updated alias.
  void sync_();
  // Performs step (2) of the sync. This is its own public API because it's
  // needed by view_inplace ops like transpose_.
  // See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync. This is its own public API because it's
  // needed by functorch. functorch wants to make sure that all input tensors
  // to a functionalized program have been properly synced so it can properly
  // propagate mutations to inputs. It can't just call sync_(), because the
  // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
  // a noop. We use the reference count on storage_ to determine if the wrapper
  // is aliased, and by the time functorch is ready to propagate updates to
  // inputs, any intermediate views of the input created by the program will
  // have been deallocated. This function also returns whether or not the base
  // actually had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a
  // pending update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's
  // "generation". Separately, each tensor maintains its own "generation"
  // counter, which is used to determine if it's up-to-date with its alias.
  // The act of syncing a tensor will set the tensor's generation equal to its
  // alias's generation.
  bool is_up_to_date() const;
  // Freezes the storage of this tensor, preventing subsequent mutations.
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
  void mutate_view_meta(const at::functionalization::ViewMeta& meta);

  // Custom implementation of self.set_(src)
  void set__impl(const FunctionalTensorWrapper* other);

  // Custom implementation of resize_storage_bytes_(self, new_size)
  void storage_resize_(const c10::SymInt& new_size);

  // Returns whether the current tensor's data was ever mutated.
  bool has_data_mutation();

  // Returns whether the current FunctionalTensorWrapper
  // experienced a set_() call.
  bool was_storage_changed() {
    return was_storage_changed_;
  }

  void set_storage_changed() {
    was_storage_changed_ = true;
  }

  // A FunctionalTensor is considered a base if it's not a view of another
  // tensor.
  bool isBaseTensor() const {
    return view_metas_.empty();
  }

  c10::SymInt get_storage_size(bool before) {
    return functional_storage_impl()->get_storage_size(before);
  }

  // Returns whether the FunctionalTensor experienced an
  // untyped_storage().resize_() call.
  bool was_inductor_storage_resized() {
    return functional_storage_impl()->was_inductor_storage_resized();
  }

  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). E.g.:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other, bool from_lazy_regenerate = false);

  bool is_multi_output_view() {
    return is_multi_output_view_;
  }

  // See Note [resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  // Replaces the storage with a new functional storage,
  // and clears the view_metas_ stack.
  // WARNING: Calling this function will sever the aliasing relationship
  // between the current FunctionalTensorWrapper and any of its outstanding
  // aliases. Please only call if you know what you're doing.
  void _unsafe_reset_storage();

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // FunctionalTensorWrapper overrides all custom size/stride functions,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;
  c10::Layout layout_impl() const override;

 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just need
  // to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  void copy_tensor_metadata_and_refresh(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_{};
  // These two counters are used for identifying
  // whether all the mutations on a given tensor are hidden from autograd or
  // not. If we have an input mutation that is hidden from autograd, then once
  // we convert the input mutation to a copy_() we know it will be safe to hide
  // the copy_() from autograd as well.
  bool has_metadata_mutation_ = false;
  bool is_multi_output_view_ = false;
  // Did the tensor experience a set_() call.
  bool was_storage_changed_ = false;
  // Did the tensor experience any view operation with symbolic int.
  bool is_symbolic_ = false;

  size_t generation_ = 0;
  std::vector<at::functionalization::ViewMeta> view_metas_;

 protected:
  static void copy_tensor_metadata(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change);
};

// Utility functions for the functionalization pass.

namespace functionalization {
namespace impl {

TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
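// Illustrative usage note: the static_cast above is not validated beyond a
// debug-only null check, so callers typically confirm isFunctionalTensor()
// (declared below) first. A minimal sketch:
//
//   if (at::functionalization::impl::isFunctionalTensor(t)) {
//     auto* wrapper =
//         at::functionalization::impl::unsafeGetFunctionalWrapper(t);
//     wrapper->sync_();
//   }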

TORCH_API bool isBaseTensor(const at::Tensor& tensor);

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API std::optional<Tensor> to_functional_tensor(
    const std::optional<Tensor>& tensor);
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API void freeze_functional_tensor(const Tensor& tensor);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API std::optional<Tensor> from_functional_tensor(
    const std::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const std::optional<Tensor>& t);
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
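
// A minimal sketch of how the wrap/sync/unwrap helpers above are typically
// combined, assuming the Functionalize dispatch key is enabled while the
// wrapped tensor is used (function and variable names here are illustrative
// only, not part of this header):
//
//   at::Tensor run_functionalized(const at::Tensor& input) {
//     auto wrapped = at::functionalization::impl::to_functional_tensor(input);
//     {
//       c10::impl::IncludeDispatchKeyGuard guard(
//           c10::DispatchKey::Functionalize);
//       wrapped.add_(1); // routed through the functionalization kernels
//     }
//     // Apply pending updates / replay views, then unwrap.
//     at::functionalization::impl::sync(wrapped);
//     return at::functionalization::impl::from_functional_tensor(wrapped);
//   }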

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);

TORCH_API void mark_mutation_hidden_from_autograd(
    const Tensor& functional_tensor);

TORCH_API bool are_all_mutations_hidden_from_autograd(
    const Tensor& functional_tensor);

TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
    const Tensor& functional_tensor);

// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
    const Tensor& functional_tensor,
    const Tensor& other);
TORCH_API void propagate_xla_data(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void propagate_xla_data_direct(
    const Tensor& tensor,
    const Tensor& other);
TORCH_API void propagate_xla_data_direct(
    const ITensorListRef tensor,
    ITensorListRef other);

Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    const functionalization::ViewMeta& meta);

void mutate_view_meta(
    const Tensor& self,
    const functionalization::ViewMeta& meta);

void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);

// ~~~~~ TLS used in functionalization ~~~~~

TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }

  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  bool prev_;
};
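
// Illustrative usage of the RAII guard above (a sketch, not a prescription):
// the TLS flag is flipped for the current scope and restored on destruction.
//
//   {
//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
//         /*reapply_views=*/true);
//     // ... code that should reapply views instead of emitting *_copy ops ...
//   } // previous TLS value restored here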

} // namespace impl

// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      (const char*)Op::name, (const char*)Op::overload_name)
                  .typed<FuncType>();

    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};

template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;

template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
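
// Illustrative usage (the operator chosen here is only an example): `Op` is
// one of the generated operator structs in at::_ops, typically spelled via
// the ATEN_OP / ATEN_OP2 helper macros that accompany those structs; `call`
// then forwards the unboxed arguments through functionalize_op_helper:
//
//   at::Tensor out =
//       at::functionalization::functionalize_aten_op<ATEN_OP2(add, Tensor)>::
//           call(self, other, /*alpha=*/1);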

} // namespace functionalization
} // namespace at