xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/operator_upgraders/utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Export.h>
3 #include <torch/csrc/jit/operator_upgraders/version_map.h>
4 #include <cstdint>
5 #include <optional>
6 #include <string>
7 #include <vector>
8 
9 namespace torch::jit {
10 
11 struct UpgraderRange {
12   int min_version;
13   int max_version;
14 };
15 
16 // Given a list of upgrader entries for a single operator
17 // and the model version for that operator, find a valid
18 // upgrader.
19 TORCH_API std::optional<UpgraderEntry> findUpgrader(
20     const std::vector<UpgraderEntry>& upgraders_for_schema,
21     size_t current_version);
22 
23 // Utility methods to find if the operator is up-to-date
24 // based on all registered upgraders for this operator.
25 // This can be different from the current server version
26 // because the implementation of this operator could have
27 // been consistent for many later version bumps.
28 TORCH_API bool isOpCurrentBasedOnUpgraderEntries(
29     const std::vector<UpgraderEntry>& upgraders_for_schema,
30     size_t current_version);
31 
32 TORCH_API bool isOpSymbolCurrent(
33     const std::string& name,
34     size_t current_version);
35 
36 // Returns the possible old schemas for the operator that
37 // doesn't exist anymore. This can be true for deprecated
38 // operators. Since name is always a symbol name, there
39 // can be multiple schemas for different overloads.
40 TORCH_API std::vector<std::string> loadPossibleHistoricOps(
41     const std::string& name,
42     std::optional<size_t> version);
43 
44 TORCH_API uint64_t getMaxOperatorVersion();
45 
46 // Returns the list of min and max version numbers of the operators
47 // that an upgrader `x` support for all upgraders for op `foo`
48 TORCH_API std::vector<UpgraderRange> getUpgradersRangeForOp(
49     const std::string& name);
50 
51 } // namespace torch::jit
52