xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import_export_helpers.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/import_export_helpers.h>
2 
3 #include <caffe2/serialize/inline_container.h>
4 #include <torch/csrc/jit/frontend/source_range.h>
5 #include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
6 
7 #include <c10/util/Exception.h>
8 
9 #include <algorithm>
10 
11 namespace torch::jit {
12 
// Maps a dotted qualified name to the path of its serialized source file
// inside the archive: every '.' in `qualifier` becomes '/', `export_prefix`
// is prepended verbatim (its dots are NOT rewritten), and ".py" is appended.
// Example: ("__torch__.foo.Bar", "code/") -> "code/__torch__/foo/Bar.py".
std::string qualifierToArchivePath(
    const std::string& qualifier,
    const std::string& export_prefix) {
  // constexpr array instead of a function-local static std::string: no
  // per-call thread-safe initialization guard for a compile-time constant.
  constexpr char kExportExtension[] = ".py";
  std::string path = qualifier;
  std::replace(path.begin(), path.end(), '.', '/');
  // `export_prefix + path` yields an rvalue, so the extension is appended in
  // place rather than through another temporary.
  return export_prefix + path + kExportExtension;
}
22 
findSourceInArchiveFromQualifier(caffe2::serialize::PyTorchStreamReader & reader,const std::string & export_prefix,const std::string & qualifier)23 std::shared_ptr<Source> findSourceInArchiveFromQualifier(
24     caffe2::serialize::PyTorchStreamReader& reader,
25     const std::string& export_prefix,
26     const std::string& qualifier) {
27   const std::string path = qualifierToArchivePath(qualifier, export_prefix);
28   if (!reader.hasRecord(path)) {
29     return nullptr;
30   }
31   auto [data, size] = reader.getRecord(path);
32 
33   std::shared_ptr<ConcreteSourceRangeUnpickler> gen_ranges = nullptr;
34 
35   std::string debug_file = path + ".debug_pkl";
36   if (reader.hasRecord(debug_file)) {
37     auto [debug_data, debug_size] = reader.getRecord(debug_file);
38     gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(
39         std::move(debug_data), debug_size);
40   }
41   return std::make_shared<Source>(
42       std::string(static_cast<const char*>(data.get()), size),
43       path,
44       1,
45       gen_ranges);
46 }
47 
48 } // namespace torch::jit
49