1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/jit/codegen/fuser/fused_kernel.h> 6 7 #include <cstdint> 8 #include <memory> 9 #include <string> 10 11 // Forward declare DynamicLibrary 12 namespace at { 13 struct DynamicLibrary; 14 } 15 16 namespace torch { 17 namespace jit { 18 namespace fuser { 19 namespace cpu { 20 21 // Represents a compiled CPU kernel and the metadata necessary to run it 22 struct TORCH_API FusedKernelCPU : public FusedKernel { 23 FusedKernelCPU( 24 std::string name, 25 std::string code, 26 std::vector<TensorDesc> input_desc, 27 std::vector<TensorDesc> output_desc, 28 std::vector<PartitionDesc> chunk_desc, 29 std::vector<PartitionDesc> concat_desc, 30 bool has_random); 31 backendFusedKernelCPU32 at::Backend backend() const override { 33 return at::Backend::CPU; 34 } 35 launch_rawFusedKernelCPU36 void launch_raw(const uint32_t numel, std::vector<void*>& arguments) 37 const override { 38 kernel(numel, arguments.data()); 39 } 40 41 private: 42 std::unique_ptr<at::DynamicLibrary> so_lib; 43 void (*kernel)(uint32_t, void**) = nullptr; 44 }; 45 46 } // namespace cpu 47 } // namespace fuser 48 } // namespace jit 49 } // namespace torch 50