1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/kernel.h> 4 5 namespace torch { 6 namespace jit { 7 namespace tensorexpr { 8 9 TORCH_API Tensor computeSum( 10 const std::vector<ArgValue>& inputs, 11 const std::vector<ExprHandle>& outputShape, 12 const std::vector<ExprHandle>& outputStrides, 13 const std::optional<ScalarType>& outputType, 14 at::Device device); 15 TORCH_API Tensor computeMean( 16 const std::vector<ArgValue>& inputs, 17 const std::vector<ExprHandle>& outputShape, 18 const std::vector<ExprHandle>& outputStrides, 19 const std::optional<ScalarType>& outputType, 20 at::Device device); 21 TORCH_API Tensor computeAdaptiveAvgPool2d( 22 const std::vector<ArgValue>& inputs, 23 const std::vector<ExprHandle>& outputShape, 24 const std::vector<ExprHandle>& outputStrides, 25 const std::optional<ScalarType>& outputType, 26 at::Device device); 27 Tensor computeMax( 28 const std::vector<ArgValue>& inputs, 29 const std::vector<ExprHandle>& outputShape, 30 const std::vector<ExprHandle>& outputStrides, 31 const std::optional<ScalarType>& outputType, 32 at::Device device); 33 34 } // namespace tensorexpr 35 } // namespace jit 36 } // namespace torch 37