#include <ATen/core/TensorBase.h>
#include <ATen/core/TensorBody.h>
#include <c10/util/OptionalArrayRef.h>

#include <sstream>
#include <stdexcept>
#include <string>

#ifdef AT_PER_OPERATOR_HEADERS
#include <ATen/ops/_assert_tensor_metadata_native.h>
#endif
8
9 namespace at {
10
11 class Tensor;
12
13 namespace native {
14
/// Verifies one piece of tensor metadata against an optional expectation.
///
/// A disengaged `compared` means "no expectation", and the call is a no-op.
///
/// @param original the actual metadata value; must be equality-comparable with
///        `compared.value()` and streamable via `operator<<`.
/// @param compared optional-like expected value; must be contextually
///        convertible to bool and provide `.value()`.
/// @param name     field name used in the error message (e.g. "sizes").
/// @throws std::runtime_error if `compared` is engaged and the values differ.
template <typename O, typename C>
void _assert_match(const O& original, const C& compared, const std::string& name) {
  if (!compared) {
    return; // caller did not ask for this field to be checked
  }
  const auto& expected = compared.value();
  if (original == expected) {
    return;
  }
  // A mismatch reflects a wrong input/expectation, not a broken internal
  // invariant, so report it as an ordinary runtime error that names both
  // values instead of going through the (deprecated) internal-assert macro.
  std::ostringstream msg;
  msg << "Tensor " << name << " mismatch! Expected: " << expected
      << ", Got: " << original;
  throw std::runtime_error(msg.str());
}
26
_assert_tensor_metadata(at::Tensor const & tensor,at::OptionalIntArrayRef sizes,at::OptionalIntArrayRef strides,std::optional<c10::ScalarType> dtype)27 void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype) {
28 _assert_match(tensor.sizes(), sizes, "sizes");
29 _assert_match(tensor.strides(), strides, "strides");
30 _assert_match(tensor.dtype(), dtype, "dtype");
31 }
32
33 }
34 } // namespace at::native
35