// aten/src/ATen/native/cuda/Equal.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/NamedTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/eq_cuda_dispatch.h>
#include <ATen/ops/equal_native.h>
#endif

namespace at::native {

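// CUDA counterpart of cpu_equal: returns true iff the two tensors have
// matching names (if any), shapes, and element values. Comparing tensors
// on different devices is an error, not a false result.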
bool cuda_equal(const Tensor& self, const Tensor& src) {
  if (!at::namedinference::are_names_equal(
          self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) {
    return false;
  }
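  // Names match; ignore them for the remaining numeric comparison.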
  at::NoNamesGuard guard;
  TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on "
              "different devices. Got: ", self.device(), " and ", src.device());
  if (self.sizes() != src.sizes()) {
    return false;
  }
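  // Equal shapes with zero elements are trivially equal.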
  if (self.numel() == 0) {
    return true;
  }

  // This is the same optimization done in cpu_equal. Since flags like
  // neg/conj should already be handled outside cuda_equal, it is safe to
  // take the following fast path once the storage, offset, and strides
  // are exactly the same.
  if (self.is_alias_of(src)
      && self.storage_offset() == src.storage_offset()
      && self.dtype() == src.dtype()
      && self.is_contiguous() == src.is_contiguous()
      && self.strides().equals(src.strides())
      // Extra checks to ensure safety in case cuda_equal is called directly from C++.
      && self.layout() == src.layout()
      && self.is_neg() == src.is_neg()
      && self.is_conj() == src.is_conj()) {
    return true;
  }

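  // General path: run the elementwise eq kernel on the GPU, reduce with
  // all(), and copy the scalar result back to the host via item().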
  return at::cuda::eq(self, src).all().item().to<bool>();
}
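
// Usage sketch (illustrative only, assumes a CUDA build; not part of the
// original file):
//   at::Tensor a = at::rand({2, 3}, at::kCUDA);
//   at::Tensor b = a.clone();
//   bool same = at::equal(a, b);  // dispatches to cuda_equal for CUDA tensors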

} // namespace at::native