# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/pt2e/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import operator
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph

# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.utils._pytree import LeafSpec


__all__ = [
    "fold_bn_weights_into_conv_node",
    "remove_tensor_overload_for_qdq_ops",
]

_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]


_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # conv_weight
    torch.randn(1),  # conv_bias
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)

# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),  # conv_bias
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)


def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
    """
    Assuming dest is one of the ops inserted by the quant workflow, check whether
    source and dest are connected. The assumption is that only quant-workflow
    inserted ops exist between source and dest.
    """
    quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
    quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
    while dest.target in quant_workflow_ops:
        if not isinstance(dest.args[0], torch.fx.Node):
            raise ValueError(
                f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}"
            )
        dest = dest.args[0]
    return dest == source


def _find_q_dq_node_for_user(
    producer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
    """
    Find the q, dq pair corresponding to [producer -> q -> dq -> user].
    This works by finding the dq arg of user and checking that it is connected
    to producer.
    """
    dq_node = None
    for n in user.args:
        if (
            isinstance(n, torch.fx.Node)
            and n.op == "call_function"
            and n.target in _DEQUANTIZE_OPS
        ):
            if _is_connected(producer, n):
                dq_node = n
                break
    if dq_node is None:
        for n in user.kwargs.values():
            if (
                isinstance(n, torch.fx.Node)
                and n.op == "call_function"
                and n.target in _DEQUANTIZE_OPS
            ):
                if _is_connected(producer, n):
                    dq_node = n
                    break
    if dq_node is None:
        return (None, None)

    q_node = None
    if (
        dq_node.args[0].op == "call_function"  # type: ignore[union-attr]
        and dq_node.args[0].target in _QUANTIZE_OPS  # type: ignore[union-attr]
    ):
        q_node = dq_node.args[0]
    return (q_node, dq_node)

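# Illustrative sketch (helper added for documentation only; not used elsewhere
# in this file): build a tiny FX graph of the shape [producer -> q -> dq -> user]
# and recover the (q, dq) pair for that user.
def _example_find_q_dq_pair() -> Tuple[Any, Any]:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    q = g.call_function(
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        (x, 1.0, 0, -128, 127, torch.int8),
    )
    dq = g.call_function(
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        (q, 1.0, 0, -128, 127, torch.int8),
    )
    user = g.call_function(torch.ops.aten.relu.default, (dq,))
    g.output(user)
    return _find_q_dq_node_for_user(x, user)  # expected: (q, dq)
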

def _is_sym_size_node(node: Node):
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )


def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
    return [user for user in node.users if not _is_sym_size_node(user)]


def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
    if annotation is None:
        return False
    input_qspec_map = annotation.input_qspec_map
    output_qspec = annotation.output_qspec
    if len(input_qspec_map) == 0 and output_qspec is None:
        return False
    return True


def _get_tensor_constant_from_node(node, m):
    if node is None:
        return None
    assert node.op == "get_attr"
    target_atoms = node.target.split(".")
    attr_itr = m
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def _get_all_arguments(orig_args, orig_kwargs, args_schema):
    all_args = []
    for i, schema in enumerate(args_schema):
        if schema.name in orig_kwargs:
            all_args.append(orig_kwargs[schema.name])
        elif not schema.kwarg_only and i < len(orig_args):
            all_args.append(orig_args[i])
        else:
            all_args.append(schema.default_value)
    return all_args

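# Illustrative sketch (helper added for documentation only; not used elsewhere
# in this file): normalize a call's positional/keyword arguments against an op
# schema, filling in defaults for anything not provided.
def _example_get_all_arguments() -> List[Any]:
    # aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    add_schema = torch.ops.aten.add.Tensor._schema.arguments
    x, y = torch.randn(2), torch.randn(2)
    # Returns [x, y, 1]: args in schema order, with the default filled in for `alpha`.
    return _get_all_arguments((x, y), {}, add_schema)
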

def _is_supported_batch_norm_for_training(node: Node):
    """
    Return True if the given node refers to an aten batch norm op QAT supports.
    """
    supported_ops = [
        torch.ops.aten.batch_norm.default,
        torch.ops.aten._native_batch_norm_legit.default,
        # Note: we won't need this op anymore after batch norm consolidation
        # For now, we need to continue to support it because it gives better
        # training numerics than `_native_batch_norm_legit`
        torch.ops.aten.cudnn_batch_norm.default,
        torch.ops.aten.miopen_batch_norm.default,
    ]
    return node.target in supported_ops


# TODO: move this to torch/ao/quantization/utils.py
def _is_conv_node(n: Node):
    """
    Return whether the node refers to an aten conv op.
    """
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
    ]


def _is_conv_transpose_node(n: Node):
    """
    Return whether the node refers to an aten conv_transpose op.
    """
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d,
        torch.ops.aten.conv_transpose2d.input,
    ]


def _is_conv_or_conv_transpose_node(n: Node):
    """
    Return whether the node refers to an aten conv or conv transpose op.
    """
    return _is_conv_node(n) or _is_conv_transpose_node(n)


def _is_conv_transpose_fn(conv_fn: Callable):
    return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]


def _is_bn_node(n: Node):
    return (
        _is_supported_batch_norm_for_training(n)
        or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
    )


def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule,
) -> None:
    # conv args: input, weight, bias, stride, padding, dilation, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = _is_conv_transpose_node(conv_node)

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif _is_supported_batch_norm_for_training(bn_node):
        eps_arg_index = 7
    else:
        raise ValueError(f"BN node target is unexpected: {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
    )

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # filling in the default bias argument
    if len(conv_args) == 2:
        conv_args.append(None)

    # fused_weight and fused_bias are nn.Parameter; re-register them on the module
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
    else:
        bias_attr_name = weight_attr_name + "_bias"
        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    if bn_node.target == torch.ops.aten.batch_norm.default:
        # With the new training ir, instead of batch_norm + getitem,
        # we only have the batch_norm node.
        #
        # Before:
        # conv -> bn -> users
        # After:
        # conv -> users
        #       bn has no users now
        bn_node.replace_all_uses_with(conv_node)
    else:
        # Before:
        # conv -> bn - (first output) -> users1
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # After:
        # conv -> (first output) -> users1
        #       bn -
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # if users2 and users3 are empty then bn will be removed through dead code elimination
        for user in bn_node.users:
            if (
                user.op != "call_function"
                or user.target != operator.getitem
                or user.args[1] != 0
            ):
                continue
            user.replace_all_uses_with(conv_node)

    # If the BN node does not have users, erase it from the graph
    # Note: we need to do this manually because the model can still be in train
    # mode at this point, in which case DCE won't erase the BN node automatically
    # since the node refers to a mutating op. Here we still need to call DCE first
    # to get rid of the unused getitem nodes that consume the BN node.
    m.graph.eliminate_dead_code()
    if len(bn_node.users) == 0:
        m.graph.erase_node(bn_node)

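# Illustrative sketch (helper added for documentation only; not used elsewhere
# in this file): numerically check the identity the folding above relies on,
# i.e. that `fuse_conv_bn_weights` reproduces conv followed by eval-mode bn.
def _example_check_conv_bn_fusion() -> bool:
    x = torch.randn(1, 1, 3, 3)
    conv_w, conv_b = torch.randn(1, 1, 1, 1), torch.randn(1)
    bn_w, bn_b = torch.randn(1), torch.randn(1)
    bn_rm, bn_rv = torch.randn(1), torch.rand(1) + 0.1  # running var must be positive
    eps = 1e-5
    ref = F.batch_norm(
        F.conv2d(x, conv_w, conv_b), bn_rm, bn_rv, bn_w, bn_b, training=False, eps=eps
    )
    fused_w, fused_b = fuse_conv_bn_weights(
        conv_w, conv_b, bn_rm, bn_rv, eps, bn_w, bn_b, transpose=False
    )
    return torch.allclose(ref, F.conv2d(x, fused_w, fused_b), atol=1e-5)
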

# Fuse conv-bn weights; modifies the GraphModule and its graph in place.
def _fuse_conv_bn_(m: GraphModule) -> None:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target not in (
            torch.ops.aten._native_batch_norm_legit_no_training.default,
            torch.ops.aten.batch_norm.default,
        ):
            continue
        bn_node = n
        n = bn_node.args[0]
        if not _is_conv_or_conv_transpose_node(n):
            continue
        conv_node = n
        conv_weight_node = conv_node.args[1]
        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        fold_bn_weights_into_conv_node(
            conv_node, conv_weight_node, conv_bias_node, bn_node, m
        )

    m.graph.eliminate_dead_code()
    m.recompile()

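# Illustrative usage sketch (hypothetical toy module; not part of the original
# file): capture a conv + bn module and fold the bn statistics into the conv.
def _example_fuse_conv_bn() -> GraphModule:
    class _ConvBn(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.bn = torch.nn.BatchNorm2d(1)

        def forward(self, x):
            return self.bn(self.conv(x))

    m = capture_pre_autograd_graph(_ConvBn().eval(), (torch.randn(1, 1, 3, 3),))
    _fuse_conv_bn_(m)  # in-place: conv weights now include the bn statistics
    return m
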

def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for n in model.graph.nodes:
        nn_module_stack = n.meta.get("nn_module_stack", None)
        current_scope = ("", type(None))
        if nn_module_stack:
            bt = list(nn_module_stack.values())[-1]
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope


def _get_aten_graph_module_for_pattern(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool = False,
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    if is_cuda:
        example_inputs = tuple(
            [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
        )
    aten_pattern = capture_pre_autograd_graph(
        pattern,  # type: ignore[arg-type]
        example_inputs,
        kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()

    # ep.module() adds copy_ nodes for mutated inputs; these don't matter for
    # patterns, so remove any copy_ node that has no users.
    for node in aten_pattern.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
            and len(node.users) == 0
        ):
            aten_pattern.graph.erase_node(node)

    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()

    return aten_pattern  # type: ignore[return-value]

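# Illustrative sketch (hypothetical pattern module; not part of the original
# file): how the conv-bn example inputs above pair with a pattern to produce an
# aten-level GraphModule that can be used for subgraph matching.
def _example_conv2d_bn_pattern() -> GraphModule:
    class _Pattern(torch.nn.Module):
        def forward(self, x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
            x = F.conv2d(x, conv_weight, conv_bias)
            return F.batch_norm(x, bn_rm, bn_rv, bn_weight, bn_bias, training=True)

    return _get_aten_graph_module_for_pattern(_Pattern(), _conv2d_bn_example_inputs)
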

def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """Remove the overload (e.g. `.tensor`) from quantize/dequantize op targets so that the
    match_pattern we get from torchdynamo export can match the output of convert_pt2e
    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]

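# Illustrative usage sketch (hypothetical variable names; not part of the
# original file): normalize a captured match pattern in place so its q/dq nodes
# target overload packets rather than specific overloads before matching.
def _example_normalize_qdq_pattern(match_pattern: GraphModule) -> GraphModule:
    remove_tensor_overload_for_qdq_ops(match_pattern)
    match_pattern.recompile()  # regenerate code for the retargeted nodes
    return match_pattern
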

def _is_literal(arg):
    if isinstance(arg, (int, float)):
        return True
    if isinstance(arg, (tuple, list)):
        return all(map(_is_literal, arg))
    return False


def _replace_literals_with_new_placeholders(
    gm: torch.fx.GraphModule,
    merge_dup: bool = False,
    exclude_literals: Optional[List[Any]] = None,
):
    """Replace the literals in the graph with placeholder nodes that are created on the fly
    while we traverse the graph, so that the literal arguments in the graph can be matched
    and replaced.

    To use this, the pattern and replacement graph should have the exact same number of literal args
    and they should be used in the exact same order in the pattern and replacement graph.

    If the literal arguments are not used in the same order in pattern and replacement graph, please
    use `_replace_literals_with_existing_placeholders` instead.

    Args:
        `gm`: input GraphModule that we'll transform
        `merge_dup`: if True, multiple occurrences of the same literal map to a single
         placeholder; otherwise each occurrence gets its own placeholder
        `exclude_literals`: a list of literals that will not be replaced with placeholders

    Example:

    # 1. Original Graph
    def pattern(self, x):
        return x + 3

    def replacement(self, x):
        return x - 3

    example_inputs = (torch.randn(1, 3, 3, 3),)
    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)

    # 2. Before calling replace literals we'll see the following graph:
    def pattern(self, x):
        return x + 3

    def replacement(self, x):
        return x - 3

    pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
    replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)

    # 3. After replacing literals with new placeholder nodes

    def pattern(self, x, new_ph):
        return x + new_ph

    def replacement(self, x, new_ph):
        return x - new_ph

    """
    last_ph = None
    cnt = 0
    literal_to_ph: Dict[Union[float, bool, int, torch.dtype], Node] = {}
    if exclude_literals is None:
        exclude_literals = []

    in_spec = gm._in_spec
    args_spec = in_spec.children_specs[0]
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            last_ph = node
            cnt += 1
            continue
        with gm.graph.inserting_after(last_ph):
            new_args = []
            for arg in node.args:
                if _is_literal(arg) and arg not in exclude_literals:
                    if merge_dup and arg in literal_to_ph:
                        new_args.append(literal_to_ph[arg])
                    else:
                        ph_node = gm.graph.placeholder("arg" + str(cnt))
                        new_args.append(ph_node)
                        args_spec.children_specs.append(LeafSpec())
                        cnt += 1
                        if merge_dup:
                            literal_to_ph[arg] = ph_node
                else:
                    new_args.append(arg)
            new_args = tuple(new_args)

        node.args = new_args

    # Update `num_nodes`, `num_leaves`, `num_children`.
    args_spec.__post_init__()
    in_spec.__post_init__()
    return gm



def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[List[Any]] = None,
    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None,
):
    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
    in the graph can be matched and replaced

    To use this, all literal args in the graph should be unique and each of them should correspond
    to exactly one placeholder node

    # 1. Original Graph
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

    example_inputs = (
        torch.randn(1, 3, 3, 3),
        1.0,
        0,
        -128,
        127,
    )
    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)

    # 2. Before calling replace literals we'll see the following graph:
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        x_i8 = torch.clamp(x_i8, -128, 127)
        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)

    # Note that literal args appear in different order in pattern and replacement graph, so
    # we can't use _replace_literals_with_new_placeholders

    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
    pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
    replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx=literal_to_ph_idx)

    # 3. After replacing literals with existing placeholder nodes

    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
    """
    if exclude_literals is None:
        exclude_literals = []

    if literal_to_ph_idx is None:
        literal_to_ph_idx = {}

    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        new_args = []
        for arg in node.args:
            if (
                _is_literal(arg)
                and arg not in exclude_literals
                and arg in literal_to_ph_idx
            ):
                ph_idx = literal_to_ph_idx[arg]
                ph_node = phs[ph_idx]
                new_args.append(ph_node)
            else:
                new_args.append(arg)
        new_args = tuple(new_args)
        node.args = new_args
    return gm


# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
    """
    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
    This is useful for exported models, where these methods don't actually behave as expected.
    """
    error_message = """
        Calling train() or eval() is not supported for exported models.
        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.

        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
        which does the above automatically for you. Note that this has limited effect on switching
        behavior between train and eval modes, and should be used only for special ops such as dropout
        and batchnorm.
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model
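

# Illustrative usage sketch (hypothetical variable name; not part of the
# original file): once wrapped, mode switches must go through the dedicated
# move helpers rather than train()/eval().
def _example_guard_exported_model(exported_model: GraphModule) -> GraphModule:
    exported_model = _disallow_eval_train(exported_model)
    # exported_model.train() / exported_model.eval() now raise NotImplementedError;
    # use torch.ao.quantization.move_exported_model_to_train(exported_model) or
    # torch.ao.quantization.move_exported_model_to_eval(exported_model) instead.
    return exported_model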