xref: /aosp_15_r20/external/executorch/extension/pytree/aten_util/test/ivalue_util_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/extension/pytree/aten_util/ivalue_util.h>
10 #include <gtest/gtest.h>
11 
12 using executorch::extension::flatten;
13 using executorch::extension::is_same;
14 using executorch::extension::unflatten;
15 
makeExampleTensors(size_t N)16 std::vector<at::Tensor> makeExampleTensors(size_t N) {
17   std::vector<at::Tensor> tensors;
18   for (int i = 0; i < N; ++i) {
19     tensors.push_back(at::randn({2, 3, 5}));
20   }
21   return tensors;
22 }
23 
24 struct TestCase {
25   c10::IValue ivalue;
26   std::vector<at::Tensor> tensors;
27 };
28 
makeExampleListOfTensors()29 TestCase makeExampleListOfTensors() {
30   auto tensors = makeExampleTensors(3);
31   auto list = c10::List<at::Tensor>{
32       tensors[0],
33       tensors[1],
34       tensors[2],
35   };
36   return {list, tensors};
37 }
38 
makeExampleTupleOfTensors()39 TestCase makeExampleTupleOfTensors() {
40   auto tensors = makeExampleTensors(3);
41   auto tuple = std::make_tuple(tensors[0], tensors[1], tensors[2]);
42   return {tuple, tensors};
43 }
44 
makeExampleDictOfTensors()45 TestCase makeExampleDictOfTensors() {
46   auto tensors = makeExampleTensors(3);
47   auto dict = c10::Dict<std::string, at::Tensor>();
48   dict.insert("x", tensors[0]);
49   dict.insert("y", tensors[1]);
50   dict.insert("z", tensors[2]);
51   return {dict, tensors};
52 }
53 
makeExampleComposite()54 TestCase makeExampleComposite() {
55   auto tensors = makeExampleTensors(8);
56 
57   c10::IValue list = c10::List<at::Tensor>{
58       tensors[1],
59       tensors[2],
60   };
61 
62   auto inner_dict1 = c10::Dict<std::string, at::Tensor>();
63   inner_dict1.insert("x", tensors[3]);
64   inner_dict1.insert("y", tensors[4]);
65 
66   auto inner_dict2 = c10::Dict<std::string, at::Tensor>();
67   inner_dict2.insert("z", tensors[5]);
68   inner_dict2.insert("w", tensors[6]);
69 
70   auto dict = c10::Dict<std::string, c10::Dict<std::string, at::Tensor>>();
71   dict.insert("a", inner_dict1);
72   dict.insert("b", inner_dict2);
73 
74   return {{std::make_tuple(tensors[0], list, dict, tensors[7])}, tensors};
75 }
76 
testFlatten(const TestCase & testcase)77 void testFlatten(const TestCase& testcase) {
78   auto ret = flatten(testcase.ivalue);
79   ASSERT_TRUE(is_same(ret.first, testcase.tensors));
80 }
81 
TEST(IValueFlattenTest, ListOfTensor) {
  // A flat c10::List of tensors.
  const TestCase example = makeExampleListOfTensors();
  testFlatten(example);
}
85 
TEST(IValueFlattenTest, TupleOfTensor) {
  // A flat tuple of tensors.
  const TestCase example = makeExampleTupleOfTensors();
  testFlatten(example);
}
89 
TEST(IValueFlattenTest, DictOfTensor) {
  // A flat string-keyed dict of tensors.
  const TestCase example = makeExampleDictOfTensors();
  testFlatten(example);
}
93 
TEST(IValueFlattenTest, Composite) {
  // Nested tuple/list/dict mix.
  const TestCase example = makeExampleComposite();
  testFlatten(example);
}
97 
testUnflatten(const TestCase & testcase)98 void testUnflatten(const TestCase& testcase) {
99   // first we flatten the IValue
100   auto ret = flatten(testcase.ivalue);
101 
102   // then we unflatten it
103   c10::IValue unflattened = unflatten(ret.first, ret.second);
104 
105   // and see if we got the same IValue back
106   ASSERT_TRUE(is_same(unflattened, testcase.ivalue));
107 }
108 
TEST(IValueUnflattenTest, ListOfTensor) {
  // Round-trip a flat c10::List of tensors.
  const TestCase example = makeExampleListOfTensors();
  testUnflatten(example);
}
112 
TEST(IValueUnflattenTest, TupleOfTensor) {
  // Round-trip a flat tuple of tensors.
  const TestCase example = makeExampleTupleOfTensors();
  testUnflatten(example);
}
116 
TEST(IValueUnflattenTest, DictOfTensor) {
  // Round-trip a flat string-keyed dict of tensors.
  const TestCase example = makeExampleDictOfTensors();
  testUnflatten(example);
}
120 
TEST(IValueUnflattenTest, Composite) {
  // Round-trip the nested tuple/list/dict mix.
  const TestCase example = makeExampleComposite();
  testUnflatten(example);
}
124