1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/kernel.h> 4 5 namespace torch { 6 namespace jit { 7 namespace tensorexpr { 8 9 Tensor computeMatmul( 10 const std::vector<ArgValue>& inputs, 11 const std::vector<ExprHandle>& outputShape, 12 const std::vector<ExprHandle>& outputStrides, 13 const std::optional<ScalarType>& outputType, 14 at::Device device); 15 Tensor computeAddMM( 16 const std::vector<ArgValue>& inputs, 17 const std::vector<ExprHandle>& outputShape, 18 const std::vector<ExprHandle>& outputStrides, 19 const std::optional<ScalarType>& outputType, 20 at::Device device); 21 22 } // namespace tensorexpr 23 } // namespace jit 24 } // namespace torch 25