/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <gtest/gtest.h>

#include <limits>

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;

class OpMmOutTest : public OperatorTest {
 protected:
  Tensor& op_mm_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
    return torch::executor::aten::mm_outf(context_, self, mat2, out);
  }

  template <class CTYPE, ScalarType DTYPE>
  void test_dtype() {
    TensorFactory<DTYPE> tf;

    if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
      if (DTYPE == ScalarType::Half) {
        GTEST_SKIP()
            << "skip Half because torch::executor::aten::mm_out does not support Half";
        return;
      }
    }

    // matmul gives 4 * 2 * 3 = 24
    Tensor x = tf.full({3, 4}, 2);
    Tensor y = tf.full({4, 5}, 3);

    // Output shape should be (3, 5)
    Tensor out = tf.zeros({3, 5});

    op_mm_out(x, y, out);

    Tensor expected = tf.full({3, 5}, 24);
    EXPECT_TENSOR_EQ(out, expected);
  }
};

TEST_F(OpMmOutTest, OutputDim) {
  TensorFactory<ScalarType::Int> tf;

  // 3 tensors with compatible dimensions: (3, 5), (3, 4) and (4, 5).
  Tensor x = tf.ones({3, 4});
  Tensor y = tf.ones({4, 5});
  Tensor out = tf.zeros({3, 5});

  Tensor ret = op_mm_out(x, y, out);

  // Should always return the provided out Tensor.
  EXPECT_TENSOR_EQ(ret, out);

  // Expected tensor, filled with 4.
  Tensor expected = tf.full({3, 5}, 4);
  EXPECT_TENSOR_EQ(out, expected);
}

/// A generic smoke test that works for any dtype that supports ones() and
/// zeros().
TEST_F(OpMmOutTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
  // TODO: Also add tests for half, complex, quantized, and other types.
  // Easiest way to do that would be to make TensorFactory support zeros() and
  // ones() for those types.
}

TEST_F(OpMmOutTest, EmptyInputWithEmptyOutTensorPasses) {
  TensorFactory<ScalarType::Float> tf;

  // Empty input matrices
  Tensor x = tf.make({0, 3}, {});
  Tensor y = tf.make({3, 0}, {});

  // Output matrix is also empty
  Tensor out = tf.make({0, 0}, {});
  Tensor expected = tf.make({0, 0}, {});

  EXPECT_TENSOR_EQ(op_mm_out(x, y, out), expected);
}

TEST_F(OpMmOutTest, InfinityTensorPasses) {
  TensorFactory<ScalarType::Float> tff;

  Tensor x = tff.full({3, 4}, std::numeric_limits<float>::infinity());
  Tensor y = tff.full({4, 5}, 3);

  // Output shape should be (3, 5)
  Tensor out = tff.zeros({3, 5});
  Tensor expected = tff.full({3, 5}, std::numeric_limits<float>::infinity());

  EXPECT_TENSOR_EQ(op_mm_out(x, y, out), expected);
}

TEST_F(OpMmOutTest, MismatchedDimensionsDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.full({2, 2}, 3);
  Tensor wrong_y = tf.full({3, 1}, 1);
  Tensor right_y = tf.full({2, 2}, 1);

  // Make an empty out tensor and demonstrate that it's empty.
  Tensor out = tf.full({2, 2}, 0);
  Tensor expected = tf.full({2, 2}, 6);

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, wrong_y, out));
  EXPECT_TENSOR_EQ(op_mm_out(x, right_y, out), expected);
}

TEST_F(OpMmOutTest, MismatchedDimensionSizeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
  }
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.full({2, 2}, 3);

  // wrong_y has incompatible dim
  Tensor wrong_y = tf.full({2, 2, 2}, 1);
  Tensor right_y = tf.full({2, 2}, 1);

  // wrong_out has incompatible dim
  Tensor right_out = tf.ones({2, 2});
  Tensor wrong_out = tf.ones({2, 2, 3});

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, right_y, wrong_out));
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, wrong_y, right_out));
}

TEST_F(OpMmOutTest, WrongOutShapeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
  }
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({10, 3});
  Tensor y = tf.ones({3, 4});

  // wrong_out has incompatible shape
  Tensor right_out = tf.ones({10, 4});
  Tensor wrong_out = tf.ones({7, 5});

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, y, wrong_out));
  EXPECT_TENSOR_EQ(op_mm_out(x, y, right_out), tf.full({10, 4}, 3));
}

TEST_F(OpMmOutTest, DynamicShapeUpperBoundSameAsExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284, 0.34793388843536377, 0.8187907934188843,
       0.9979893565177917, 0.7049332857131958, 0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277, 0.13667285442352295, 0.9002121090888977,
       0.9070476293563843, 0.31638312339782715, 0.3691965937614441,
       0.09420186281204224, 0.9310881495475769});
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744, 0.15225356817245483, 0.18952149152755737,
       0.48189279437065125, 0.976661741733551, 0.480360746383667,
       0.8310978412628174, 1.6718982458114624, 0.703657865524292,
       0.2534688115119934, 0.6746801733970642, 1.0356627702713013});

  Tensor out =
      tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_mm_out(x, y, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpMmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284, 0.34793388843536377, 0.8187907934188843,
       0.9979893565177917, 0.7049332857131958, 0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277, 0.13667285442352295, 0.9002121090888977,
       0.9070476293563843, 0.31638312339782715, 0.3691965937614441,
       0.09420186281204224, 0.9310881495475769});
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744, 0.15225356817245483, 0.18952149152755737,
       0.48189279437065125, 0.976661741733551, 0.480360746383667,
       0.8310978412628174, 1.6718982458114624, 0.703657865524292,
       0.2534688115119934, 0.6746801733970642, 1.0356627702713013});

  Tensor out =
      tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_mm_out(x, y, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpMmOutTest, DynamicShapeUnbound) {
  GTEST_SKIP() << "Dynamic shape not supported";
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284, 0.34793388843536377, 0.8187907934188843,
       0.9979893565177917, 0.7049332857131958, 0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277, 0.13667285442352295, 0.9002121090888977,
       0.9070476293563843, 0.31638312339782715, 0.3691965937614441,
       0.09420186281204224, 0.9310881495475769});
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744,
       0.15225356817245483, 0.18952149152755737, 0.48189279437065125,
       0.976661741733551, 0.480360746383667, 0.8310978412628174,
       1.6718982458114624, 0.703657865524292, 0.2534688115119934,
       0.6746801733970642, 1.0356627702713013});

  Tensor out =
      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  Tensor ret = op_mm_out(x, y, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
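
// Hedged sketch of an additional case, not part of the original suite: mm
// with non-uniform integer inputs, so the accumulation over the inner
// dimension is checked element by element rather than against a constant
// fill. The test name and values are illustrative; expected entries are
// computed by hand:
//   [[2, 0],     [[1, 2],     [[ 2,  4],
//    [1, 3]]  x   [4, 5]]  =   [13, 17]]
TEST_F(OpMmOutTest, NonUniformValuesSketch) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.make({2, 2}, {2, 0, 1, 3});
  Tensor y = tf.make({2, 2}, {1, 2, 4, 5});
  Tensor out = tf.zeros({2, 2});

  op_mm_out(x, y, out);

  Tensor expected = tf.make({2, 2}, {2, 4, 13, 17});
  EXPECT_TENSOR_EQ(out, expected);
}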