"""Generic visitor pattern implementation for Python objects."""

import enum
import types

# Shared, read-only empty registry returned by Visitor._visitorsFor() when no
# visit function is registered for a type. It is immutable so that a caller
# cannot accidentally mutate a value that every lookup miss hands out.
_NO_VISITORS = types.MappingProxyType({})


class Visitor(object):
    """Base class for visitors that walk a graph of Python objects.

    Subclass it, register visit functions on the subclass with the
    @register(), @register_attr() or @register_attrs() class-method
    decorators, then call visit() on an instance.
    """

    # When a user visit function returns None, this decides whether the
    # default traversal is skipped (True) or performed (False).
    defaultStop = False

    @classmethod
    def _register(celf, clazzes_attrs):
        """Return a decorator that registers a ``visit`` function on *celf*.

        clazzes_attrs is an iterable of (clazzes, attrs) pairs. clazzes is a
        class or a tuple of classes; attrs is one attribute name or an
        iterable of names. The key None registers a whole-object visitor
        (looked up by visit()); "*" registers a fallback for every attribute
        (looked up by visitObject()).

        The decorator returns None, so after decoration the name ``visit`` is
        bound to None; the function is reachable only through the registry.
        """
        assert celf != Visitor, "Subclass Visitor instead."
        # Give each subclass its own registry instead of writing into one
        # inherited from a base class.
        if "_visitors" not in celf.__dict__:
            celf._visitors = {}

        def wrapper(method):
            assert method.__name__ == "visit"
            for clazzes, attrs in clazzes_attrs:
                if not isinstance(clazzes, tuple):
                    clazzes = (clazzes,)
                if isinstance(attrs, str):
                    attrs = (attrs,)
                for clazz in clazzes:
                    _visitors = celf._visitors.setdefault(clazz, {})
                    for attr in attrs:
                        assert attr not in _visitors, (
                            "Oops, class '%s' has visitor function for '%s' defined already."
                            % (clazz.__name__, attr)
                        )
                        _visitors[attr] = method
            return None

        return wrapper

    @classmethod
    def register(celf, clazzes):
        """Decorator: register a whole-object visit function for *clazzes*.

        visit() calls it as ``visit(visitor, obj, *args, **kwargs)`` for
        objects whose type is exactly one of *clazzes*.
        """
        if not isinstance(clazzes, tuple):
            clazzes = (clazzes,)
        return celf._register([(clazzes, (None,))])

    @classmethod
    def register_attr(celf, clazzes, attrs):
        """Decorator: register a visit function for attribute(s) *attrs* of *clazzes*.

        visitObject() calls it as
        ``visit(visitor, obj, attr, value, *args, **kwargs)``.
        """
        if not isinstance(clazzes, tuple):
            clazzes = (clazzes,)
        if isinstance(attrs, str):
            attrs = (attrs,)
        return celf._register([(clazz, attrs) for clazz in clazzes])

    @classmethod
    def register_attrs(celf, clazzes_attrs):
        """Decorator: register one visit function for several (clazzes, attrs) pairs."""
        return celf._register(clazzes_attrs)

    @classmethod
    def _visitorsFor(celf, thing, _default=_NO_VISITORS):
        """Return the {attr: function} registry entry for ``type(thing)``.

        Classes are searched in MRO order, so a subclass's registrations take
        precedence over its bases'. Only the exact type of *thing* is matched.
        Returns *_default* (a read-only empty mapping unless given) when no
        entry is found.
        """
        typ = type(thing)
        for klass in celf.mro():
            # Consult only registries a class defines itself. Classes without
            # one (Visitor, object, or mixins that are not Visitor subclasses)
            # are skipped instead of ending the search, so registries further
            # along the MRO are still reached, in C3 order.
            registry = klass.__dict__.get("_visitors")
            if registry is None:
                continue
            visitors = registry.get(typ)
            if visitors is not None:
                return visitors
        return _default

    def visitObject(self, obj, *args, **kwargs):
        """Called to visit an object.

        Loops, in sorted name order, over the instance attributes of *obj*
        whose names do not start with "_". For each one it calls any
        user-registered (via @register_attr() or @register_attrs()) visit
        function, falling back to one registered for "*".

        If that function returns a value equal to False, the attribute is
        skipped. If it returns None (or doesn't return anything), the
        attribute is skipped only when visitor.defaultStop is True. Otherwise,
        including when no function is registered, self.visitAttr() is called.
        """
        _visitors = self._visitorsFor(obj)
        defaultVisitor = _visitors.get("*", None)
        for key in sorted(vars(obj)):
            # Names starting with "_" are treated as private and not visited.
            if key.startswith("_"):
                continue
            value = getattr(obj, key)
            visitorFunc = _visitors.get(key, defaultVisitor)
            if visitorFunc is not None:
                ret = visitorFunc(self, obj, key, value, *args, **kwargs)
                if ret == False or (ret is None and self.defaultStop):
                    continue
            self.visitAttr(obj, key, value, *args, **kwargs)

    def visitAttr(self, obj, attr, value, *args, **kwargs):
        """Called to visit an attribute of an object; recurses into *value*."""
        self.visit(value, *args, **kwargs)

    def visitList(self, obj, *args, **kwargs):
        """Called to visit any value that is a list; visits each element."""
        for value in obj:
            self.visit(value, *args, **kwargs)

    def visitDict(self, obj, *args, **kwargs):
        """Called to visit any value that is a dictionary; visits each value
        (keys are not visited)."""
        for value in obj.values():
            self.visit(value, *args, **kwargs)

    def visitLeaf(self, obj, *args, **kwargs):
        """Called to visit any value that is not an object, list, or
        dictionary. Does nothing by default."""
        pass

    def visit(self, obj, *args, **kwargs):
        """This is the main entry to the visitor. The visitor will visit
        object obj.

        The visitor first determines whether there is a registered (via
        @register()) visit function for the exact type of the object. If
        there is, it is called with (visitor, obj, *args, **kwargs).

        If that function returns a value equal to False, traversal of obj
        stops. If it returns None (or doesn't return anything), traversal
        stops only when visitor.defaultStop is True. Otherwise, including
        when no function is registered, the visitor dispatches to one of
        self.visitObject(), self.visitList(), self.visitDict(), or
        self.visitLeaf() (any of which can be overridden in a subclass).
        """
        visitorFunc = self._visitorsFor(obj).get(None, None)
        if visitorFunc is not None:
            ret = visitorFunc(self, obj, *args, **kwargs)
            if ret == False or (ret is None and self.defaultStop):
                return
        # Enum members carry a __dict__ but are treated as leaves.
        if hasattr(obj, "__dict__") and not isinstance(obj, enum.Enum):
            self.visitObject(obj, *args, **kwargs)
        elif isinstance(obj, list):
            self.visitList(obj, *args, **kwargs)
        elif isinstance(obj, dict):
            self.visitDict(obj, *args, **kwargs)
        else:
            self.visitLeaf(obj, *args, **kwargs)