// torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/native/CPUFallback.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/library.h>
#include <sstream>
#include <unordered_map>

namespace torch {
namespace lazy {
namespace {

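// Moves a list of tensors to the given eager device. The CPU path uses the
// batched at::_to_cpu kernel; other device types copy each tensor
// individually via Tensor::to.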
std::vector<at::Tensor> _to_eager(
    at::TensorList tensors,
    c10::DeviceType device_type) {
  switch (device_type) {
    case at::kCPU: {
      return at::_to_cpu(tensors);
    }
    default: {
      std::vector<at::Tensor> eager_tensors;
      for (const auto& t : tensors) {
        c10::TensorOptions options = t.options().device(device_type);
        at::Tensor eager_tensor = t.to(
            options,
            /*non_blocking*/ false,
            /*copy*/ false);
        eager_tensors.push_back(eager_tensor);
      }
      return eager_tensors;
    }
  }
}

// Convenience helpers for converting tensors to the eager device.
std::vector<at::Tensor> to_eager(
    const at::TensorList& tensors,
    c10::DeviceType device_type) {
  // We can't just call _to_eager() on the entire list of Tensors because it
  // will break on undefined tensors. Separate out undefined tensors first.
  std::vector<at::Tensor> eager_tensors(tensors.size());
  std::vector<at::Tensor> valid_tensors;
  std::vector<bool> to_translate(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    const at::Tensor& tensor = tensors[i];
    // Explicitly handling undefined tensors here instead of letting `_to_eager`
    // handle it. Otherwise, we'd need to require all backends with their own
    // implementation of _to_eager to properly handle undefined tensors.
    if (tensor.defined()) {
      to_translate[i] = true;
      valid_tensors.push_back(tensor);
    } else {
      eager_tensors[i] = tensor;
    }
  }
  auto eager_valid_tensors = _to_eager(valid_tensors, device_type);
  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
    if (to_translate[i]) {
      eager_tensors[i] = std::move(eager_valid_tensors[defined_pos++]);
    }
  }
  return eager_tensors;
}

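// Overload for optional tensors: entries that are std::nullopt or hold an
// undefined tensor are passed through unchanged.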
std::vector<std::optional<at::Tensor>> to_eager(
    const std::vector<std::optional<at::Tensor>>& tensors,
    c10::DeviceType device_type) {
  // We can't just call _to_eager() on the entire list of Tensors because it
  // will break on undefined tensors. Separate out undefined tensors first.
  std::vector<std::optional<at::Tensor>> eager_tensors(tensors.size());
  std::vector<at::Tensor> valid_tensors;
  std::vector<bool> to_translate(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    const std::optional<at::Tensor>& tensor = tensors[i];
    // Explicitly handling undefined tensors here instead of letting `_to_eager`
    // handle it. Otherwise, we'd need to require all backends with their own
    // implementation of _to_eager to properly handle undefined tensors.
    if (tensor.has_value() && tensor->defined()) {
      to_translate[i] = true;
      valid_tensors.push_back(*tensor);
    } else {
      eager_tensors[i] = tensor;
    }
  }
  auto eager_valid_tensors = _to_eager(valid_tensors, device_type);
  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
    if (to_translate[i]) {
      eager_tensors[i] = std::move(eager_valid_tensors[defined_pos++]);
    }
  }
  return eager_tensors;
}

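// Maps the eager device type to the dispatch key used for the redispatch in
// ts_eager_fallback; only CPU and CUDA eager fallbacks are supported.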
c10::DispatchKey dispatch_key(c10::DeviceType device_type) {
  switch (device_type) {
    case at::kCPU: {
      return c10::DispatchKey::CPU;
    }
    case at::kCUDA: {
      return c10::DispatchKey::CUDA;
    }
    default: {
      AT_ERROR("Unsupported device type: ", device_type);
    }
  }
}

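// Picks the device that eager outputs should be moved back to, based on the
// tensor and tensor-list arguments of the op.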
std::optional<c10::Device> compute_target_device(
    std::vector<at::Tensor>& t_args,
    std::vector<c10::List<at::Tensor>> tlist_args,
    std::vector<c10::List<std::optional<at::Tensor>>> opt_tlist_args) {
  // Decide what device to move the output tensor(s) to.
  // The current convention is that we use the first tensor arg to pick the
  // device. Barring that, we take the first tensor from a TensorList arg.
  if (!t_args.empty()) {
    return t_args[0].device();
  } else {
    // We need to loop through all of the (potentially multiple) TensorList
    // arguments in case, e.g., the first one is empty but the second is not.
    for (auto& tens_list : tlist_args) {
      for (const auto i : c10::irange(tens_list.size())) {
        return tens_list.get(i).device();
      }
    }
    for (auto& tens_list : opt_tlist_args) {
      for (const auto i : c10::irange(tens_list.size())) {
        if (tens_list.get(i).has_value()) {
          return tens_list.get(i)->device();
        }
      }
    }
  }
  return std::nullopt;
}

} // namespace

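// Per-operator fallback counters, keyed by operator name. Populated lazily in
// ltc_eager_fallback (see the note there about the TORCH_LAZY_COUNTER macro).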
static std::unordered_map<std::string, ::torch::lazy::Counter*>
    _eager_fallback_counters;

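// Returns true if the op must always take the eager fallback path: either it
// matches the op named by getLTCForceFallback(), or it is aten::nonzero while
// symbolic shape mode is disabled (the static nonzero shape function would
// produce an incorrect shape in that case).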
bool force_eager_fallback(c10::Symbol op) {
  auto force_str = getLTCForceFallback();
  if (!force_str.empty()) {
    static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
    if (op == force_sym) {
      return true;
    }
  }
  if (op == at::aten::nonzero) {
    // When symbolic shape mode is not enabled, the nonzero shape function
    // returns an incorrect result.
    return !symbolicShapeEnabled();
  }

  return false;
}

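// Boxed fallback kernel registered for the Lazy dispatch key: bumps a
// per-operator counter and forwards to ts_eager_fallback on the backend's
// eager fallback device.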
void ltc_eager_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // TODO(whc): the FN_TRACK mechanism hasn't been used in LTC so far (iirc),
  // but we could land/re-enable it here: LTC_FN_TRACK(3);
  const auto name = c10::toString(op.operator_name());

  // Manually applying the TORCH_LAZY_COUNTER macro.
  // We need to do it ourselves and explicitly keep a mapping of counters
  // because this boxed fallback kernel is used by multiple operators,
  // and the macro stamps out a static Counter object with a fixed name
  // at the code location where it was called.
  if (_eager_fallback_counters.find(name) == _eager_fallback_counters.end()) {
    _eager_fallback_counters[name] = new ::torch::lazy::Counter(name);
  }
  _eager_fallback_counters[name]->AddValue(1);

  auto& args = op.schema().arguments();
  auto arguments = torch::jit::last(stack, args.size());

  // Log each tensor argument.
  for (const auto& ivalue : arguments) {
    if (ivalue.isTensor()) {
      VLOG(3) << ivalue.toTensor().toString();
    }
  }

  // Call the actual boxed eager fallback.
  ts_eager_fallback(
      op, stack, torch::lazy::getBackend()->EagerFallbackDeviceType());
}

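// Registers ltc_eager_fallback as the boxed fallback for the Lazy dispatch
// key. This is invoked explicitly rather than at static init time; see the
// comment below for why.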
void register_ts_ltc_eager_fallback() {
  static auto m = MAKE_TORCH_LIBRARY_IMPL(_, Lazy);
  // Most backends use TORCH_LIBRARY_* macros, which perform their dispatcher
  // registrations at static library init time, but the lazy Torchscript
  // backend does not, since it is built into the main torch lib yet not always
  // used. In particular, if another external backend wants to register itself
  // to the same key (Lazy), the Torchscript backend must not be initialized.
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
}

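// Generic eager fallback: moves all tensor arguments to the eager device,
// redispatches the op there, copies mutated inputs back into the original
// tensors, and moves any outputs back to the original device.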
void ts_eager_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    c10::DeviceType device_type) {
  auto& schema_args = op.schema().arguments();
  const auto num_arguments = schema_args.size();
  auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  std::vector<at::Tensor> tensor_args;
  std::vector<int> tensor_args_indices;

  std::vector<c10::List<at::Tensor>> tensorlist_args;
  std::vector<c10::List<std::optional<at::Tensor>>> opt_tensorlist_args;

  // Step 1: Convert all non-eager tensor inputs into eager tensors and put them
  // on the stack at the correct indices.
  for (size_t idx = 0; idx < arguments.size(); ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      tensor_args.push_back(ivalue.toTensor());
      tensor_args_indices.push_back(idx);
    } else if (ivalue.isTensorList()) {
      // Note: we copy each TensorList argument to eager individually out of
      // convenience, but XLA would benefit from materializing all tensor and
      // TensorList args onto the CPU at the same time. We can improve this if
      // we need better perf for XLA's CPU fallbacks.
      auto eager_ivalue = c10::IValue(c10::List<at::Tensor>(
          to_eager(ivalue.toTensorVector(), device_type)));
      (*stack)[arguments_begin + idx] = std::move(eager_ivalue);
      tensorlist_args.push_back(ivalue.toTensorList());
    } else if (ivalue.isOptionalTensorList()) {
      auto eager_ivalue = c10::IValue(c10::List<std::optional<at::Tensor>>(
          to_eager(ivalue.toOptionalTensorVector(), device_type)));
      (*stack)[arguments_begin + idx] = std::move(eager_ivalue);
      opt_tensorlist_args.push_back(ivalue.toOptionalTensorList());
    }
  }
  // XLA requires all of the tensor arguments to be gathered up and converted to
  // CPU together.
  auto eager_tensors = to_eager(tensor_args, device_type);

  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto idx = tensor_args_indices[i];
    (*stack)[arguments_begin + idx] = c10::IValue(eager_tensors[i]);
  }

  // Step 2: Call the underlying eager implementation of the operator
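  // Redispatching with a key set containing only the eager backend's key sends
  // the call straight to the CPU/CUDA kernel rather than re-entering this
  // fallback.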
  op.redispatchBoxed(c10::DispatchKeySet(dispatch_key(device_type)), stack);

  // Step 3: We need to take special care to handle mutable aliases properly:
  // If any input tensors are mutable aliases, we need to directly copy the
  // updated data from the eager tensors back onto the original inputs.
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto tensor_idx = tensor_args_indices[i];
    const auto alias_info = schema_args[tensor_idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      at::_copy_from_and_resize(eager_tensors[i], tensor_args[i]);
    }
  }

  // Step 4: Convert any eager output tensors back to the original input device.
  // For mutable alias'd outputs, we also need to take special care
  // to move the ORIGINAL input tensor back onto the stack, in place of
  // the temporary eager output tensor that we created.
  //
  // Note [Eager Fallback Does Not Handle View Operators]
  // Also note that we are incapable of handling immutable aliases properly.
  // Why?
  // Schemas with immutable alias'd tensor outputs correspond to view
  // operators. For example, the `view_as` schema from native_functions.yaml:
  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
  // We can't handle these ops properly, because view ops are supposed to return
  // a NEW tensor that shares the SAME storage as the original tensor.
  // However, the new tensor that we created cannot share the same storage,
  // since it lives on the eager CPU / CUDA device and the original tensor lives
  // on a different device. Because of that, we warn if someone attempts to call
  // the eager fallback on a view operator (this is to maintain BC for view ops
  // for XLA that fall back to CPU).
  const auto& schema_returns = op.schema().returns();
  const auto& num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  for (const auto idx : c10::irange(returns.size())) {
    if (returns[idx].isTensor()) {
      const auto& return_tens = returns[idx].toTensor();
      if (return_tens.defined()) {
        const auto alias_info = schema_returns[idx].alias_info();
        if (alias_info != nullptr && alias_info->isWrite()) {
          // Case (1): mutable alias case. Move the input ivalue directly onto
          // the stack in place of the existing eager output tensor.
          bool found_alias = false;
          // We could store some extra metadata on the function schema to avoid
          // the loop here if we need to improve perf.
          for (const auto i : c10::irange(tensor_args_indices.size())) {
            auto input_tensor_idx = tensor_args_indices[i];
            const auto& input_tensor = eager_tensors[i];
            const auto input_alias_info =
                schema_args[input_tensor_idx].alias_info();
            if (input_tensor.defined() && input_alias_info != nullptr &&
                *alias_info == *input_alias_info) {
              // We've found the original input tensor that aliases with the
              // current output. Wrap it in an IValue and put it directly on the
              // stack.
              (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
              found_alias = true;
              break;
            }
          }
          TORCH_CHECK(
              found_alias,
              "The operator ",
              op.schema().operator_name(),
              " appears to have invalid alias information. ",
              "Found a return tensor argument with a mismatched "
              "mutable alias: ",
              schema_returns[idx]);
        } else {
          std::optional<c10::Device> tgt_device = compute_target_device(
              tensor_args, tensorlist_args, opt_tensorlist_args);
          if (alias_info != nullptr && !alias_info->isWrite()) {
            // Immutable alias (view) case: error out here, since we're copying
            // and not creating a view.
            // If this operator is needed, the backend should provide a kernel
            // for it.
            // See Note [Eager Fallback Does Not Handle View Operators]
            std::stringstream dev_str;
            if (tgt_device) {
              dev_str << *tgt_device;
            } else {
              dev_str << "<none>";
            }
            // We should never hit this for a view op,
            // because LazyTensor should provide a lowering for the
            // corresponding view_copy operator. The functionalization pass will
            // take care of calling the view_copy operator instead of the view.
            TORCH_CHECK(
                false,
                "The operator ",
                op.schema().operator_name(),
                " appears to be a view operator, ",
                "but it has no implementation for the backend \"",
                dev_str.str(),
                "\". View operators don't support ",
                "falling back to run on the eager device, since the tensor's "
                "storage cannot be shared across devices.");
          }
          // Case (2): copy case. Copy the eager output tensor to the original
          // device.

          // We technically might not have a target device, e.g. if you call
          // torch.cat() with an empty list. In that case, we shouldn't have any
          // tensors to schlep across devices anyway.
          if (tgt_device) {
            (*stack)[returns_begin + idx] =
                c10::IValue(returns[idx].toTensor().to(*tgt_device));
          }
        }
      }
    }
  }
}

} // namespace lazy
} // namespace torch