1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/SmallBuffer.h>
4 #include <c10/core/impl/COW.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_has_same_storage_numel_native.h>
11 #include <ATen/ops/_make_dual_native.h>
12 #include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
13 #include <ATen/ops/_unpack_dual_native.h>
14 #include <ATen/ops/_lazy_clone_native.h>
15 #include <ATen/ops/alias.h>
16 #include <ATen/ops/zeros.h>
17 #endif
18
19 namespace at::native {
20
21 // We expect this code to only be reached in inference mode and when all inputs are inference tensors
_make_dual(const Tensor & primal,const Tensor & tangent,int64_t level)22 Tensor _make_dual(const Tensor& primal, const Tensor& tangent, int64_t level) {
23 TORCH_INTERNAL_ASSERT(
24 InferenceMode::is_enabled() && primal.is_inference() && tangent.is_inference(),
25 "Expected this function to only be reached in inference mode and when all the "
26 "inputs are inference tensors. You should NOT call this function directly as "
27 "native::_make_dual. Please use the dispatcher, i.e., at::_make_dual. Please "
28 "file an issue if you come across this error otherwise.");
29 return at::alias(primal);
30 }
31
32 /// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
33 /// is a view of the dual and the tangent is returned as is.
34 /// This function is backward differentiable.
_unpack_dual(const at::Tensor & tensor,int64_t level)35 std::tuple<at::Tensor, at::Tensor> _unpack_dual(const at::Tensor& tensor, int64_t level) {
36 return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor._fw_grad(level));
37 }
38
39 // NB: This function can be called directly from _set_fw_grad or
40 // if self is batched, from this function's batching rule.
41 // See NOTE: [_new_zeros_with_same_feature_meta] for more information.
_new_zeros_with_same_feature_meta(const at::Tensor & self,const at::Tensor & other,int64_t self_num_batch_dims)42 Tensor _new_zeros_with_same_feature_meta(
43 const at::Tensor& self,
44 const at::Tensor& other,
45 int64_t self_num_batch_dims) {
46 auto other_sizes = other.sym_sizes();
47 auto other_strides = other.sym_strides();
48 auto other_storage_offset = other.storage_offset();
49 auto other_storage_numel = other.storage().sym_nbytes() / other.itemsize();
50
51 if (self_num_batch_dims == 0) {
52 auto new_tensor = at::zeros_symint({other_storage_numel}, other.options());
53 return new_tensor.as_strided_symint(other_sizes, other_strides, other_storage_offset);
54 }
55
56 auto self_sizes = self.sym_sizes();
57
58 // NB: We don't check that the sizes of self is the same as that of other
59 // because this function is also used in the inplace over view case
60 // In the inplace over view case we cannot rely on self and other being
61 // the same size. So we will use the size of other, and simply tack on
62 // the batch dims from self. For example: If self.sizes: [B, 2, 3],
63 // and other.size: [6], we return [B, 6].
64 // Also see the test test_inplace_on_view_not_same_layout, for when we reach
65 // this case.
66 constexpr int64_t kSmallBufferSizeHint = 8;
67
68 auto out_sizes = c10::SmallVector<c10::SymInt, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
69 std::copy(self_sizes.begin(), self_sizes.begin() + self_num_batch_dims, out_sizes.begin());
70 std::copy(other_sizes.begin(), other_sizes.end(), out_sizes.begin() + self_num_batch_dims);
71
72 // We use the strides of other, and tack on the strides computed with
73 // the batch dims of self, so that the slices are arranged contiguously
74 auto out_strides = c10::SmallVector<c10::SymInt, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
75 auto prod = other_storage_numel;
76
77 for (int64_t i = self_num_batch_dims - 1; i >= 0; --i) {
78 out_strides[i] = prod;
79 prod *= self_sizes[i];
80 }
81 std::copy(other_strides.begin(), other_strides.end(), out_strides.begin() + self_num_batch_dims);
82
83 auto storage_numel = prod;
84
85 // Inherit the TensorOptions of the primal
86 auto new_tensor = at::zeros_symint({storage_numel}, other.options());
87 return new_tensor.as_strided_symint(out_sizes, out_strides, other_storage_offset);
88 }
89
_has_same_storage_numel(const at::Tensor & base,const at::Tensor & other)90 bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
91 return base.storage().sym_nbytes() / base.itemsize() == other.storage().sym_nbytes() / other.itemsize();
92 }
93
_lazy_clone(Tensor const & self)94 Tensor _lazy_clone(Tensor const& self) {
95 c10::StorageImpl* self_storage = self.storage().unsafeGetStorageImpl();
96 c10::intrusive_ptr<c10::StorageImpl> storage =
97 c10::impl::cow::lazy_clone_storage(*self_storage);
98 TORCH_CHECK(storage != nullptr);
99 auto tensor = c10::make_intrusive<c10::TensorImpl>(
100 c10::Storage(std::move(storage)),
101 self.key_set(),
102 self.dtype());
103 tensor->set_sizes_and_strides(self.sym_sizes(),
104 self.sym_strides(),
105 self.sym_storage_offset());
106 return Tensor(std::move(tensor));
107 }
108
109 } // namespace at::native
110