1
2 #include <gtest/gtest.h>
3
4 #include <ATen/ATen.h>
5
6 using namespace at;
7
8 // can't expand empty tensor
TestEmptyTensor(DeprecatedTypeProperties & T)9 void TestEmptyTensor(DeprecatedTypeProperties& T) {
10 auto empty = randn({0}, T);
11 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
12 ASSERT_ANY_THROW(empty.expand({3}));
13 }
14
15 // out-place function with 2 args
TestOut2Basic(DeprecatedTypeProperties & T)16 void TestOut2Basic(DeprecatedTypeProperties& T) {
17 auto a = randn({3, 1}, T);
18 auto b = randn({5}, T);
19 std::vector<int64_t> expanded_sizes = {3, 5};
20 ASSERT_TRUE(
21 (a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
22 }
23
24 // with scalar
TestOut2WithScalar(DeprecatedTypeProperties & T)25 void TestOut2WithScalar(DeprecatedTypeProperties& T) {
26 auto aScalar = ones({}, T);
27 auto b = randn({3, 5}, T);
28 ASSERT_TRUE(
29 (aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
30 }
31
32 // old fallback behavior yields error
TestOut2OldFallback(DeprecatedTypeProperties & T)33 void TestOut2OldFallback(DeprecatedTypeProperties& T) {
34 auto a = randn({3, 5}, T);
35 auto b = randn({5, 3}, T);
36 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
37 ASSERT_ANY_THROW(a + b);
38 }
39
40 // with mismatched sizes
TestOut2MismatchedSizes(DeprecatedTypeProperties & T)41 void TestOut2MismatchedSizes(DeprecatedTypeProperties& T) {
42 auto a = randn({3, 5}, T);
43 auto b = randn({7, 5}, T);
44 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
45 ASSERT_ANY_THROW(a + b);
46 }
47
48 // out-place function with 3 args
TestOut3Basic(DeprecatedTypeProperties & T)49 void TestOut3Basic(DeprecatedTypeProperties& T) {
50 auto a = randn({3, 1, 1}, T);
51 auto b = randn({1, 2, 1}, T);
52 auto c = randn({1, 1, 5}, T);
53 std::vector<int64_t> expanded_sizes = {3, 2, 5};
54 ASSERT_TRUE((a + b + c).equal(
55 a.expand(expanded_sizes) + b.expand(expanded_sizes) +
56 c.expand(expanded_sizes)));
57 }
58
59 // with scalar
TestOut3WithScalar(DeprecatedTypeProperties & T)60 void TestOut3WithScalar(DeprecatedTypeProperties& T) {
61 auto aTensorScalar = ones({}, T);
62 auto b = randn({3, 2, 1}, T);
63 auto c = randn({1, 2, 5}, T);
64 std::vector<int64_t> expanded_sizes = {3, 2, 5};
65 ASSERT_TRUE(aTensorScalar.addcmul(b, c).equal(
66 aTensorScalar.expand(expanded_sizes)
67 .addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
68 }
69
70 // old fallback behavior yields error
TestOut3OldFallback(DeprecatedTypeProperties & T)71 void TestOut3OldFallback(DeprecatedTypeProperties& T) {
72 auto a = randn({3, 2, 5}, T);
73 auto b = randn({2, 3, 5}, T);
74 auto c = randn({5, 3, 2}, T);
75 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
76 ASSERT_ANY_THROW(a.addcmul(b, c));
77 }
78
79 // with mismatched sizes
TestOut3MismatchedSizes(DeprecatedTypeProperties & T)80 void TestOut3MismatchedSizes(DeprecatedTypeProperties& T) {
81 auto a = randn({3, 2, 5}, T);
82 auto b = randn({2, 3, 5}, T);
83 auto c = randn({5, 5, 5}, T);
84 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
85 ASSERT_ANY_THROW(a.addcmul(b, c));
86 }
87
88 // in-place function with 2 args
TestIn2Basic(DeprecatedTypeProperties & T)89 void TestIn2Basic(DeprecatedTypeProperties& T) {
90 auto a = randn({3, 5}, T);
91 auto b = randn({3, 1}, T);
92 ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
93 }
94
95 // with scalar
TestIn2WithScalar(DeprecatedTypeProperties & T)96 void TestIn2WithScalar(DeprecatedTypeProperties& T) {
97 auto a = randn({3, 5}, T);
98 auto bScalar = ones({}, T);
99 ASSERT_TRUE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
100 }
101
102 // error: would have to expand inplace arg
TestIn2ExpandError(DeprecatedTypeProperties & T)103 void TestIn2ExpandError(DeprecatedTypeProperties& T) {
104 auto a = randn({1, 5}, T);
105 auto b = randn({3, 1}, T);
106 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
107 ASSERT_ANY_THROW(a.add_(b));
108 }
109
110 // in-place function with 3 args
TestIn3Basic(DeprecatedTypeProperties & T)111 void TestIn3Basic(DeprecatedTypeProperties& T) {
112 auto a = randn({3, 5, 2}, T);
113 auto b = randn({3, 1, 2}, T);
114 auto c = randn({1, 5, 1}, T);
115 auto aClone = a.clone();
116 ASSERT_TRUE(a.addcmul_(b, c).equal(
117 aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
118 }
119
120 // with scalar
TestIn3WithScalar(DeprecatedTypeProperties & T)121 void TestIn3WithScalar(DeprecatedTypeProperties& T) {
122 auto a = randn({3, 5, 2}, T);
123 auto b = randn({3, 1, 2}, T);
124 auto c = randn({1, 5, 1}, T);
125 auto aClone = a.clone();
126 auto bScalar = ones({}, T);
127 ASSERT_TRUE(a.addcmul_(bScalar, c)
128 .equal(aClone.addcmul_(
129 bScalar.expand(a.sizes()), c.expand(a.sizes()))));
130 }
131
132 // error: would have to expand inplace arg
TestIn3ExpandError(DeprecatedTypeProperties & T)133 void TestIn3ExpandError(DeprecatedTypeProperties& T) {
134 auto a = randn({1, 3, 5}, T);
135 auto b = randn({4, 1, 1}, T);
136 auto c = randn({1, 3, 1}, T);
137 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
138 ASSERT_ANY_THROW(a.addcmul_(b, c));
139 }
140
141 // explicit dim specification
TestExplicitDimBasic(DeprecatedTypeProperties & T)142 void TestExplicitDimBasic(DeprecatedTypeProperties& T) {
143 auto a = randn({1}, T);
144 auto b = randn({5, 3}, T);
145 auto c = randn({3, 7}, T);
146 ASSERT_TRUE(a.addmm(b, c).equal(a.expand({5, 7}).addmm(b, c)));
147 }
148
149 // with scalar
TestExplicitDimWithScalar(DeprecatedTypeProperties & T)150 void TestExplicitDimWithScalar(DeprecatedTypeProperties& T) {
151 auto a = randn({1}, T);
152 auto b = randn({5, 3}, T);
153 auto c = randn({3, 7}, T);
154 Tensor aScalar = ones({}, T);
155 ASSERT_TRUE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
156 }
157
158 // with mismatched sizes
TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties & T)159 void TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties& T) {
160 auto b = randn({5, 3}, T);
161 auto c = randn({3, 7}, T);
162 auto a = randn({3, 3}, T);
163 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
164 ASSERT_ANY_THROW(a.addmm(b, c));
165 }
166
TEST(BroadcastTest, Broadcast) {
  // Fixed seed so the randn() inputs drawn by the helpers are reproducible.
  manual_seed(123);
  DeprecatedTypeProperties& T = CPU(kFloat);

  // Every helper builds its tensors from T. They run in this fixed order,
  // which also fixes the sequence of random draws each one sees.
  using BroadcastCheck = void (*)(DeprecatedTypeProperties&);
  const BroadcastCheck checks[] = {
      TestEmptyTensor,

      TestOut2Basic,
      TestOut2WithScalar,
      TestOut2OldFallback,
      TestOut2MismatchedSizes,

      TestOut3Basic,
      TestOut3WithScalar,
      TestOut3OldFallback,
      TestOut3MismatchedSizes,

      TestIn2Basic,
      TestIn2WithScalar,
      TestIn2ExpandError,

      TestIn3Basic,
      TestIn3WithScalar,
      TestIn3ExpandError,

      TestExplicitDimBasic,
      TestExplicitDimWithScalar,
      TestExplicitDimWithMismatchedSizes,
  };

  for (BroadcastCheck check : checks) {
    check(T);
  }
}
195