1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator 10 #include <executorch/kernels/test/TestUtil.h> 11 #include <executorch/runtime/core/exec_aten/exec_aten.h> 12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> 13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> 14 15 #include <gtest/gtest.h> 16 17 using namespace ::testing; 18 using exec_aten::IntArrayRef; 19 using exec_aten::ScalarType; 20 using exec_aten::Tensor; 21 using torch::executor::testing::TensorFactory; 22 23 class OpOnesOutTest : public OperatorTest { 24 protected: op_ones_out(IntArrayRef size,Tensor & out)25 Tensor& op_ones_out(IntArrayRef size, Tensor& out) { 26 return torch::executor::aten::ones_outf(context_, size, out); 27 } 28 29 template <ScalarType DTYPE> test_ones_out(std::vector<int32_t> && size_int32_t)30 void test_ones_out(std::vector<int32_t>&& size_int32_t) { 31 TensorFactory<DTYPE> tf; 32 std::vector<int64_t> size_int64_t(size_int32_t.begin(), size_int32_t.end()); 33 auto aref = IntArrayRef(size_int64_t.data(), size_int64_t.size()); 34 35 // Before: `out` consists of 0s. 36 Tensor out = tf.zeros(size_int32_t); 37 38 // After: `out` consists of 1s. 39 op_ones_out(aref, out); 40 41 EXPECT_TENSOR_EQ(out, tf.ones(size_int32_t)); 42 } 43 }; 44 45 #define GENERATE_TEST(_, DTYPE) \ 46 TEST_F(OpOnesOutTest, DTYPE##Tensors) { \ 47 test_ones_out<ScalarType::DTYPE>({}); \ 48 test_ones_out<ScalarType::DTYPE>({1}); \ 49 test_ones_out<ScalarType::DTYPE>({1, 1, 1}); \ 50 test_ones_out<ScalarType::DTYPE>({2, 0, 4}); \ 51 test_ones_out<ScalarType::DTYPE>({2, 3, 4}); \ 52 } 53 54 ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TEST) 55