1 #include <gtest/gtest.h>
2
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/lazy/core/permutation_util.h>
5
6 namespace torch {
7 namespace lazy {
8
TEST(PermutationUtilTest,TestInversePermutation)9 TEST(PermutationUtilTest, TestInversePermutation) {
10 EXPECT_EQ(InversePermutation({0}), std::vector<int64_t>({0}));
11 EXPECT_EQ(InversePermutation({0, 1, 2}), std::vector<int64_t>({0, 1, 2}));
12 EXPECT_EQ(
13 InversePermutation({1, 3, 2, 0}), std::vector<int64_t>({3, 0, 2, 1}));
14 // Not a valid permutation
15 EXPECT_THROW(InversePermutation({-1}), c10::Error);
16 EXPECT_THROW(InversePermutation({1, 1}), c10::Error);
17 }
18
TEST(PermutationUtilTest,TestIsPermutation)19 TEST(PermutationUtilTest, TestIsPermutation) {
20 EXPECT_TRUE(IsPermutation({0}));
21 EXPECT_TRUE(IsPermutation({0, 1, 2, 3}));
22 EXPECT_FALSE(IsPermutation({-1}));
23 EXPECT_FALSE(IsPermutation({5, 3}));
24 EXPECT_FALSE(IsPermutation({1, 2, 3}));
25 }
26
TEST(PermutationUtilTest,TestPermute)27 TEST(PermutationUtilTest, TestPermute) {
28 EXPECT_EQ(
29 PermuteDimensions({0}, std::vector<int64_t>({224})),
30 std::vector<int64_t>({224}));
31 EXPECT_EQ(
32 PermuteDimensions({1, 2, 0}, std::vector<int64_t>({3, 224, 224})),
33 std::vector<int64_t>({224, 224, 3}));
34 // Not a valid permutation
35 EXPECT_THROW(
36 PermuteDimensions({-1}, std::vector<int64_t>({244})), c10::Error);
37 EXPECT_THROW(
38 PermuteDimensions({3, 2}, std::vector<int64_t>({244})), c10::Error);
39 // Permutation size is different from the to-be-permuted vector size
40 EXPECT_THROW(
41 PermuteDimensions({0, 1}, std::vector<int64_t>({244})), c10::Error);
42 EXPECT_THROW(
43 PermuteDimensions({0}, std::vector<int64_t>({3, 244, 244})), c10::Error);
44 }
45
46 } // namespace lazy
47 } // namespace torch
48