1 #pragma once 2 3 #include <ATen/core/stack.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/jit/codegen/fuser/arg_spec.h> 6 #include <torch/csrc/jit/codegen/fuser/fused_kernel.h> 7 #include <torch/csrc/jit/codegen/fuser/interface.h> 8 #include <torch/csrc/jit/codegen/fuser/kernel_spec.h> 9 #include <torch/csrc/jit/ir/ir.h> 10 11 #include <cstdint> 12 #include <vector> 13 14 namespace torch::jit::fuser { 15 16 // Performs device-independent "upfront" compilation of the given fusion_group, 17 // if it has not been registered already. 18 // Returns a key that can be used to run the fusion later 19 TORCH_API int64_t registerFusion(const Node* fusion_group); 20 21 // Performs device-specific "runtime" compilation of the given kernel 22 // with the runtime arguments specified in ArgSpec. 23 // Outputs are allocated using map_size on the specified device. 24 TORCH_API std::shared_ptr<FusedKernel> compileKernel( 25 const KernelSpec& spec, 26 const ArgSpec& arg_spec, 27 const std::vector<int64_t>& map_size, 28 const at::Device device); 29 30 TORCH_API size_t nCompiledKernels(); 31 32 TORCH_API int debugFuser(); 33 34 using FusedKernelConstructor = std::function<std::shared_ptr<FusedKernel>( 35 int16_t device, 36 std::string name, 37 std::string code, 38 std::vector<TensorDesc> input_desc, 39 std::vector<TensorDesc> output_desc, 40 std::vector<PartitionDesc> chunk_desc, 41 std::vector<PartitionDesc> concat_desc, 42 bool has_random)>; 43 44 TORCH_API void registerFusionBackend( 45 at::Device::Type backend_type, 46 FusedKernelConstructor ctor); 47 TORCH_API bool hasFusionBackend(at::Device::Type backend_type); 48 struct TORCH_API RegisterFusionBackend { RegisterFusionBackendRegisterFusionBackend49 RegisterFusionBackend( 50 at::Device::Type backend_type, 51 FusedKernelConstructor ctor) { 52 registerFusionBackend(backend_type, std::move(ctor)); 53 } 54 }; 55 56 } // namespace torch::jit::fuser 57