#pragma once

#include <ATen/ATen.h>
#include <ATen/core/functional.h> // fmap
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/tensor_desc.h>

#include <cstdint>
#include <vector>

namespace torch::jit::fuser {

// Describes the (runtime) arguments to a kernel.
// ArgSpecs are also used as keys to look up instantiated kernels, so
// they are hashable.
// Note: the device to run on is included in the arg spec because kernels
// are compiled per-device.
struct TORCH_API ArgSpec {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  ArgSpec(at::TensorList inputs, const int _device)
      : descs_{c10::fmap<TensorDesc>(inputs)},
        hash_code_{c10::get_hash(_device, inputs.size(), descs_)},
        device_{_device} {}

  // (Common) hash function
  static size_t hash(const ArgSpec& spec) {
    return spec.hash_code_;
  }

  // Comparators
  bool operator==(const ArgSpec& other) const {
    return (descs_ == other.descs_ && device_ == other.device_);
  }

  bool operator!=(const ArgSpec& spec) const {
    return !(*this == spec);
  }

  // Getters
  size_t hashCode() const {
    return hash_code_;
  }

  const std::vector<TensorDesc>& descs() const {
    return descs_;
  }

  int device() const {
    return device_;
  }

 private:
  std::vector<TensorDesc> descs_;
  size_t hash_code_;
  int device_;
};

} // namespace torch::jit::fuser
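
// Illustrative sketch (not part of this header) of how an ArgSpec might serve
// as a cache key for compiled kernels, using the static hash() helper and
// operator== defined above. `FusedKernel`, `ArgSpecHash`, and `kernel_cache`
// are hypothetical names introduced only for this example:
//
//   #include <unordered_map>
//
//   struct ArgSpecHash {
//     size_t operator()(const ArgSpec& spec) const {
//       return ArgSpec::hash(spec); // reuses the hash precomputed in the ctor
//     }
//   };
//
//   std::unordered_map<ArgSpec, FusedKernel, ArgSpecHash> kernel_cache;
//
//   // At runtime: build a spec from the actual inputs and target device,
//   // then look up (or compile and insert) the matching kernel.
//   ArgSpec spec(inputs, device);
//   auto it = kernel_cache.find(spec);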