1 #include <torch/csrc/lazy/core/debug_util.h>
2 #include <torch/csrc/lazy/ts_backend/ts_node.h>
3
4 namespace {
GetFirstUserFrameInPythonIfEnabled()5 std::string GetFirstUserFrameInPythonIfEnabled() {
6 static const auto LTC_ENABLE_SOURCE_INFO =
7 std::getenv("LTC_ENABLE_SOURCE_INFO");
8 if (!LTC_ENABLE_SOURCE_INFO) {
9 return {};
10 }
11
12 return torch::lazy::GetFirstUserFrameInPython();
13 }
14 } // namespace
15
16 namespace torch {
17 namespace lazy {
18
OperandHashes(const OpList & operands,const c10::ArrayRef<Shape> & shapes,const hash_t & seed,bool bakeInSizes)19 static hash_t OperandHashes(
20 const OpList& operands,
21 const c10::ArrayRef<Shape>& shapes,
22 const hash_t& seed,
23 bool bakeInSizes) {
24 hash_t hash = seed;
25 for (auto& operand : operands) {
26 if (!operand) {
27 hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
28 continue;
29 }
30 auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
31 hash = HashCombine(hash, operand_hash);
32 }
33 for (auto& shape : shapes) {
34 hash = HashCombine(hash, shape.hash(bakeInSizes));
35 }
36 return hash;
37 }
38
TsNode(OpKind op,OpList operands,std::vector<Shape> && shapes,size_t num_outputs,hash_t hash_seed)39 TsNode::TsNode(
40 OpKind op,
41 OpList operands,
42 std::vector<Shape>&& shapes,
43 size_t num_outputs,
44 hash_t hash_seed)
45 : Node(op, operands, std::move(shapes), num_outputs) {
46 hash_seed = HashCombine(op.hash(), hash_seed);
47 shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
48 dag_hash_ =
49 (enableDynamicShape()
50 ? OperandHashes(operands, this->shapes(), hash_seed, false)
51 : shape_hash_);
52 }
53
TsNode(OpKind op,OpList operands,const std::function<Shape ()> & shape_fn,size_t num_outputs,hash_t hash_seed)54 TsNode::TsNode(
55 OpKind op,
56 OpList operands,
57 const std::function<Shape()>& shape_fn,
58 size_t num_outputs,
59 hash_t hash_seed)
60 : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
61 addComputedShape(shape_fn);
62 }
63
TsNode(OpKind op,OpList operands,size_t num_outputs,hash_t hash_seed)64 TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
65 : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
66
TsNode(OpKind op,Shape shape,size_t num_outputs,hash_t hash_seed)67 TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
68 : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
69
hash() const70 hash_t TsNode::hash() const {
71 return dag_hash_;
72 }
73
shapeHash() const74 hash_t TsNode::shapeHash() const {
75 return shape_hash_;
76 }
77
getPythonStacktrace() const78 const std::string TsNode::getPythonStacktrace() const {
79 return GetFirstUserFrameInPythonIfEnabled();
80 }
81
TensorList(OpList values)82 TensorList::TensorList(OpList values)
83 : TsNode(
84 /*op=*/ClassOpKind(),
85 /*operands=*/values,
86 /*shapes=*/std::vector<Shape>(),
87 /*num_outputs=*/1,
88 /*hash_seed=*/kHashSeed) {}
89
Lower(std::shared_ptr<torch::jit::GraphFunction> function,TSLoweringContext * loctx) const90 TSOpVector TensorList::Lower(
91 std::shared_ptr<torch::jit::GraphFunction> function,
92 TSLoweringContext* loctx) const {
93 std::vector<torch::jit::Value*> tensor_list;
94 TORCH_CHECK(!operands().empty());
95 for (const torch::lazy::Output& operand : operands()) {
96 tensor_list.emplace_back(loctx->GetOutputOp(operand));
97 }
98 auto graph = function->graph();
99 auto listnode =
100 graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
101 return {listnode->output()};
102 }
103
104 } // namespace lazy
105 } // namespace torch
106