xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_dict.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Dict.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/jit_type.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 namespace torch::jit {
9 
/// Entry point for installing the ScriptDict Python bindings on `module`.
/// Only declared here; the definition lives out of line.
// NOTE(review): presumably registers ScriptDict, ScriptDictIterator and
// ScriptDictKeyIterator with pybind11 on `module` — confirm in the .cpp.
void initScriptDictBindings(PyObject* module);
11 
12 /// An iterator over the keys of ScriptDict. This is used to support
13 /// .keys() and iteration.
14 class ScriptDictKeyIterator final {
15  public:
ScriptDictKeyIterator(c10::impl::GenericDict::iterator iter,c10::impl::GenericDict::iterator end)16   ScriptDictKeyIterator(
17       c10::impl::GenericDict::iterator iter,
18       c10::impl::GenericDict::iterator end)
19       : iter_(std::move(iter)), end_(std::move(end)) {}
20   at::IValue next();
21 
22  private:
23   c10::impl::GenericDict::iterator iter_;
24   c10::impl::GenericDict::iterator end_;
25 };
26 
27 /// An iterator over the key-value pairs of ScriptDict. This is used to support
28 /// .items().
29 class ScriptDictIterator final {
30  public:
ScriptDictIterator(c10::impl::GenericDict::iterator iter,c10::impl::GenericDict::iterator end)31   ScriptDictIterator(
32       c10::impl::GenericDict::iterator iter,
33       c10::impl::GenericDict::iterator end)
34       : iter_(std::move(iter)), end_(std::move(end)) {}
35   at::IValue next();
36 
37  private:
38   c10::impl::GenericDict::iterator iter_;
39   c10::impl::GenericDict::iterator end_;
40 };
41 
42 /// A wrapper around c10::Dict that can be exposed in Python via pybind
43 /// with an API identical to the Python dictionary class. This allows
44 /// dictionaries to have reference semantics across the Python/TorchScript
45 /// boundary.
46 class ScriptDict final {
47  public:
48   // Constructor.
ScriptDict(const at::IValue & data)49   ScriptDict(const at::IValue& data)
50       : dict_(at::AnyType::get(), at::AnyType::get()) {
51     TORCH_INTERNAL_ASSERT(data.isGenericDict());
52     dict_ = data.toGenericDict();
53   }
54 
55   // Get the type of the dictionary.
type()56   at::DictTypePtr type() const {
57     return at::DictType::create(dict_.keyType(), dict_.valueType());
58   }
59 
60   // Return a string representation that can be used
61   // to reconstruct the instance.
repr()62   std::string repr() const {
63     std::ostringstream s;
64     s << '{';
65     bool f = false;
66     for (auto const& kv : dict_) {
67       if (f) {
68         s << ", ";
69       }
70       s << kv.key() << ": " << kv.value();
71       f = true;
72     }
73     s << '}';
74     return s.str();
75   }
76 
77   // Return an iterator over the keys of the dictionary.
iter()78   ScriptDictKeyIterator iter() const {
79     auto begin = dict_.begin();
80     auto end = dict_.end();
81     return ScriptDictKeyIterator(begin, end);
82   }
83 
84   // Return an iterator over the key-value pairs of the dictionary.
items()85   ScriptDictIterator items() const {
86     auto begin = dict_.begin();
87     auto end = dict_.end();
88     return ScriptDictIterator(begin, end);
89   }
90 
91   // Interpret the dictionary as a boolean; empty means false, non-empty means
92   // true.
toBool()93   bool toBool() const {
94     return !(dict_.empty());
95   }
96 
97   // Get the value for the given key. Throws std::out_of_range if the key does
98   // not exist.
getItem(const at::IValue & key)99   at::IValue getItem(const at::IValue& key) {
100     return dict_.at(key);
101   };
102 
103   // Set the value for the given key.
setItem(const at::IValue & key,const at::IValue & value)104   void setItem(const at::IValue& key, const at::IValue& value) {
105     dict_.insert_or_assign(key, value);
106   };
107 
108   // Check whether the dictionary contains the given key.
contains(const at::IValue & key)109   bool contains(const at::IValue& key) {
110     return dict_.contains(key);
111   }
112 
113   // Delete the given key from the dictionary.
delItem(const at::IValue & key)114   bool delItem(const at::IValue& key) {
115     return dict_.erase(key);
116   }
117 
118   // Get the size of the dictionary.
len()119   int64_t len() const {
120     return dict_.size();
121   }
122 
123   // A c10::Dict instance that holds the actual data.
124   c10::impl::GenericDict dict_;
125 };
126 
127 } // namespace torch::jit
128