xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/AutogradComposite.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/SmallBuffer.h>
4 #include <c10/core/impl/COW.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_has_same_storage_numel_native.h>
11 #include <ATen/ops/_make_dual_native.h>
12 #include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
13 #include <ATen/ops/_unpack_dual_native.h>
14 #include <ATen/ops/_lazy_clone_native.h>
15 #include <ATen/ops/alias.h>
16 #include <ATen/ops/zeros.h>
17 #endif
18 
19 namespace at::native {
20 
21 // We expect this code to only be reached in inference mode and when all inputs are inference tensors
_make_dual(const Tensor & primal,const Tensor & tangent,int64_t level)22 Tensor _make_dual(const Tensor& primal, const Tensor& tangent, int64_t level) {
23   TORCH_INTERNAL_ASSERT(
24       InferenceMode::is_enabled() && primal.is_inference() && tangent.is_inference(),
25       "Expected this function to only be reached in inference mode and when all the "
26       "inputs are inference tensors. You should NOT call this function directly as "
27       "native::_make_dual. Please use the dispatcher, i.e., at::_make_dual. Please "
28       "file an issue if you come across this error otherwise.");
29   return at::alias(primal);
30 }
31 
32 /// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
33 /// is a view of the dual and the tangent is returned as is.
34 /// This function is backward differentiable.
_unpack_dual(const at::Tensor & tensor,int64_t level)35 std::tuple<at::Tensor, at::Tensor> _unpack_dual(const at::Tensor& tensor, int64_t level) {
36   return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor._fw_grad(level));
37 }
38 
39 // NB: This function can be called directly from _set_fw_grad or
40 //     if self is batched, from this function's batching rule.
41 //     See NOTE: [_new_zeros_with_same_feature_meta] for more information.
_new_zeros_with_same_feature_meta(const at::Tensor & self,const at::Tensor & other,int64_t self_num_batch_dims)42 Tensor _new_zeros_with_same_feature_meta(
43     const at::Tensor& self,
44     const at::Tensor& other,
45     int64_t self_num_batch_dims) {
46   auto other_sizes = other.sym_sizes();
47   auto other_strides = other.sym_strides();
48   auto other_storage_offset = other.storage_offset();
49   auto other_storage_numel = other.storage().sym_nbytes() / other.itemsize();
50 
51   if (self_num_batch_dims == 0) {
52     auto new_tensor = at::zeros_symint({other_storage_numel}, other.options());
53     return new_tensor.as_strided_symint(other_sizes, other_strides, other_storage_offset);
54   }
55 
56   auto self_sizes = self.sym_sizes();
57 
58   // NB: We don't check that the sizes of self is the same as that of other
59   //     because this function is also used in the inplace over view case
60   //     In the inplace over view case we cannot rely on self and other being
61   //     the same size. So we will use the size of other, and simply tack on
62   //     the batch dims from self. For example: If self.sizes: [B, 2, 3],
63   //     and other.size: [6], we return [B, 6].
64   //     Also see the test test_inplace_on_view_not_same_layout, for when we reach
65   //     this case.
66   constexpr int64_t kSmallBufferSizeHint = 8;
67 
68   auto out_sizes = c10::SmallVector<c10::SymInt, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
69   std::copy(self_sizes.begin(), self_sizes.begin() + self_num_batch_dims, out_sizes.begin());
70   std::copy(other_sizes.begin(), other_sizes.end(), out_sizes.begin() + self_num_batch_dims);
71 
72   // We use the strides of other, and tack on the strides computed with
73   // the batch dims of self, so that the slices are arranged contiguously
74   auto out_strides = c10::SmallVector<c10::SymInt, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
75   auto prod = other_storage_numel;
76 
77   for (int64_t i = self_num_batch_dims - 1; i >= 0; --i) {
78     out_strides[i] = prod;
79     prod *= self_sizes[i];
80   }
81   std::copy(other_strides.begin(), other_strides.end(), out_strides.begin() + self_num_batch_dims);
82 
83   auto storage_numel = prod;
84 
85   // Inherit the TensorOptions of the primal
86   auto new_tensor = at::zeros_symint({storage_numel}, other.options());
87   return new_tensor.as_strided_symint(out_sizes, out_strides, other_storage_offset);
88 }
89 
_has_same_storage_numel(const at::Tensor & base,const at::Tensor & other)90 bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
91   return base.storage().sym_nbytes() / base.itemsize() == other.storage().sym_nbytes() / other.itemsize();
92 }
93 
_lazy_clone(Tensor const & self)94 Tensor _lazy_clone(Tensor const& self) {
95   c10::StorageImpl* self_storage = self.storage().unsafeGetStorageImpl();
96   c10::intrusive_ptr<c10::StorageImpl> storage =
97     c10::impl::cow::lazy_clone_storage(*self_storage);
98   TORCH_CHECK(storage != nullptr);
99   auto tensor = c10::make_intrusive<c10::TensorImpl>(
100       c10::Storage(std::move(storage)),
101       self.key_set(),
102       self.dtype());
103   tensor->set_sizes_and_strides(self.sym_sizes(),
104                                 self.sym_strides(),
105                                 self.sym_storage_offset());
106   return Tensor(std::move(tensor));
107 }
108 
109 } // namespace at::native
110