xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/stride_properties_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 
5 using namespace at;
6 
// Tensor shapes exercised by StrideIndicesTest. All entries are 4-D; presumably
// because that test also lays them out as ChannelsLast (a 4-D format) — confirm
// before adding shapes of another rank.
// Read-only on purpose: `const` also gives the table internal linkage, so it
// cannot clash with a same-named global in another test translation unit.
// TODO: {4, 1, 4, 1} is known to fail and is therefore not listed yet.
const std::vector<std::vector<int64_t>> sizes = {
    {4, 4, 4, 4}, {4, 4, 1, 1}, {4, 1, 4, 4}, {4, 1, 1, 4}, {1, 4, 1, 4}, {1, 4, 4, 1}};
9 
CheckStrideIndices(const Tensor & t,at::MemoryFormat format)10 inline bool CheckStrideIndices(const Tensor& t, at::MemoryFormat format) {
11   size_t n_dim = t.dim();
12   std::vector<size_t> stride_indices(n_dim);
13   if (format == at::MemoryFormat::ChannelsLast) {
14     // stride_indices_ should be {1, n-1, n-2, ..., 2, 0}
15     std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
16     stride_indices[0] = 1;
17     stride_indices[n_dim - 1] = 0;
18   } else if (format == at::MemoryFormat::Contiguous) {
19     // stride_indices_ should be {n-1, n-2, n-3, ..., 0}
20     std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
21   } else {
22     TORCH_INTERNAL_ASSERT(false, "not recognized memory format");
23   }
24 
25   // testing computeStrideProps with `IValue ival(t)` somehow doesn't work on CI
26   // with onnx; The function works fine within, but stride properties is somehow
27   // altered in ival->type()->cast<TensorType>();
28   auto tt = TensorType::create(std::nullopt, std::nullopt, t.sizes(), t.strides(), std::nullopt);
29   TORCH_INTERNAL_ASSERT(tt->stride_properties().isComplete(), "complete stride properties is needed for the test");
30 
31   auto index_iter = stride_indices.begin();
32   for (const auto& opt_stride : *tt->stride_properties().sizes()) {
33     if (*index_iter++ != opt_stride->stride_index_.value()) {
34       return false;
35     }
36   }
37 
38   return true;
39 }
40 
// For every shape in `sizes`, lay the tensor out in each supported memory
// format and verify that TensorType derives the matching stride indices.
TEST(StridePropertiesTest, StrideIndicesTest) {
  for (const auto& size : sizes) {
    Tensor t = at::rand(size);
    for (auto memory_format : {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
      // resize_ to the same shape with an explicit memory_format re-lays t out
      // in that format, so the check sees the strides for `memory_format`.
      t.resize_(size, memory_format);
      EXPECT_TRUE(CheckStrideIndices(t, memory_format))
          << "sizes: " << t.sizes() << ", strides: " << t.strides()
          << ", memory format: " << memory_format;
    }
  }
}
51 
// Stride indices derived for a permuted-and-expanded tensor (which has a zero
// stride on the expanded dim) must agree with the stride order of the tensor
// produced by an eager elementwise op on it.
TEST(StridePropertiesTest, ZeroStrideIndicesEagerConsistencyTest) {
  auto permuted_tensor = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4}); // permute dim-1 & dim-3
  auto tensor = permuted_tensor.expand({6, 5, 4, 3, 2}); // expand dim-2

  auto temp = TensorType::create(std::nullopt, std::nullopt, tensor.sizes(), tensor.strides(), std::nullopt);

  // TensorIterator would preserve stride order, this is the eager reference
  auto eager_tensor = tensor.relu();
  auto ref_type = TensorType::create(std::nullopt, std::nullopt, eager_tensor.sizes(), eager_tensor.strides(), std::nullopt);

  // Both types must be complete: every stride index of each is read via
  // `.value()` below.
  TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete() &&
      ref_type->stride_properties().isComplete(), "complete stride properties is needed for the test");

  const auto& props = *temp->stride_properties().sizes();
  const auto& ref_props = *ref_type->stride_properties().sizes();
  // Equal ranks are required for the position-by-position comparison.
  ASSERT_EQ(props.size(), ref_props.size());
  for (size_t i = 0; i < props.size(); ++i) {
    EXPECT_EQ(props[i]->stride_index_.value(), ref_props[i]->stride_index_.value())
        << "stride index mismatch at position " << i;
  }
}
70 
// A one-element tensor broadcast to {4, 4, 4} is a view whose strides are all
// zero; its stride indices are still expected to follow the contiguous order.
// Beware: expanding a size-1 dimension is subtle, because the resulting stride
// depends on the order in which the dimension was unsqueezed.
TEST(StridePropertiesTest, ExpandedStrideIndicesTest) {
  const Tensor expanded = at::rand({1}).expand({4, 4, 4});
  EXPECT_TRUE(CheckStrideIndices(expanded, at::MemoryFormat::Contiguous));
}
78 
// Slicing with a step (here along dim 1 with step 4) changes the strides, but
// the sliced tensor shouldn't have changed stride order: its stride indices must
// still be the contiguous sequence {n-1, n-2, ..., 0}.
TEST(StridePropertiesTest, SlicedStrideIndicesTest) {
  Tensor t = at::rand({16, 4}).slice(1, 0, 4, 4);

  auto temp = TensorType::create(std::nullopt, std::nullopt, t.sizes(), t.strides(), std::nullopt);
  TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete(), "complete stride properties is needed for the test");

  // Expected contiguous order, derived from the tensor's rank rather than a
  // hard-coded dimension count.
  std::vector<size_t> stride_indices(static_cast<size_t>(t.dim()));
  std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);

  const auto& props = *temp->stride_properties().sizes();
  // Equal lengths are required for the position-by-position comparison.
  ASSERT_EQ(props.size(), stride_indices.size());
  for (size_t i = 0; i < props.size(); ++i) {
    EXPECT_EQ(stride_indices[i], props[i]->stride_index_.value())
        << "stride index mismatch at position " << i;
  }
}
94