xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/arg_spec.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/ATen.h>
3 #include <ATen/core/functional.h> // fmap
4 #include <c10/util/hash.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/codegen/fuser/tensor_desc.h>
7 
8 #include <cstdint>
9 #include <vector>
10 
11 namespace torch::jit::fuser {
12 
13 // Describes the (runtime) arguments to a kernel.
14 // ArgSpecs are also used as keys to lookup instantiated kernels, so
15 //  they are hashable.
16 // Note: the device to run on is included in the arg spec because kernels
17 //  are compiled per-device.
18 struct TORCH_API ArgSpec {
19   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ArgSpecArgSpec20   ArgSpec(at::TensorList inputs, const int _device)
21       : descs_{c10::fmap<TensorDesc>(inputs)},
22         hash_code_{c10::get_hash(_device, inputs.size(), descs_)},
23         device_{_device} {}
24 
25   // (Common) hash function
hashArgSpec26   static size_t hash(const ArgSpec& spec) {
27     return spec.hash_code_;
28   }
29 
30   // Comparators
31   bool operator==(const ArgSpec& other) const {
32     return (descs_ == other.descs_ && device_ == other.device_);
33   }
34 
35   bool operator!=(const ArgSpec& spec) const {
36     return !(*this == spec);
37   }
38 
39   // Getters
hashCodeArgSpec40   size_t hashCode() const {
41     return hash_code_;
42   }
descsArgSpec43   const std::vector<TensorDesc>& descs() const {
44     return descs_;
45   }
deviceArgSpec46   int device() const {
47     return device_;
48   }
49 
50  private:
51   std::vector<TensorDesc> descs_;
52   size_t hash_code_;
53   int device_;
54 };
55 
56 } // namespace torch::jit::fuser
57