#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/mm.h>
#endif

#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/library.h>

#include <ATen/FunctionalTensorWrapper.h>

namespace torch::inductor {
using namespace at;
Tensor _mm_plus_mm_out(
    Tensor& out,
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& d) {
  at::mm_out(out, a, b);
  out.addmm_(c, d);
  return out;
}

Tensor _mm_plus_mm(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& d,
    Tensor& out) {
  return _mm_plus_mm_out(out, a, b, c, d);
}
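
// Usage sketch (illustrative comment, not part of the original source):
// _mm_plus_mm fuses two matmuls sharing one output buffer, computing
// out = a.mm(b) + c.mm(d). The shapes below are hypothetical.
//
//   at::Tensor a = at::randn({8, 4}), b = at::randn({4, 16});
//   at::Tensor c = at::randn({8, 4}), d = at::randn({4, 16});
//   at::Tensor out = at::empty({8, 16}, a.options());
//   _mm_plus_mm(a, b, c, d, out);  // out == a.mm(b) + c.mm(d)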

Tensor _alloc_from_pool(
    const Tensor& self,
    int64_t offset_bytes,
    ScalarType dtype,
    IntArrayRef size,
    IntArrayRef stride) {
  TORCH_CHECK(self.storage_offset() == 0);
  // based on alias_with_sizes_and_strides from TensorShape.cpp
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      // c10::TensorImpl::VIEW,
      Storage(self.storage()),
      self.key_set(),
      caffe2::TypeMeta::fromScalarType(dtype));
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  self_tmp_->set_storage_offset(
      offset_bytes / static_cast<int64_t>(c10::elementSize(dtype)));
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}
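
// Usage sketch (illustrative comment; `pool` and all sizes are hypothetical):
// carve a typed view out of a flat memory pool at a given byte offset. The
// pool tensor must have storage_offset 0, per the TORCH_CHECK above.
//
//   at::Tensor pool = at::empty({1024}, at::kByte);
//   // Contiguous 4x4 float32 view starting 256 bytes into the pool:
//   at::Tensor t = _alloc_from_pool(pool, 256, at::kFloat, {4, 4}, {4, 1});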

// Similar to as_strided, with the following differences:
// - offset is added to the existing offset (rather than replacing it)
// - view tracking is disabled, similar to unsafe_view
Tensor _reinterpret_tensor(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    int64_t offset_increment) {
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      Storage(self.storage()), self.key_set(), self.dtype());
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  self_tmp_->set_storage_offset(self.storage_offset() + offset_increment);
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}
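
// Usage sketch (illustrative comment): reinterpret a contiguous 4x4 tensor
// as its 2x2 top-left block without copying. Unlike as_strided, the offset
// argument is an increment on top of self's existing storage offset.
//
//   at::Tensor base = at::randn({4, 4});
//   at::Tensor blk =
//       _reinterpret_tensor(base, {2, 2}, {4, 1}, /*offset_increment=*/0);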

static void accumulate_grad_(const Tensor& variable, const Tensor& new_grad) {
  at::Tensor& grad = variable.mutable_grad();
  if (new_grad.device() != kMeta) {
    // Do not call into this codepath from the C++ frontend; instead, call
    // directly into accumulateGrad with num_expected_refs set to 1. Here,
    // num_expected_refs is set to 2 so the gradient can be stolen when this
    // is called from Python.
    torch::autograd::AccumulateGrad::accumulateGrad(
        variable,
        grad,
        new_grad,
        2 /* num_expected_refs */,
        [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  } else {
    // No shape checking for `device="meta"`, to work around FSDP in-place
    // mutation.
    if (!grad.defined()) {
      grad = new_grad;
    }
  }
}
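
// Behavior sketch (illustrative comment): for a parameter with no existing
// .grad, the first call installs new_grad; subsequent calls accumulate.
//
//   at::Tensor p = at::randn({3}).set_requires_grad(true);
//   accumulate_grad_(p, at::ones({3}));  // p.grad() == ones
//   accumulate_grad_(p, at::ones({3}));  // p.grad() == 2 * ones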

TORCH_LIBRARY_FRAGMENT(inductor, m) {
  m.def(
      "_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) out) -> Tensor(t!)",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, _mm_plus_mm),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "_alloc_from_pool(Tensor self, int offset_bytes, ScalarType dtype, int[] size, int[] stride) -> Tensor",
      _alloc_from_pool,
      {at::Tag::pt2_compliant_tag});
  m.def(
      "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
      dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, _reinterpret_tensor),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "accumulate_grad_(Tensor variable, Tensor new_grad) -> ()",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, accumulate_grad_),
      {at::Tag::pt2_compliant_tag});
}
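
// Call sketch (illustrative comment): once registered, these ops are
// reachable by schema name, e.g. torch.ops.inductor._mm_plus_mm from Python.
// A C++ caller can go through the dispatcher; types mirror the schema above.
//
//   static auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("inductor::_reinterpret_tensor", "")
//       .typed<at::Tensor(
//           const at::Tensor&, at::IntArrayRef, at::IntArrayRef, int64_t)>();
//   at::Tensor view = op.call(self, size, stride, /*offset_increment=*/0);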

} // namespace torch::inductor