xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/dynamic_ir.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/symbol.h>
4 
5 #include <functional>
6 #include <memory>
7 #include <set>
8 #include <string>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <utility>
12 #include <vector>
13 
14 #include <c10/core/ScalarType.h>
15 #include <c10/util/Flags.h>
16 #include <torch/csrc/lazy/core/hash.h>
17 #include <torch/csrc/lazy/core/ir.h>
18 #include <torch/csrc/lazy/core/ir_metadata.h>
19 #include <torch/csrc/lazy/ts_backend/ts_node.h>
20 
21 namespace torch {
22 namespace lazy {
23 
24 /**
25  * The goal of "dynamic" Nodes is to patch a hole in our tracing.
26  * Previously, if a user called `sizes` on a Tensor, it would leak out
27  * of our tracing system, as `sizes` returns a torch.Size or an int. To
28  * prevent this from happening, we introduce DimensionNode, a new type
29  * of Node that abstracts the operation of getting the dimensions of a
30  * Tensor.
31  *
32  * Consider the following example:
33  * ```
34  * numel = x.shape[0] * x.shape[1]
35  * ```
36  *
37  * Here, `x.shape[i]` will be a SizeNode (subclass of DimensionNode),
38  * and the multiplication of the two SizeNodes will be represented by
39  * a SizeMul (also a subclass of DimensionNode). Through this, we can
40  * prevent `numel` from being represented as a Python int and thus
41  * burned into the Graph.
42  */
43 
44 class TORCH_API DimensionNode {
45  public:
isSymbolic()46   virtual bool isSymbolic() const {
47     return false;
48   };
getDynamicValue()49   virtual int64_t getDynamicValue() const {
50     TORCH_CHECK(false, "NYI");
51   };
getStaticValue()52   virtual int64_t getStaticValue() const {
53     TORCH_CHECK(false, "NYI");
54   };
55   virtual ~DimensionNode() = default;
56 };
57 
58 } // namespace lazy
59 } // namespace torch
60