xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_trie_cache.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/lazy/core/config.h>
5 #include <torch/csrc/lazy/core/ir.h>
6 #include <torch/csrc/lazy/core/ir_builder.h>
7 #include <torch/csrc/lazy/core/ir_metadata.h>
8 #include <torch/csrc/lazy/core/ir_util.h>
9 #include <memory>
10 
11 namespace torch {
12 namespace lazy {
13 
14 class TrieCacheNode : public Node {
15  public:
ClassOpKind()16   static OpKind ClassOpKind() {
17     return OpKind();
18   }
19 
TrieCacheNode(size_t id)20   explicit TrieCacheNode(size_t id)
21       : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
22   ~TrieCacheNode() override = default;
23 
CanBeReused(size_t id) const24   bool CanBeReused(size_t id) const {
25     return (id_ == id);
26   }
27 
AddOperand(Value v)28   void AddOperand(Value v) {
29     if (!v.node) {
30       return;
31     }
32     operands_as_outputs_.emplace_back(v.node.get(), v.index);
33     operands_.push_back(std::move(v.node));
34   }
35 
hash() const36   hash_t hash() const override {
37     return hash_;
38   }
shapeHash() const39   hash_t shapeHash() const override {
40     return hash_;
41   }
42 
43  private:
44   size_t id_;
45   hash_t hash_;
46 };
47 
// Creates a single chain of nodes (ids 0, 1, 2), rewinds the cache, and then
// verifies that requesting the same ids in the same order returns the very
// same node objects rather than new ones.
TEST(TrieCacheTest, TestSinglePath) {
  FLAGS_torch_lazy_reuse_ir = true;
  auto* const cache = TrieCache::Get();
  cache->Clear();

  // First pass: populate the cache with the path 0 -> 1 -> 2.
  NodePtr node0 = ReuseOrMakeNode<TrieCacheNode>(0);
  NodePtr node1 = ReuseOrMakeNode<TrieCacheNode>(1);
  NodePtr node2 = ReuseOrMakeNode<TrieCacheNode>(2);
  cache->ResetCurrent(); // MarkStep

  // Second pass: the identical sequence must be served from the cache, so
  // the raw pointers have to match those from the first pass.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), node0.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), node1.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), node2.get());
  cache->ResetCurrent(); // MarkStep
}
62 
63 /*
64  *    0
65  *    |
66  *    1
67  *   / \
68  *  2   3
69  */
TEST(TrieCacheTest,TestTwoPaths)70 TEST(TrieCacheTest, TestTwoPaths) {
71   FLAGS_torch_lazy_reuse_ir = true;
72   TrieCache::Get()->Clear();
73 
74   NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
75   NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
76   NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
77   TrieCache::Get()->ResetCurrent(); // MarkStep
78 
79   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
80   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
81   NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
82   EXPECT_NE(d.get(), c.get());
83   TrieCache::Get()->ResetCurrent(); // MarkStep
84 
85   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
86   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
87   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
88   TrieCache::Get()->ResetCurrent(); // MarkStep
89 
90   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
91   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
92   EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
93   TrieCache::Get()->ResetCurrent(); // MarkStep
94 }
95 
96 } // namespace lazy
97 } // namespace torch
98