1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/core/stack.h> 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/jit/codegen/fuser/arg_spec.h> 7 #include <torch/csrc/jit/codegen/fuser/fused_kernel.h> 8 #include <torch/csrc/jit/codegen/fuser/interface.h> 9 #include <torch/csrc/jit/ir/ir.h> 10 #include <torch/csrc/jit/runtime/interpreter.h> 11 #include <optional> 12 13 #include <cstdint> 14 #include <memory> 15 #include <mutex> 16 #include <unordered_map> 17 #include <vector> 18 19 namespace torch::jit::fuser { 20 21 // Helper struct containing partition information: the number of tensors 22 // created and the dimension the partitioning is performed on. 23 // Note: created during upfront compilation, once the tensors are known 24 // at runtime the partition info is logically combined with the tensor 25 // descriptions to create PartitionDesc objects. 26 struct TORCH_API PartitionInfo { PartitionInfoPartitionInfo27 PartitionInfo(const int64_t _nSubTensors, const int64_t _dim) 28 : nSubTensors_{_nSubTensors}, dim_{_dim} {}; 29 nSubTensorsPartitionInfo30 int64_t nSubTensors() const { 31 return nSubTensors_; 32 } dimPartitionInfo33 int64_t dim() const { 34 return dim_; 35 } 36 37 private: 38 int64_t nSubTensors_; 39 int64_t dim_; 40 }; 41 42 // "Kernel Specification." - Contains device-independent fusion information. 43 // Each kernel specification contains a map of instantiated generated functions 44 // that implement some or most of its functionality. Multiple generated 45 // functions are needed by each abstract specification because of different 46 // devices (cpu vs gpu, different gpus) and different inputs (int vs float, 47 // contiguous vs discontiguous). 48 // Note: uses a mutex to control access to its kernel store 49 // Note: unordered containers do not invalidate references/pointers on 50 // rehashing, which is critical for thread-safety. 51 // TODO: allow abstract kernels to use multiple generated kernels 52 // TODO: allow abstract kernels to reuse generated kernels from common pool 53 struct TORCH_API KernelSpec { 54 // Note: assumes the spec is a single block 55 // Note: This is the appropriate place to generalize if you want to add other 56 // passes to upfront compilation that walk the graph. KernelSpecKernelSpec57 KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph) 58 : key_{_key}, 59 graph_{_graph}, 60 code_{_graph, "<fused code>"}, 61 nInputs_{_graph->inputs().size()}, 62 63 inputBroadcastGroups_{}, 64 inputChunks_{}, 65 66 kernels_{} { 67 // No need to iterate over reference since n is pointer 68 for (const auto n : graph_->nodes()) { 69 static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer"); 70 if (n->kind() == aten::rand_like) { 71 has_random_ = true; 72 break; 73 } 74 } 75 nTensorInputs_ = std::count_if( 76 graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) { 77 return v->type()->isSubtypeOf(*TensorType::get()); 78 }); 79 } 80 81 // Getters keyKernelSpec82 int64_t key() const { 83 return key_; 84 } graphKernelSpec85 std::shared_ptr<Graph> graph() const { 86 return graph_; 87 } codeKernelSpec88 const Code& code() const { 89 return code_; 90 } nInputsKernelSpec91 int64_t nInputs() const { 92 return nInputs_; 93 } nTensorInputsKernelSpec94 int64_t nTensorInputs() const { 95 return nTensorInputs_; 96 } 97 inputBroadcastGroupsKernelSpec98 std::vector<std::vector<int64_t>>& inputBroadcastGroups() { 99 return inputBroadcastGroups_; 100 } inputBroadcastGroupsKernelSpec101 const std::vector<std::vector<int64_t>>& inputBroadcastGroups() const { 102 return inputBroadcastGroups_; 103 } 104 inputChunksKernelSpec105 std::vector<PartitionInfo>& inputChunks() { 106 return inputChunks_; 107 } inputChunksKernelSpec108 const std::vector<PartitionInfo>& inputChunks() const { 109 return inputChunks_; 110 } 111 hasRandomKernelSpec112 bool hasRandom() const { 113 return has_random_; 114 } 115 116 // Cache functions findKernelKernelSpec117 std::optional<std::shared_ptr<FusedKernel>> findKernel( 118 const ArgSpec& arg_spec) const { 119 std::lock_guard<std::mutex> guard{mutex_}; 120 const auto it = kernels_.find(arg_spec); 121 if (it == kernels_.end()) 122 return std::nullopt; 123 return it->second; 124 } cacheKernelKernelSpec125 void cacheKernel( 126 const ArgSpec& arg_spec, 127 const std::shared_ptr<FusedKernel>& kernel) const { 128 std::lock_guard<std::mutex> guard{mutex_}; 129 kernels_.emplace(arg_spec, kernel); 130 } 131 132 private: 133 int64_t key_; 134 std::shared_ptr<Graph> graph_; 135 Code code_; 136 uint64_t nInputs_; 137 uint64_t nTensorInputs_{}; 138 std::vector<std::vector<int64_t>> inputBroadcastGroups_; 139 std::vector<PartitionInfo> inputChunks_; 140 bool has_random_{false}; 141 mutable std::mutex mutex_; 142 mutable std:: 143 unordered_map<ArgSpec, std::shared_ptr<FusedKernel>, c10::hash<ArgSpec>> 144 kernels_; 145 }; 146 147 } // namespace torch::jit::fuser 148