/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The 'middle-end' in ruy. See TrMul function comment.

#include "ruy/trmul.h"

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <vector>

#include "ruy/allocator.h"
#include "ruy/block_map.h"
#include "ruy/check_macros.h"
#include "ruy/cpu_cache_params.h"
#include "ruy/cpuinfo.h"
#include "ruy/ctx.h"
#include "ruy/denormal.h"
#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/strategy_controls.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
#include "ruy/size_util.h"
#include "ruy/thread_pool.h"
#include "ruy/trace.h"
#include "ruy/tune.h"

namespace ruy {

namespace {

// Enum to track the packing status of a block of the LHS or RHS matrix.
enum class PackingStatus : std::uint8_t {
  kNotStarted,  // No thread has started packing this block yet.
  kInProgress,  // Some thread is currently packing this block.
  kFinished     // This block has already been packed.
};

// TrMulTask is the task that a ruy thread runs to perform the TrMul operation.
class TrMulTask final : public Task {
 public:
  TrMulTask(TrMulParams* params, const BlockMap& block_map,
            std::atomic<int>* atomic_block_id, int thread_id, bool need_atomics,
            SidePair<std::atomic<PackingStatus>*> packing_status,
            TuningResolver* tuning_resolver, Allocator* local_allocator,
            CpuInfo* cpuinfo)
      : params_(params),
        block_map_(block_map),
        atomic_block_id_(atomic_block_id),
        thread_id_(thread_id),
        need_atomics_(need_atomics),
        packing_status_(packing_status),
        tuning_resolver_(tuning_resolver),
        local_allocator_(local_allocator),
        local_already_packed_{nullptr, nullptr},
        cpuinfo_(cpuinfo) {}

  // Thread main function. This is one thread's share of the TrMul work.
  void Run() override {
    RUY_TRACE_SCOPE_NAME("TrMulTask::Run");
    RUY_TRACE_SET_THEAD_ID(thread_id_);
    // Allocate and initialize `local_already_packed_`.
    for (Side side : {Side::kLhs, Side::kRhs}) {
      if (!params_->is_prepacked[side]) {
        const int size = NumBlocksPerSide(side, block_map_);
        local_allocator_->Allocate(size, &local_already_packed_[side]);
        memset(local_already_packed_[side], 0, size * sizeof(bool));
      }
    }

    const Tuning tuning = tuning_resolver_->Resolve(cpuinfo_);
    const int num_blocks = NumBlocks(block_map_);

    // Each thread starts by initially reserving the block whose id
    // is the thread id.
    int block_id = thread_id_;
    // Loop until all blocks have been computed.
    while (block_id < num_blocks) {
      RUY_TRACE_SCOPE_NAME("Main loop iteration");
      // Reserve the next block to handle, hiding the latency of this atomic op.
      const int next_block_id =
          atomic_block_id_->fetch_add(1, std::memory_order_relaxed);
      // Get coordinates of the current block to handle, in "block space".
      SidePair<int> block;
      GetBlockByIndex(block_map_, block_id, &block);
      // Get coordinates of the current block to handle, in matrix space.
      SidePair<int> start, end;
      GetBlockMatrixCoords(block_map_, block, &start, &end);
      RUY_TRACE_INFO(TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS);
      // Maybe pack the current LHS/RHS block, if not already packed.
      EnsurePacked(block, start, end, tuning);
      // Actually do the matrix multiplication work.
      params_->RunKernel(tuning, start, end);
      // Move on to the next block as obtained by the atomic increment
      // at the start of this while loop iteration.
      block_id = next_block_id;
    }

    local_allocator_->FreeAll();
  }

 private:
  // Tries to pack a block, without blocking.
  // If the block is already packed, returns true.
  // If packing has not started yet, packs the block and returns true.
  // If the block is being packed by another thread, returns false.
  bool TryPack(Side side, int block, int start, int end, Tuning tuning) {
    if (params_->is_prepacked[side]) {
      return true;
    }
    if (!local_already_packed_[side][block]) {
      if (need_atomics_) {
        // Explanation of this compare_exchange_strong operation:
        // This atomically performs all of the following:
        // 1. Read `status` with "acquire" memory order.
        //    * That this read uses "acquire" is because both memory orders
        //      specified have "acquire" as their read-component.
        // 2. Compare (bitwise) with `exchanged_status`.
        // 3. If equal, stores the value kInProgress to `status` with "release"
        //    memory order, and returns true, so we take this 'if' branch.
        //    * That this store uses "release" is because of the _rel part in
        //      memory_order_acq_rel passed as the first memory order argument.
        // 4. If not equal, stores the loaded value of `status` to
        //    `exchanged_status` with "relaxed" semantics, and returns false,
        //    so we take the 'else' branch.
        //    * That this store uses "relaxed" is because the second memory
        //      order argument, memory_order_acquire, implies no particular
        //      store semantics. "relaxed" is acceptable here because this
        //      stores to a local stack variable.
        //
        // Rationale for compare_exchange_strong as opposed to
        // compare_exchange_weak:
        // The spurious-failure case with compare_exchange_weak will actually
        // happen a lot here, because the atomic 'status' bytes are stored
        // contiguously in arrays and neighboring values will be accessed
        // by multiple threads concurrently. On a typical ARM CPU, an exclusives
        // reservation granule is 64 bytes, so a lot of false-sharing may
        // happen. Using compare_exchange_weak would thus result in often having
        // TryPack return 'false' when it could instead have done the packing
        // work and returned 'true'. Heuristically, that is not a good thing.
        // Moreover, this changes the TryPack contract, loosening it and making
        // it harder for the caller to reason about. Finally, the overhead of
        // atomic operations is mitigated by the enclosing check on
        // local_already_packed, so maybe the overhead of
        // compare_exchange_strong isn't such a problem. But we don't really
        // know for sure; that would be interesting to experiment more with.
        PackingStatus exchanged_status = PackingStatus::kNotStarted;
        std::atomic<PackingStatus>& status = packing_status_[side][block];
        if (status.compare_exchange_strong(
                exchanged_status, PackingStatus::kInProgress,
                std::memory_order_acq_rel, std::memory_order_acquire)) {
          // In this branch, the status was kNotStarted and we just atomically
          // changed it to kInProgress as we are about to handle the packing
          // ourselves.
          RUY_TRACE_INFO(TRYPACK_PACKING);
          params_->RunPack(side, tuning, start, end);
          status.store(PackingStatus::kFinished, std::memory_order_release);
        } else if (exchanged_status == PackingStatus::kInProgress) {
          // Another thread is currently packing this block.
          RUY_TRACE_INFO(TRYPACK_ANOTHER_THREAD_PACKING);
          return false;
        } else {
          RUY_TRACE_INFO(TRYPACK_PACKED_BY_ANOTHER_THREAD);
        }
        RUY_DCHECK(status.load(std::memory_order_acquire) ==
                   PackingStatus::kFinished);
      } else {
        // Single-threaded case: no need for expensive atomics,
        // local_already_packed is the truth already.
        params_->RunPack(side, tuning, start, end);
      }
      local_already_packed_[side][block] = true;
    } else {
      RUY_TRACE_INFO(TRYPACK_PREVIOUSLY_PACKED);
    }
    return true;
  }

  // Ensures that both the LHS and RHS blocks required by the specified block
  // are packed. In the event that they are already being packed by another
  // thread, this function may perform the packing of some other block while
  // waiting for that other thread to finish packing the requested block.
  void EnsurePacked(const SidePair<int>& block, const SidePair<int>& start,
                    const SidePair<int>& end, Tuning tuning) {
#if RUY_OPT(PACK_AHEAD)
    SidePair<int> next_runahead_block{block[Side::kLhs] + 1,
                                      block[Side::kRhs] + 1};
    Side next_runahead_side = Side::kLhs;
#endif
    while (true) {
      bool both_sides_packed = true;
      for (Side side : {Side::kLhs, Side::kRhs}) {
        both_sides_packed &=
            TryPack(side, block[side], start[side], end[side], tuning);
      }
      if (both_sides_packed) {
        break;
      }
#if RUY_OPT(PACK_AHEAD)
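      // At least one of the two blocks we need is being packed by another
      // thread. Rather than spin-wait, run ahead and try to pack the next
      // not-yet-packed block, alternating between the LHS and RHS sides, so
      // that this thread keeps doing useful work in the meantime.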
      RUY_TRACE_INFO(ENSURE_PACKED_ENTER_RUN_AHEAD);
      const Side runahead_side = next_runahead_side;
      const int runahead_block = next_runahead_block[runahead_side];
      next_runahead_side = OtherSide(next_runahead_side);
      if (runahead_block >= NumBlocksPerSide(runahead_side, block_map_)) {
        continue;
      }
      int runahead_block_start, runahead_block_end;
      GetBlockMatrixCoords(runahead_side, block_map_, runahead_block,
                           &runahead_block_start, &runahead_block_end);
      TryPack(runahead_side, runahead_block, runahead_block_start,
              runahead_block_end, tuning);
      next_runahead_block[runahead_side] = runahead_block + 1;
#endif
    }
    RUY_TRACE_INFO(ENSURE_PACKED_END);
  }

  TrMulParams* params_;
  const BlockMap& block_map_;
  std::atomic<int>* atomic_block_id_;
  int thread_id_;
  bool need_atomics_;
  SidePair<std::atomic<PackingStatus>*> packing_status_;
  TuningResolver* tuning_resolver_;
  Allocator* local_allocator_;

  // Local indicators of packedness to avoid the overhead of atomic ops.
  SidePair<bool*> local_already_packed_;

  CpuInfo* cpuinfo_;
};

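// Heuristically chooses a tentative number of threads to use for this TrMul,
// based on the problem size (rows * cols * depth) and on the Ctx's thread
// settings.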
int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
#if RUY_PLATFORM_EMSCRIPTEN
  // b/139927184, std::thread constructor raises exception
  return 1;
#endif
  RUY_TRACE_SCOPE;
  // Empirically determined rule for reasonable number of
  // threads to use. This is proportional to the number of arithmetic ops
  // in this Mul (product of the 3 sizes).
  // Be defensive here by explicitly promoting operands to int64 to avoid the
  // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow.
  if (ctx->num_threads_strategy() == NumThreadsStrategy::kForceMaxNumThreads) {
    return ctx->max_num_threads();
  }
  RUY_CHECK_EQ(ctx->num_threads_strategy(), NumThreadsStrategy::kDefault);
  const std::int64_t rows_i64 = rows;
  const std::int64_t cols_i64 = cols;
  const std::int64_t depth_i64 = depth;
  const std::int64_t problem_size = rows_i64 * cols_i64 * depth_i64;
  // Division is cheap when the denominator is constant
  static constexpr std::int64_t kSizePerAdditionalThread = 32768;
  std::int64_t tentative_thread_count = problem_size / kSizePerAdditionalThread;
  // tentative_thread_count is still an int64, not necessarily in the range of
  // type int. It probably is, as long as kSizePerAdditionalThread is large,
  // but that constant might change in the future.
  tentative_thread_count = std::max<std::int64_t>(tentative_thread_count, 1);
  tentative_thread_count =
      std::min<std::int64_t>(tentative_thread_count, ctx->max_num_threads());
  // now tentative_thread_count must be in the range of type int, because
  // ctx->max_num_threads() is.
  RUY_DCHECK_LE(tentative_thread_count, std::numeric_limits<int>::max());
  return tentative_thread_count;
}

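// Returns true if this TrMul should be performed as a single simple loop over
// the blocks: that is only the case with a single thread and when the
// traversal is obviously linear (see IsObviouslyLinearTraversal).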
bool GetUseSimpleLoop(int tentative_thread_count, int rows, int cols, int depth,
                      int lhs_scalar_size, int rhs_scalar_size,
                      const CpuCacheParams& cpu_cache_params) {
  RUY_TRACE_SCOPE;
  if (tentative_thread_count == 1) {
    if (IsObviouslyLinearTraversal(rows, cols, depth, lhs_scalar_size,
                                   rhs_scalar_size, cpu_cache_params)) {
      RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_TRUE);
      return true;
    }
  }
  RUY_TRACE_INFO(GET_USE_SIMPLE_LOOP_RETURNS_FALSE);
  return false;
}

}  // namespace

// TrMul is the ruy middle-end. It contains the high-level logic to perform
// a ruy::Mul's work, down to calls to back-end Kernel and Pack functions.
// This includes determining how many threads to use, computing the BlockMap,
// and executing tasks on a thread pool. The TrMul function itself runs on the
// main thread; the code that potentially runs on worker threads is in
// TrMulTask::Run().
void TrMul(Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  profiler::ScopeLabel label(
      "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))",
      static_cast<int>(params->path), ctx->max_num_threads(),
      params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]);

  PEMat& packed_lhs = params->packed_matrix[Side::kLhs];
  PEMat& packed_rhs = params->packed_matrix[Side::kRhs];
  EMat& lhs = params->src[Side::kLhs];
  EMat& rhs = params->src[Side::kRhs];

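  // TrMul multiplies the transpose of its LHS by its RHS, so the LHS here is
  // stored transposed: its columns give the destination rows and its rows
  // give the accumulation depth.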
  const int rows = lhs.layout.cols;
  const int cols = rhs.layout.cols;
  const int depth = lhs.layout.rows;

  const int tentative_thread_count =
      GetTentativeThreadCount(ctx, rows, cols, depth);
  const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams();

  // Suppress denormals to avoid computation inefficiency.
  // Note this only handles the denormal suppression on the main thread. As for
  // worker threads, the suppression is handled in each thread's main loop. See
  // the corresponding code in thread_pool.cc for details.
  ScopedSuppressDenormals suppress_denormals;

  // Case of running this TrMul as a simple loop.
  // This is a good place to start reading this function: all the rest
  // of this function is just an optimized, but functionally equivalent,
  // version of that.
  if (GetUseSimpleLoop(tentative_thread_count, rows, cols, depth,
                       lhs.data_type.size, rhs.data_type.size,
                       cpu_cache_params)) {
    profiler::ScopeLabel label_simple("TrMulImpl, simple loop");
    Tuning tuning = ctx->GetMainThreadTuning();
    RUY_TRACE_INFO(TRMUL_SIMPLE_LOOP);

    const SidePair<int> origin{0, 0};
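    // The packed matrices are padded up to multiples of the kernel block
    // size, so these rounded dimensions cover the whole packed extent.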
    const SidePair<int> rounded_dims{packed_lhs.layout.cols,
                                     packed_rhs.layout.cols};
    for (Side side : {Side::kLhs, Side::kRhs}) {
      if (!params->is_prepacked[side]) {
        params->RunPack(side, tuning, origin[side], rounded_dims[side]);
      }
    }
    params->RunKernel(tuning, origin, rounded_dims);
    return;
  }

  profiler::ScopeLabel label_general("TrMulImpl, general case");
  RUY_TRACE_INFO(TRMUL_GENERAL_CASE);
  Allocator* main_allocator = ctx->GetMainAllocator();

  // Initialize block map.
  BlockMap block_map;
  MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth,
               packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols,
               packed_lhs.data_type.size, packed_rhs.data_type.size,
               tentative_thread_count, cpu_cache_params, &block_map);

  // Initialize per-thread state.
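  // The BlockMap may have settled on fewer threads than the tentative count
  // passed to MakeBlockMap above.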
  const int thread_count = block_map.thread_count;
  const bool need_atomics = thread_count > 1;
  ctx->EnsureThreadSpecificResources(thread_count);
  for (int i = 0; i < thread_count; i++) {
    ctx->GetThreadSpecificTuningResolver(i)->SetTuning(ctx->explicit_tuning());
  }

  // In the need_atomics case, allocate and initialize atomic values tracking
  // the packing status of blocks.
  SidePair<std::atomic<PackingStatus>*> packing_status{nullptr, nullptr};
  if (need_atomics) {
    for (Side side : {Side::kLhs, Side::kRhs}) {
      if (!params->is_prepacked[side]) {
        const int size = NumBlocksPerSide(side, block_map);
        main_allocator->Allocate(size, &packing_status[side]);
        for (int i = 0; i < size; i++) {
          packing_status[side][i].store(PackingStatus::kNotStarted,
                                        std::memory_order_relaxed);
        }
      }
    }
  }

  // Create the atomic block id; allocate it using the Allocator so that we
  // get an alignment that ensures it sits alone in its exclusives reservation
  // granule.
  std::atomic<int>* atomic_block_id;
  main_allocator->Allocate(1, &atomic_block_id);
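  // Each of the thread_count tasks initially claims the block whose id equals
  // its thread id (see TrMulTask::Run), so the shared counter starts at
  // thread_count.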
  atomic_block_id->store(thread_count);

  // Create task objects. We allocate a single buffer and then use placement-new
  // to construct N TrMulTask objects within it. To avoid having the Clang CFI
  // sanitizer complain about a TrMulTask* pointer temporarily pointing to
  // garbage, we keep the pointer a plain char* until finished constructing.
  char* tasks_buf =
      main_allocator->Allocate<char>(thread_count * sizeof(TrMulTask));
  for (int i = 0; i < thread_count; i++) {
    auto* allocator = ctx->GetThreadSpecificAllocator(i);
    auto* tuning_resolver = ctx->GetThreadSpecificTuningResolver(i);
    new (tasks_buf + i * sizeof(TrMulTask)) TrMulTask(
        params, block_map, atomic_block_id, i, need_atomics, packing_status,
        tuning_resolver, allocator, ctx->mutable_cpuinfo());
  }
  TrMulTask* tasks = reinterpret_cast<TrMulTask*>(tasks_buf);

  // Do the computation.
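  // Execute runs one task on the calling thread and the remaining
  // thread_count - 1 tasks on worker threads, and returns once all tasks have
  // completed.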
  ctx->mutable_thread_pool()->Execute(thread_count, tasks);

  // Finish up.
  for (int i = 0; i < thread_count; i++) {
    tasks[i].~TrMulTask();
  }
}

}  // namespace ruy