1 #include <ATen/ATen.h> 2 #include <ATen/core/dispatch/Dispatcher.h> 3 #include <ATen/core/op_registration/op_registration.h> 4 #include <ATen/native/UnaryOps.h> 5 #include <ATen/NativeFunctions.h> 6 #include <c10/util/irange.h> 7 #include <torch/library.h> 8 #include <ATen/native/MathBitFallThroughLists.h> 9 10 namespace at { 11 12 // TODO: add a note explaining the design decisions 13 // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors zeroTensorFallback(const c10::OperatorHandle & op,DispatchKeySet dispatch_keys,torch::jit::Stack * stack)14 static void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { 15 const auto& arguments = op.schema().arguments(); 16 const auto num_arguments = arguments.size(); 17 const auto stack_start = stack->size() - num_arguments; 18 19 std::optional<bool> is_write; 20 for (const auto i : c10::irange(num_arguments)) { 21 const auto& alias_info = arguments[i].alias_info(); 22 if (alias_info != nullptr) { 23 if (is_write.has_value()) { 24 TORCH_CHECK(*is_write == alias_info->isWrite(), 25 "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(), 26 "ZeroTensor fallback doesn't work for operators with a mix " 27 "mutable and non-mutable inputs that alias with outputs, " 28 "this must be implemented manually. " 29 "If you got this error on a core op, please report a bug to PyTorch."); 30 } else { 31 is_write = alias_info->isWrite(); 32 } 33 } 34 } 35 36 if (is_write.has_value() && !*is_write) { 37 // We assume that view operators automatically handle the ZeroTensor bit 38 // correctly by propagating the dispatch key in key_set. 39 // This is not necessarily always right, so you should test these cases. 40 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack); 41 return; 42 } 43 44 for (const auto i : c10::irange(num_arguments)) { 45 auto& ivalue = (*stack)[stack_start + i]; 46 if (!(ivalue.isTensor() || ivalue.isTensorList())) { 47 continue; 48 } 49 const auto& argument = arguments[i]; 50 bool mut_arg = false; 51 52 if (argument.alias_info()) { 53 // Was already tested by is_write loop above 54 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite()); 55 mut_arg = true; 56 } 57 58 if (ivalue.isTensor()) { 59 auto tensor = std::move(ivalue).toTensor(); 60 if (tensor._is_zerotensor()) { 61 TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ", 62 "obtained using .clone() if you want a mutable tensor."); 63 tensor = at::zeros({}, tensor.options()).expand(tensor.sizes()); 64 } 65 (*stack)[stack_start + i] = std::move(tensor); 66 } else if (ivalue.isTensorList()) { 67 auto tensors = std::move(ivalue).toTensorList(); 68 for(const auto j : c10::irange(tensors.size())) { 69 const Tensor& tensor = tensors[j]; 70 if (tensor._is_zerotensor()) { 71 // TODO: assert requires_grad=False 72 //_like should not propagate zerotensor dispatch key 73 TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ", 74 "obtained using .clone() if you want a mutable tensor."); 75 tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes()); 76 } 77 } 78 (*stack)[stack_start + i] = std::move(tensors); 79 } 80 } 81 82 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack); 83 } 84 85 TORCH_LIBRARY_IMPL(_,ZeroTensor,m)86 TORCH_LIBRARY_IMPL(_, ZeroTensor, m) { 87 m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>()); 88 } 89 TORCH_LIBRARY_IMPL(aten,ZeroTensor,m)90 TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) { 91 m.impl("zeros_like", torch::CppFunction::makeFallthrough()); 92 m.impl("mul.Scalar", torch::CppFunction::makeFallthrough()); 93 m.impl("add.Scalar", torch::CppFunction::makeFallthrough()); 94 m.impl("copy_", torch::CppFunction::makeFallthrough()); 95 m.impl("clone", torch::CppFunction::makeFallthrough()); 96 m.impl("dot", torch::CppFunction::makeFallthrough()); 97 m.impl("vdot", torch::CppFunction::makeFallthrough()); 98 // The functions in the list below have a specific registeration in native_functions.yaml and 99 // do not use the fallback. 100 // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough()); 101 // m.impl("add.Tensor", torch::CppFunction::makeFallthrough()); 102 // m.impl("linalg_cross", torch::CppFunction::makeFallthrough()); 103 104 TORCH_VIEW_FNS(m) 105 TENSOR_UTILITIES_AND_CONSTRUCTORS(m) 106 } 107 } // namespace at 108