#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/SparseTensorUtils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <optional>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {
// look what you made me do >.<
// Divergent paths for per-Impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
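
// Records each data_ptr that backs `var` on `stream` (the per-layout analogue
// of Tensor::record_stream), so the caching allocator will not reuse that
// memory before the work already queued on `stream` has completed. Batched,
// sparse, and strided tensors keep their data in different storages, hence
// the per-layout dispatch below.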
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}

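// True when it is safe to accumulate into `v`'s buffer in place: `v` must be
// a plain, non-overlapping-and-dense Tensor whose Tensor and storage
// reference counts we hold exclusively.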
bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
      v.storage().use_count() == 1);
}
} // anonymous namespace

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage, we will try
  // to repurpose it to store the output. (Or, if `old_var` is sparse, then
  // `var` becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled, since Autograd has special handling for
  //     in-place mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count()` and `.storage().use_count()` are one.)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock-standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

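  // Branch summary: under GradMode we always add out of place; if `old_var`
  // is sparse (or sparse compressed), `var` is the in-place candidate instead
  // (see the note below about sparse routing); otherwise we accumulate into
  // `old_var` in place when the checks above allow it, falling back to an
  // out-of-place add.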
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}

void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream,
    const std::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }

  // Switches to the accumulation device.
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA/privateuse1 variable.
  //      Accumulation happens on var's device.
  //  (2) var is a CUDA/privateuse1 variable and it, the consumer, and the
  //      producer share the same device:
  //      (2a) Uses the consumer's stream as the accumulation stream
  //      (2b) Syncs the accumulation stream with the producer's stream (if
  //           different)
  //      (2c) Accumulates.
  //  (3) var is a CUDA/privateuse1 variable and it shares a device with the
  //      consumer but not the producer:
  //      (3a) Uses the consumer's stream as the accumulation stream
  //      (3b) Syncs the accumulation stream with the consumer device's
  //           default stream
  //      (3c) Accumulates.
  //  (4) var is a CUDA/privateuse1 variable and it shares a device with the
  //      producer but not the consumer:
  //      (4a) Uses the producer device's default stream as the accumulation
  //           stream
  //      (4b) Syncs the accumulation stream with the producer's stream
  //      (4c) Accumulates.
  //  (5) var is a CUDA/privateuse1 variable and it does not share a device
  //      with the consumer or producer.
  //      Accumulation happens on the var device's default stream.
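  //
  // For example (illustrative only): if var lives on cuda:0, the producer ran
  // on cuda:0 stream P, and the consumer will run on cuda:0 stream C, we are
  // in case (2): accumulation runs on C, and when P != C we record an event
  // on P, make C wait on it, and record var's storage on C so its memory is
  // not reused prematurely.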

  TORCH_INTERNAL_ASSERT(device_of(var));
  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
  const auto device_type = device_of(var).value().type();
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  if (device_of(var)->is_cuda() || device_of(var)->is_privateuseone()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{device_type};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{device_type};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{device_type};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
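    // (2c), (3c), (4c), (5): perform the accumulation with the chosen stream
    // current; for case (1) we only switch to var's device.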
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA/privateuse1 variable
      //     Accumulation happens on the variable's device.
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}

auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA). This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}

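// Consumes the InputBuffer and returns the accumulated gradients; entries for
// inputs that never received a gradient are left undefined.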
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}

} // namespace torch::autograd