#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/CPUFallback.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/irange.h>

#include <sstream>
#include <vector>


#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_copy_from_and_resize.h>
#include <ATen/ops/_to_cpu.h>
#endif


namespace at::native {

// convenience helper for converting tensors to cpu

template<typename T, std::enable_if_t<std::is_same_v<T, at::Tensor> || std::is_same_v<T, std::optional<at::Tensor>>, int> = 1>
static std::vector<T> to_cpu(const std::vector<T>& tensors) {
    // We can't just call at::_to_cpu() on the entire list of Tensors,
    // because it will break on undefined tensors. Separate out undefined tensors first.
    const int num = tensors.size();
    std::vector<T> cpu_tensors(num);
    std::vector<at::Tensor> valid_tensors;
    std::vector<bool> to_translate(num);
    for (const auto i : c10::irange(num)) {
      // Explicitly handle undefined tensors here instead of letting `at::_to_cpu` handle them.
      // Otherwise, we'd need to require every backend with its own implementation of _to_cpu
      // to properly handle undefined tensors.
      if constexpr(std::is_same_v<T, std::optional<at::Tensor>>) {
        if (tensors[i].has_value() && tensors[i].value().defined()) {
          to_translate[i] = true;
          valid_tensors.push_back(tensors[i].value());
        } else {
          cpu_tensors[i] = tensors[i];
        }
      } else {
        if (tensors[i].defined()) {
          to_translate[i] = true;
          valid_tensors.push_back(tensors[i]);
        } else {
          cpu_tensors[i] = tensors[i];
        }
      }
    }
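    // Convert all of the defined tensors to CPU in one call, then scatter the
    // results back into their original slots in `cpu_tensors`.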
    auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
    for (int i = 0, defined_pos = 0; i < num; ++i) {
      if (to_translate[i]) {
        cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
      }
    }
    return cpu_tensors;
}

static std::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args, const std::vector<c10::List<at::Tensor>>& tlist_args) {
  // Decide what device to move the output tensor(s) to.
  // The current convention is that we use the first tensor arg to pick the device.
  // Barring that, we take the first tensor from a TensorList arg.
  if (!t_args.empty()) {
    return t_args[0].device();
  } else {
    // We need to loop through all of the (potentially multiple) TensorList arguments,
    // in case, e.g., the first one is empty but the second is not.
    for (auto& tens_list : tlist_args) {
      for (const auto i : c10::irange(tens_list.size())) {
        return tens_list.get(i).device();
      }
    }
  }
  return std::nullopt;
}

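// Returns true if `tensorlist` contains at least one defined tensor.
// Used below to decide whether a TensorList input or return actually carries any data.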
static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {
  bool flag = false;

  for (const auto i : c10::irange(tensorlist.size())) {
    if (tensorlist[i].defined())
      flag = true;
  }

  return flag;
}

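// Sketch of how a backend might register this as its boxed fallback. The wrapper
// name and the PrivateUse1 key are illustrative only, and the call below assumes
// the default arguments declared for cpu_fallback in CPUFallback.h:
//
//   static void my_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
//     at::native::cpu_fallback(op, stack);
//   }
//   TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_cpu_fallback>());
//   }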
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views,
                  c10::DispatchKey cpu_dispatch_key) {
  TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key),
              "Expected CPU backend DispatchKey but got ",
              c10::toString(cpu_dispatch_key));
  auto& schema_args = op.schema().arguments();
  const auto num_arguments = schema_args.size();
  auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;
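  // `arguments` is a view of the top `num_arguments` IValues on the stack;
  // `arguments_begin` is the absolute index of the first argument, which lets us
  // overwrite individual argument slots in place below.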

  std::vector<at::Tensor> tensor_args;
  std::vector<int> tensor_args_indices;

  std::vector<c10::List<at::Tensor>> tensorlist_args;
  std::vector<int> tensorlist_args_indices;

  std::vector<c10::List<std::optional<at::Tensor>>> optional_tensorlist_args;
  std::vector<int> optional_tensorlist_args_indices;

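  // If the op takes an explicit Device argument, it is captured in `tgt_device`
  // and rewritten to CPU in the loop below; otherwise the target device is
  // inferred from the tensor arguments later (see compute_target_device).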
  std::optional<c10::Device> tgt_device = std::nullopt;
  // Save the converted CPU tensors for TensorList and optional TensorList arguments.
  std::vector<c10::IValue> tensorlist_cpu_args;
  std::vector<c10::IValue> optional_tensorlist_cpu_args;

  // Step 1: Convert all non-CPU tensor inputs into CPU tensors
  // and put them on the stack at the correct indices.
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      tensor_args.push_back(ivalue.toTensor());
      tensor_args_indices.push_back(idx);
    } else if (ivalue.isTensorList()) {
      // Note: we copy each TensorList argument to CPU individually out of convenience,
      // but XLA would benefit from materializing all tensor and TensorList args onto the CPU at the same time.
      // We can improve this if we need better perf for XLA's CPU fallbacks.
      tensorlist_args.push_back(ivalue.toTensorList());
      tensorlist_args_indices.push_back(idx);
      auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorVector())));
      tensorlist_cpu_args.push_back(cpu_ivalue);
      (*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
    } else if (ivalue.isOptionalTensorList()) {
      optional_tensorlist_args.push_back(ivalue.toOptionalTensorList());
      optional_tensorlist_args_indices.push_back(idx);
      auto cpu_ivalue = c10::IValue(c10::List<std::optional<at::Tensor>>(to_cpu(ivalue.toOptionalTensorVector())));
      optional_tensorlist_cpu_args.push_back(cpu_ivalue);
      (*stack)[arguments_begin + idx] = c10::IValue(cpu_ivalue);
    } else if (ivalue.isDevice()) {
      tgt_device = ivalue.toDevice();
      (*stack)[arguments_begin + idx] = c10::IValue(c10::Device(kCPU));
    }
  }
  // XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
  auto cpu_tensors = to_cpu(tensor_args);

  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto idx = tensor_args_indices[i];
    (*stack)[arguments_begin + idx] = c10::IValue(cpu_tensors[i]);
  }

  // Step 2: Call the underlying CPU implementation of the operator.
  op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack);
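  // Per the boxed calling convention, the CPU kernel pops the (now CPU) arguments
  // off the stack and pushes its outputs in their place, so the top of `stack`
  // now holds the op's returns.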

  // Step 3: We need to take special care to handle mutable aliases properly:
  // if any input tensors are mutable aliases, we need to directly copy the
  // updated data from the CPU tensors back to the original inputs.
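  // `_copy_from_and_resize` copies the CPU result back into the original device
  // tensor, resizing the destination if the op changed its shape.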
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto tensor_idx = tensor_args_indices[i];
    const AliasInfo* alias_info = schema_args[tensor_idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      if (!tensor_args[i].defined()) continue;
      at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
    }
  }

  // We also need to explicitly reapply input mutations to inputs that are lists
  // of tensors.
  for (const auto i : c10::irange(tensorlist_args_indices.size())) {
    auto tensorlist_idx = tensorlist_args_indices[i];
    const AliasInfo* alias_info = schema_args[tensorlist_idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      const auto& cpu_tensors = tensorlist_cpu_args[i].toTensorVector();
      for (const auto idx : c10::irange(tensorlist_args[i].size())) {
        if (!cpu_tensors[idx].defined()) continue;
        at::_copy_from_and_resize(cpu_tensors[idx], tensorlist_args[i][idx]);
      }
    }
  }

  // We also need to explicitly reapply input mutations to inputs that are lists
  // of optional tensors.
  for (const auto i : c10::irange(optional_tensorlist_args_indices.size())) {
    auto tensorlist_idx = optional_tensorlist_args_indices[i];
    const AliasInfo* alias_info = schema_args[tensorlist_idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      const auto& cpu_tensors = optional_tensorlist_cpu_args[i].toOptionalTensorList();
      for (const auto idx : c10::irange(optional_tensorlist_args[i].size())) {
        if (cpu_tensors[idx].has_value() && cpu_tensors[idx].value().defined()) {
          const std::optional<at::Tensor>& optional_tensor = optional_tensorlist_args[i][idx];
          at::_copy_from_and_resize(cpu_tensors[idx].value(), optional_tensor.value());
        }
      }
    }
  }

  // Step 4: Convert any CPU output tensors back to the original input device.
  // For mutable alias'd outputs, we also need to take special care
  // to move the ORIGINAL input tensor back onto the stack, in place of
  // the temporary CPU output tensor that we created.
  //
  // Note [CPU Fallback Does Not Handle View Operators]
  // Also note that we are incapable of handling immutable aliases properly.
  // Why?
  // Schemas with immutable alias'd tensor outputs correspond to view operators.
  // For example, the `view_as` schema from native_functions.yaml:
  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
  // We can't handle these ops properly, because view ops are supposed to return
  // a NEW tensor that shares the SAME storage as the original tensor.
  // However, the new tensor that we created cannot share the same storage,
  // since it lives on CPU and the original tensor lives on a different device.
  // Because of that, we warn if someone attempts to call the
  // CPU fallback on a view operator (this is to maintain BC for view ops for XLA
  // that fall back to CPU).
  const auto& schema_returns = op.schema().returns();
  const auto& num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  if (tgt_device == std::nullopt) {
    tgt_device = compute_target_device(tensor_args, tensorlist_args);
  }

  for (const auto idx : c10::irange(returns.size())) {
    const AliasInfo* alias_info = schema_returns[idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      // Case (1): mutable alias case.
      // Move the input ivalue directly onto the stack in place of
      // the existing cpu output tensor.
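      // For example, an in-place schema along the lines of
      // `fill_(Tensor(a!) self, Scalar value) -> Tensor(a!)` returns its mutated
      // input, so the return slot should hold the original (non-CPU) `self`.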
      bool found_alias = false;
      if (returns[idx].isTensor() && returns[idx].toTensor().defined()) {
        // We could store some extra metadata on the function schema to avoid
        // the loop here if we need to improve perf.
        for (const auto i : c10::irange(tensor_args_indices.size())) {
          auto input_tensor_idx = tensor_args_indices[i];
          const auto& input_tensor = cpu_tensors[i];
          const AliasInfo* input_alias_info =
              schema_args[input_tensor_idx].alias_info();
          // Checked above; adding assert to guard against breakage of the below
          // condition due to changing the above if test.
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr);
          if (input_tensor.defined() &&
              (alias_info == input_alias_info ||
               (input_alias_info != nullptr &&
                *alias_info == *input_alias_info))) {
            // We've found the original input tensor that aliases with the
            // current output. Wrap it in an IValue and put it directly on the
            // stack.
            (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
            found_alias = true;
            break;
          }
        }
      } else if (
          returns[idx].isTensorList() &&
          validate_tensor_list(returns[idx].toTensorList())) {
        for (const auto i : c10::irange(tensorlist_args_indices.size())) {
          auto input_tensor_idx = tensorlist_args_indices[i];
          const AliasInfo* input_alias_info =
              schema_args[input_tensor_idx].alias_info();
          // Checked above; adding assert to guard against breakage of the below
          // condition due to changing the above if test.
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr);
          if (validate_tensor_list(tensorlist_args[i]) &&
              (alias_info == input_alias_info ||
               (input_alias_info != nullptr &&
                *alias_info == *input_alias_info))) {
            // We've found the original input TensorList that aliases with the
            // current output. Wrap it in an IValue and put it directly on the
            // stack.
            (*stack)[returns_begin + idx] = c10::IValue(tensorlist_args[i]);
            found_alias = true;
            break;
          }
        }
      }
      TORCH_CHECK(
          found_alias,
          "The operator ",
          op.schema().operator_name(),
          " appears to have invalid alias information. ",
          "Found a return tensor argument with a mismatched mutable alias: ",
          schema_returns[idx]);
    } else {
      if (alias_info != nullptr && !alias_info->isWrite()) {
        // Case (3): immutable alias (view) case.
        // Warn here, since we're copying and not creating a view.
        // If this operator is needed, the backend should provide a kernel for
        // it. See Note [CPU Fallback Does Not Handle View Operators]
        std::stringstream dev_str;
        if (tgt_device) {
          dev_str << *tgt_device;
        } else {
          dev_str << "<none>";
        }
        if (error_on_views) {
          TORCH_CHECK(
              false,
              "The operator ",
              op.schema().operator_name(),
              " appears to be a view operator, ",
              "but it has no implementation for the backend \"",
              dev_str.str(),
              "\". View operators don't support falling back to run on the CPU, ",
              "since the tensor's storage cannot be shared across devices.");
        } else {
          TORCH_WARN(
              "The operator ",
              op.schema().operator_name(),
              " appears to be a view operator, ",
              "but it has no implementation for the backend \"",
              dev_str.str(),
              "\". View operators don't support falling back to run on the CPU, ",
              "since the tensor's storage cannot be shared across devices.");
        }
      }
      // Case (2): copy case.
      // Copy the cpu output tensor to the original device.

      // We technically might not have a target device, e.g. if you call
      // torch.cat() with an empty list. In that case, we shouldn't have any
      // tensors to schlep across devices anyway.
      if (tgt_device) {
        if (returns[idx].isTensor() && returns[idx].toTensor().defined()) {
          (*stack)[returns_begin + idx] =
              c10::IValue(returns[idx].toTensor().to(*tgt_device));
        } else if (
            returns[idx].isTensorList() &&
            validate_tensor_list(returns[idx].toTensorList())) {
          const auto& cpu_tensors = returns[idx].toTensorList().vec();
          std::vector<at::Tensor> tensors;
          tensors.reserve(cpu_tensors.size());

          for (const auto& tensor : cpu_tensors) {
            tensors.push_back(tensor.to(*tgt_device));
          }
          (*stack)[returns_begin + idx] =
              c10::IValue(c10::List<at::Tensor>(tensors));
        }
      }
    }
  }
}

} // namespace at::native