1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <limits> 5 6 namespace at::native::xnnpack { 7 8 // 9 // Convolution 10 // 11 12 bool use_convolution2d( 13 const Tensor& input, 14 const Tensor& weight, 15 const at::OptionalIntArrayRef bias_sizes_opt, 16 const IntArrayRef padding, 17 const IntArrayRef stride, 18 const IntArrayRef dilation, 19 const int64_t groups, 20 const bool transposed); 21 22 Tensor convolution2d( 23 const Tensor& input, 24 const Tensor& weight, 25 const Tensor& bias, 26 const IntArrayRef padding, 27 const IntArrayRef stride, 28 const IntArrayRef dilation, 29 const int64_t groups); 30 31 // 32 // Linear 33 // 34 35 bool use_linear( 36 const Tensor& input, 37 const Tensor& weight, 38 const Tensor& bias); 39 40 Tensor linear( 41 const Tensor& input, 42 const Tensor& weight, 43 const Tensor& bias); 44 45 // 46 // Max Pooling 47 // 48 49 bool use_max_pool2d( 50 const Tensor& input, 51 const IntArrayRef kernel, 52 const IntArrayRef padding, 53 IntArrayRef stride, 54 const IntArrayRef dilation, 55 const bool ceil_mode, 56 const float output_min = -std::numeric_limits<float>::infinity(), 57 const float output_max = +std::numeric_limits<float>::infinity()); 58 59 Tensor max_pool2d( 60 const Tensor& input, 61 const IntArrayRef kernel, 62 const IntArrayRef padding, 63 IntArrayRef stride, 64 const IntArrayRef dilation, 65 const bool ceil_mode, 66 const float output_min = -std::numeric_limits<float>::infinity(), 67 const float output_max = +std::numeric_limits<float>::infinity()); 68 69 // 70 // Global Average Pooling 71 // 72 73 bool use_global_average_pool(const Tensor& input); 74 Tensor global_average_pool(const Tensor& input); 75 76 // 77 // Channel Shuffle 78 // 79 80 bool use_channel_shuffle( 81 const Tensor& input, 82 const int64_t groups); 83 84 Tensor channel_shuffle( 85 const Tensor& input, 86 const int64_t groups); 87 88 // 89 // Activations 90 // 91 bool use_hardswish(const Tensor& input); 92 Tensor hardswish(const Tensor& input); 93 Tensor& hardswish_(Tensor& input); 94 95 } // namespace at::native::xnnpack 96