# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Callable, List, Optional, Tuple

import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from executorch.exir.passes.replace_aten_with_edge_pass import (
    aten_to_edge,
    should_lower_to_edge,
)
from torch import fx
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl, register_fake


__all__ = [
    "get_quant_patterns_and_replacements",
]

# TODO: extending an existing library that is defined in OSS might be a bit
# confusing; we could investigate whether it is possible to define a new library.

quantized_decomposed_lib.define(
    "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


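# Weight layout for the embedding_byte ops, as enforced by the checks below:
# `weight` is a 2D (num_embeddings, embedding_dim) int8/uint8 tensor quantized
# per channel (one row per embedding) or per channel group; `weight_scales` is
# 1D with one scale per row, or 2D with one scale per group; and
# `weight_zero_points`, when present, is 1D and shares the scales' dtype.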
def embedding_weight_checks(weight, weight_scales, weight_zero_points):
    assert weight.dtype in [
        torch.int8,
        torch.uint8,
    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
    assert (
        weight.dim() == 2
    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"

    assert weight_scales.dtype in [
        torch.float16,
        torch.float32,
    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
    assert (
        weight_scales.dim() == 1 or weight_scales.dim() == 2
    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
    assert weight_scales.size(0) == weight.size(
        0
    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"

    assert (
        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
    assert (
        weight_zero_points is None or weight_zero_points.dim() == 1
    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
        0
    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"


@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
def embedding_byte(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    # 2D scales hold one scale per group within each row; 1D scales mean a
    # single group spanning the whole row.
    group_size = weight.size(1) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)

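# Minimal usage sketch for embedding_byte (illustrative only; kept as a comment
# so nothing runs at import time). Shapes and quantization parameters here are
# hypothetical:
#
#   weight_fp = torch.randn(16, 8)  # (num_embeddings, embedding_dim)
#   scales = weight_fp.abs().amax(dim=1, keepdim=True) / 127.0  # one group/row
#   weight_q = torch.clamp(
#       torch.round(weight_fp / scales), -128, 127
#   ).to(torch.int8)
#   out = torch.ops.quantized_decomposed.embedding_byte(
#       weight_q, scales, None, -128, 127, torch.tensor([0, 3, 7])
#   )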

@register_fake("quantized_decomposed::embedding_byte.out")
def embedding_byte_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_byte(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
def embedding_byte_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = weight.size(1) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_byte.dtype_out")
def embedding_byte_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_byte_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )


quantized_decomposed_lib.define(
    "embedding_2bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
def embedding_2bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    # Each packed byte holds four 2-bit values, so the unpacked row is 4x wider.
    group_size = (4 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Unpack the four 2-bit fields of every byte, least-significant pair first.
    weight_0 = weight & 3
    weight_1 = (weight & 12) >> 2
    weight_2 = (weight & 48) >> 4
    weight_3 = (weight & 192) >> 6
    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    # Shift from the stored unsigned range [0, 3] to the signed range [-2, 1].
    weight = weight.view(torch.int8).add(-2)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)

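# Packing convention assumed by the unpacking above (sketch; `pack_2bit` is a
# hypothetical helper, not part of this module): each uint8 byte holds four
# 2-bit values, least-significant pair first, stored with a +2 offset:
#
#   def pack_2bit(w: torch.Tensor) -> torch.Tensor:
#       # w holds int8 values in [-2, 1]; rows must be a multiple of 4 wide.
#       w = (w + 2).to(torch.uint8).reshape(w.shape[0], -1, 4)
#       return w[..., 0] | (w[..., 1] << 2) | (w[..., 2] << 4) | (w[..., 3] << 6)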

@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_2bit(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
def embedding_2bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (4 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight_0 = weight & 3
    weight_1 = (weight & 12) >> 2
    weight_2 = (weight & 48) >> 4
    weight_3 = (weight & 192) >> 6
    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-2)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_2bit_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )


quantized_decomposed_lib.define(
    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)

quantized_decomposed_lib.define(
    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
)


@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
def embedding_4bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    # Each packed byte holds two 4-bit values, so the unpacked row is 2x wider.
    group_size = (2 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    # Split each byte into its high and low nibbles, high nibble first.
    weight_even = weight.div(16, rounding_mode="trunc")
    weight_odd = weight.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    # Shift from the stored unsigned range [0, 15] to the signed range [-8, 7].
    weight = weight.view(torch.int8).add(-8)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)

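# Packing convention assumed by the unpacking above (sketch; `pack_4bit` is a
# hypothetical helper, not part of this module): each uint8 byte holds two
# 4-bit values, high nibble first, stored with a +8 offset:
#
#   def pack_4bit(w: torch.Tensor) -> torch.Tensor:
#       # w holds int8 values in [-8, 7]; rows must be a multiple of 2 wide.
#       w = (w + 8).to(torch.uint8).reshape(w.shape[0], -1, 2)
#       return (w[..., 0] << 4) | w[..., 1]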

@register_fake("quantized_decomposed::embedding_4bit.out")
def embedding_4bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_4bit(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
    )


@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
def embedding_4bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    embedding_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = (2 * weight.size(1)) // (
        weight_scales.size(1) if weight_scales.dim() == 2 else 1
    )
    weight_even = weight.div(16, rounding_mode="trunc")
    weight_odd = weight.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(weight.shape[0], -1)
    weight = weight.view(torch.int8).add(-8)

    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        weight.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.dtype_out")
def embedding_4bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    return embedding_4bit_dtype(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indices,
        dtype,
    )


quantized_decomposed_lib.define(
    "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
)

quantized_decomposed_lib.define(
    "mixed_linear(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, ScalarType? dtype=None) -> Tensor",
)

quantized_decomposed_lib.define(
    "add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)

quantized_decomposed_lib.define(
    "add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor"
)

quantized_decomposed_lib.define(
    "add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc"
)


def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
    """Symbolically trace `f` and rewrite lowerable ATen call targets to their
    edge-dialect equivalents."""
    gm = fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function" and should_lower_to_edge(node.target):
            node.target = aten_to_edge(node.target)
    gm.recompile()
    return gm

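# Illustrative sketch of the lowering (kept as a comment so nothing runs at
# import time; `_add` is a hypothetical example function):
#
#   def _add(x, y):
#       return torch.ops.aten.add.Tensor(x, y)
#
#   gm = _trace_and_lower_to_edge_ops(_add)
#   # call_function nodes in gm now target exir_ops.edge.aten.add.Tensor, so
#   # the traced pattern matches graphs produced after to_edge().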

def _sixth_input_is_scalar(match, original_graph, pattern_graph):
    """Check that the node matched to the sixth input of the pattern graph is
    a scalar number."""
    input_idx = 0
    for node in pattern_graph.nodes:
        if node.op == "placeholder":
            if input_idx == 5:
                num_node = node
            input_idx += 1
    return isinstance(match.nodes_map[num_node], (int, float))


def _get_binary_op_patterns_and_replacements(
    binary_op: Callable,
    qbinary_op: Callable,
    qbinary_scalar_op: Callable,
    qbinary_relu_op: Callable,
) -> List[Tuple[Callable, Callable, List[Callable]]]:
    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_op.name())
    def binary_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_1_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(x, num)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_1_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_scalar_op.name())
    def binary_op_scalar_2_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        out = binary_op(num, x)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_op_scalar_2_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        num,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_scalar_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            num,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

    @bind_pattern_to_op(quantized_decomposed_lib, qbinary_relu_op.name())
    def binary_relu_op_pattern(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, y_scale, y_zero_point, y_qmin, y_qmax, torch.uint8
        )

        out = binary_op(x, y)
        out = torch.ops.aten.relu.default(out)
        out = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            out, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return out

    def binary_relu_op_replacement(
        x,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        y,
        y_scale,
        y_zero_point,
        y_qmin,
        y_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        out = qbinary_relu_op(
            x,
            x_scale,
            x_zero_point,
            x_qmin,
            x_qmax,
            y,
            y_scale,
            y_zero_point,
            y_qmin,
            y_qmax,
            out_scale,
            out_zero_point,
            out_qmin,
            out_qmax,
        )

        return out

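    # The fused add+relu pattern is listed before the plain binary-op pattern
    # so that, when these rewrites are applied in order, an add followed by a
    # relu is fused into add_relu instead of first being consumed by the
    # standalone add rewrite.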
    return [
        (
            _trace_and_lower_to_edge_ops(binary_relu_op_pattern),
            _trace_and_lower_to_edge_ops(binary_relu_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_pattern),
            _trace_and_lower_to_edge_ops(binary_op_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_1_replacement),
            [_sixth_input_is_scalar],
        ),
        (
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_pattern),
            _trace_and_lower_to_edge_ops(binary_op_scalar_2_replacement),
            [_sixth_input_is_scalar],
        ),
    ]


def _get_binary_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    # TODO: replace qbinary op with the ops implemented in lean mode
    binary_op_to_qbinary_ops = {
        exir_ops.edge.aten.add.Tensor: (
            exir_ops.edge.quantized_decomposed.add.default,
            exir_ops.edge.quantized_decomposed.add.scalar,
            exir_ops.edge.quantized_decomposed.add_relu.default,
        ),
    }
    pattern_and_replacements = []
    for binary_op, (qbop, qbscalar_op, qbrelu_op) in binary_op_to_qbinary_ops.items():
        pattern_and_replacements.extend(
            _get_binary_op_patterns_and_replacements(
                binary_op, qbop, qbscalar_op, qbrelu_op
            )
        )

    return pattern_and_replacements


def _get_reshape_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def pattern(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )

        x = torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)
        x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, out_scale, out_zero_point, out_qmin, out_qmax, torch.uint8
        )

        return x

    def replacement(
        x,
        arg0,
        arg1,
        x_scale,
        x_zero_point,
        x_qmin,
        x_qmax,
        out_scale,
        out_zero_point,
        out_qmin,
        out_qmax,
    ):
        x = torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)
        return x

    return [
        (
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
            [],
        )
    ]


def _get_slice_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def pattern(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8
        )
        return x

    def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
        # The dequantize/quantize pair in the pattern uses identical qparams,
        # so slicing can operate directly on the quantized tensor.
        x = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
        return x

    return [
        (
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
            [],
        )
    ]


def _get_embedding_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def get_pattern_and_replacement():
        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            padding_idx,
        ):
            weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
                weight,
                weight_scales,
                weight_zero_points,
                0,
                weight_quant_min,
                weight_quant_max,
                torch.uint8,
            )
            out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
            return out

        def replacement_with_padding_idx(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            _,  # padding_idx only matters for training, not for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
        def pattern_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            padding_idx,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    weight_scales.dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
            return out

        def replacement_with_padding_idx_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            _,  # padding_idx only matters for training, not for inference
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.default(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
            )
            return out

        @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
        def pattern_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            weight = (
                torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
                    weight,
                    weight_scales,
                    weight_zero_points,
                    weight_quant_min,
                    weight_quant_max,
                    weight.dtype,
                    group_size,
                    dtype,
                )
            )
            out = torch.ops.aten.embedding.default(weight, indices)
            return out

        def replacement_with_dtype_groupwise(
            weight,
            weight_scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            indices,
            group_size,
            dtype,
        ):
            out = torch.ops.quantized_decomposed.embedding_byte.dtype(
                weight,
                weight_scales,
                weight_zero_points,
                weight_quant_min,
                weight_quant_max,
                indices,
                dtype=dtype,
            )
            return out

        return [
            (
                _trace_and_lower_to_edge_ops(pattern),
                _trace_and_lower_to_edge_ops(replacement),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_groupwise),
                _trace_and_lower_to_edge_ops(replacement_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
                [],
            ),
            (
                _trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
                _trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
                [],
            ),
        ]

    patterns_and_replacements = []
    patterns_and_replacements.extend(
        get_pattern_and_replacement(),
    )
    return patterns_and_replacements


1104"""
1105def _get_fixed_qparams_ops_patterns_and_replacements() -> List[Tuple[Callable, Callable, List[Callable]]]:
1106    fixed_qparams_op_to_qop = {
1107        torch.ops.aten.softmax: (torch.ops.quantized_decomposed.softmax, 1.0 / 256.0, 0)
1108    }
1109    def get_pattern_and_replacement(fixed_qparams_op, fixed_scale, fixed_zero_point):
1110        def pattern(x, x_scale, x_zero_point, x_qmin, x_qmax):
1111            x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
1112            x = fixed_qparams_op(x)
1113            x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, fixed_scale, fixed_zero_point, 0, 255, torch.uint8)
1114            return x
1115
1116        def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
1117            x = fixed_qparams_qop(x, x_scale, x_zero_point, x_qmin, x_qmax, torch.uint8)
1118            return x
1119
1120n        return [(pattern, replacement, [])]
1121
1122    patterns_and_replacements = []
1123    for op, (qop, fixed_scale, fixed_zero_point) in fixed_qparams_op_to_qop.items():
1124        patterns_and_replacements.extend(
1125            get_pattern_and_replacement(op, qop, fixed_scale, fixed_zero_point)
1126        )
1127"""
1128
1129
def get_quant_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
    return copy.copy(
        [
            *_get_binary_ops_patterns_and_replacements(),
            *_get_reshape_patterns_and_replacements(),
            *_get_slice_patterns_and_replacements(),
            # TODO: enable after the corresponding quantized ops are implemented
            # *_get_fixed_qparams_ops_patterns_and_replacements(),
            *_get_embedding_ops_patterns_and_replacements(),
        ]
    )
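
# How these triples are consumed (illustrative sketch; the actual pass lives
# elsewhere in exir, and `_fuse_quant_patterns` is a hypothetical name). Each
# entry is a (pattern, replacement, match_filters) triple that can be fed to
# torch.fx's subgraph rewriter:
#
#   from torch.fx import subgraph_rewriter
#
#   def _fuse_quant_patterns(gm: torch.fx.GraphModule) -> None:
#       for pattern, replacement, match_filters in get_quant_patterns_and_replacements():
#           subgraph_rewriter.replace_pattern_with_filters(
#               gm, pattern, replacement, match_filters
#           )
#       gm.recompile()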