1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/kernel/kernel_includes.h> 12 13 namespace torch { 14 namespace executor { 15 16 bool check_addmm_args( 17 const Tensor& in, 18 const Tensor& mat1, 19 const Tensor& mat2, 20 const Scalar& beta, 21 const Scalar& alpha, 22 Tensor& out); 23 24 bool check_bmm_args(const Tensor& in, const Tensor& mat2, Tensor& out); 25 26 void get_bmm_out_target_size( 27 const Tensor& mat1, 28 const Tensor& mat2, 29 Tensor::SizesType* out_sizes, 30 size_t* out_ndim); 31 32 bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out); 33 34 void get_mm_out_target_size( 35 const Tensor& mat1, 36 const Tensor& mat2, 37 Tensor::SizesType* out_sizes, 38 size_t* out_ndim); 39 40 bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out); 41 42 void get_linear_out_target_size( 43 const Tensor& mat1, 44 const Tensor& mat2, 45 Tensor::SizesType* out_sizes, 46 size_t* out_ndim); 47 48 } // namespace executor 49 } // namespace torch 50