1 #pragma once 2 3 #include <ATen/core/Dict.h> 4 #include <ATen/core/ivalue.h> 5 #include <ATen/core/jit_type.h> 6 #include <torch/csrc/utils/pybind.h> 7 8 namespace torch::jit { 9 10 void initScriptDictBindings(PyObject* module); 11 12 /// An iterator over the keys of ScriptDict. This is used to support 13 /// .keys() and iteration. 14 class ScriptDictKeyIterator final { 15 public: ScriptDictKeyIterator(c10::impl::GenericDict::iterator iter,c10::impl::GenericDict::iterator end)16 ScriptDictKeyIterator( 17 c10::impl::GenericDict::iterator iter, 18 c10::impl::GenericDict::iterator end) 19 : iter_(std::move(iter)), end_(std::move(end)) {} 20 at::IValue next(); 21 22 private: 23 c10::impl::GenericDict::iterator iter_; 24 c10::impl::GenericDict::iterator end_; 25 }; 26 27 /// An iterator over the key-value pairs of ScriptDict. This is used to support 28 /// .items(). 29 class ScriptDictIterator final { 30 public: ScriptDictIterator(c10::impl::GenericDict::iterator iter,c10::impl::GenericDict::iterator end)31 ScriptDictIterator( 32 c10::impl::GenericDict::iterator iter, 33 c10::impl::GenericDict::iterator end) 34 : iter_(std::move(iter)), end_(std::move(end)) {} 35 at::IValue next(); 36 37 private: 38 c10::impl::GenericDict::iterator iter_; 39 c10::impl::GenericDict::iterator end_; 40 }; 41 42 /// A wrapper around c10::Dict that can be exposed in Python via pybind 43 /// with an API identical to the Python dictionary class. This allows 44 /// dictionaries to have reference semantics across the Python/TorchScript 45 /// boundary. 46 class ScriptDict final { 47 public: 48 // Constructor. ScriptDict(const at::IValue & data)49 ScriptDict(const at::IValue& data) 50 : dict_(at::AnyType::get(), at::AnyType::get()) { 51 TORCH_INTERNAL_ASSERT(data.isGenericDict()); 52 dict_ = data.toGenericDict(); 53 } 54 55 // Get the type of the dictionary. type()56 at::DictTypePtr type() const { 57 return at::DictType::create(dict_.keyType(), dict_.valueType()); 58 } 59 60 // Return a string representation that can be used 61 // to reconstruct the instance. repr()62 std::string repr() const { 63 std::ostringstream s; 64 s << '{'; 65 bool f = false; 66 for (auto const& kv : dict_) { 67 if (f) { 68 s << ", "; 69 } 70 s << kv.key() << ": " << kv.value(); 71 f = true; 72 } 73 s << '}'; 74 return s.str(); 75 } 76 77 // Return an iterator over the keys of the dictionary. iter()78 ScriptDictKeyIterator iter() const { 79 auto begin = dict_.begin(); 80 auto end = dict_.end(); 81 return ScriptDictKeyIterator(begin, end); 82 } 83 84 // Return an iterator over the key-value pairs of the dictionary. items()85 ScriptDictIterator items() const { 86 auto begin = dict_.begin(); 87 auto end = dict_.end(); 88 return ScriptDictIterator(begin, end); 89 } 90 91 // Interpret the dictionary as a boolean; empty means false, non-empty means 92 // true. toBool()93 bool toBool() const { 94 return !(dict_.empty()); 95 } 96 97 // Get the value for the given key. Throws std::out_of_range if the key does 98 // not exist. getItem(const at::IValue & key)99 at::IValue getItem(const at::IValue& key) { 100 return dict_.at(key); 101 }; 102 103 // Set the value for the given key. setItem(const at::IValue & key,const at::IValue & value)104 void setItem(const at::IValue& key, const at::IValue& value) { 105 dict_.insert_or_assign(key, value); 106 }; 107 108 // Check whether the dictionary contains the given key. contains(const at::IValue & key)109 bool contains(const at::IValue& key) { 110 return dict_.contains(key); 111 } 112 113 // Delete the given key from the dictionary. delItem(const at::IValue & key)114 bool delItem(const at::IValue& key) { 115 return dict_.erase(key); 116 } 117 118 // Get the size of the dictionary. len()119 int64_t len() const { 120 return dict_.size(); 121 } 122 123 // A c10::Dict instance that holds the actual data. 124 c10::impl::GenericDict dict_; 125 }; 126 127 } // namespace torch::jit 128