1 #pragma once 2 3 #include <torch/csrc/lazy/core/ir.h> 4 5 #include <c10/util/CallOnce.h> 6 7 #include <mutex> 8 #include <string> 9 10 namespace torch { 11 namespace lazy { 12 13 class TORCH_API OpKindWrapper { 14 public: OpKindWrapper(const char * name)15 explicit OpKindWrapper(const char* name) : name_(name) {} 16 17 const OpKind& operator*() const { 18 return get(); 19 } 20 OpKind()21 operator OpKind() const { 22 return get(); 23 } 24 25 private: get()26 const OpKind& get() const { 27 c10::call_once(once_, [this]() { op_kind_ = OpKind::Get(name_); }); 28 return op_kind_; 29 } 30 31 const char* name_; 32 mutable OpKind op_kind_; 33 mutable c10::once_flag once_; 34 }; 35 36 const OpKindWrapper ltc_all_to_all("lazy_tensors::all_to_all"); 37 const OpKindWrapper ltc_cast("lazy_tensors::cast"); 38 const OpKindWrapper ltc_collective_permute("lazy_tensors::collective_permute"); 39 const OpKindWrapper ltc_cross_replica_sum("lazy_tensors::cross_replica_sum"); 40 const OpKindWrapper ltc_device_data("lazy_tensors::device_data"); 41 const OpKindWrapper ltc_get_dimensions_size( 42 "lazy_tensors::ltc_get_dimensions_size"); 43 const OpKindWrapper ltc_moving_average("lazy_tensors::moving_average"); 44 const OpKindWrapper ltc_nms("lazy_tensors::nms"); 45 const OpKindWrapper ltc_not_supported("lazy_tensors::not_supported"); 46 const OpKindWrapper ltc_replication_pad("lazy_tensors::replication_pad"); 47 const OpKindWrapper ltc_replication_pad_backward( 48 "lazy_tensors::replication_pad_backward"); 49 const OpKindWrapper ltc_tensor_data("lazy_tensors::tensor_data"); 50 51 } // namespace lazy 52 } // namespace torch 53