xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ops/generic.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/lazy/ts_backend/ts_node.h>
4 
5 #include <torch/csrc/lazy/core/ir_builder.h>
6 
7 namespace torch {
8 namespace lazy {
9 
10 // Generic IR Node implementation for nodes which can simply be described by a
11 // specific OpKind and a lowering function. IR nodes carrying
12 // metadata should not be using this class TORCH_API (and have the metadata
13 // captured by the LowerFn), but they should instead create a dedicated IR node.
14 // Doing the former would limit IR introspection.
15 class TORCH_API Generic : public TsNode {
16  public:
17   Generic(
18       OpKind op,
19       OpList operands,
20       Shape shape,
21       size_t num_outputs = 1,
22       hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
23 
24   Generic(
25       OpKind op,
26       OpList operands,
27       const std::function<Shape()>& shape_fn,
28       size_t num_outputs = 1,
29       hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
30 
31   Generic(
32       OpKind op,
33       OpList operands,
34       size_t num_outputs = 1,
35       hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
36 
37   Generic(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed);
38 
39  private:
40   hash_t hash_seed_;
41 };
42 
43 inline NodePtr GenericOp(
44     OpKind op,
45     OpList operands,
46     Shape shape,
47     size_t num_outputs = 1,
48     hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9)) {
49   return MakeNode<Generic>(
50       op, operands, std::move(shape), num_outputs, hash_seed);
51 }
52 
53 } // namespace lazy
54 } // namespace torch
55