1 #pragma once 2 3 #include <c10/util/ArrayRef.h> 4 #include <torch/csrc/jit/api/function_impl.h> 5 #include <torch/csrc/jit/ir/ir.h> 6 #include <torch/csrc/lazy/backend/lowering_context.h> 7 #include <torch/csrc/lazy/core/ir.h> 8 #include <torch/csrc/lazy/core/shape.h> 9 #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h> 10 11 namespace torch { 12 namespace lazy { 13 14 using TSOpVector = std::vector<torch::jit::Value*>; 15 16 class TORCH_API TsNode : public lazy::Node { 17 public: 18 TsNode( 19 OpKind op, 20 OpList operands, 21 std::vector<Shape>&& shapes, 22 size_t num_outputs, 23 hash_t hash_seed = kHashSeed); 24 25 TsNode( 26 OpKind op, 27 OpList operands, 28 const std::function<Shape()>& shape_fn, 29 size_t num_outputs, 30 hash_t hash_seed = kHashSeed); 31 32 TsNode( 33 OpKind op, 34 OpList operands, 35 size_t num_outputs, 36 hash_t hash_seed = kHashSeed); 37 38 TsNode( 39 OpKind op, 40 Shape shape, 41 size_t num_outputs, 42 hash_t hash_seed = kHashSeed); 43 44 ~TsNode() override = default; 45 46 hash_t hash() const override; 47 48 hash_t shapeHash() const override; 49 50 const std::string getPythonStacktrace() const; 51 52 // Lower is a backend-specific method since it returns a backend specific 53 // type. hence, it is convenient to define it differently per-backend rather 54 // than at Node API 55 virtual TSOpVector Lower( 56 std::shared_ptr<torch::jit::GraphFunction> function, 57 TSLoweringContext* loctx) const; 58 59 private: 60 // The hash of the dag WITH size info. Used for shape caching 61 hash_t shape_hash_; 62 // The hash of the dag used to look up the compiled graph by a hash 63 // in this case, we will use the dag hash WITHOUT size info if dynamic shape 64 // is enabled and use the dag hash WITH size info otherwise. 65 hash_t dag_hash_; 66 }; 67 68 // Note: this OpKind is separate from ltc_ops.h since it would be a circular 69 // import otherwise, I like leaving TensorList in this file, and I think most of 70 // ltc_ops special cases will be deleted anyway 71 const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); 72 73 // TensorList represents an at::TensorList which is a vector[Tensor] but is also 74 // a first-class IValue and can be fed as a single input to a TS program. It is 75 // much easier to handle TensorLists in Lazy Tensor code if they are represented 76 // as a single Node so there can be more than one TensorList and more than one 77 // Tensor side-by-side as operands to an op. 78 // 79 // Note: shape is undefined for TensorList. We assert in some places that 80 // #shapes matches #outputs and this stems from 81 // the fact that currently all IR nodes represent tensors (there is no 82 // type system for this IR). Becuase of this, TensorList is a bit of a 83 // hack. 84 // 85 // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and 86 // then implement it as NotImplemented for TensorList, also fixing the assertion 87 // that would fail. 88 struct TORCH_API TensorList : public TsNode { ClassOpKindTensorList89 static OpKind ClassOpKind() { 90 return tensor_list_opkind; 91 } 92 93 TensorList() = delete; 94 TensorList(OpList values); 95 CanBeReusedTensorList96 bool CanBeReused(OpList values) const { 97 return operands() == std::vector<Output>(values.begin(), values.end()); 98 } 99 100 TSOpVector Lower( 101 std::shared_ptr<torch::jit::GraphFunction> function, 102 TSLoweringContext* loctx) const override; 103 }; 104 105 } // namespace lazy 106 } // namespace torch 107