xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_aten.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <algorithm>
2 #include <sstream>
3 #include <stdexcept>
4 
5 #include <gtest/gtest.h>
6 
7 #include <c10/macros/Macros.h>
8 #include <c10/util/irange.h>
9 #include "test/cpp/tensorexpr/padded_buffer.h"
10 #include "test/cpp/tensorexpr/test_base.h"
11 #include "torch/csrc/jit/tensorexpr/ir_printer.h"
12 
13 namespace torch {
14 namespace jit {
15 
16 using namespace torch::jit::tensorexpr;
17 
// Converts an int buffer to float element by element (B[i] = float(A[i]))
// and checks that A is left untouched while B receives the converted values.
TEST(ATen, _cast_Float) {
  const int kTotalSize = 128;
  BufHandle src("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle dst("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle converted = Cast::make(kFloat, src.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, dst.store({idx}, converted));

  PaddedBuffer<int> src_host(kTotalSize);
  PaddedBuffer<float> dst_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    src_host(k) = k;
  }

  SimpleIREvaluator evaluator(loop, {src, dst});
  evaluator(src_host, dst_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = static_cast<float>(k);
    ASSERT_EQ(src_host(k), k);
    ASSERT_EQ(dst_host(k), expected);
  }
}
44 
// Integer negation expressed as `0 - A[i]` on kInt buffers.
// Both buffers are int, so the expected result is compared as an int too,
// keeping the check in the same dtype as the data under test.
TEST(ATen, negInt) {
  const int kTotalSize = 128;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);

  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a = a_buf.load(index);
  // Negation is built as a subtraction from zero.
  ExprHandle neg_a = Sub::make(0, load_a);
  StmtPtr store_b = b_buf.store({index}, neg_a);
  StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<int> a_v(kTotalSize);
  PaddedBuffer<int> b_v(kTotalSize);

  for (const auto i : c10::irange(kTotalSize)) {
    a_v(i) = i;
  }

  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
  ir_eval(a_v, b_v);

  for (const auto i : c10::irange(kTotalSize)) {
    ASSERT_EQ(a_v(i), i);
    ASSERT_EQ(b_v(i), -i);
  }
}
71 
// Float negation expressed as `0 - A[i]` on kFloat buffers.
// The expected result is formed as a float so the check stays in the same
// dtype as the data under test (for i == 0, -0.0f compares equal to 0.0f).
TEST(ATen, negFloat) {
  const int kTotalSize = 128;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a = a_buf.load(index);
  // Negation is built as a subtraction from (integer literal) zero.
  ExprHandle neg_a = Sub::make(0, load_a);
  StmtPtr store_b = b_buf.store({index}, neg_a);
  StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  for (const auto i : c10::irange(kTotalSize)) {
    a_v(i) = i;
  }

  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
  ir_eval(a_v, b_v);

  for (const auto i : c10::irange(kTotalSize)) {
    ASSERT_EQ(a_v(i), i);
    ASSERT_EQ(b_v(i), -static_cast<float>(i));
  }
}
98 
// Elementwise multiply-add on int buffers: D[i] = A[i] + B[i] * C[i].
// The evaluator output is compared with the same expression computed on the
// host, and the inputs are checked to be left unmodified.
TEST(ATen, addInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) + B.load(idx) * C.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, D.store({idx}, rhs));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  PaddedBuffer<int> d_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D});
  evaluator(a_host, b_host, c_host, d_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = a_host(k) + b_host(k) * c_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), expected);
  }
}
134 
// Elementwise multiply-add on float buffers: D[i] = A[i] + B[i] * C[i].
// The inputs are small integers, so the products are exact in float and an
// exact comparison against the host-side expression is used.
TEST(ATen, addFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) + B.load(idx) * C.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, D.store({idx}, rhs));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  PaddedBuffer<float> d_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D});
  evaluator(a_host, b_host, c_host, d_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) + b_host(k) * c_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), expected);
  }
}
170 
// Elementwise multiply-subtract on int buffers: D[i] = A[i] - B[i] * C[i].
// The evaluator output is compared with the host-side expression, and the
// inputs are checked to be left unmodified.
TEST(ATen, subInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) - B.load(idx) * C.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, D.store({idx}, rhs));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  PaddedBuffer<int> d_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D});
  evaluator(a_host, b_host, c_host, d_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = a_host(k) - b_host(k) * c_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), expected);
  }
}
206 
// Elementwise multiply-subtract on float buffers: D[i] = A[i] - B[i] * C[i],
// compared exactly against the host-side float expression.
TEST(ATen, subFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) - B.load(idx) * C.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, D.store({idx}, rhs));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  PaddedBuffer<float> d_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D});
  evaluator(a_host, b_host, c_host, d_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) - b_host(k) * c_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), expected);
  }
}
242 
// Linear interpolation D[i] = A[i] + C[i] * (B[i] - A[i]) on float buffers,
// where C acts as the interpolation weight. A single load of A is shared by
// both uses in the expression.
TEST(ATen, lerp) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle start = A.load(idx);
  ExprHandle end = B.load(idx);
  ExprHandle weight = C.load(idx);
  ExprHandle rhs = start + weight * (end - start);
  StmtPtr loop = For::make(idx, 0, kTotalSize, D.store({idx}, rhs));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  PaddedBuffer<float> d_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D});
  evaluator(a_host, b_host, c_host, d_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) + c_host(k) * (b_host(k) - a_host(k));
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), expected);
  }
}
278 
// addcmul-style expression on int buffers: E[i] = A[i] + B[i] * C[i] * D[i],
// compared with the same expression evaluated on the host.
TEST(ATen, addcmulInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kInt);
  BufHandle E("E", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) + B.load(idx) * C.load(idx) * D.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, E.store({idx}, rhs));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  PaddedBuffer<int> d_host(kTotalSize);
  PaddedBuffer<int> e_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
    d_host(k) = 5 * k + 3;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D, E});
  evaluator(a_host, b_host, c_host, d_host, e_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = a_host(k) + b_host(k) * c_host(k) * d_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), 5 * k + 3);
    ASSERT_EQ(e_host(k), expected);
  }
}
319 
// addcmul-style expression on float buffers: E[i] = A[i] + B[i] * C[i] * D[i].
// The triple product exceeds float's exact-integer range for larger i, so the
// result is compared with ASSERT_FLOAT_EQ (ULP tolerance) instead of exactly.
TEST(ATen, addcmulFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle D("D", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle E("E", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle rhs = A.load(idx) + B.load(idx) * C.load(idx) * D.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, E.store({idx}, rhs));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  PaddedBuffer<float> d_host(kTotalSize);
  PaddedBuffer<float> e_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
    c_host(k) = 3 * k + 2;
    d_host(k) = 5 * k + 3;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C, D, E});
  evaluator(a_host, b_host, c_host, d_host, e_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) + b_host(k) * c_host(k) * d_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), 3 * k + 2);
    ASSERT_EQ(d_host(k), 5 * k + 3);
    ASSERT_FLOAT_EQ(e_host(k), expected);
  }
}
360 
// Elementwise product on int buffers: C[i] = A[i] * B[i].
TEST(ATen, mulInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle product = A.load(idx) * B.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, product));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = a_host(k) * b_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
391 
// Elementwise product on float buffers: C[i] = A[i] * B[i], compared exactly
// against the host-side float product.
TEST(ATen, mulFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle product = A.load(idx) * B.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, product));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) * b_host(k);
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
422 
// Elementwise integer division C[i] = A[i] / B[i] with A[i] = 2i + 1 and
// B[i] = i + 1, so the divisor is never zero.
TEST(ATen, divInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle quotient = A.load(idx) / B.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, quotient));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = 2 * k + 1;
    b_host(k) = k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = a_host(k) / b_host(k);
    ASSERT_EQ(a_host(k), 2 * k + 1);
    ASSERT_EQ(b_host(k), k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
453 
// Elementwise float division C[i] = A[i] / B[i] with A[i] = 2i + 1 and
// B[i] = i + 1 (never zero), compared exactly with the host-side quotient.
TEST(ATen, divFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle quotient = A.load(idx) / B.load(idx);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, quotient));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = 2 * k + 1;
    b_host(k) = k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = a_host(k) / b_host(k);
    ASSERT_EQ(a_host(k), 2 * k + 1);
    ASSERT_EQ(b_host(k), k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
484 
// Elementwise maximum on int buffers, C[i] = Max(A[i], B[i]) built with the
// trailing Max::make flag set to true, checked against std::max.
TEST(ATen, maxInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle larger = Max::make(A.load(idx), B.load(idx), true);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, larger));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = std::max(a_host(k), b_host(k));
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
515 
// Elementwise maximum on float buffers, C[i] = Max(A[i], B[i]) built with the
// trailing Max::make flag set to true, checked against std::fmax.
TEST(ATen, maxFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle larger = Max::make(A.load(idx), B.load(idx), true);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, larger));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::fmax(a_host(k), b_host(k));
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
546 
// Elementwise minimum on int buffers, C[i] = Min(A[i], B[i]) built with the
// trailing Min::make flag set to true, checked against std::min.
TEST(ATen, minInt) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kInt);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle smaller = Min::make(A.load(idx), B.load(idx), true);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, smaller));

  PaddedBuffer<int> a_host(kTotalSize);
  PaddedBuffer<int> b_host(kTotalSize);
  PaddedBuffer<int> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = std::min(a_host(k), b_host(k));
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
577 
// Elementwise minimum on float buffers, C[i] = Min(A[i], B[i]) built with the
// trailing Min::make flag set to true, checked against std::fmin.
TEST(ATen, minFloat) {
  const int kTotalSize = 128;
  BufHandle A("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle B("B", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle C("C", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle smaller = Min::make(A.load(idx), B.load(idx), true);
  StmtPtr loop = For::make(idx, 0, kTotalSize, C.store({idx}, smaller));

  PaddedBuffer<float> a_host(kTotalSize);
  PaddedBuffer<float> b_host(kTotalSize);
  PaddedBuffer<float> c_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    a_host(k) = k;
    b_host(k) = 2 * k + 1;
  }

  SimpleIREvaluator evaluator(loop, {A, B, C});
  evaluator(a_host, b_host, c_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::fmin(a_host(k), b_host(k));
    ASSERT_EQ(a_host(k), k);
    ASSERT_EQ(b_host(k), 2 * k + 1);
    ASSERT_EQ(c_host(k), expected);
  }
}
608 
testATenreciprocal()609 void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
610   const int kTotalSize = 128;
611   BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
612   BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);
613 
614   VarHandle index = VarHandle("index", kInt);
615   ExprHandle load_a = a_buf.load(index);
616   StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a);
617   StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
618 
619   PaddedBuffer<float> a_v(kTotalSize);
620   PaddedBuffer<float> b_v(kTotalSize);
621 
622   for (const auto i : c10::irange(kTotalSize)) {
623     a_v(i) = i;
624   }
625 
626   SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf});
627   ir_eval(a_v, b_v);
628 
629   for (const auto i : c10::irange(kTotalSize)) {
630     ASSERT_EQ(a_v(i), i);
631     ASSERT_EQ(b_v(i), 1.0f / i);
632   }
633 }
634 
// ReLU on an int buffer, B[i] = Max(A[i], 0) with NaN propagation disabled.
// The inputs run from -64 to 63, so both the clamped and pass-through
// branches are exercised.
TEST(ATen, reluInt) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kInt);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle clamped = Max::make(in.load(idx), 0, false);
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, clamped));

  PaddedBuffer<int> in_host(kTotalSize);
  PaddedBuffer<int> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = k - 64;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const int expected = std::max(in_host(k), 0);
    ASSERT_EQ(in_host(k), k - 64);
    ASSERT_EQ(out_host(k), expected);
  }
}
660 
// ReLU on a float buffer, B[i] = Max(A[i], 0), over inputs from -64 to 63.
TEST(ATen, reluFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  // false: relu does not propagate NaNs.
  ExprHandle clamped = Max::make(in.load(idx), 0, false);
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, clamped));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = k - 64;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::fmax(in_host(k), 0.0f);
    ASSERT_EQ(in_host(k), k - 64);
    ASSERT_EQ(out_host(k), expected);
  }
}
688 
// Natural log on a float buffer, B[i] = log(A[i]) with A[i] = i + 10 (always
// positive), compared exactly against std::log on the host.
TEST(ATen, logFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = log(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = k + 10;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::log(in_host(k));
    ASSERT_EQ(in_host(k), k + 10);
    ASSERT_EQ(out_host(k), expected);
  }
}
714 
// fast_log on 128 standard-normal samples (one at::randn draw per element).
// Where std::log of the input is NaN (negative inputs), the evaluator output
// must also be NaN; otherwise it must match std::log under ASSERT_FLOAT_EQ.
TEST(ATen, fastLogFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = fast_log(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = at::randn({1}).item().to<float>();
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const auto actual = out_host(k);
    const auto expected = std::log(in_host(k));
    if (!std::isnan(expected)) {
      ASSERT_FLOAT_EQ(actual, expected);
    } else {
      ASSERT_EQ(std::isnan(actual), true);
    }
  }
}
745 
// fast_tanh on 128 standard-normal samples (one at::randn draw per element),
// compared with std::tanh to within an absolute tolerance of 1e-6. A NaN
// reference requires a NaN result.
TEST(ATen, fastTanhFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = fast_tanh(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = at::randn({1}).item().to<float>();
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const auto actual = out_host(k);
    const auto expected = std::tanh(in_host(k));
    if (!std::isnan(expected)) {
      ASSERT_NEAR(actual, expected, 1e-6);
    } else {
      ASSERT_EQ(std::isnan(actual), true);
    }
  }
}
776 
// fast_sigmoid on 128 standard-normal samples (one at::randn draw per
// element). The reference is ATen's own at::sigmoid applied to a one-element
// tensor holding the same input; results must agree within 1e-6, and a NaN
// reference requires a NaN result.
TEST(ATen, fastSigmoidFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = fast_sigmoid(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = at::randn({1}).item().to<float>();
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const auto actual = out_host(k);
    const float expected =
        at::sigmoid(at::ones({1}) * in_host(k)).item().to<float>();
    if (!std::isnan(expected)) {
      ASSERT_NEAR(actual, expected, 1e-6);
    } else {
      ASSERT_EQ(std::isnan(actual), true);
    }
  }
}
808 
// Base-10 log on a float buffer, B[i] = log10(A[i]) with A[i] = i + 10,
// compared exactly against std::log10 on the host.
TEST(ATen, log10Float) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = log10(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = k + 10;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::log10(in_host(k));
    ASSERT_EQ(in_host(k), k + 10);
    ASSERT_EQ(out_host(k), expected);
  }
}
834 
// Base-2 log on a float buffer, B[i] = log2(A[i]) with A[i] = i + 10,
// compared exactly against std::log2 on the host.
TEST(ATen, log2Float) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = log2(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    in_host(k) = k + 10;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::log2(in_host(k));
    ASSERT_EQ(in_host(k), k + 10);
    ASSERT_EQ(out_host(k), expected);
  }
}
860 
// Exponential on a float buffer, B[i] = exp(A[i]) with A[i] = i / 10,
// compared exactly against std::exp on the host.
TEST(ATen, expFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = exp(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    in_host(k) = k / 10.0f;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::exp(in_host(k));
    ASSERT_EQ(in_host(k), k / 10.0f);
    ASSERT_EQ(out_host(k), expected);
  }
}
887 
// Error function on a float buffer, B[i] = erf(A[i]) with A[i] = i / 10,
// compared exactly against std::erf on the host.
TEST(ATen, erfFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = erf(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    in_host(k) = k / 10.0f;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::erf(in_host(k));
    ASSERT_EQ(in_host(k), k / 10.0f);
    ASSERT_EQ(out_host(k), expected);
  }
}
914 
// Cosine on a float buffer, B[i] = cos(A[i]) with A[i] = i / 10,
// compared exactly against std::cos on the host.
TEST(ATen, cosFloat) {
  const int kTotalSize = 128;
  BufHandle in("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle out("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx = VarHandle("index", kInt);
  ExprHandle result = cos(in.load(idx));
  StmtPtr loop = For::make(idx, 0, kTotalSize, out.store({idx}, result));

  PaddedBuffer<float> in_host(kTotalSize);
  PaddedBuffer<float> out_host(kTotalSize);
  for (const auto k : c10::irange(kTotalSize)) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    in_host(k) = k / 10.0f;
  }

  SimpleIREvaluator evaluator(loop, {in, out});
  evaluator(in_host, out_host);

  for (const auto k : c10::irange(kTotalSize)) {
    const float expected = std::cos(in_host(k));
    ASSERT_EQ(in_host(k), k / 10.0f);
    ASSERT_EQ(out_host(k), expected);
  }
}
941 
// Evaluates C[i] = (A[i] == B[i]) with CompareSelect/kEQ and checks each
// element against the host-side comparison. A cycles through {0, 1, 2}
// against a constant B of 1, so both the "equal" and "not equal" outcomes
// are exercised. With uniform equal inputs, a comparison that always
// yielded 1 would also pass.
TEST(ATen, eqInt) {
  constexpr int N = 128;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N);
  std::vector<int> b_buffer(N, 1);
  // Sentinel that is neither 0 nor 1, so an element the loop never writes
  // cannot accidentally match the expected result.
  std::vector<int> c_buffer(N, -1);
  for (const auto k : c10::irange(N)) {
    a_buffer[k] = k % 3;
  }

  VarHandle i("i", kInt);
  auto compare_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(compare_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  for (const auto k : c10::irange(N)) {
    ASSERT_EQ(c_buffer[k], a_buffer[k] == b_buffer[k] ? 1 : 0)
        << "mismatch at index " << k;
  }
}
966 
// Evaluates C[i] = (A[i] >= B[i]) with CompareSelect/kGE and checks each
// element against the host-side comparison. A cycles through {4, 5, 6}
// against a constant B of 5, so the less-than, equal and greater-than cases
// all occur. With uniform equal inputs, kGE could not be told apart from
// kEQ or kLE.
TEST(ATen, geInt) {
  constexpr int N = 128;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N);
  std::vector<int> b_buffer(N, 5);
  // Sentinel that is neither 0 nor 1, so an element the loop never writes
  // cannot accidentally match the expected result.
  std::vector<int> c_buffer(N, -1);
  for (const auto k : c10::irange(N)) {
    a_buffer[k] = 4 + k % 3;
  }

  VarHandle i("i", kInt);
  auto compare_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kGE)));

  SimpleIREvaluator ir_eval(compare_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  for (const auto k : c10::irange(N)) {
    ASSERT_EQ(c_buffer[k], a_buffer[k] >= b_buffer[k] ? 1 : 0)
        << "mismatch at index " << k;
  }
}
991 
// Evaluates C[i] = (A[i] > B[i]) with CompareSelect/kGT and checks each
// element against the host-side comparison. A cycles through {2, 3, 4}
// against a constant B of 3, so the less-than, equal and greater-than cases
// all occur. With a uniform "A > B" input, kGT could not be told apart from
// kGE or kNE.
TEST(ATen, gtInt) {
  constexpr int N = 128;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N);
  std::vector<int> b_buffer(N, 3);
  // Sentinel that is neither 0 nor 1, so an element the loop never writes
  // cannot accidentally match the expected result.
  std::vector<int> c_buffer(N, -1);
  for (const auto k : c10::irange(N)) {
    a_buffer[k] = 2 + k % 3;
  }

  VarHandle i("i", kInt);
  auto compare_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kGT)));

  SimpleIREvaluator ir_eval(compare_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  for (const auto k : c10::irange(N)) {
    ASSERT_EQ(c_buffer[k], a_buffer[k] > b_buffer[k] ? 1 : 0)
        << "mismatch at index " << k;
  }
}
1016 
// Evaluates C[i] = (A[i] <= B[i]) with CompareSelect/kLE and checks each
// element against the host-side comparison. A cycles through {4, 5, 6}
// against a constant B of 5, so the less-than, equal and greater-than cases
// all occur. With uniform equal inputs, kLE could not be told apart from
// kEQ or kGE.
TEST(ATen, leInt) {
  constexpr int N = 128;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N);
  std::vector<int> b_buffer(N, 5);
  // Sentinel that is neither 0 nor 1, so an element the loop never writes
  // cannot accidentally match the expected result.
  std::vector<int> c_buffer(N, -1);
  for (const auto k : c10::irange(N)) {
    a_buffer[k] = 4 + k % 3;
  }

  VarHandle i("i", kInt);
  auto compare_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kLE)));

  SimpleIREvaluator ir_eval(compare_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  for (const auto k : c10::irange(N)) {
    ASSERT_EQ(c_buffer[k], a_buffer[k] <= b_buffer[k] ? 1 : 0)
        << "mismatch at index " << k;
  }
}
1041 
// Evaluates C[i] = (A[i] < B[i]) with CompareSelect/kLT and checks each
// element against the host-side comparison. A cycles through {4, 5, 6}
// against a constant B of 5, so the less-than, equal and greater-than cases
// all occur. With uniform equal inputs (expecting 0), kLT could not be told
// apart from kGT or kNE.
TEST(ATen, ltInt) {
  constexpr int N = 128;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N);
  std::vector<int> b_buffer(N, 5);
  // Sentinel that is neither 0 nor 1, so an element the loop never writes
  // cannot accidentally match the expected result.
  std::vector<int> c_buffer(N, -1);
  for (const auto k : c10::irange(N)) {
    a_buffer[k] = 4 + k % 3;
  }

  VarHandle i("i", kInt);
  auto compare_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kLT)));

  SimpleIREvaluator ir_eval(compare_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  for (const auto k : c10::irange(N)) {
    ASSERT_EQ(c_buffer[k], a_buffer[k] < b_buffer[k] ? 1 : 0)
        << "mismatch at index " << k;
  }
}
1066 
1067 } // namespace jit
1068 } // namespace torch
1069