xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/mps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Context.h>
2 #include <torch/mps.h>
3 
4 namespace torch {
5 namespace mps {
6 
is_available()7 bool is_available() {
8   return at::detail::getMPSHooks().hasMPS();
9 }
10 
11 /// Sets the seed for the MPS's default generator.
manual_seed(uint64_t seed)12 void manual_seed(uint64_t seed) {
13   if (is_available()) {
14     auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
15     {
16       // See Note [Acquire lock when using random generators]
17       std::lock_guard<std::mutex> lock(gen.mutex());
18       gen.set_current_seed(seed);
19     }
20   }
21 }
22 
synchronize()23 void synchronize() {
24   at::detail::getMPSHooks().deviceSynchronize();
25 }
26 
commit()27 void commit() {
28   at::detail::getMPSHooks().commitStream();
29 }
30 
get_command_buffer()31 MTLCommandBuffer_t get_command_buffer() {
32   return at::detail::getMPSHooks().getCommandBuffer();
33 }
34 
get_dispatch_queue()35 DispatchQueue_t get_dispatch_queue() {
36   return at::detail::getMPSHooks().getDispatchQueue();
37 }
38 
39 } // namespace mps
40 } // namespace torch
41