xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ts_node.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/core/debug_util.h>
2 #include <torch/csrc/lazy/ts_backend/ts_node.h>
3 
4 namespace {
GetFirstUserFrameInPythonIfEnabled()5 std::string GetFirstUserFrameInPythonIfEnabled() {
6   static const auto LTC_ENABLE_SOURCE_INFO =
7       std::getenv("LTC_ENABLE_SOURCE_INFO");
8   if (!LTC_ENABLE_SOURCE_INFO) {
9     return {};
10   }
11 
12   return torch::lazy::GetFirstUserFrameInPython();
13 }
14 } // namespace
15 
16 namespace torch {
17 namespace lazy {
18 
OperandHashes(const OpList & operands,const c10::ArrayRef<Shape> & shapes,const hash_t & seed,bool bakeInSizes)19 static hash_t OperandHashes(
20     const OpList& operands,
21     const c10::ArrayRef<Shape>& shapes,
22     const hash_t& seed,
23     bool bakeInSizes) {
24   hash_t hash = seed;
25   for (auto& operand : operands) {
26     if (!operand) {
27       hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
28       continue;
29     }
30     auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
31     hash = HashCombine(hash, operand_hash);
32   }
33   for (auto& shape : shapes) {
34     hash = HashCombine(hash, shape.hash(bakeInSizes));
35   }
36   return hash;
37 }
38 
TsNode(OpKind op,OpList operands,std::vector<Shape> && shapes,size_t num_outputs,hash_t hash_seed)39 TsNode::TsNode(
40     OpKind op,
41     OpList operands,
42     std::vector<Shape>&& shapes,
43     size_t num_outputs,
44     hash_t hash_seed)
45     : Node(op, operands, std::move(shapes), num_outputs) {
46   hash_seed = HashCombine(op.hash(), hash_seed);
47   shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
48   dag_hash_ =
49       (enableDynamicShape()
50            ? OperandHashes(operands, this->shapes(), hash_seed, false)
51            : shape_hash_);
52 }
53 
TsNode(OpKind op,OpList operands,const std::function<Shape ()> & shape_fn,size_t num_outputs,hash_t hash_seed)54 TsNode::TsNode(
55     OpKind op,
56     OpList operands,
57     const std::function<Shape()>& shape_fn,
58     size_t num_outputs,
59     hash_t hash_seed)
60     : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
61   addComputedShape(shape_fn);
62 }
63 
TsNode(OpKind op,OpList operands,size_t num_outputs,hash_t hash_seed)64 TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
65     : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
66 
TsNode(OpKind op,Shape shape,size_t num_outputs,hash_t hash_seed)67 TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
68     : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
69 
hash() const70 hash_t TsNode::hash() const {
71   return dag_hash_;
72 }
73 
shapeHash() const74 hash_t TsNode::shapeHash() const {
75   return shape_hash_;
76 }
77 
getPythonStacktrace() const78 const std::string TsNode::getPythonStacktrace() const {
79   return GetFirstUserFrameInPythonIfEnabled();
80 }
81 
TensorList(OpList values)82 TensorList::TensorList(OpList values)
83     : TsNode(
84           /*op=*/ClassOpKind(),
85           /*operands=*/values,
86           /*shapes=*/std::vector<Shape>(),
87           /*num_outputs=*/1,
88           /*hash_seed=*/kHashSeed) {}
89 
Lower(std::shared_ptr<torch::jit::GraphFunction> function,TSLoweringContext * loctx) const90 TSOpVector TensorList::Lower(
91     std::shared_ptr<torch::jit::GraphFunction> function,
92     TSLoweringContext* loctx) const {
93   std::vector<torch::jit::Value*> tensor_list;
94   TORCH_CHECK(!operands().empty());
95   for (const torch::lazy::Output& operand : operands()) {
96     tensor_list.emplace_back(loctx->GetOutputOp(operand));
97   }
98   auto graph = function->graph();
99   auto listnode =
100       graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
101   return {listnode->output()};
102 }
103 
104 } // namespace lazy
105 } // namespace torch
106