/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_MISC_RANDOM_H 25*c217d954SCole Faust #define ARM_COMPUTE_MISC_RANDOM_H 26*c217d954SCole Faust 27*c217d954SCole Faust #include "arm_compute/core/Error.h" 28*c217d954SCole Faust #include "utils/Utils.h" 29*c217d954SCole Faust 30*c217d954SCole Faust #include <random> 31*c217d954SCole Faust #include <type_traits> 32*c217d954SCole Faust 33*c217d954SCole Faust namespace arm_compute 34*c217d954SCole Faust { 35*c217d954SCole Faust namespace utils 36*c217d954SCole Faust { 37*c217d954SCole Faust namespace random 38*c217d954SCole Faust { 39*c217d954SCole Faust /** Uniform distribution within a given number of sub-ranges 40*c217d954SCole Faust * 41*c217d954SCole Faust * @tparam T Distribution primitive type 42*c217d954SCole Faust */ 43*c217d954SCole Faust template <typename T> 44*c217d954SCole Faust class RangedUniformDistribution 45*c217d954SCole Faust { 46*c217d954SCole Faust public: 47*c217d954SCole Faust static constexpr bool is_fp_16bit = std::is_same<T, half>::value || std::is_same<T, bfloat16>::value; 48*c217d954SCole Faust static constexpr bool is_integral = std::is_integral<T>::value && !is_fp_16bit; 49*c217d954SCole Faust 50*c217d954SCole Faust using fp_dist = typename std::conditional<is_fp_16bit, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; 51*c217d954SCole Faust using DT = typename std::conditional<is_integral, std::uniform_int_distribution<T>, fp_dist>::type; 52*c217d954SCole Faust using result_type = T; 53*c217d954SCole Faust using range_pair = std::pair<result_type, result_type>; 54*c217d954SCole Faust 55*c217d954SCole Faust /** Constructor 56*c217d954SCole Faust * 57*c217d954SCole Faust * @param[in] low lowest value in the range (inclusive) 58*c217d954SCole Faust * @param[in] high highest value in the range (inclusive for uniform_int_distribution, exclusive for uniform_real_distribution) 59*c217d954SCole Faust * @param[in] exclude_ranges 
Ranges to exclude from the generator 60*c217d954SCole Faust */ RangedUniformDistribution(result_type low,result_type high,const std::vector<range_pair> & exclude_ranges)61*c217d954SCole Faust RangedUniformDistribution(result_type low, result_type high, const std::vector<range_pair> &exclude_ranges) 62*c217d954SCole Faust : _distributions(), _selector() 63*c217d954SCole Faust { 64*c217d954SCole Faust result_type clow = low; 65*c217d954SCole Faust for(const auto &erange : exclude_ranges) 66*c217d954SCole Faust { 67*c217d954SCole Faust result_type epsilon = is_integral ? result_type(1) : result_type(std::numeric_limits<T>::epsilon()); 68*c217d954SCole Faust 69*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(clow > erange.first || clow >= erange.second); 70*c217d954SCole Faust 71*c217d954SCole Faust _distributions.emplace_back(DT(clow, erange.first - epsilon)); 72*c217d954SCole Faust clow = erange.second + epsilon; 73*c217d954SCole Faust } 74*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(clow > high); 75*c217d954SCole Faust _distributions.emplace_back(DT(clow, high)); 76*c217d954SCole Faust _selector = std::uniform_int_distribution<uint32_t>(0, _distributions.size() - 1); 77*c217d954SCole Faust } 78*c217d954SCole Faust /** Generate random number 79*c217d954SCole Faust * 80*c217d954SCole Faust * @tparam URNG Random number generator object type 81*c217d954SCole Faust * 82*c217d954SCole Faust * @param[in] g A uniform random number generator object, used as the source of randomness. 83*c217d954SCole Faust * 84*c217d954SCole Faust * @return A new random number. 
85*c217d954SCole Faust */ 86*c217d954SCole Faust template <class URNG> operator()87*c217d954SCole Faust result_type operator()(URNG &g) 88*c217d954SCole Faust { 89*c217d954SCole Faust unsigned int rand_select = _selector(g); 90*c217d954SCole Faust return _distributions[rand_select](g); 91*c217d954SCole Faust } 92*c217d954SCole Faust 93*c217d954SCole Faust private: 94*c217d954SCole Faust std::vector<DT> _distributions; 95*c217d954SCole Faust std::uniform_int_distribution<uint32_t> _selector; 96*c217d954SCole Faust }; 97*c217d954SCole Faust } // namespace random 98*c217d954SCole Faust } // namespace utils 99*c217d954SCole Faust } // namespace arm_compute 100*c217d954SCole Faust #endif /* ARM_COMPUTE_MISC_RANDOM_H */ 101