#pragma once

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/arg_spec.h>
#include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <algorithm> // std::count_if
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits> // std::is_pointer_v
#include <unordered_map>
#include <vector>

namespace torch::jit::fuser {

// Helper struct containing partition information: the number of sub-tensors
// created and the dimension the partitioning is performed on.
// Note: created during upfront compilation; once the tensors are known
// at runtime, the partition info is logically combined with the tensor
// descriptions to create PartitionDesc objects.
struct TORCH_API PartitionInfo {
  PartitionInfo(const int64_t _nSubTensors, const int64_t _dim)
      : nSubTensors_{_nSubTensors}, dim_{_dim} {}

  int64_t nSubTensors() const {
    return nSubTensors_;
  }
  int64_t dim() const {
    return dim_;
  }

 private:
  int64_t nSubTensors_;
  int64_t dim_;
};
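
// A minimal illustration (hypothetical values): an input that is chunked into
// four sub-tensors along dimension 1 would be recorded as PartitionInfo{4, 1},
// i.e. nSubTensors() == 4 and dim() == 1.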

// "Kernel Specification" - Contains device-independent fusion information.
// Each kernel specification contains a map of instantiated generated functions
// that implement some or most of its functionality. Multiple generated
// functions are needed by each abstract specification because of different
// devices (CPU vs. GPU, different GPUs) and different inputs (int vs. float,
// contiguous vs. discontiguous).
// Note: uses a mutex to control access to its kernel store.
// Note: unordered containers do not invalidate references/pointers on
//   rehashing, which is critical for thread-safety.
// TODO: allow abstract kernels to use multiple generated kernels
// TODO: allow abstract kernels to reuse generated kernels from a common pool
struct TORCH_API KernelSpec {
  // Note: assumes the spec is a single block.
  // Note: this is the appropriate place to generalize if you want to add
  //   other passes to upfront compilation that walk the graph.
  KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
      : key_{_key},
        graph_{_graph},
        code_{_graph, "<fused code>"},
        nInputs_{_graph->inputs().size()},

        inputBroadcastGroups_{},
        inputChunks_{},

        kernels_{} {
    // No need to iterate by reference since n is a pointer.
    for (const auto n : graph_->nodes()) {
      static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
      if (n->kind() == aten::rand_like) {
        has_random_ = true;
        break;
      }
    }
    // Count the tensor-typed graph inputs (non-tensor inputs, e.g. scalars,
    // are excluded).
    nTensorInputs_ = std::count_if(
        graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) {
          return v->type()->isSubtypeOf(*TensorType::get());
        });
  }
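
  // A minimal construction sketch (assuming `graph` already holds a single
  // fusible block, as required above; the key value here is arbitrary):
  //
  //   std::shared_ptr<Graph> graph = ...; // the fusion group's subgraph
  //   KernelSpec spec{/*_key=*/0, graph};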

  // Getters
  int64_t key() const {
    return key_;
  }
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }
  const Code& code() const {
    return code_;
  }
  int64_t nInputs() const {
    return nInputs_;
  }
  int64_t nTensorInputs() const {
    return nTensorInputs_;
  }

  std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
    return inputBroadcastGroups_;
  }
  const std::vector<std::vector<int64_t>>& inputBroadcastGroups() const {
    return inputBroadcastGroups_;
  }

  std::vector<PartitionInfo>& inputChunks() {
    return inputChunks_;
  }
  const std::vector<PartitionInfo>& inputChunks() const {
    return inputChunks_;
  }

  bool hasRandom() const {
    return has_random_;
  }

  // Cache functions
  std::optional<std::shared_ptr<FusedKernel>> findKernel(
      const ArgSpec& arg_spec) const {
    std::lock_guard<std::mutex> guard{mutex_};
    const auto it = kernels_.find(arg_spec);
    if (it == kernels_.end())
      return std::nullopt;
    return it->second;
  }
  void cacheKernel(
      const ArgSpec& arg_spec,
      const std::shared_ptr<FusedKernel>& kernel) const {
    std::lock_guard<std::mutex> guard{mutex_};
    kernels_.emplace(arg_spec, kernel);
  }

 private:
  int64_t key_;
  std::shared_ptr<Graph> graph_;
  Code code_;
  uint64_t nInputs_;
  uint64_t nTensorInputs_{};
  std::vector<std::vector<int64_t>> inputBroadcastGroups_;
  std::vector<PartitionInfo> inputChunks_;
  bool has_random_{false};
  mutable std::mutex mutex_;
  mutable std::unordered_map<
      ArgSpec,
      std::shared_ptr<FusedKernel>,
      c10::hash<ArgSpec>>
      kernels_;
};
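
// A minimal usage sketch of the kernel cache (compileKernel() is a
// hypothetical helper that lowers the graph for a concrete argument
// specialization; it is not part of this header):
//
//   std::shared_ptr<FusedKernel> getOrCompile(
//       const KernelSpec& spec, const ArgSpec& arg_spec) {
//     if (auto kernel = spec.findKernel(arg_spec)) {
//       return *kernel; // cache hit: reuse the specialized kernel
//     }
//     auto kernel = compileKernel(spec, arg_spec); // hypothetical helper
//     spec.cacheKernel(arg_spec, kernel); // thread-safe: guarded by mutex_
//     return kernel;
//   }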

} // namespace torch::jit::fuser