1 #include <gtest/gtest.h>
2
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/lazy/core/config.h>
5 #include <torch/csrc/lazy/core/ir.h>
6 #include <torch/csrc/lazy/core/ir_builder.h>
7 #include <torch/csrc/lazy/core/ir_metadata.h>
8 #include <torch/csrc/lazy/core/ir_util.h>
9 #include <memory>
10
11 namespace torch {
12 namespace lazy {
13
14 class TrieCacheNode : public Node {
15 public:
ClassOpKind()16 static OpKind ClassOpKind() {
17 return OpKind();
18 }
19
TrieCacheNode(size_t id)20 explicit TrieCacheNode(size_t id)
21 : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
22 ~TrieCacheNode() override = default;
23
CanBeReused(size_t id) const24 bool CanBeReused(size_t id) const {
25 return (id_ == id);
26 }
27
AddOperand(Value v)28 void AddOperand(Value v) {
29 if (!v.node) {
30 return;
31 }
32 operands_as_outputs_.emplace_back(v.node.get(), v.index);
33 operands_.push_back(std::move(v.node));
34 }
35
hash() const36 hash_t hash() const override {
37 return hash_;
38 }
shapeHash() const39 hash_t shapeHash() const override {
40 return hash_;
41 }
42
43 private:
44 size_t id_;
45 hash_t hash_;
46 };
47
// Replaying an identical sequence of node constructions after a MarkStep must
// return the very same node objects that the trie cache stored the first time.
TEST(TrieCacheTest, TestSinglePath) {
  // FLAGS_torch_lazy_reuse_ir is process-wide state. Restore its previous value
  // when the test ends (normally or via an exception) so later tests in this
  // binary do not silently run with IR reuse enabled.
  struct RestoreReuseIrFlag {
    const bool saved = FLAGS_torch_lazy_reuse_ir;
    ~RestoreReuseIrFlag() {
      FLAGS_torch_lazy_reuse_ir = saved;
    }
  } restore_reuse_ir_flag;
  FLAGS_torch_lazy_reuse_ir = true;
  TrieCache::Get()->Clear();

  // Step 1: the cache was just cleared, so each call creates a new node and
  // records the path 0 -> 1 -> 2.
  NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
  NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
  NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
  // Distinct ids must yield distinct nodes; otherwise the reuse checks below
  // would pass trivially.
  EXPECT_NE(a.get(), b.get());
  EXPECT_NE(b.get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Step 2: the same sequence walks the recorded path and reuses every node.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Drop the nodes cached by this test so they do not outlive it.
  TrieCache::Get()->Clear();
}
62
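/*
 * Trie built by TestTwoPaths: node 1 gets two children, so the cache must
 * keep both branches alive across MarkSteps and reuse whichever one a step
 * replays.
 *
 *      0
 *      |
 *      1
 *     / \
 *    2   3
 */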
// After two different sequences share the prefix 0 -> 1 and then diverge
// (-> 2 and -> 3), replaying either sequence must reuse the nodes of that
// branch.
TEST(TrieCacheTest, TestTwoPaths) {
  // FLAGS_torch_lazy_reuse_ir is process-wide state. Restore its previous value
  // when the test ends (normally or via an exception) so later tests in this
  // binary do not silently run with IR reuse enabled.
  struct RestoreReuseIrFlag {
    const bool saved = FLAGS_torch_lazy_reuse_ir;
    ~RestoreReuseIrFlag() {
      FLAGS_torch_lazy_reuse_ir = saved;
    }
  } restore_reuse_ir_flag;
  FLAGS_torch_lazy_reuse_ir = true;
  TrieCache::Get()->Clear();

  // Step 1: record the path 0 -> 1 -> 2 in a freshly cleared cache.
  NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
  NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
  NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Step 2: reuse the shared prefix, then branch off to a new node 3.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
  NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
  EXPECT_NE(d.get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Step 3: replaying the 0 -> 1 -> 3 branch reuses d.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Step 4: the original 0 -> 1 -> 2 branch is still cached and reuses c.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Drop the nodes cached by this test so they do not outlive it.
  TrieCache::Get()->Clear();
}
95
96 } // namespace lazy
97 } // namespace torch
98