xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ComparisonUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/TensorBase.h>
2 #include <ATen/core/TensorBody.h>
3 #include <c10/util/OptionalArrayRef.h>
4 
5 #ifdef AT_PER_OPERATOR_HEADERS
6 #include <ATen/ops/_assert_tensor_metadata_native.h>
7 #endif
8 
9 namespace at {
10 
11 class Tensor;
12 
13 namespace native {
14 
/// Compares one observed tensor property against an optional expectation.
///
/// @param original  The value observed on the tensor (e.g. `tensor.sizes()`).
/// @param compared  An optional-like expectation: contextually convertible to
///                  bool and exposing `value()`. When it is empty, nothing is
///                  checked.
/// @param name      Human-readable property name, used only in the error
///                  message.
/// @throws c10::Error (via TORCH_CHECK) when an expectation is present and
///         `original == compared.value()` is false.
template <typename O, typename C>
static void _assert_match(const O& original, const C& compared, const std::string& name) {
  if (!compared) {
    // No expectation supplied for this property: nothing to verify.
    return;
  }
  // TORCH_CHECK (a user-facing check) rather than AT_ASSERT (an internal
  // assert): a mismatch means the caller's expectation was wrong, not that
  // PyTorch itself has a bug. The "Tensor <name> mismatch!" prefix is kept so
  // existing message matching still works; expected/actual values are appended
  // to make the failure diagnosable.
  TORCH_CHECK(
      original == compared.value(),
      "Tensor ", name, " mismatch! Expected: ", compared.value(), ", Got: ", original);
}
26 
_assert_tensor_metadata(at::Tensor const & tensor,at::OptionalIntArrayRef sizes,at::OptionalIntArrayRef strides,std::optional<c10::ScalarType> dtype)27 void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype) {
28   _assert_match(tensor.sizes(), sizes, "sizes");
29   _assert_match(tensor.strides(), strides, "strides");
30   _assert_match(tensor.dtype(), dtype, "dtype");
31 }
32 
33 }
34 }  // namespace at::native
35