1 #ifndef AT_PER_OPERATOR_HEADERS
2 #include <ATen/Functions.h>
3 #else
4 #include <ATen/ops/mm.h>
5 #endif
6
7 #include <torch/csrc/autograd/functions/accumulate_grad.h>
8 #include <torch/csrc/inductor/inductor_ops.h>
9 #include <torch/library.h>
10
11 #include <ATen/FunctionalTensorWrapper.h>
12
13 namespace torch::inductor {
14 using namespace at;
15
_mm_plus_mm_out(Tensor & out,const Tensor & a,const Tensor & b,const Tensor & c,const Tensor & d)16 Tensor _mm_plus_mm_out(
17 Tensor& out,
18 const Tensor& a,
19 const Tensor& b,
20 const Tensor& c,
21 const Tensor& d) {
22 at::mm_out(out, a, b);
23 out.addmm_(c, d);
24 return out;
25 }
26
// Schema-facing entry point for the registered `inductor::_mm_plus_mm`
// op: identical to _mm_plus_mm_out, but with `out` as the trailing
// argument to match the op signature. Delegates all work to the helper.
Tensor _mm_plus_mm(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& d,
    Tensor& out) {
  Tensor result = _mm_plus_mm_out(out, a, b, c, d);
  return result;
}
35
// Carves a new tensor out of `self`'s backing storage at a byte offset,
// with the requested dtype, sizes, and strides. Used by inductor's
// memory-pool planner to hand out tensors from one pooled allocation.
//
// Preconditions:
//  - `self` must view its storage from the very beginning
//    (storage_offset() == 0), since `offset_bytes` is interpreted
//    relative to the storage base.
//  - `offset_bytes` is presumably a multiple of elementSize(dtype);
//    the division below truncates otherwise — callers are expected to
//    pass aligned offsets (TODO confirm against inductor codegen).
Tensor _alloc_from_pool(
    const Tensor& self,
    int64_t offset_bytes,
    ScalarType dtype,
    IntArrayRef size,
    IntArrayRef stride) {
  // Fix: bare TORCH_CHECK produced an unhelpful "Expected ... to be true"
  // error; include the actual offset so failures are diagnosable.
  TORCH_CHECK(
      self.storage_offset() == 0,
      "_alloc_from_pool expects the pool tensor to have storage_offset() == 0, got ",
      self.storage_offset());
  // based on alias_with_sizes_and_strides from TensorShape.cpp
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      // c10::TensorImpl::VIEW,
      Storage(self.storage()),
      self.key_set(),
      caffe2::TypeMeta::fromScalarType(dtype));
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  // Convert the byte offset into an element offset for the new dtype.
  self_tmp_->set_storage_offset(
      offset_bytes / static_cast<int64_t>(c10::elementSize(dtype)));
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}
55
56 // Similar to as_strided with the following differences
57 // - offset is added to the existing offset (rather than replacing it)
58 // - view tracking is disabled similar to unsafe_view
// Similar to as_strided with the following differences
// - offset is added to the existing offset (rather than replacing it)
// - view tracking is disabled similar to unsafe_view
Tensor _reinterpret_tensor(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    int64_t offset_increment) {
  // Build a fresh TensorImpl sharing self's storage, key set, and dtype.
  Tensor result = at::detail::make_tensor<TensorImpl>(
      Storage(self.storage()), self.key_set(), self.dtype());
  auto* impl = result.unsafeGetTensorImpl();
  impl->set_sizes_and_strides(size, stride);
  // Offset is relative: it accumulates on top of self's current offset.
  impl->set_storage_offset(self.storage_offset() + offset_increment);
  return result;
}
71
// Accumulates `new_grad` into `variable.grad()` in place, mirroring what
// autograd does at the end of backward. Registered below as the
// `inductor::accumulate_grad_` op.
static void accumulate_grad_(const Tensor& variable, const Tensor& new_grad) {
  at::Tensor& grad = variable.mutable_grad();
  if (new_grad.device() != kMeta) {
    // Do not call into this codepath from C++ frontend, instead call directly
    // into accumulateGrad with num_expected_refs set to 1 Here,
    // num_expected_refs is set to 2 to steal the gradient when this is called
    // from Python
    torch::autograd::AccumulateGrad::accumulateGrad(
        variable,
        grad,
        new_grad,
        2 /* num_expected_refs */,
        // The callback receives the (possibly stolen/summed) result and
        // installs it as the variable's grad; `grad` is captured by
        // reference so the assignment writes through to mutable_grad().
        [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  } else {
    // no shape checking for `device="meta"` to workaround FSDP inplace mutation
    // An already-defined meta grad is deliberately left untouched.
    if (!grad.defined()) {
      grad = new_grad;
    }
  }
}
92
// Registers the inductor helper ops defined above with the dispatcher.
// All ops carry pt2_compliant_tag so torch.compile treats them as safe.
TORCH_LIBRARY_FRAGMENT(inductor, m) {
  // `out` is mutably aliased (Tensor(t!)) and returned.
  m.def(
      "_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) out) -> Tensor(t!)",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, _mm_plus_mm),
      {at::Tag::pt2_compliant_tag});
  // Registered without an explicit dispatch key, unlike the others.
  m.def(
      "_alloc_from_pool(Tensor self, int offset_bytes, ScalarType dtype, int[] size, int[] stride) -> Tensor",
      _alloc_from_pool,
      {at::Tag::pt2_compliant_tag});
  m.def(
      "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
      dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, _reinterpret_tensor),
      {at::Tag::pt2_compliant_tag});
  // Mutates variable.grad() in place; returns nothing.
  m.def(
      "accumulate_grad_(Tensor variable, Tensor new_grad) -> ()",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, accumulate_grad_),
      {at::Tag::pt2_compliant_tag});
}
112
113 } // namespace torch::inductor
114