#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif

#include <algorithm>

namespace at {

// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
//   TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
//   For example, many Tensor subclasses do not have storage, and meta Tensors
//   do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() and
//    .data_ptr() calls should be avoided.
// 2. Certain in-place operations can eliminate the typing of the Tensor
//    subclass. For example:
//    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
//    If input is a Tensor subclass, then the above ends up either erroring out
//    or returning a regular non-Tensor-subclass Tensor!
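//
// A composite compliant fix (a sketch of the usual approach, not something
// this note prescribes) is to compose out-of-place ops on `input` directly,
// so the subclass type propagates, instead of allocating a plain Tensor and
// copy_()-ing into it. The helpers below detect when such a path is needed.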

constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
    {DispatchKey::FuncTorchGradWrapper,
     DispatchKey::FuncTorchBatched,
     DispatchKey::Functionalize});

constexpr auto kTensorSubclassLike =
    kFunctorchWrappedTensors |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality keys
         // here; you will incorrectly always match on the functionality key,
         // no matter the backend component. For example, DispatchKey::SparseCPU
         // would set the Sparse functionality bit plus the CPU backend bit, so
         // the intersection check in isTensorSubclassLike below would match
         // Sparse tensors on every backend.
         DispatchKey::Batched,
         DispatchKey::Sparse,
         DispatchKey::SparseCsr,
         DispatchKey::Python}) |
    DispatchKeySet(BackendComponent::MetaBit);

inline bool isTensorSubclassLike(const Tensor& tensor) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return !(key_set & kTensorSubclassLike).empty();
}
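// Example (a sketch only; composite_compliant_path and fast_path_with_data_ptr
// are hypothetical helpers): composite kernels can use this to guard
// storage-dependent fast paths, per class 1 of
// Note [Tensor-subclass-like Tensors]:
//
//   Tensor my_op(const Tensor& self) {
//     if (at::isTensorSubclassLike(self)) {
//       return composite_compliant_path(self); // no .item() / .data_ptr()
//     }
//     return fast_path_with_data_ptr(self); // plain dense Tensor: OK
//   }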

inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}
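// Example (a sketch; subclass_safe_formula is a hypothetical helper): backward
// formulas typically use this to choose between an in-place pattern and an
// out-of-place one, per class 2 of Note [Tensor-subclass-like Tensors]:
//
//   Tensor result;
//   if (at::areAnyTensorSubclassLike({input, grad})) {
//     result = subclass_safe_formula(input, grad);
//   } else {
//     result = at::zeros(input.sizes(), grad.options()).diag().copy_(input);
//   }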

inline bool areAnyOptionalTensorSubclassLike(
    const c10::List<std::optional<Tensor>>& tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(
      tensors.begin(),
      tensors.end(),
      [](const std::optional<Tensor>& opt_tensor) {
        return (
            opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
      });
}
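// Example (a sketch; `grads` is a hypothetical argument of type
// c10::List<std::optional<Tensor>>, as in formulas whose gradients may be
// undefined):
//
//   if (at::areAnyOptionalTensorSubclassLike(grads)) {
//     // take the subclass-safe, out-of-place path
//   }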

// Helper function for testing the truth value of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// E.g.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Pattern     : is_scalar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.dim() == 0);
  TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool);
  return at::equal(t, t.new_ones({}, t.options()));
}

} // namespace at