1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/operators/misc.h> 4 #include <torch/csrc/jit/tensorexpr/tensor.h> 5 6 namespace torch { 7 namespace jit { 8 namespace tensorexpr { 9 10 // An API to compute 2D depthwise convolutions with bias. 11 TORCH_API Tensor conv2d_depthwise( 12 BufHandle input, 13 BufHandle weight, 14 BufHandle bias, 15 int stride, 16 int pad, 17 int groups); 18 19 // An API to compute 2D depthwise convolutions without bias. 20 TORCH_API Tensor conv2d_depthwise( 21 BufHandle input, 22 BufHandle weight, 23 int stride, 24 int pad, 25 int groups); 26 27 TORCH_API Tensor conv2d_depthwise( 28 BufHandle input, 29 BufHandle weight, 30 BufHandle bias, 31 ExprHandle N, 32 ExprHandle C, 33 ExprHandle H, 34 ExprHandle W, 35 ExprHandle K, 36 ExprHandle CperG, 37 ExprHandle R, 38 ExprHandle S, 39 ExprHandle stride, 40 ExprHandle pad, 41 ExprHandle groups); 42 43 TORCH_API Tensor conv2d_depthwise( 44 BufHandle input, 45 BufHandle weight, 46 ExprHandle N, 47 ExprHandle C, 48 ExprHandle H, 49 ExprHandle W, 50 ExprHandle K, 51 ExprHandle CperG, 52 ExprHandle R, 53 ExprHandle S, 54 ExprHandle stride, 55 ExprHandle pad, 56 ExprHandle groups); 57 58 bool conv2dIsSupported( 59 const TensorInfo& input, 60 const TensorInfo& weight, 61 const TensorInfo& bias, 62 const std::vector<int64_t>& stride, 63 const std::vector<int64_t>& pad, 64 const std::vector<int64_t>& dilation, 65 int64_t groups); 66 bool mkldnnPrepackedConvIsSupported( 67 const TensorInfo& input, 68 const TensorInfo& weight, 69 const std::vector<int64_t>& stride, 70 const std::vector<int64_t>& pad, 71 const std::vector<int64_t>& dilation, 72 int64_t groups); 73 Tensor computeConv2d( 74 const std::vector<ArgValue>& inputs, 75 const std::vector<ExprHandle>& outputShape, 76 const std::vector<ExprHandle>& outputStrides, 77 const std::optional<ScalarType>& outputType, 78 at::Device device); 79 Tensor computeConv1d( 80 const std::vector<ArgValue>& inputs, 81 const std::vector<ExprHandle>& outputShape, 82 const std::vector<ExprHandle>& outputStrides, 83 const std::optional<ScalarType>& outputType, 84 at::Device device); 85 Tensor computePrepackedConv2dClampRun( 86 const std::vector<ArgValue>& inputs, 87 const std::vector<ExprHandle>& outputShape, 88 const std::vector<ExprHandle>& outputStrides, 89 const std::optional<ScalarType>& outputType, 90 at::Device device); 91 Tensor computePrepackedLinearClampRun( 92 const std::vector<ArgValue>& inputs, 93 const std::vector<ExprHandle>& outputShape, 94 const std::vector<ExprHandle>& outputStrides, 95 const std::optional<ScalarType>& outputType, 96 at::Device device); 97 Tensor computeMkldnnPrepackedConvRun( 98 const std::vector<ArgValue>& inputs, 99 const std::vector<ExprHandle>& outputShape, 100 const std::vector<ExprHandle>& outputStrides, 101 const std::optional<ScalarType>& outputType, 102 at::Device device); 103 } // namespace tensorexpr 104 } // namespace jit 105 } // namespace torch 106