xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/versioned_symbols.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/versioned_symbols.h>
2 
3 #include <caffe2/serialize/versions.h>
4 #include <torch/csrc/api/include/torch/jit.h>
5 
6 #include <unordered_map>
7 
8 namespace torch::jit {
9 // Note [Versioned Symbols]
// When the schema or behavior of a symbol changes, serialized TorchScript
// programs using that symbol are likely to break. To prevent those breaks,
// the symbol's historic behavior can be implemented as a TorchScript builtin
// and when an older TorchScript program is loaded the program's uses of the
// symbol can be replaced with the builtin.
15 //
16 // For example, a function _test_serialization_subcmul(a, b, alpha) might have
17 // been improperly implemented as (b - alpha * a).
18 // Some users may have written and serialized programs using that function,
19 // however, and fixing it to perform (a - alpha * b) would break their programs.
20 // Using the "Versioned Symbol" pattern lets you replace
21 // _test_serialization_subcmul in older programs with a builtin
22 // _test_serialization_subcmul<version_range> that implements the historic
23 // behavior. That way old programs preserve their semantics while new programs
24 // can take advantage of the fix.
25 //
26 // To do this:
27 //
28 // 1) Identify the file version range where the symbol should be replaced,
29 //    e.g. versions 0 to 2, inclusive.
30 // 2) Create one or more builtins implementing the symbol's historic behavior.
31 //    These should be named <function>_<start_version>_<end_version> and
32 //    go into the "upgraders" namespace.
33 //    For example, the test-only aten::_test_serialization_subcmul has a builtin
34 //    for its "historic" behavior called
35 //    upgraders::_test_serialization_subcmul_0_2.
36 // 3) Add a mapping from the symbol to the corresponding SymbolRange
37 //    in the symbol_range_map (below).
38 //
39 // To test your versioning:
40 //
41 // 1) Serialize a module demonstrating the historic behavior.
42 // 2) Save it to test/jit/fixtures.
43 // 3) Implement your new behavior and bump the version counter.
44 // 4) Write the builtins and extend the symbol_range_map per the above
45 //    instructions.
46 // 5) Create a test in jit/test_save_load.py that loads the old module
47 //    and verifies it exhibits the historic behavior, then saves and
48 //    loads the same module and verifies it exhibits the current behavior.
49 //    See test_versioned_symbols for an example.
50 
51 // Helper to hold the version range (inclusive on both ends) and the symbol
52 // to map to for that range.
53 struct SymbolRange {
SymbolRangetorch::jit::SymbolRange54   SymbolRange(
55       const uint64_t _start_version,
56       const uint64_t _end_version,
57       const Symbol _sym)
58       : start_version_{_start_version},
59         end_version_{_end_version},
60         sym_{_sym} {}
61   const uint64_t start_version_;
62   const uint64_t end_version_;
63   const Symbol sym_;
64 };
65 
66 static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
67     {Symbol::fromQualString("aten::_test_serialization_subcmul"),
68      {0,
69       2,
70       Symbol::fromQualString("upgraders::_test_serialization_subcmul_0_2")}},
71     {Symbol::fromQualString("aten::div"),
72      {0, 3, Symbol::fromQualString("upgraders::div_0_3")}},
73     {Symbol::fromQualString("aten::div_"),
74      {0, 3, Symbol::fromQualString("upgraders::div__0_3")}},
75     {Symbol::fromQualString("aten::full"),
76      {0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
77 });
78 
79 static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
80     {aten::div, 4},
81     {aten::div_, 4},
82     {aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
83 });
84 
// Returns the symbol to use for `name` in a program serialized at file
// version `version`. If symbol_range_map has an entry for `name` and
// `version` lies inside that entry's inclusive range, the entry's
// replacement symbol is returned. Otherwise `name` is returned unchanged.
Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
  const auto found = symbol_range_map.find(name);
  if (found != symbol_range_map.end()) {
    const SymbolRange& range = found->second;
    const bool within_range =
        version >= range.start_version_ && version <= range.end_version_;
    if (within_range) {
      return range.sym_;
    }
  }
  // No mapping for this symbol, or the version falls outside its range.
  return name;
}
98 
get_min_version_for_kind(const NodeKind & kind)99 uint64_t get_min_version_for_kind(const NodeKind& kind) {
100   auto it = kind_min_version_map.find(kind);
101   if (it == kind_min_version_map.end()) {
102     return 0;
103   }
104 
105   return it->second;
106 }
107 
108 } // namespace torch::jit
109