xref: /aosp_15_r20/external/pytorch/aten/src/ATen/templates/FunctionalInverses.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // ${generated_comment}
4 
5 #include <ATen/Tensor.h>
6 
7 namespace at {
8 namespace functionalization {
9 
/// Selects what kind of tensor a functional inverse is expected to hand back.
enum class InverseReturnMode {
  /// The inverse must always produce a view.
  AlwaysView = 0,
  /// The inverse must never produce a view; it returns a non-view tensor / copy.
  NeverView = 1,
  /// The inverse produces a view by default. If a (copying) scatter inverse
  /// exists, that one is used instead. Preferring the scatter form avoids
  /// as_strided() calls, which can be difficult for subclasses to handle.
  ViewOrScatterInverse = 2
};
20 
21 struct FunctionalInverses {
22 
23 ${view_inverse_declarations}
24 
25 // NB: These are not generated! They're manually implemented in the template.
26 // TODO: Change codegen to generate these. See the following link:
27 // https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
28 static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
29 static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
30 
31 };
32 }
33 }
34