#pragma once

#include <c10/util/ArrayRef.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

namespace torch {
namespace lazy {

using TSOpVector = std::vector<torch::jit::Value*>;

class TORCH_API TsNode : public lazy::Node {
 public:
  TsNode(
      OpKind op,
      OpList operands,
      std::vector<Shape>&& shapes,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  TsNode(
      OpKind op,
      OpList operands,
      const std::function<Shape()>& shape_fn,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  TsNode(
      OpKind op,
      OpList operands,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  TsNode(
      OpKind op,
      Shape shape,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  ~TsNode() override = default;

  hash_t hash() const override;

  hash_t shapeHash() const override;

  const std::string getPythonStacktrace() const;

  // Lower is a backend-specific method since it returns a backend-specific
  // type; hence, it is convenient to define it per backend rather than in
  // the Node API.
  virtual TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      TSLoweringContext* loctx) const;

 private:
  // The hash of the DAG WITH size info, used for shape caching.
  hash_t shape_hash_;
  // The hash of the DAG used to look up the compiled graph by hash. We use
  // the DAG hash WITHOUT size info if dynamic shapes are enabled, and the
  // DAG hash WITH size info otherwise.
  hash_t dag_hash_;
};
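
// Illustrative sketch (not part of the original header): a hypothetical
// backend op showing how a TsNode subclass would typically override Lower.
// The class name MyRelu is an assumption for illustration; LowerTSBuiltin
// is declared in ts_node_lowering.h and GetOutputOp on TSLoweringContext.
//
//   class MyRelu : public TsNode {
//    public:
//     explicit MyRelu(const Value& input)
//         : TsNode(OpKind(at::aten::relu), {input}, /*num_outputs=*/1) {}
//
//     TSOpVector Lower(
//         std::shared_ptr<torch::jit::GraphFunction> function,
//         TSLoweringContext* loctx) const override {
//       // Look up the jit::Value already lowered for our operand, then
//       // emit the corresponding builtin call into the TS graph.
//       std::vector<torch::jit::NamedValue> arguments;
//       arguments.emplace_back(loctx->GetOutputOp(operand(0)));
//       return LowerTSBuiltin(function, op().op, arguments);
//     }
//   };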

// Note: this OpKind is kept separate from ltc_ops.h since including it there
// would create a circular import. I like leaving TensorList in this file, and
// I think most of the ltc_ops special cases will be deleted anyway.
const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");

// TensorList represents an at::TensorList, which is a vector<Tensor> but is
// also a first-class IValue and can be fed as a single input to a TS program.
// It is much easier to handle TensorLists in Lazy Tensor code if they are
// represented as a single Node, so there can be more than one TensorList and
// more than one Tensor side-by-side as operands to an op.
//
// Note: shape is undefined for TensorList. We assert in some places that
// #shapes matches #outputs, and this stems from the fact that currently all
// IR nodes represent tensors (there is no type system for this IR). Because
// of this, TensorList is a bit of a hack.
//
// TODO(whc) once the Shape() API is moved to the Node base, also make it
// virtual, and then implement it as NotImplemented for TensorList, also
// fixing the assertion that would fail.
struct TORCH_API TensorList : public TsNode {
  static OpKind ClassOpKind() {
    return tensor_list_opkind;
  }

  TensorList() = delete;
  TensorList(OpList values);

  bool CanBeReused(OpList values) const {
    return operands() == std::vector<Output>(values.begin(), values.end());
  }

  TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      TSLoweringContext* loctx) const override;
};
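
// Illustrative sketch (not part of the original header): building a
// TensorList over existing lazy IR values so several tensors can flow into
// a TS op as a single operand. The `values` vector is hypothetical;
// MakeNode/ReuseOrMakeNode live in torch/csrc/lazy/core/ir_builder.h, and
// ReuseOrMakeNode consults CanBeReused via the trie cache before
// constructing a fresh node.
//
//   std::vector<Value> values = {tensor_a, tensor_b}; // lazy IR values
//   NodePtr list = ReuseOrMakeNode<TensorList>(OpList(values));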

} // namespace lazy
} // namespace torch