1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2021 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "src/runtime/CL/mlgo/MLGOHeuristics.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "arm_compute/core/Log.h"
27*c217d954SCole Faust #include "src/runtime/CL/mlgo/MLGOParser.h"
28*c217d954SCole Faust #include "src/runtime/CL/mlgo/Utils.h"
29*c217d954SCole Faust
30*c217d954SCole Faust #include <fstream>
31*c217d954SCole Faust
32*c217d954SCole Faust namespace arm_compute
33*c217d954SCole Faust {
34*c217d954SCole Faust namespace mlgo
35*c217d954SCole Faust {
operator ==(const GEMMConfigNative & lhs,const GEMMConfigNative & rhs)36*c217d954SCole Faust bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
37*c217d954SCole Faust {
38*c217d954SCole Faust return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0);
39*c217d954SCole Faust }
operator ==(const GEMMConfigReshapedOnlyRHS & lhs,const GEMMConfigReshapedOnlyRHS & rhs)40*c217d954SCole Faust bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
41*c217d954SCole Faust {
42*c217d954SCole Faust return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs,
43*c217d954SCole Faust rhs.export_cl_image);
44*c217d954SCole Faust }
operator ==(const GEMMConfigReshaped & lhs,const GEMMConfigReshaped & rhs)45*c217d954SCole Faust bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
46*c217d954SCole Faust {
47*c217d954SCole Faust return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
48*c217d954SCole Faust rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
49*c217d954SCole Faust }
50*c217d954SCole Faust
// Namespace-scope definition of the in-class `static constexpr` member. Under C++14 this is
// required if the member is odr-used (e.g. bound to a reference); since C++17 such members are
// implicitly inline and this redundant definition is still accepted (deprecated).
constexpr size_t MLGOHeuristics::_max_num_trees;
52*c217d954SCole Faust
MLGOHeuristics()53*c217d954SCole Faust MLGOHeuristics::MLGOHeuristics()
54*c217d954SCole Faust : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
55*c217d954SCole Faust {
56*c217d954SCole Faust }
57*c217d954SCole Faust
query_gemm_type(const Query & query) const58*c217d954SCole Faust std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
59*c217d954SCole Faust {
60*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
61*c217d954SCole Faust const auto invalid = GEMMType::RESHAPED;
62*c217d954SCole Faust if(!_valid)
63*c217d954SCole Faust {
64*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
65*c217d954SCole Faust return { false, invalid };
66*c217d954SCole Faust }
67*c217d954SCole Faust auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
68*c217d954SCole Faust GEMMShape shape_query{ query.m, query.n, query.k, query.b };
69*c217d954SCole Faust if(_trees.find(index) == _trees.end())
70*c217d954SCole Faust {
71*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
72*c217d954SCole Faust return { false, invalid };
73*c217d954SCole Faust }
74*c217d954SCole Faust return _trees.at(index).query<GEMMType>(shape_query);
75*c217d954SCole Faust }
query_gemm_config_native(const Query & query) const76*c217d954SCole Faust std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
77*c217d954SCole Faust {
78*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
79*c217d954SCole Faust const auto invalid = GEMMConfigNative{};
80*c217d954SCole Faust if(!_valid)
81*c217d954SCole Faust {
82*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
83*c217d954SCole Faust return { false, invalid };
84*c217d954SCole Faust }
85*c217d954SCole Faust auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
86*c217d954SCole Faust GEMMShape shape_query{ query.m, query.n, query.k, query.b };
87*c217d954SCole Faust if(_trees.find(index) == _trees.end())
88*c217d954SCole Faust {
89*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
90*c217d954SCole Faust return { false, invalid };
91*c217d954SCole Faust }
92*c217d954SCole Faust return _trees.at(index).query<GEMMConfigNative>(shape_query);
93*c217d954SCole Faust }
query_gemm_config_reshaped_only_rhs(const Query & query) const94*c217d954SCole Faust std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
95*c217d954SCole Faust {
96*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
97*c217d954SCole Faust const auto invalid = GEMMConfigReshapedOnlyRHS{};
98*c217d954SCole Faust if(!_valid)
99*c217d954SCole Faust {
100*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
101*c217d954SCole Faust return { false, invalid };
102*c217d954SCole Faust }
103*c217d954SCole Faust auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
104*c217d954SCole Faust GEMMShape shape_query{ query.m, query.n, query.k, query.b };
105*c217d954SCole Faust if(_trees.find(index) == _trees.end())
106*c217d954SCole Faust {
107*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
108*c217d954SCole Faust return { false, invalid };
109*c217d954SCole Faust }
110*c217d954SCole Faust return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
111*c217d954SCole Faust }
query_gemm_config_reshaped(const Query & query) const112*c217d954SCole Faust std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
113*c217d954SCole Faust {
114*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
115*c217d954SCole Faust const auto invalid = GEMMConfigReshaped{};
116*c217d954SCole Faust if(!_valid)
117*c217d954SCole Faust {
118*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
119*c217d954SCole Faust return { false, invalid };
120*c217d954SCole Faust }
121*c217d954SCole Faust auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
122*c217d954SCole Faust GEMMShape shape_query{ query.m, query.n, query.k, query.b };
123*c217d954SCole Faust if(_trees.find(index) == _trees.end())
124*c217d954SCole Faust {
125*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
126*c217d954SCole Faust return { false, invalid };
127*c217d954SCole Faust }
128*c217d954SCole Faust return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
129*c217d954SCole Faust }
130*c217d954SCole Faust
check_heuristic_tree(HeuristicTree::TreeID id)131*c217d954SCole Faust bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
132*c217d954SCole Faust {
133*c217d954SCole Faust bool status;
134*c217d954SCole Faust HeuristicTree *tree{ nullptr };
135*c217d954SCole Faust std::tie(status, tree) = get_heuristic_tree(id);
136*c217d954SCole Faust if(!status)
137*c217d954SCole Faust {
138*c217d954SCole Faust return status;
139*c217d954SCole Faust }
140*c217d954SCole Faust status = tree->check();
141*c217d954SCole Faust if(!status)
142*c217d954SCole Faust {
143*c217d954SCole Faust return status;
144*c217d954SCole Faust }
145*c217d954SCole Faust _tree_valid[id] = true;
146*c217d954SCole Faust return true;
147*c217d954SCole Faust }
148*c217d954SCole Faust
check_all() const149*c217d954SCole Faust bool MLGOHeuristics::check_all() const
150*c217d954SCole Faust {
151*c217d954SCole Faust // Tree validities are already checked and cached.
152*c217d954SCole Faust bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
153*c217d954SCole Faust {
154*c217d954SCole Faust return !v.second;
155*c217d954SCole Faust })
156*c217d954SCole Faust == _tree_valid.end();
157*c217d954SCole Faust if(!all_trees_are_checked)
158*c217d954SCole Faust {
159*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
160*c217d954SCole Faust return false;
161*c217d954SCole Faust }
162*c217d954SCole Faust
163*c217d954SCole Faust // Other top level checks...
164*c217d954SCole Faust
165*c217d954SCole Faust return true;
166*c217d954SCole Faust }
167*c217d954SCole Faust
get_heuristic_tree(HeuristicTree::TreeID id)168*c217d954SCole Faust std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
169*c217d954SCole Faust {
170*c217d954SCole Faust if(_indices.find(id) == _indices.end())
171*c217d954SCole Faust {
172*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
173*c217d954SCole Faust return std::make_pair(false, nullptr);
174*c217d954SCole Faust }
175*c217d954SCole Faust const auto index = _indices[id];
176*c217d954SCole Faust
177*c217d954SCole Faust if(_trees.find(index) == _trees.end())
178*c217d954SCole Faust {
179*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
180*c217d954SCole Faust return std::make_pair(false, nullptr);
181*c217d954SCole Faust }
182*c217d954SCole Faust auto &t = _trees[index];
183*c217d954SCole Faust
184*c217d954SCole Faust return std::make_pair(true, &t);
185*c217d954SCole Faust }
186*c217d954SCole Faust
add_heuristic_tree(HeuristicTree && t)187*c217d954SCole Faust bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
188*c217d954SCole Faust {
189*c217d954SCole Faust if(_indices.size() >= _max_num_trees)
190*c217d954SCole Faust {
191*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
192*c217d954SCole Faust return false;
193*c217d954SCole Faust }
194*c217d954SCole Faust // PRE: correctness of t is guaranteed by the tree construction process
195*c217d954SCole Faust // Ensure unique id
196*c217d954SCole Faust const auto id = t.id();
197*c217d954SCole Faust if(_indices.find(id) != _indices.end())
198*c217d954SCole Faust {
199*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
200*c217d954SCole Faust return false;
201*c217d954SCole Faust }
202*c217d954SCole Faust
203*c217d954SCole Faust // Ensure unique index
204*c217d954SCole Faust const auto index = t.index();
205*c217d954SCole Faust if(_trees.find(index) != _trees.end())
206*c217d954SCole Faust {
207*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
208*c217d954SCole Faust return false;
209*c217d954SCole Faust }
210*c217d954SCole Faust
211*c217d954SCole Faust _indices[id] = index;
212*c217d954SCole Faust _trees[index] = std::move(t);
213*c217d954SCole Faust _tree_valid[id] = false;
214*c217d954SCole Faust return true;
215*c217d954SCole Faust }
216*c217d954SCole Faust
reload_from_file(const std::string & filename)217*c217d954SCole Faust bool MLGOHeuristics::reload_from_file(const std::string &filename)
218*c217d954SCole Faust {
219*c217d954SCole Faust std::ifstream fs;
220*c217d954SCole Faust fs.exceptions(std::ifstream::badbit);
221*c217d954SCole Faust fs.open(filename, std::ios::in);
222*c217d954SCole Faust if(!fs.is_open())
223*c217d954SCole Faust {
224*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
225*c217d954SCole Faust return _valid = false;
226*c217d954SCole Faust }
227*c217d954SCole Faust return reload_from_stream(fs);
228*c217d954SCole Faust }
229*c217d954SCole Faust
reload_from_stream(std::istream & in)230*c217d954SCole Faust bool MLGOHeuristics::reload_from_stream(std::istream &in)
231*c217d954SCole Faust {
232*c217d954SCole Faust auto parsed = parser::parse_mlgo(in);
233*c217d954SCole Faust if(!parsed.first)
234*c217d954SCole Faust {
235*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
236*c217d954SCole Faust return _valid = false;
237*c217d954SCole Faust }
238*c217d954SCole Faust *this = std::move(parsed.second);
239*c217d954SCole Faust ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
240*c217d954SCole Faust return _valid = true;
241*c217d954SCole Faust }
242*c217d954SCole Faust
243*c217d954SCole Faust } // namespace mlgo
244*c217d954SCole Faust } // namespace arm_compute