xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/internal_ops/ltc_ops.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/lazy/core/ir.h>
4 
5 #include <c10/util/CallOnce.h>
6 
7 #include <mutex>
8 #include <string>
9 
10 namespace torch {
11 namespace lazy {
12 
13 class TORCH_API OpKindWrapper {
14  public:
OpKindWrapper(const char * name)15   explicit OpKindWrapper(const char* name) : name_(name) {}
16 
17   const OpKind& operator*() const {
18     return get();
19   }
20 
OpKind()21   operator OpKind() const {
22     return get();
23   }
24 
25  private:
get()26   const OpKind& get() const {
27     c10::call_once(once_, [this]() { op_kind_ = OpKind::Get(name_); });
28     return op_kind_;
29   }
30 
31   const char* name_;
32   mutable OpKind op_kind_;
33   mutable c10::once_flag once_;
34 };
35 
36 const OpKindWrapper ltc_all_to_all("lazy_tensors::all_to_all");
37 const OpKindWrapper ltc_cast("lazy_tensors::cast");
38 const OpKindWrapper ltc_collective_permute("lazy_tensors::collective_permute");
39 const OpKindWrapper ltc_cross_replica_sum("lazy_tensors::cross_replica_sum");
40 const OpKindWrapper ltc_device_data("lazy_tensors::device_data");
41 const OpKindWrapper ltc_get_dimensions_size(
42     "lazy_tensors::ltc_get_dimensions_size");
43 const OpKindWrapper ltc_moving_average("lazy_tensors::moving_average");
44 const OpKindWrapper ltc_nms("lazy_tensors::nms");
45 const OpKindWrapper ltc_not_supported("lazy_tensors::not_supported");
46 const OpKindWrapper ltc_replication_pad("lazy_tensors::replication_pad");
47 const OpKindWrapper ltc_replication_pad_backward(
48     "lazy_tensors::replication_pad_backward");
49 const OpKindWrapper ltc_tensor_data("lazy_tensors::tensor_data");
50 
51 } // namespace lazy
52 } // namespace torch
53