#pragma once

#include <cstdint>
#include <type_traits>

namespace at::native { inline namespace CPU_CAPABILITY {

// Compile-time recursion that checks a `strides` array against the element
// sizes of a function's arguments (and, optionally, its result).
//
// Strides are compared directly against `sizeof(...)`, i.e. they are treated
// as byte strides.
//
// n:            number of function arguments still to be checked (arity)
// stride_index: index into `strides` that belongs to argument `n - 1`
// traits:       function_traits (see FunctionTraits.h)
// s:            1-based index of the scalar argument, or -1 if there is none
//
// Each step checks argument `n - 1` and then recurses with `n - 1` and
// `stride_index - 1`, until one of the `n == 0` specializations below is hit.
template <int n, int stride_index, typename traits, int s = -1>
struct IsContiguous {
  static bool eval(const int64_t* strides) {
    using type = typename traits::template arg<n - 1>::type;
    // The scalar argument must have stride 0; every other argument must have
    // a stride equal to its element size. The comparison is done in the
    // signed domain so that negative strides never rely on unsigned
    // wrap-around (and so that -Wsign-compare stays quiet).
    const int64_t expected_stride =
        (s == n) ? int64_t{0} : static_cast<int64_t>(sizeof(type));
    return strides[stride_index] == expected_stride &&
        IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
  }
};

// Recursion end when there is an output: the output lives at strides[0] and
// must have a stride equal to the size of the result type.
template <typename traits, int s>
struct IsContiguous<0, 0, traits, s> {
  static bool eval(const int64_t* strides) {
    return strides[0] ==
        static_cast<int64_t>(sizeof(typename traits::result_type));
  }
};

// Recursion end when there is no output: nothing is left to check.
template <typename traits, int s>
struct IsContiguous<0, -1, traits, s> {
  static bool eval(const int64_t* /*strides*/) {
    return true;
  }
};

// All inputs are contiguous (void result: `strides` holds only the inputs,
// so the last argument is at index `arity - 1`).
template <typename traits,
    std::enable_if_t<std::is_void_v<typename traits::result_type>>* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
}

// Output and all inputs are contiguous (non-void result: the output is at
// strides[0], so the last argument is at index `arity`).
template <typename traits,
    std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
}

// Input at `s` is scalar (stride 0); all other inputs are contiguous.
// Variant for functions with a void result (no output entry in `strides`).
// NB: `s` is 1-based: s=1 refers to the first input.
template <typename traits, int s,
    std::enable_if_t<std::is_void_v<typename traits::result_type>>* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
}

// Input at `s` is scalar (stride 0); output and other inputs are contiguous.
// NB: output is typically at strides[0] so first input corresponds to s=1
template <typename traits, int s,
    std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
}

}}