// xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_permutation_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <cstdint>
#include <vector>

#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/permutation_util.h>

6 namespace torch {
7 namespace lazy {
8 
TEST(PermutationUtilTest,TestInversePermutation)9 TEST(PermutationUtilTest, TestInversePermutation) {
10   EXPECT_EQ(InversePermutation({0}), std::vector<int64_t>({0}));
11   EXPECT_EQ(InversePermutation({0, 1, 2}), std::vector<int64_t>({0, 1, 2}));
12   EXPECT_EQ(
13       InversePermutation({1, 3, 2, 0}), std::vector<int64_t>({3, 0, 2, 1}));
14   // Not a valid permutation
15   EXPECT_THROW(InversePermutation({-1}), c10::Error);
16   EXPECT_THROW(InversePermutation({1, 1}), c10::Error);
17 }
18 
TEST(PermutationUtilTest,TestIsPermutation)19 TEST(PermutationUtilTest, TestIsPermutation) {
20   EXPECT_TRUE(IsPermutation({0}));
21   EXPECT_TRUE(IsPermutation({0, 1, 2, 3}));
22   EXPECT_FALSE(IsPermutation({-1}));
23   EXPECT_FALSE(IsPermutation({5, 3}));
24   EXPECT_FALSE(IsPermutation({1, 2, 3}));
25 }
26 
TEST(PermutationUtilTest,TestPermute)27 TEST(PermutationUtilTest, TestPermute) {
28   EXPECT_EQ(
29       PermuteDimensions({0}, std::vector<int64_t>({224})),
30       std::vector<int64_t>({224}));
31   EXPECT_EQ(
32       PermuteDimensions({1, 2, 0}, std::vector<int64_t>({3, 224, 224})),
33       std::vector<int64_t>({224, 224, 3}));
34   // Not a valid permutation
35   EXPECT_THROW(
36       PermuteDimensions({-1}, std::vector<int64_t>({244})), c10::Error);
37   EXPECT_THROW(
38       PermuteDimensions({3, 2}, std::vector<int64_t>({244})), c10::Error);
39   // Permutation size is different from the to-be-permuted vector size
40   EXPECT_THROW(
41       PermuteDimensions({0, 1}, std::vector<int64_t>({244})), c10::Error);
42   EXPECT_THROW(
43       PermuteDimensions({0}, std::vector<int64_t>({3, 244, 244})), c10::Error);
44 }
45 
46 } // namespace lazy
47 } // namespace torch
48