xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/input_buffer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/SparseTensorUtils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <optional>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {
// look what you made me do >.<
// We shouldn't need divergent per-impl code paths for stream recording here;
// they leak the impls' implementation details.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306
// is improved
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}
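// Illustrative note (not in the original file): the recording above is about
// memory safety across streams. The CUDA caching allocator tracks which
// streams may still touch a storage; when a gradient produced on one stream
// is accumulated on another, every storage backing it must be recorded on the
// accumulation stream, or the allocator could recycle that memory while the
// accumulation kernel is still reading it. A hypothetical caller pairs the
// recording with an event sync, as InputBuffer::add does below:
//
//   c10::Event event{device_type};
//   event.record(producer_stream);   // producer has finished writing `var`
//   accumulate_stream.wait(event);   // order accumulation after the producer
//   record_stream_any_impl(var, accumulate_stream);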

bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
      v.storage().use_count() == 1);
}
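// Illustrative note: this predicate gates the in-place fast path in
// accumulate() below. When two consumers feed gradients into the same input
// slot, the second arrival can be added in place only if the candidate is a
// plain dense tensor whose Tensor and storage are uniquely owned by this
// buffer; views (shared storage), tensor subclasses, nested tensors, and
// ZeroTensors all take the out-of-place path.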
} // anonymous namespace

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}
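// Branch summary for accumulate() above (illustrative restatement):
//   GradMode enabled              -> old_var + var (out of place; never
//                                    mutate in place while autograd records)
//   old_var sparse / sparse CSR   -> var.add_(old_var) if `var` is reusable,
//                                    else var + old_var (sparse operand kept
//                                    on the right so ATen routes the add)
//   old_var reusable, var vanilla -> old_var.add_(var)
//   otherwise                     -> old_var + var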

void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream,
    const std::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }

  // Switches to accumulate device
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA/privateuse1 variable. Accumulation happens on
  //      var's device.
  //  (2) var is a CUDA/privateuse1 variable and it, the consumer, and the
  //      producer share the same device:
  //       (2a) Uses the consumer's stream as the accumulation stream
  //       (2b) Syncs the accumulation stream with the producer's stream
  //            (if different)
  //       (2c) Accumulates.
  //  (3) var is a CUDA/privateuse1 variable and it shares a device with the
  //      consumer but not the producer:
  //       (3a) Uses the consumer's stream as the accumulation stream
  //       (3b) Syncs the accumulation stream with the consumer device's
  //            default stream
  //       (3c) Accumulates.
  //  (4) var is a CUDA/privateuse1 variable and it shares a device with the
  //      producer but not the consumer:
  //       (4a) Uses the producer device's default stream as the accumulation
  //            stream
  //       (4b) Syncs the accumulation stream with the producer's stream
  //       (4c) Accumulates.
  //  (5) var is a CUDA/privateuse1 variable and it does not share a device
  //      with the consumer or producer.
  //      Accumulation happens on the var device's default stream.

  TORCH_INTERNAL_ASSERT(device_of(var));
  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
  const auto device_type = device_of(var).value().type();
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  if (device_of(var)->is_cuda() || device_of(var)->is_privateuseone()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{device_type};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{device_type};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{device_type};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA/privateuse1 variable
      //     Accumulation happens on variable's device
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}
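// Usage sketch (hypothetical; `fn`, `grad`, and the streams stand in for
// what the autograd engine supplies while evaluating a node):
//
//   InputBuffer input_buffer{fn.num_inputs()};
//   for (/* each incoming gradient edge */) {
//     input_buffer.add(
//         input_nr, std::move(grad), opt_producer_stream, opt_consumer_stream);
//   }
//   auto grads = InputBuffer::variables(std::move(input_buffer));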

auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA).  This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}
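// Illustrative note: the engine consumes this device to choose the ready
// queue (and hence the worker thread) on which the next node runs, so a
// buffer holding one CPU and one CUDA gradient reports the CUDA device.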

auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}

} // namespace torch::autograd