xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/broadcast_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #include <gtest/gtest.h>
3 
4 #include <ATen/ATen.h>
5 
6 using namespace at;
7 
8 // can't expand empty tensor
TestEmptyTensor(DeprecatedTypeProperties & T)9 void TestEmptyTensor(DeprecatedTypeProperties& T) {
10   auto empty = randn({0}, T);
11   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
12   ASSERT_ANY_THROW(empty.expand({3}));
13 }
14 
15 // out-place function with 2 args
TestOut2Basic(DeprecatedTypeProperties & T)16 void TestOut2Basic(DeprecatedTypeProperties& T) {
17   auto a = randn({3, 1}, T);
18   auto b = randn({5}, T);
19   std::vector<int64_t> expanded_sizes = {3, 5};
20   ASSERT_TRUE(
21       (a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
22 }
23 
24 // with scalar
TestOut2WithScalar(DeprecatedTypeProperties & T)25 void TestOut2WithScalar(DeprecatedTypeProperties& T) {
26   auto aScalar = ones({}, T);
27   auto b = randn({3, 5}, T);
28   ASSERT_TRUE(
29       (aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
30 }
31 
32 // old fallback behavior yields error
TestOut2OldFallback(DeprecatedTypeProperties & T)33 void TestOut2OldFallback(DeprecatedTypeProperties& T) {
34   auto a = randn({3, 5}, T);
35   auto b = randn({5, 3}, T);
36   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
37   ASSERT_ANY_THROW(a + b);
38 }
39 
40 // with mismatched sizes
TestOut2MismatchedSizes(DeprecatedTypeProperties & T)41 void TestOut2MismatchedSizes(DeprecatedTypeProperties& T) {
42   auto a = randn({3, 5}, T);
43   auto b = randn({7, 5}, T);
44   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
45   ASSERT_ANY_THROW(a + b);
46 }
47 
48 // out-place function with 3 args
TestOut3Basic(DeprecatedTypeProperties & T)49 void TestOut3Basic(DeprecatedTypeProperties& T) {
50   auto a = randn({3, 1, 1}, T);
51   auto b = randn({1, 2, 1}, T);
52   auto c = randn({1, 1, 5}, T);
53   std::vector<int64_t> expanded_sizes = {3, 2, 5};
54   ASSERT_TRUE((a + b + c).equal(
55       a.expand(expanded_sizes) + b.expand(expanded_sizes) +
56       c.expand(expanded_sizes)));
57 }
58 
59 // with scalar
TestOut3WithScalar(DeprecatedTypeProperties & T)60 void TestOut3WithScalar(DeprecatedTypeProperties& T) {
61   auto aTensorScalar = ones({}, T);
62   auto b = randn({3, 2, 1}, T);
63   auto c = randn({1, 2, 5}, T);
64   std::vector<int64_t> expanded_sizes = {3, 2, 5};
65   ASSERT_TRUE(aTensorScalar.addcmul(b, c).equal(
66       aTensorScalar.expand(expanded_sizes)
67           .addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
68 }
69 
70 // old fallback behavior yields error
TestOut3OldFallback(DeprecatedTypeProperties & T)71 void TestOut3OldFallback(DeprecatedTypeProperties& T) {
72   auto a = randn({3, 2, 5}, T);
73   auto b = randn({2, 3, 5}, T);
74   auto c = randn({5, 3, 2}, T);
75   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
76   ASSERT_ANY_THROW(a.addcmul(b, c));
77 }
78 
79 // with mismatched sizes
TestOut3MismatchedSizes(DeprecatedTypeProperties & T)80 void TestOut3MismatchedSizes(DeprecatedTypeProperties& T) {
81   auto a = randn({3, 2, 5}, T);
82   auto b = randn({2, 3, 5}, T);
83   auto c = randn({5, 5, 5}, T);
84   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
85   ASSERT_ANY_THROW(a.addcmul(b, c));
86 }
87 
88 // in-place function with 2 args
TestIn2Basic(DeprecatedTypeProperties & T)89 void TestIn2Basic(DeprecatedTypeProperties& T) {
90   auto a = randn({3, 5}, T);
91   auto b = randn({3, 1}, T);
92   ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
93 }
94 
95 // with scalar
TestIn2WithScalar(DeprecatedTypeProperties & T)96 void TestIn2WithScalar(DeprecatedTypeProperties& T) {
97   auto a = randn({3, 5}, T);
98   auto bScalar = ones({}, T);
99   ASSERT_TRUE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
100 }
101 
102 // error: would have to expand inplace arg
TestIn2ExpandError(DeprecatedTypeProperties & T)103 void TestIn2ExpandError(DeprecatedTypeProperties& T) {
104   auto a = randn({1, 5}, T);
105   auto b = randn({3, 1}, T);
106   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
107   ASSERT_ANY_THROW(a.add_(b));
108 }
109 
110 // in-place function with 3 args
TestIn3Basic(DeprecatedTypeProperties & T)111 void TestIn3Basic(DeprecatedTypeProperties& T) {
112   auto a = randn({3, 5, 2}, T);
113   auto b = randn({3, 1, 2}, T);
114   auto c = randn({1, 5, 1}, T);
115   auto aClone = a.clone();
116   ASSERT_TRUE(a.addcmul_(b, c).equal(
117       aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
118 }
119 
120 // with scalar
TestIn3WithScalar(DeprecatedTypeProperties & T)121 void TestIn3WithScalar(DeprecatedTypeProperties& T) {
122   auto a = randn({3, 5, 2}, T);
123   auto b = randn({3, 1, 2}, T);
124   auto c = randn({1, 5, 1}, T);
125   auto aClone = a.clone();
126   auto bScalar = ones({}, T);
127   ASSERT_TRUE(a.addcmul_(bScalar, c)
128                   .equal(aClone.addcmul_(
129                       bScalar.expand(a.sizes()), c.expand(a.sizes()))));
130 }
131 
132 // error: would have to expand inplace arg
TestIn3ExpandError(DeprecatedTypeProperties & T)133 void TestIn3ExpandError(DeprecatedTypeProperties& T) {
134   auto a = randn({1, 3, 5}, T);
135   auto b = randn({4, 1, 1}, T);
136   auto c = randn({1, 3, 1}, T);
137   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
138   ASSERT_ANY_THROW(a.addcmul_(b, c));
139 }
140 
141 // explicit dim specification
TestExplicitDimBasic(DeprecatedTypeProperties & T)142 void TestExplicitDimBasic(DeprecatedTypeProperties& T) {
143   auto a = randn({1}, T);
144   auto b = randn({5, 3}, T);
145   auto c = randn({3, 7}, T);
146   ASSERT_TRUE(a.addmm(b, c).equal(a.expand({5, 7}).addmm(b, c)));
147 }
148 
149 // with scalar
TestExplicitDimWithScalar(DeprecatedTypeProperties & T)150 void TestExplicitDimWithScalar(DeprecatedTypeProperties& T) {
151   auto a = randn({1}, T);
152   auto b = randn({5, 3}, T);
153   auto c = randn({3, 7}, T);
154   Tensor aScalar = ones({}, T);
155   ASSERT_TRUE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
156 }
157 
158 // with mismatched sizes
TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties & T)159 void TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties& T) {
160   auto b = randn({5, 3}, T);
161   auto c = randn({3, 7}, T);
162   auto a = randn({3, 3}, T);
163   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
164   ASSERT_ANY_THROW(a.addmm(b, c));
165 }
166 
// Seeds the RNG with a fixed value, then runs every broadcasting helper
// defined above against CPU float tensors.
TEST(BroadcastTest, Broadcast) {
  manual_seed(123);
  DeprecatedTypeProperties& cpuFloat = CPU(kFloat);

  // Helpers run in this fixed order, so the sequence of random draws they
  // consume is deterministic. A fatal assertion inside one helper only
  // returns from that helper; the remaining helpers still run.
  using BroadcastCheck = void (*)(DeprecatedTypeProperties&);
  const BroadcastCheck checks[] = {
      TestEmptyTensor,

      TestOut2Basic,
      TestOut2WithScalar,
      TestOut2OldFallback,
      TestOut2MismatchedSizes,

      TestOut3Basic,
      TestOut3WithScalar,
      TestOut3OldFallback,
      TestOut3MismatchedSizes,

      TestIn2Basic,
      TestIn2WithScalar,
      TestIn2ExpandError,

      TestIn3Basic,
      TestIn3WithScalar,
      TestIn3ExpandError,

      TestExplicitDimBasic,
      TestExplicitDimWithScalar,
      TestExplicitDimWithMismatchedSizes,
  };
  for (BroadcastCheck check : checks) {
    check(cpuFloat);
  }
}
195