xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ZeroTensorFallback.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/core/dispatch/Dispatcher.h>
3 #include <ATen/core/op_registration/op_registration.h>
4 #include <ATen/native/UnaryOps.h>
5 #include <ATen/NativeFunctions.h>
6 #include <c10/util/irange.h>
7 #include <torch/library.h>
8 #include <ATen/native/MathBitFallThroughLists.h>
9 
10 namespace at {
11 
12   // TODO: add a note explaining the design decisions
  // ZeroTensors are designed to be immutable. Thus, we raise an error when an in-place operation is performed on a ZeroTensor.
zeroTensorFallback(const c10::OperatorHandle & op,DispatchKeySet dispatch_keys,torch::jit::Stack * stack)14   static void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
15     const auto& arguments = op.schema().arguments();
16     const auto num_arguments = arguments.size();
17     const auto stack_start = stack->size() - num_arguments;
18 
19     std::optional<bool> is_write;
20     for (const auto i : c10::irange(num_arguments)) {
21       const auto& alias_info = arguments[i].alias_info();
22       if (alias_info != nullptr) {
23         if (is_write.has_value()) {
24           TORCH_CHECK(*is_write == alias_info->isWrite(),
25             "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
26             "ZeroTensor fallback doesn't work for operators with a mix "
27             "mutable and non-mutable inputs that alias with outputs, "
28             "this must be implemented manually.  "
29             "If you got this error on a core op, please report a bug to PyTorch.");
30         } else {
31           is_write = alias_info->isWrite();
32         }
33       }
34     }
35 
36     if (is_write.has_value() && !*is_write) {
37       // We assume that view operators automatically handle the ZeroTensor bit
38       // correctly by propagating the dispatch key in key_set.
39       // This is not necessarily always right, so you should test these cases.
40       op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
41       return;
42     }
43 
44     for (const auto i : c10::irange(num_arguments)) {
45       auto& ivalue = (*stack)[stack_start + i];
46       if (!(ivalue.isTensor() || ivalue.isTensorList())) {
47         continue;
48       }
49       const auto& argument = arguments[i];
50       bool mut_arg = false;
51 
52       if (argument.alias_info()) {
53         // Was already tested by is_write loop above
54         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
55         mut_arg = true;
56       }
57 
58       if (ivalue.isTensor()) {
59         auto tensor = std::move(ivalue).toTensor();
60         if (tensor._is_zerotensor()) {
61           TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
62                     "obtained using .clone() if you want a mutable tensor.");
63           tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
64         }
65         (*stack)[stack_start + i] = std::move(tensor);
66       } else if (ivalue.isTensorList()) {
67         auto tensors = std::move(ivalue).toTensorList();
68         for(const auto j : c10::irange(tensors.size())) {
69           const Tensor& tensor = tensors[j];
70           if (tensor._is_zerotensor()) {
71             // TODO: assert requires_grad=False
72             //_like should not propagate zerotensor dispatch key
73             TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
74                     "obtained using .clone() if you want a mutable tensor.");
75             tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
76           }
77         }
78         (*stack)[stack_start + i] = std::move(tensors);
79       }
80     }
81 
82     op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
83   }
84 
85 
TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
  // Catch-all registration for the ZeroTensor key: any operator without its own
  // ZeroTensor kernel (or fallthrough) is routed to the boxed zeroTensorFallback.
  auto boxed_fallback = torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>();
  m.fallback(std::move(boxed_fallback));
}
89 
TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
  // These aten operators skip the ZeroTensor key (fallthrough), so they never
  // reach zeroTensorFallback. Registration order matches the list order.
  for (const char* op_name : {"zeros_like", "mul.Scalar", "add.Scalar", "copy_", "clone", "dot", "vdot"}) {
    m.impl(op_name, torch::CppFunction::makeFallthrough());
  }

  // The operators below have a dedicated registration in native_functions.yaml
  // and do not use the fallback, so they are intentionally not listed above:
  // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
  // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
  // m.impl("linalg_cross", torch::CppFunction::makeFallthrough());

  // Further registrations supplied by macros; presumably defined in the
  // included ATen/native/MathBitFallThroughLists.h.
  TORCH_VIEW_FNS(m)
  TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}
107 } // namespace at
108