xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/symbolic_shape_cache.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
2 #include <torch/csrc/jit/passes/symbolic_shape_cache.h>
3 #include <torch/csrc/lazy/core/cache.h>
4 
5 #include <utility>
6 
7 // SHAPE CACHING CODE
8 
9 namespace torch::jit {
10 namespace {
// An argument after canonicalization: either a plain IValue, or a symbolic
// shape whose symbolic dims have been renumbered to canonical ids.
using CanonicalArg = std::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
// Canonicalized shapes of an op's outputs (the cached value).
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
// Cache key: operator name plus the canonicalized argument vector.
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
15 
cannonicalizeVec(const std::vector<SSAInput> & arg_vec,std::unordered_map<int64_t,int64_t> & ss_map,bool deep_copy=true)16 CanonicalArgVec cannonicalizeVec(
17     const std::vector<SSAInput>& arg_vec,
18     std::unordered_map<int64_t, int64_t>& ss_map,
19     bool deep_copy = true) {
20   CanonicalArgVec canonical_args;
21   canonical_args.reserve(arg_vec.size());
22   for (auto& arg : arg_vec) {
23     if (const IValue* iv = std::get_if<IValue>(&arg)) {
24       if (deep_copy) {
25         canonical_args.emplace_back(iv->deepcopy());
26       } else {
27         canonical_args.emplace_back(*iv);
28       }
29     } else {
30       auto& ss = std::get<at::SymbolicShape>(arg);
31       canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
32     }
33   }
34   return canonical_args;
35 }
36 
cannonicalizeVec(const std::vector<at::SymbolicShape> & ret_vec,std::unordered_map<int64_t,int64_t> & ss_map)37 std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
38     const std::vector<at::SymbolicShape>& ret_vec,
39     std::unordered_map<int64_t, int64_t>& ss_map) {
40   std::vector<CanonicalizedSymbolicShape> canonical_rets;
41   canonical_rets.reserve(ret_vec.size());
42   for (auto& ss : ret_vec) {
43     canonical_rets.emplace_back(ss, ss_map);
44   }
45   return canonical_rets;
46 }
47 
48 struct ArgumentsHasher {
operator ()torch::jit::__anon33b674b20111::ArgumentsHasher49   size_t operator()(const ShapeCacheKey& cacheKey) const {
50     // TODO: ignore arguments that are not used in shape function (not needed
51     // initially)
52     auto& op_name = std::get<0>(cacheKey);
53     auto& arg_vec = std::get<1>(cacheKey);
54 
55     size_t hash_val = c10::hash<c10::OperatorName>()(op_name);
56 
57     hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
58     for (const CanonicalArg& arg : arg_vec) {
59       size_t cur_arg = 0;
60       if (const IValue* ival = std::get_if<IValue>(&arg)) {
61         // IValue doesn't hash List (as Python doesn't), so we will do a custom
62         // list hash
63         if (ival->isList()) {
64           TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
65           cur_arg = ival->toListRef().size();
66           for (const IValue& elem_ival : ival->toListRef()) {
67             cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
68           }
69         } else {
70           cur_arg = IValue::hash(ival);
71         }
72       } else {
73         cur_arg = std::get<CanonicalizedSymbolicShape>(arg).hash();
74       }
75       hash_val = at::hash_combine(hash_val, cur_arg);
76     }
77     return hash_val;
78   }
79 };
80 
// LRU cache mapping (operator name, canonicalized args) to the canonicalized
// symbolic shapes of the op's outputs.
using ShapeCache = lazy::Cache<
    ShapeCacheKey,
    std::vector<CanonicalizedSymbolicShape>,
    ArgumentsHasher>;

// Maximum number of cached shape-function results before eviction.
constexpr size_t kShapeCacheSize = 1024;
ShapeCache shapeCache(kShapeCacheSize);
88 
get_cache_key(const FunctionSchema * schema,const std::vector<SSAInput> & arg_vec,std::unordered_map<int64_t,int64_t> & ss_map,bool deep_copy=true)89 ShapeCacheKey get_cache_key(
90     const FunctionSchema* schema,
91     const std::vector<SSAInput>& arg_vec,
92     std::unordered_map<int64_t, int64_t>& ss_map,
93     bool deep_copy = true) {
94   CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy);
95   return std::make_tuple(schema->operator_name(), canonical_args);
96 }
97 
98 } // namespace
99 
cache_shape_function(const FunctionSchema * schema,const std::vector<SSAInput> & arg_vec,const std::vector<at::SymbolicShape> & ret_vec)100 TORCH_API void cache_shape_function(
101     const FunctionSchema* schema,
102     const std::vector<SSAInput>& arg_vec,
103     const std::vector<at::SymbolicShape>& ret_vec) {
104   // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
105   auto ss_map = std::unordered_map<int64_t, int64_t>();
106   auto cache_key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true);
107   auto can_ret_vec = std::make_shared<std::vector<CanonicalizedSymbolicShape>>(
108       cannonicalizeVec(ret_vec, ss_map));
109   shapeCache.Add(std::move(cache_key), std::move(can_ret_vec));
110 }
111 
112 TORCH_API std::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(const FunctionSchema * schema,const std::vector<SSAInput> & arg_vec)113 get_cached_shape_function(
114     const FunctionSchema* schema,
115     const std::vector<SSAInput>& arg_vec) {
116   // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
117   // ss_map and inverse_ss_map
118   auto ss_map = std::unordered_map<int64_t, int64_t>();
119   auto cache_key =
120       get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false);
121   auto cached_ret_vec = shapeCache.Get(cache_key);
122   if (cached_ret_vec == nullptr) {
123     return std::nullopt;
124   }
125   // Decanonicalize the return values
126   auto inverse_ss_map = std::unordered_map<int64_t, int64_t>();
127   for (auto& ss_val : ss_map) {
128     inverse_ss_map[ss_val.second] = ss_val.first;
129   }
130   std::vector<at::SymbolicShape> ret_vec;
131   for (auto& css : *cached_ret_vec) {
132     ret_vec.emplace_back(css.toSymbolicShape(inverse_ss_map));
133   }
134   return ret_vec;
135 }
136 
137 // Function only to access the cache, used for testing
clear_shape_cache()138 TORCH_API void clear_shape_cache() {
139   shapeCache.Clear();
140 }
141 
get_shape_cache_size()142 TORCH_API size_t get_shape_cache_size() {
143   return shapeCache.Numel();
144 }
145 
init(const c10::SymbolicShape & orig_shape,std::unordered_map<int64_t,int64_t> & ss_map)146 void CanonicalizedSymbolicShape::init(
147     const c10::SymbolicShape& orig_shape,
148     std::unordered_map<int64_t, int64_t>& ss_map) {
149   auto sizes = orig_shape.sizes();
150   if (!sizes) {
151     values_ = std::nullopt;
152     return;
153   }
154   values_ = std::vector<int64_t>();
155   int64_t cur_symbolic_index = -static_cast<int64_t>(ss_map.size()) - 1;
156   for (auto& cur_shape : *sizes) {
157     if (cur_shape.is_static()) {
158       values_->push_back(cur_shape.static_size());
159     } else {
160       // Check for aliasing
161       auto it = ss_map.find(cur_shape.value());
162 
163       if (it == ss_map.end()) {
164         values_->push_back(cur_symbolic_index);
165         ss_map.insert({cur_shape.value(), cur_symbolic_index});
166         cur_symbolic_index--;
167       } else {
168         values_->push_back(it->second);
169       }
170     }
171   }
172 }
173 
toSymbolicShape(std::unordered_map<int64_t,int64_t> & inverse_ss_map) const174 c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
175     std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
176   if (!values_.has_value()) {
177     return c10::SymbolicShape();
178   }
179   std::vector<at::ShapeSymbol> sizes;
180   for (long long cur_val : *values_) {
181     if (cur_val >= 0) {
182       sizes.push_back(at::ShapeSymbol::fromStaticSize(cur_val));
183       continue;
184     }
185     auto res = inverse_ss_map.find(cur_val);
186     if (res != inverse_ss_map.end()) {
187       sizes.push_back(at::ShapeSymbol::fromStaticSize(res->second));
188     } else {
189       auto new_symbol = at::ShapeSymbol::newSymbol();
190       inverse_ss_map.insert({cur_val, new_symbol.value()});
191       sizes.push_back(new_symbol);
192     }
193   }
194   return c10::SymbolicShape(std::move(sizes));
195 }
196 
hash() const197 size_t CanonicalizedSymbolicShape::hash() const {
198   if (!values_.has_value()) {
199     return 0x8cc80c80; // random value to prevent hash collisions
200   }
201   return c10::hash<std::vector<int64_t>>()(values_.value());
202 }
203 
operator ==(const CanonicalizedSymbolicShape & a,const CanonicalizedSymbolicShape & b)204 bool operator==(
205     const CanonicalizedSymbolicShape& a,
206     const CanonicalizedSymbolicShape& b) {
207   return a.values_ == b.values_;
208 };
209 } // namespace torch::jit
210