#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

#include <gtest/gtest.h>

#include <ATen/ATen.h>
4
5 using namespace at;
6
// Shapes exercised by StrideIndicesTest under both the ChannelsLast and
// Contiguous memory formats. Declared `const`: the tests only read it, and a
// const namespace-scope object has internal linkage, so it cannot collide with
// a same-named global in another test translation unit.
// TODO: failing sizes {4, 1, 4, 1}
const std::vector<std::vector<int64_t>> sizes = {
    {4, 4, 4, 4},
    {4, 4, 1, 1},
    {4, 1, 4, 4},
    {4, 1, 1, 4},
    {1, 4, 1, 4},
    {1, 4, 4, 1}};
9
CheckStrideIndices(const Tensor & t,at::MemoryFormat format)10 inline bool CheckStrideIndices(const Tensor& t, at::MemoryFormat format) {
11 size_t n_dim = t.dim();
12 std::vector<size_t> stride_indices(n_dim);
13 if (format == at::MemoryFormat::ChannelsLast) {
14 // stride_indices_ should be {1, n-1, n-2, ..., 2, 0}
15 std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
16 stride_indices[0] = 1;
17 stride_indices[n_dim - 1] = 0;
18 } else if (format == at::MemoryFormat::Contiguous) {
19 // stride_indices_ should be {n-1, n-2, n-3, ..., 0}
20 std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
21 } else {
22 TORCH_INTERNAL_ASSERT(false, "not recognized memory format");
23 }
24
25 // testing computeStrideProps with `IValue ival(t)` somehow doesn't work on CI
26 // with onnx; The function works fine within, but stride properties is somehow
27 // altered in ival->type()->cast<TensorType>();
28 auto tt = TensorType::create(std::nullopt, std::nullopt, t.sizes(), t.strides(), std::nullopt);
29 TORCH_INTERNAL_ASSERT(tt->stride_properties().isComplete(), "complete stride properties is needed for the test");
30
31 auto index_iter = stride_indices.begin();
32 for (const auto& opt_stride : *tt->stride_properties().sizes()) {
33 if (*index_iter++ != opt_stride->stride_index_.value()) {
34 return false;
35 }
36 }
37
38 return true;
39 }
40
TEST(StridePropertiesTest, StrideIndicesTest) {
  // Restride each shape in place to every tested memory format and check that
  // TensorType's stride properties report the expected dimension order.
  for (const auto& size : sizes) {
    Tensor t = at::rand(size);
    for (auto memory_format : {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
      // Attach the (size, format) pair to any failure reported below.
      SCOPED_TRACE(
          ::testing::Message() << "sizes " << at::IntArrayRef(size)
                               << ", memory format " << memory_format);
      t.resize_(size, memory_format);
      EXPECT_TRUE(CheckStrideIndices(t, memory_format));
    }
  }
}
51
TEST(StridePropertiesTest, ZeroStrideIndicesEagerConsistencyTest) {
  // permute swaps dim-1 and dim-3: {6, 3, 1, 5, 2} -> {6, 5, 1, 3, 2}
  auto permuted_tensor = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4});
  // expand the size-1 dim-2 (a broadcast, zero-stride dim): -> {6, 5, 4, 3, 2}
  auto tensor = permuted_tensor.expand({6, 5, 4, 3, 2});

  auto temp = TensorType::create(std::nullopt, std::nullopt, tensor.sizes(), tensor.strides(), std::nullopt);

  // TensorIterator would preserve stride order, this is the eager reference
  auto eager_tensor = tensor.relu();
  auto ref_type = TensorType::create(std::nullopt, std::nullopt, eager_tensor.sizes(), eager_tensor.strides(), std::nullopt);

  // Both the expanded input and the eager reference must have complete stride
  // properties, since every entry's stride_index_ is dereferenced below.
  TORCH_INTERNAL_ASSERT(
      temp->stride_properties().isComplete() &&
          ref_type->stride_properties().isComplete(),
      "complete stride properties is needed for the test");

  const auto& strides = *temp->stride_properties().sizes();
  const auto& ref_strides = *ref_type->stride_properties().sizes();
  ASSERT_EQ(strides.size(), ref_strides.size());
  for (size_t i = 0; i < strides.size(); ++i) {
    EXPECT_EQ(strides[i]->stride_index_.value(), ref_strides[i]->stride_index_.value())
        << "stride order differs from eager at position " << i;
  }
}
70
TEST(StridePropertiesTest, ExpandedStrideIndicesTest) {
  // Broadcast a single element to {4, 4, 4}; the stride properties must still
  // report the plain contiguous dimension order {2, 1, 0}.
  // Note: expanding a dimension of size 1 is tricky because the resulting
  // stride depends on the order of the unsqueezed dimension.
  const Tensor expanded = at::rand({1}).expand({4, 4, 4});
  EXPECT_TRUE(CheckStrideIndices(expanded, at::MemoryFormat::Contiguous));
}
78
TEST(StridePropertiesTest, SlicedStrideIndicesTest) {
  // A sliced tensor shouldn't have a changed stride order. Slicing dim 1 of a
  // {16, 4} tensor with step 4 gives a {16, 1} view whose stride indices must
  // still follow the contiguous order {1, 0}. CheckStrideIndices builds the
  // TensorType, asserts completeness, and compares against that order.
  Tensor t = at::rand({16, 4}).slice(1, 0, 4, 4);
  EXPECT_TRUE(CheckStrideIndices(t, at::MemoryFormat::Contiguous));
}
94