xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/TorchDispatchUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/TorchDispatchUtils.h>
2 
3 
4 namespace at::impl {
5 
tensor_has_dispatch(const at::Tensor & t)6 bool tensor_has_dispatch(const at::Tensor& t) {
7   DispatchKeySet key_set({DispatchKey::Python, DispatchKey::PythonTLSSnapshot});
8   return t.key_set().has_any(key_set);
9 }
10 
tensorlist_has_dispatch(at::ITensorListRef li)11 bool tensorlist_has_dispatch(at::ITensorListRef li) {
12   for (const auto& t : li) {
13     if (tensor_has_dispatch(t)) {
14       return true;
15     }
16   }
17   return false;
18 }
19 
tensorlist_has_dispatch(const c10::List<std::optional<at::Tensor>> & li)20 bool tensorlist_has_dispatch(const c10::List<std::optional<at::Tensor>>& li) {
21   for (auto i : c10::irange(li.size())) {
22     auto t = li.get(i);
23     if (t && tensor_has_dispatch(*t)) {
24       return true;
25     }
26   }
27   return false;
28 }
29 
30 } // namespace at::impl
31