1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/pytree/aten_util/ivalue_util.h>
10 #include <gtest/gtest.h>
11
12 using executorch::extension::flatten;
13 using executorch::extension::is_same;
14 using executorch::extension::unflatten;
15
makeExampleTensors(size_t N)16 std::vector<at::Tensor> makeExampleTensors(size_t N) {
17 std::vector<at::Tensor> tensors;
18 for (int i = 0; i < N; ++i) {
19 tensors.push_back(at::randn({2, 3, 5}));
20 }
21 return tensors;
22 }
23
24 struct TestCase {
25 c10::IValue ivalue;
26 std::vector<at::Tensor> tensors;
27 };
28
makeExampleListOfTensors()29 TestCase makeExampleListOfTensors() {
30 auto tensors = makeExampleTensors(3);
31 auto list = c10::List<at::Tensor>{
32 tensors[0],
33 tensors[1],
34 tensors[2],
35 };
36 return {list, tensors};
37 }
38
makeExampleTupleOfTensors()39 TestCase makeExampleTupleOfTensors() {
40 auto tensors = makeExampleTensors(3);
41 auto tuple = std::make_tuple(tensors[0], tensors[1], tensors[2]);
42 return {tuple, tensors};
43 }
44
makeExampleDictOfTensors()45 TestCase makeExampleDictOfTensors() {
46 auto tensors = makeExampleTensors(3);
47 auto dict = c10::Dict<std::string, at::Tensor>();
48 dict.insert("x", tensors[0]);
49 dict.insert("y", tensors[1]);
50 dict.insert("z", tensors[2]);
51 return {dict, tensors};
52 }
53
makeExampleComposite()54 TestCase makeExampleComposite() {
55 auto tensors = makeExampleTensors(8);
56
57 c10::IValue list = c10::List<at::Tensor>{
58 tensors[1],
59 tensors[2],
60 };
61
62 auto inner_dict1 = c10::Dict<std::string, at::Tensor>();
63 inner_dict1.insert("x", tensors[3]);
64 inner_dict1.insert("y", tensors[4]);
65
66 auto inner_dict2 = c10::Dict<std::string, at::Tensor>();
67 inner_dict2.insert("z", tensors[5]);
68 inner_dict2.insert("w", tensors[6]);
69
70 auto dict = c10::Dict<std::string, c10::Dict<std::string, at::Tensor>>();
71 dict.insert("a", inner_dict1);
72 dict.insert("b", inner_dict2);
73
74 return {{std::make_tuple(tensors[0], list, dict, tensors[7])}, tensors};
75 }
76
testFlatten(const TestCase & testcase)77 void testFlatten(const TestCase& testcase) {
78 auto ret = flatten(testcase.ivalue);
79 ASSERT_TRUE(is_same(ret.first, testcase.tensors));
80 }
81
TEST(IValueFlattenTest, ListOfTensor) {
  // A flat list of tensors should flatten to its elements.
  const TestCase example = makeExampleListOfTensors();
  testFlatten(example);
}
85
TEST(IValueFlattenTest, TupleOfTensor) {
  // A tuple of tensors should flatten to its elements.
  const TestCase example = makeExampleTupleOfTensors();
  testFlatten(example);
}
89
TEST(IValueFlattenTest, DictOfTensor) {
  // A string->tensor dict should flatten to its values.
  const TestCase example = makeExampleDictOfTensors();
  testFlatten(example);
}
93
TEST(IValueFlattenTest, Composite) {
  // A nested tuple/list/dict structure should flatten to all of its leaves.
  const TestCase example = makeExampleComposite();
  testFlatten(example);
}
97
testUnflatten(const TestCase & testcase)98 void testUnflatten(const TestCase& testcase) {
99 // first we flatten the IValue
100 auto ret = flatten(testcase.ivalue);
101
102 // then we unflatten it
103 c10::IValue unflattened = unflatten(ret.first, ret.second);
104
105 // and see if we got the same IValue back
106 ASSERT_TRUE(is_same(unflattened, testcase.ivalue));
107 }
108
TEST(IValueUnflattenTest, ListOfTensor) {
  // flatten -> unflatten should reproduce a flat list of tensors.
  const TestCase example = makeExampleListOfTensors();
  testUnflatten(example);
}
112
TEST(IValueUnflattenTest, TupleOfTensor) {
  // flatten -> unflatten should reproduce a tuple of tensors.
  const TestCase example = makeExampleTupleOfTensors();
  testUnflatten(example);
}
116
TEST(IValueUnflattenTest, DictOfTensor) {
  // flatten -> unflatten should reproduce a string->tensor dict.
  const TestCase example = makeExampleDictOfTensors();
  testUnflatten(example);
}
120
TEST(IValueUnflattenTest, Composite) {
  // flatten -> unflatten should reproduce the nested tuple/list/dict example.
  const TestCase example = makeExampleComposite();
  testUnflatten(example);
}
124