1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2021 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H 25*c217d954SCole Faust #define SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H 26*c217d954SCole Faust 27*c217d954SCole Faust #include "src/runtime/CL/mlgo/Common.h" 28*c217d954SCole Faust #include "src/runtime/CL/mlgo/HeuristicTree.h" 29*c217d954SCole Faust 30*c217d954SCole Faust #include <iostream> 31*c217d954SCole Faust #include <map> 32*c217d954SCole Faust #include <string> 33*c217d954SCole Faust #include <utility> 34*c217d954SCole Faust namespace arm_compute 35*c217d954SCole Faust { 36*c217d954SCole Faust namespace mlgo 37*c217d954SCole Faust { 38*c217d954SCole Faust /** Query interface */ 39*c217d954SCole Faust struct Query 40*c217d954SCole Faust { 41*c217d954SCole Faust std::string ip_target; /**< The name of the IP target */ 42*c217d954SCole Faust DataType data_type; /**< Data type */ 43*c217d954SCole Faust unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ 44*c217d954SCole Faust unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ 45*c217d954SCole Faust unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ 46*c217d954SCole Faust unsigned int b; /**< Batch size */ 47*c217d954SCole Faust }; 48*c217d954SCole Faust 49*c217d954SCole Faust bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs); 50*c217d954SCole Faust bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs); 51*c217d954SCole Faust bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs); 52*c217d954SCole Faust 53*c217d954SCole Faust /** MLGOHeuristics for configuring GEMM kernels */ 54*c217d954SCole Faust class MLGOHeuristics 55*c217d954SCole Faust { 56*c217d954SCole Faust public: 57*c217d954SCole Faust /** Constructor */ 58*c217d954SCole Faust MLGOHeuristics(); 59*c217d954SCole Faust /** Default Destructor */ 60*c217d954SCole Faust ~MLGOHeuristics() = default; 61*c217d954SCole Faust /** Prevent Copy Construct */ 62*c217d954SCole Faust MLGOHeuristics(const MLGOHeuristics &) = delete; 63*c217d954SCole Faust /** Prevent Copy Assignment */ 64*c217d954SCole Faust MLGOHeuristics &operator=(const MLGOHeuristics &) = delete; 65*c217d954SCole Faust /** Default Move Constructor */ 66*c217d954SCole Faust MLGOHeuristics(MLGOHeuristics &&) = default; 67*c217d954SCole Faust /** Default Move Assignment */ 68*c217d954SCole Faust MLGOHeuristics &operator=(MLGOHeuristics &&) = default; 69*c217d954SCole Faust /** Query the gemm type 70*c217d954SCole Faust * 71*c217d954SCole Faust * @param[in] query Query 72*c217d954SCole Faust * 73*c217d954SCole Faust * @return std::pair<bool, GEMMType> signals if the query succeeded or failed 74*c217d954SCole Faust */ 75*c217d954SCole Faust std::pair<bool, GEMMType> query_gemm_type(const Query &query) const; 76*c217d954SCole Faust /** Query the gemm configuration for native kernel 77*c217d954SCole Faust * 78*c217d954SCole Faust * @param[in] query Query 79*c217d954SCole Faust * 80*c217d954SCole Faust * @return std::pair<bool, GEMMConfigNative> bool signals if the query succeeded or failed 81*c217d954SCole Faust */ 82*c217d954SCole Faust std::pair<bool, GEMMConfigNative> query_gemm_config_native(const Query &query) const; 83*c217d954SCole Faust /** Query the gemm configuration for reshaped only rhs kernel 84*c217d954SCole Faust * 85*c217d954SCole Faust * @param[in] query Query 86*c217d954SCole Faust * 87*c217d954SCole Faust * @return std::pair<bool, GEMMConfigReshapedOnlyRHS> bool signals if the query succeeded or failed 88*c217d954SCole Faust */ 89*c217d954SCole Faust std::pair<bool, GEMMConfigReshapedOnlyRHS> query_gemm_config_reshaped_only_rhs(const Query &query) const; 90*c217d954SCole Faust /** Query the gemm configuration for reshaped kernel 91*c217d954SCole Faust * 92*c217d954SCole Faust * @param[in] query Query 93*c217d954SCole Faust * 94*c217d954SCole Faust * @return std::pair<bool, GEMMConfigReshaped> bool signals if the query succeeded or failed 95*c217d954SCole Faust */ 96*c217d954SCole Faust std::pair<bool, GEMMConfigReshaped> query_gemm_config_reshaped(const Query &query) const; 97*c217d954SCole Faust /** (Re)Load the heuristics from reading a dotmlgo file 98*c217d954SCole Faust * 99*c217d954SCole Faust * @param[in] filename Path to the dotmlgo file 100*c217d954SCole Faust * 101*c217d954SCole Faust * @return bool Signals if the reload succeeded or failed 102*c217d954SCole Faust */ 103*c217d954SCole Faust bool reload_from_file(const std::string &filename); 104*c217d954SCole Faust /** (Re)Load the heuristics from reading an input stream 105*c217d954SCole Faust * 106*c217d954SCole Faust * @param[in] istream Istream containing mlgo heuristics 107*c217d954SCole Faust * 108*c217d954SCole Faust * @return bool Signals if the reload succeeded or failed 109*c217d954SCole Faust */ 110*c217d954SCole Faust bool reload_from_stream(std::istream &istream); 111*c217d954SCole Faust 112*c217d954SCole Faust /** Get the heuristic tree from tree id 113*c217d954SCole Faust * 114*c217d954SCole Faust * @param[in] id Tree id. 115*c217d954SCole Faust * 116*c217d954SCole Faust * @return HeuristicTree& 117*c217d954SCole Faust */ 118*c217d954SCole Faust std::pair<bool, HeuristicTree *> get_heuristic_tree(HeuristicTree::TreeID id); 119*c217d954SCole Faust /** Add a heuristic tree 120*c217d954SCole Faust * @param t Heuristic tree to be added 121*c217d954SCole Faust */ 122*c217d954SCole Faust bool add_heuristic_tree(HeuristicTree &&t); 123*c217d954SCole Faust 124*c217d954SCole Faust /** Check the validity of the heuristic tree. 125*c217d954SCole Faust * 126*c217d954SCole Faust * @param id ID of the tree to be checked 127*c217d954SCole Faust * 128*c217d954SCole Faust * @return bool 129*c217d954SCole Faust */ 130*c217d954SCole Faust bool check_heuristic_tree(HeuristicTree::TreeID id); 131*c217d954SCole Faust 132*c217d954SCole Faust /** Check the overall validity of the heuristics. 133*c217d954SCole Faust * @return bool 134*c217d954SCole Faust */ 135*c217d954SCole Faust bool check_all() const; 136*c217d954SCole Faust 137*c217d954SCole Faust private: 138*c217d954SCole Faust static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/ 139*c217d954SCole Faust 140*c217d954SCole Faust private: 141*c217d954SCole Faust // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree 142*c217d954SCole Faust std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices; /**< A mapping from TreeID to Index */ 143*c217d954SCole Faust std::map<HeuristicTree::Index, HeuristicTree> _trees; /**< A mapping from Index to HeuristicTree */ 144*c217d954SCole Faust std::map<HeuristicTree::TreeID, bool> _tree_valid; /**< Result cache of the tree validity checks */ 145*c217d954SCole Faust bool _valid; /**< Overall validity */ 146*c217d954SCole Faust }; 147*c217d954SCole Faust 148*c217d954SCole Faust } // namespace mlgo 149*c217d954SCole Faust } // namespace arm_compute 150*c217d954SCole Faust #endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H