xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
6 
7 #include <cstdint>
8 #include <memory>
9 #include <string>
10 
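// Forward-declare at::DynamicLibrary: this header only names the type as the
// pointee of a std::unique_ptr member, so its full definition is not included.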
12 namespace at {
13 struct DynamicLibrary;
14 }
15 
16 namespace torch {
17 namespace jit {
18 namespace fuser {
19 namespace cpu {
20 
21 // Represents a compiled CPU kernel and the metadata necessary to run it
22 struct TORCH_API FusedKernelCPU : public FusedKernel {
23   FusedKernelCPU(
24       std::string name,
25       std::string code,
26       std::vector<TensorDesc> input_desc,
27       std::vector<TensorDesc> output_desc,
28       std::vector<PartitionDesc> chunk_desc,
29       std::vector<PartitionDesc> concat_desc,
30       bool has_random);
31 
backendFusedKernelCPU32   at::Backend backend() const override {
33     return at::Backend::CPU;
34   }
35 
launch_rawFusedKernelCPU36   void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
37       const override {
38     kernel(numel, arguments.data());
39   }
40 
41  private:
42   std::unique_ptr<at::DynamicLibrary> so_lib;
43   void (*kernel)(uint32_t, void**) = nullptr;
44 };
45 
46 } // namespace cpu
47 } // namespace fuser
48 } // namespace jit
49 } // namespace torch
50