1 #include <torch/csrc/jit/frontend/versioned_symbols.h>
2
3 #include <caffe2/serialize/versions.h>
4 #include <torch/csrc/api/include/torch/jit.h>
5
6 #include <unordered_map>
7
8 namespace torch::jit {
9 // Note [Versioned Symbols]
// When the schema or behavior of a symbol changes, serialized TorchScript
// programs using that symbol are likely to break. To prevent those breaks,
// the symbol's historic behavior can be implemented as a TorchScript builtin,
// and when an older TorchScript program is loaded, the program's uses of the
// symbol can be replaced with that builtin.
15 //
16 // For example, a function _test_serialization_subcmul(a, b, alpha) might have
17 // been improperly implemented as (b - alpha * a).
18 // Some users may have written and serialized programs using that function,
19 // however, and fixing it to perform (a - alpha * b) would break their programs.
// Using the "Versioned Symbol" pattern lets you replace
// _test_serialization_subcmul in older programs with a builtin named
// _test_serialization_subcmul_<start_version>_<end_version> that implements
// the historic behavior. That way old programs preserve their semantics while
// new programs can take advantage of the fix.
25 //
26 // To do this:
27 //
28 // 1) Identify the file version range where the symbol should be replaced,
29 // e.g. versions 0 to 2, inclusive.
30 // 2) Create one or more builtins implementing the symbol's historic behavior.
31 // These should be named <function>_<start_version>_<end_version> and
32 // go into the "upgraders" namespace.
33 // For example, the test-only aten::_test_serialization_subcmul has a builtin
34 // for its "historic" behavior called
35 // upgraders::_test_serialization_subcmul_0_2.
36 // 3) Add a mapping from the symbol to the corresponding SymbolRange
37 // in the symbol_range_map (below).
38 //
39 // To test your versioning:
40 //
41 // 1) Serialize a module demonstrating the historic behavior.
42 // 2) Save it to test/jit/fixtures.
43 // 3) Implement your new behavior and bump the version counter.
44 // 4) Write the builtins and extend the symbol_range_map per the above
45 // instructions.
// 5) Create a test in test/jit/test_save_load.py that loads the old module
47 // and verifies it exhibits the historic behavior, then saves and
48 // loads the same module and verifies it exhibits the current behavior.
49 // See test_versioned_symbols for an example.
50
51 // Helper to hold the version range (inclusive on both ends) and the symbol
52 // to map to for that range.
53 struct SymbolRange {
SymbolRangetorch::jit::SymbolRange54 SymbolRange(
55 const uint64_t _start_version,
56 const uint64_t _end_version,
57 const Symbol _sym)
58 : start_version_{_start_version},
59 end_version_{_end_version},
60 sym_{_sym} {}
61 const uint64_t start_version_;
62 const uint64_t end_version_;
63 const Symbol sym_;
64 };
65
66 static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
67 {Symbol::fromQualString("aten::_test_serialization_subcmul"),
68 {0,
69 2,
70 Symbol::fromQualString("upgraders::_test_serialization_subcmul_0_2")}},
71 {Symbol::fromQualString("aten::div"),
72 {0, 3, Symbol::fromQualString("upgraders::div_0_3")}},
73 {Symbol::fromQualString("aten::div_"),
74 {0, 3, Symbol::fromQualString("upgraders::div__0_3")}},
75 {Symbol::fromQualString("aten::full"),
76 {0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
77 });
78
79 static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
80 {aten::div, 4},
81 {aten::div_, 4},
82 {aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
83 });
84
// Resolves which symbol a program serialized at file version `version` should
// use in place of `name`. If symbol_range_map has an entry for `name` whose
// inclusive [start_version_, end_version_] range contains `version`, that
// entry's replacement symbol is returned; otherwise `name` is returned as-is.
Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
  const auto found = symbol_range_map.find(name);
  if (found != symbol_range_map.end()) {
    const SymbolRange& range = found->second;
    const bool within_range =
        version >= range.start_version_ && version <= range.end_version_;
    if (within_range) {
      return range.sym_;
    }
  }
  return name;
}
98
get_min_version_for_kind(const NodeKind & kind)99 uint64_t get_min_version_for_kind(const NodeKind& kind) {
100 auto it = kind_min_version_map.find(kind);
101 if (it == kind_min_version_map.end()) {
102 return 0;
103 }
104
105 return it->second;
106 }
107
108 } // namespace torch::jit
109