# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes

from typing import cast, Dict, List, Optional, Tuple

import torch
from executorch.backends.transforms import get_shape

from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)

from executorch.backends.xnnpack.operators.quant_params import QuantParams

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    PerChannelGroupQuant,
    PerChannelQuant,
    PerTensorQuant,
    PerTokenDynamicQuant,
    XNNDatatype,
    XNNGraph,
    XNNQuantizedTensorValue,
    XNNQuantParams,
    XNNTensorValue,
    XValue,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    _aligned_size,
    _pad_to,
    CONSTANT_TENSOR_ALIGNMENT,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_input_node,
    get_param_tensor,
    is_param_node,
    PERM_NCHW_TO_NHWC,
)

from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
from torch.export import ExportedProgram

XNN_TYPE_MAP = {
    torch.float32: XNNDatatype.xnn_datatype_fp32,
}


class InputTypeToIndex:
    """
    Mapping from input type to the arg index of a node
    """

    node_input: int
    node_weight: int
    node_bias: Optional[int]

    def __init__(
        self, node_input: int, node_weight: int, node_bias: Optional[int] = None
    ):
        self.node_input = node_input
        self.node_weight = node_weight
        self.node_bias = node_bias


def get_tensor_value(xvalue: XValue) -> XNNTensorValue:
    val_union = xvalue.xvalue_union
    if isinstance(val_union, XNNTensorValue):
        return val_union
    else:
        # it is an XNNQuantizedTensorValue
        q_tensor = val_union
        return q_tensor.tensor_value


class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph and
    serializing them using the XNNPACK serialization schema
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        external_ids: Dict,
        constant_data_bytes: bytearray,
    ) -> None:
        self._external_ids = external_ids or {}
        self._exported_program = exported_program
        self._constant_data_bytes = constant_data_bytes

    @property
    def external_ids(self) -> Dict:
        return self._external_ids

    @property
    def exported_program(self) -> ExportedProgram:
        return self._exported_program

    def is_graph_input(self, tensor: torch.fx.Node) -> bool:
        """
        Checks if the given tensor is a graph input

        Args:
            tensor: EdgeIR Tensor that is being checked for graph input
        """
        return tensor.op == "placeholder" and not is_param_node(
            self.exported_program, tensor
        )

    def is_graph_output(self, tensor: torch.fx.Node) -> bool:
        """
        Checks if the given tensor is used as a graph output

        Args:
            tensor: EdgeIR Tensor that is being checked for graph output
        """

        for user in tensor.users.keys():
            if user.op == "output":
                return True
        return False

    def gen_ids_and_flags(
        self,
        tensor: torch.fx.Node,
        xnn_graph: XNNGraph,
        quant_params: Optional[QuantParams],
    ) -> Tuple[int, int, int]:
        """
        Generate new id, external id, and flag values for tensor info

        Args:
           tensor: EdgeIR Tensor that is being defined into xnn_graph
           xnn_graph: XNNGraph object for serializing into flatbuffer
           quant_params: QuantParams object representing the q params of this tensor;
                    None if not quantized

        Returns:
            tuple of external_id, id_out and external input/output flags
        """
        id_out = len(xnn_graph.xvalues)
        ext_id = XNN_INVALID_VALUE_ID
        flag = 0

        # Dynamic quant happens at runtime, so the tensor that actually enters
        # the graph is the float input of the quantize op
        if quant_params is not None and quant_params.is_dynamic:
            tensor = quant_params.q_input

        # TODO: tensor here for [placeholder -> q -> dq -> op] must be the placeholder node.
        # This will break if we change the way q/dq are partitioned.

        # Tensor can still be an input if its quantizing node is an input
        is_input = self.is_graph_input(tensor) or (
            quant_params.is_input
            and not is_param_node(self.exported_program, quant_params.q_input)
            if quant_params
            else False
        )

        # Tensor can still be an output if its quantizing node is an output
        is_output = self.is_graph_output(tensor) or (
            quant_params.is_output if quant_params else False
        )

        if is_input:
            tensor_input = tensor
            if (
                quant_params
                and quant_params.is_input
                and not is_param_node(self.exported_program, quant_params.q_input)
                and not self.is_graph_input(tensor)
            ):
                tensor_input = quant_params.q_input

            assert (
                tensor_input in self.external_ids.keys()
            ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"

            ext_id = self.external_ids[tensor_input].external_id
            xnn_graph.input_ids.append(id_out)
            flag = self.external_ids[tensor_input].io_type

        elif is_output:
            tensor_output = tensor
            if (
                quant_params
                and quant_params.is_output
                and not self.is_graph_output(tensor)
            ):
                tensor_output = list(tensor.users)[0]

            assert (
                tensor_output in self.external_ids.keys()
            ), f"Tensor {tensor_output} is_output. ext_ids: {self.external_ids.keys()}"

            ext_id = self.external_ids[tensor_output].external_id
            xnn_graph.output_ids.append(id_out)
            flag = self.external_ids[tensor_output].io_type

        return ext_id, id_out, flag

    def get_serialized_dtype(
        self,
        quant_params: Optional[QuantParams],
        node: torch.fx.Node,
        fp32_static_weight: bool = False,
    ) -> XNNDatatype:
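        """
        Resolve the XNNPACK datatype to serialize for this tensor: derived
        from quant_params when the tensor is quantized, otherwise from the
        node's tensor metadata (fp16 weights may be forced to fp32 via
        fp32_static_weight).
        """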
        # Default initialization
        dtype = XNNDatatype.xnn_datatype_fp32

        def get_node_dtype(node: torch.fx.Node) -> Optional[torch.dtype]:
            """
            Extract the tensor.dtype from the node meta data if possible
            """
            node_val = node.meta.get("val", None)
            if node_val is not None:
                if isinstance(node_val, torch.Tensor):
                    return node_val.dtype

        # only for static quant
        def get_per_channel_dtype(
            quant_params: QuantParams,
        ) -> XNNDatatype:
            if quant_params.dtype == torch.int32:
                return XNNDatatype.xnn_datatype_qcint32
            elif quant_params.dtype == torch.int8:
                if quant_params.is_per_channel_group:
                    # 4-bit per channel group quantized weights
                    # No 8-bit support yet
                    assert (
                        quant_params.is_qc4w is True
                    ), "Only 4-bit per channel group quantization is supported"
                    return XNNDatatype.xnn_datatype_qbint4
                else:
                    # 4/8-bit per channel quantized weights
                    return (
                        XNNDatatype.xnn_datatype_qcint4
                        if quant_params.is_qc4w
                        else XNNDatatype.xnn_datatype_qcint8
                    )
            else:
                raise RuntimeError(
                    f"Unable to resolve static quantized tensor dtype using quant params dtype: {quant_params.dtype}, [qmin, qmax]: {quant_params.qmin}, {quant_params.qmax} for per channel quantization"
                )

        if quant_params is not None:
            if quant_params.is_dynamic:
                dtype = XNNDatatype.xnn_datatype_qdint8
            else:
                if quant_params.per_channel:
                    dtype = get_per_channel_dtype(quant_params)
                else:
                    dtype = (
                        XNNDatatype.xnn_datatype_qint32
                        if quant_params.dtype == torch.int32
                        else XNNDatatype.xnn_datatype_qint8
                    )
        else:
            node_dtype = get_node_dtype(node)
            if node_dtype is not None and node_dtype == torch.float16:
                dtype = (
                    XNNDatatype.xnn_datatype_fp32
                    if fp32_static_weight
                    else XNNDatatype.xnn_datatype_fp16
                )

        return dtype

    def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
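        """
        Map the QuantParams onto the matching XNNPACK quantization schema
        record: per-channel-group, per-channel, per-token dynamic, or
        per-tensor.
        """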
        if quant_params.per_channel:
            scale = cast(torch.Tensor, quant_params.scale)
            if quant_params.is_per_channel_group:
                return PerChannelGroupQuant(
                    scale=scale.flatten().tolist(),
                    channel_dim=quant_params.axis,
                    group_size=quant_params.group_size,
                )
            else:  # per_channel quant
                return PerChannelQuant(
                    scale=scale.tolist(),
                    channel_dim=quant_params.axis,
                )
        elif quant_params.is_dynamic:
            # NB:
            # We use per_token quantization for per_tensor quantization
            # because that's the only option in XNNPACK in the absence of per_tensor dynamic quantization
            # TODO: Upstream support for per_tensor dynamic quantization or broadcasting same scale value internally
            return PerTokenDynamicQuant(
                num_nonbatch_dims=quant_params.num_nonbatch_dims,
            )

        return PerTensorQuant(
            scale=cast(float, quant_params.scale),
            zero_point=cast(int, quant_params.zp),
        )

    @staticmethod
    def _check_per_channel_group_params(
        quant_params: QuantParams, dims: List[int]
    ) -> None:
        # Make sure things are lining up for the per_channel_group quantization case.
        # Has to be done this late because we don't have clean access to the actual tensor
        assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
        # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
        num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
        assert (
            quant_params.axis == 0
        ), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
        assert (
            len(dims) == 2
        ), f"For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
        assert (
            num_groups > 0 and quant_params.group_size > 0
        ), f"For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
        output_channels = dims[quant_params.axis]
        input_channels = dims[quant_params.axis ^ 1]
        assert (
            output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
        ), f"For per_channel_group quant, expecting output channels to match scale.shape[0], but got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
        assert (
            input_channels % num_groups == 0
        ), f"For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
        assert (
            input_channels % quant_params.group_size == 0
        ), f"For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
        assert (
            input_channels // quant_params.group_size == num_groups
        ), f"For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"

        # For now group quantization is only supported for 4b weights
        assert quant_params.is_qc4w, "Only 4b group quantization is supported"

    def define_tensor(
        self,
        tensor: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        convert_to_nhwc: bool = False,
        swap_nc_for_depthwise_weights: bool = False,
        quant_params: Optional[QuantParams] = None,
        fp32_static_weights: bool = False,
    ) -> None:
        """
        Defines a tensor value into the XNNGraph

        Args:
            tensor: EdgeIR Tensor that is being defined into xnn_graph
            xnn_graph: XNNGraph object for serializing into flatbuffer
            vals_to_ids: dictionary mapping edge_graph values(node targets) to
                        their corresponding ids in XNNGraph
            convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                        reflect the nhwc memory format.
            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
                        should be permuted such that the N and C dimensions are
                        swapped, which should be used for depthwise convolution
                        weights. This is only valid for tensors which hold
                        constant data. If used along with convert_to_nhwc, this
                        swap will happen before converting to nhwc.
            quant_params: Quantization meta data for this tensor, None if it is not quantized
            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
        """

        if tensor in vals_to_ids:
            return

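        # If the quantize op's input has already been defined, reuse its id
        # for this tensor instead of defining a duplicate value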
        if quant_params is not None:
            if quant_params.q_input in vals_to_ids:
                vals_to_ids[tensor] = vals_to_ids[quant_params.q_input]
                return
        # Tag added by ChannelsLastTaggedReshapePass
        convert_to_nhwc |= tensor.meta.get(
            ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
        )

        # Get new xnn id for tensor value
        ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
        dims = get_shape(tensor)
        dims = [1] if len(dims) == 0 else dims

        # check for per_channel_group quantization
        if quant_params and quant_params.per_channel_group:
            self._check_per_channel_group_params(quant_params, dims)

        # serialize the tensor's constant data, if any
        buffer_idx = self.get_serialized_buffer_index(
            tensor,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc,
            swap_nc_for_depthwise_weights,
            quant_params,
            fp32_static_weights,
        )

        # The serialized tensor shape must reflect the memory format. The default is
        # contiguous, so only permute the shape if we are converting the tensor to nhwc
        if swap_nc_for_depthwise_weights:
            dims = [dims[1], dims[0]] + dims[2:]
        if convert_to_nhwc:
            check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
            dims = [dims[i] for i in PERM_NCHW_TO_NHWC]

        dtype = self.get_serialized_dtype(
            quant_params, tensor, fp32_static_weight=fp32_static_weights
        )

        tvalue = XNNTensorValue(
            datatype=dtype,
            num_dims=len(dims),
            dims=dims,
            external_id=ext_id,
            constant_buffer_idx=buffer_idx,
            flags=flag,
            id_out=id_out,
        )

        # Override the quant params axis since we have
        # updated the weights for depthwise, with that the out_channels dim
        # will be dims[3] instead of dims[0]. Let's update the per_channel
        # quant axis to match the new weight tensor before serializing
        if swap_nc_for_depthwise_weights and (
            quant_params and quant_params.per_channel
        ):
            if quant_params.axis == 0:
                quant_params.axis = len(dims) - 1
            else:
                raise AssertionError(
                    f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
                )

        # Serialize tensor value
        ser_val = (
            XValue(xvalue_union=tvalue)
            if quant_params is None
            else XValue(
                xvalue_union=XNNQuantizedTensorValue(
                    tensor_value=tvalue,
                    quant_params=self.get_quant_params(quant_params),
                )
            )
        )

        xnn_graph.xvalues.append(ser_val)
        vals_to_ids[tensor] = id_out
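        # Also map the quantize op's input to the same id so later lookups
        # of either node resolve to this value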
        if quant_params is not None:
            vals_to_ids[quant_params.q_input] = id_out

    @staticmethod
    def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
        """
        Convert a channelwise-quantized tensor to a packed 4-bit tensor
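
        Each value v in [-8, 7] is stored as (v + 8) in a 4-bit nibble; for
        example, the pair (-8, 7) packs into the single byte 0xF0.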
454        """
455
456        import torch.nn.functional as F
457
458        # Assert we got a properly quantized tensor.
459        min, max = inp.min().item(), inp.max().item()
460        assert (
461            max <= 7 and min >= -8
462        ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]"
463
464        # Assuming we have a 2d tensor
465        if inp.ndim != 2:
466            inp = inp.squeeze()
467        assert (
468            inp.ndim == 2
469        ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
470
471        # pad ic
472        if inp.shape[-1] % 2 != 0:
473            inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
474
475        # Shape after padding
476        oc, ic = inp.shape
477        assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
478
479        # Adjust inp tensor for zp
480        inp = inp.to(dtype=torch.uint8) + 8
481
482        # Prepare the Result tensor
483        inp = inp.contiguous().view(-1)
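        # Pack two 4-bit values per byte: element 2k lands in the low nibble
        # and element 2k+1 in the high nibble of output byte k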
        return (inp[1::2] << 4 | inp[::2]).view(oc, ic // 2)

    def get_serialized_buffer_index(
        self,
        tensor: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        convert_to_nhwc: bool,
        swap_nc_for_depthwise_weights: bool,
        quant_params: Optional[QuantParams],
        fp32_static_weights: bool = False,
    ) -> int:
        """
        If tensor holds some constant data, serialize it and return the
        index of its placement in the constant buffer

        Args:
            tensor: EdgeIR Tensor that is being defined into xnn_graph
            xnn_graph: XNNGraph object for serializing into flatbuffer
            vals_to_ids: dictionary mapping edge_graph values(node targets) to
                        their corresponding ids in XNNGraph
            convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                        reflect the nhwc memory format.
            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
                        should be permuted such that the N and C dimensions are
                        swapped, which should be used for depthwise convolution
                        weights. This is only valid for tensors which hold
                        constant data. If used along with convert_to_nhwc, this
                        swap will happen before converting to nhwc.
            quant_params: Quantization meta data for this tensor, None if it is not quantized
            fp32_static_weights: bool to indicate whether tensor is fp32 static weights

        Returns:
            buffer_idx: idx of the serialized data; 0 if the tensor has no
                        associated constant data
        """
        # If quantized, the get_attr node holding the data is the input to the quantize op
        get_attr_node = tensor if quant_params is None else quant_params.q_input
        if not is_param_node(self.exported_program, get_attr_node):
            check_or_raise(
                not swap_nc_for_depthwise_weights,
                "Swapping N and C dimensions is only valid for constant data tensors",
            )
            return 0

        buffer_idx = len(xnn_graph.constant_data)
        const_val = get_param_tensor(self.exported_program, get_attr_node)
        assert const_val is not None and isinstance(const_val, torch.Tensor)
        const_val = const_val.contiguous()

        # Quantize the buffer if the static data is indeed quantized
        if quant_params is not None and not quant_params.is_dynamic:
            const_val = quant_params.quantize_tensor(const_val).contiguous()
        elif const_val.dtype != torch.float16 or fp32_static_weights:
            # ensure that the const is fp32
            const_val = const_val.to(dtype=torch.float32).contiguous()

        if swap_nc_for_depthwise_weights:
            const_val = const_val.permute(
                dims=((1, 0) + tuple(range(2, const_val.dim())))
            ).contiguous()

        if convert_to_nhwc:
            const_val = const_val.to(memory_format=torch.channels_last)

        if quant_params is not None and quant_params.is_qc4w:
            const_val = self.convert_to_qc4w(const_val)

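        # Reinterpret the tensor's untyped storage as raw bytes so it can be
        # appended (aligned and padded) to the serialized constant data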
        array_type = ctypes.c_char * const_val.untyped_storage().nbytes()
        array = ctypes.cast(
            const_val.untyped_storage().data_ptr(),
            ctypes.POINTER(array_type),
        ).contents

        offset = len(self._constant_data_bytes)
        size = const_val.untyped_storage().nbytes()
        xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
        self._constant_data_bytes.extend(
            _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
        )

        return buffer_idx

    def define_nodes_tensor_inputs_outputs(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        convert_to_nhwc: bool = False,
        input_type_map: Optional[InputTypeToIndex] = None,
    ) -> None:
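        """
        Define the node's output tensor and each of its input tensors in the
        XNNGraph. When input_type_map is given, inputs are defined in
        (input, weight[, bias]) order with the corresponding quant params;
        otherwise every input node is defined with input quant params.
        """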
        # serialize node outputs if not already defined
        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            quant_params=QuantParams.from_outputs(node),
            convert_to_nhwc=convert_to_nhwc,
        )

        if input_type_map is None:
            # serialize node inputs if not already defined
            for inp in node.all_input_nodes:
                self.define_tensor(
                    inp,
                    xnn_graph,
                    vals_to_ids,
                    quant_params=QuantParams.from_inputs(inp, self._exported_program),
                    convert_to_nhwc=convert_to_nhwc,
                )
        else:
            num_inputs = 3 if input_type_map.node_bias is not None else 2
            check_or_raise(
                num_inputs == len(node.all_input_nodes),
                f"Invalid input type map given, {input_type_map}, {num_inputs}, {node.all_input_nodes}",
            )
            # Define Input Node
            input_node = get_input_node(node, input_type_map.node_input)
            input_quant_params = QuantParams.from_inputs(
                input_node, self._exported_program
            )
            self.define_tensor(
                input_node,
                xnn_graph,
                vals_to_ids,
                quant_params=input_quant_params,
                convert_to_nhwc=convert_to_nhwc,
            )
            # Define Weight Node
            weight_node = get_input_node(node, input_type_map.node_weight)
            weight_quant_params = QuantParams.from_weights(
                weight_node, self._exported_program
            )
            self.define_tensor(
                weight_node,
                xnn_graph,
                vals_to_ids,
                quant_params=weight_quant_params,
                convert_to_nhwc=convert_to_nhwc,
            )
            # Define Bias Node
            if input_type_map.node_bias is not None:
                bias_node = get_input_node(node, input_type_map.node_bias)
                bias_quant_params = QuantParams.from_bias(
                    bias_node, weight_quant_params, input_quant_params
                )
                self.define_tensor(
                    bias_node,
                    xnn_graph,
                    vals_to_ids,
                    quant_params=bias_quant_params,
                    convert_to_nhwc=False,  # Bias is generally 1d and cannot be in NHWC
                )

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        raise NotImplementedError("NodeVisitor must be extended!")


# This holds a mapping of all node names to the visitor class that will define
# the torch.fx.Node object into the XNNGraph. Don't use it directly!
_node_visitor_dict = {}


def register_node_visitor(visitor):
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Ill-formed NodeVisitor subclass, can't register! Got: {visitor}"
    _node_visitor_dict[visitor.target] = visitor

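
# A minimal sketch of how visitors are registered (illustrative only; the
# concrete visitors live in the sibling operator modules):
#
#   @register_node_visitor
#   class AddVisitor(NodeVisitor):
#       target = "aten.add.Tensor"
#
#       def define_node(self, node, xnn_graph, vals_to_ids, debug_handle):
#           self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
#           ...  # then append the serialized XNode for this op to xnn_graph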

# @lru_cache - TODO: enable caching. At the moment the dict argument is not
# hashable, which breaks lru_cache.
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
    """
    Create a new instance of each registered visitor class at runtime
    and put them in a dict keyed by the visitor's target
    """
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(
            visitor
        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
        node_visitors[target] = visitor(*args)
    return node_visitors