1 #pragma once 2 3 #include <ATen/core/symbol.h> 4 5 #include <functional> 6 #include <memory> 7 #include <set> 8 #include <string> 9 #include <unordered_map> 10 #include <unordered_set> 11 #include <utility> 12 #include <vector> 13 14 #include <c10/core/ScalarType.h> 15 #include <c10/util/Flags.h> 16 #include <torch/csrc/lazy/core/hash.h> 17 #include <torch/csrc/lazy/core/ir.h> 18 #include <torch/csrc/lazy/core/ir_metadata.h> 19 #include <torch/csrc/lazy/ts_backend/ts_node.h> 20 21 namespace torch { 22 namespace lazy { 23 24 /** 25 * The goal of "dynamic" Nodes is to patch a hole in our tracing. 26 * Previously, if a user called `sizes` on a Tensor, it would leak out 27 * of our tracing system, as `sizes` returns a torch.Size or an int. To 28 * prevent this from happening, we introduce DimensionNode, a new type 29 * of Node that abstracts the operation of getting the dimensions of a 30 * Tensor. 31 * 32 * Consider the following example: 33 * ``` 34 * numel = x.shape()[0] * x.shape()[1] 35 * ``` 36 * 37 * Here, `x.shape()[i]` will be a SizeNode (subclass of DimensionNode), 38 * and the multiplication of the two SizeNodes will be represented by 39 * a SizeMul (also a subclass of DimensionNode). Through this, we can 40 * prevent `numel` from being represented as a Python int and thus 41 * burned into the Graph. 42 */ 43 44 class TORCH_API DimensionNode { 45 public: isSymbolic()46 virtual bool isSymbolic() const { 47 return false; 48 }; getDynamicValue()49 virtual int64_t getDynamicValue() const { 50 TORCH_CHECK(false, "NYI"); 51 }; getStaticValue()52 virtual int64_t getStaticValue() const { 53 TORCH_CHECK(false, "NYI"); 54 }; 55 virtual ~DimensionNode() = default; 56 }; 57 58 } // namespace lazy 59 } // namespace torch 60