# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import List, Optional

import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args

from . import ir
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
    add,
    add_needs_realized_inputs,
    aten,
    permute,
    register_lowering,
    to_dtype,
    view,
)
from .select_algorithm import (
    autotune_select_algorithm,
    ChoiceCaller,
    ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
from .virtualized import ops, V


def register_onednn_fusion_ops():
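    """Register Inductor lowerings for oneDNN (mkldnn) fusion ops on CPU.

    Covers pointwise-fused convolution, linear, and RNN ops as well as their
    quantized (qconv/qlinear) variants, plus the MKL packed linear op when
    available. Only takes effect when PyTorch is built with MKL-DNN support
    (torch._C._has_mkldnn).
    """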
    if torch._C._has_mkldnn:
        from . import mkldnn_ir

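        # The ExternKernelChoice wrappers below expose the oneDNN extern kernels
        # as autotuning choices, so they can be compared against the generated
        # C++ packed GEMM templates in the linear/qlinear lowerings further down.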
        aten_mkldnn_linear_unary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearUnary.create,
        )
        aten_mkldnn_linear_binary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise.binary,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearBinary.create,
        )
        aten_mkldnn_qlinear_unary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
        )
        aten_mkldnn_qlinear_binary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise.binary,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
        )
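        # These extern ops consume their inputs as materialized buffers, so mark
        # them as needing realized inputs; the list is registered at the end of
        # this function via add_needs_realized_inputs.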
        cpu_needs_realized_inputs = [
            torch.ops.mkldnn._convolution_pointwise,
            torch.ops.mkldnn._convolution_pointwise_,
            torch.ops.mkldnn._convolution_transpose_pointwise,
            torch.ops.mkldnn._linear_pointwise,
            aten.mkldnn_rnn_layer.default,
            torch.ops.onednn.qconv2d_pointwise,
        ]

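        # The convolution, RNN, and quantized-convolution lowerings below are thin
        # wrappers: they forward their arguments to the matching mkldnn_ir node
        # constructor and wrap the result in TensorBox.create.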
        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
        def convolution_binary(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinary.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
        def convolution_binary_inplace(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinaryInplace.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._linear_pointwise)
        def linear_unary(
            x: TensorBox,
            w: TensorBox,
            b: TensorBox,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
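            """Lower mkldnn._linear_pointwise with autotuning.

            Flattens >2D inputs to 2D (the GEMM template requires 2D operands),
            collects candidate kernels (the C++ packed GEMM template when
            max-autotune is enabled, plus the oneDNN extern kernel), selects one
            via autotune_select_algorithm, and reshapes the result back to the
            original leading dims.
            """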
            x_size = x.get_size()
            if len(x_size) > 2:
                # The GEMM template needs 2D inputs; normalize the input shape here
                x = view(x, [-1, x_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
                if use_cpp_packed_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(
                            buf, attr, scalars=scalars, algorithm=algorithm
                        )

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=None if attr == "none" else epilogue_creator,
                    )
                    if b is not None:
                        kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, w] if b is None else [x, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_unary.bind(
                        [x, w] if b is None else [x, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                1: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_unary",
                choices,
                [x, w] if b is None else [x, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
        def linear_binary(
            x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
        ):
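            """Lower mkldnn._linear_pointwise.binary (linear fused with a binary op).

            Follows the same flow as linear_unary, with the second operand y
            applied as the binary epilogue.
            """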
            x_size = x.get_size()
            if len(x_size) > 2:
                # The GEMM template needs 2D inputs; normalize the input shape here
                x = view(x, [-1, x_size[-1]])
            y_size = y.get_size()
            if len(y_size) > 2:
                y = view(y, [-1, y_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w, y = mm_args(
                    x, transposed_w, y, layout=layout
                )
                if use_cpp_packed_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(buf, attr, other=y)

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=epilogue_creator,
                    )
                    kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, y, w] if b is None else [x, y, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_binary.bind(
                        [x, y, w] if b is None else [x, y, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                2: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_binary",
                choices,
                [x, y, w] if b is None else [x, y, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
        def convolution_transpose_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            output_padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionTransposeUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    output_padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(aten.mkldnn_rnn_layer.default)
        def mkldnn_rnn_layer(
            x: TensorBox,
            w0: TensorBox,
            w1: TensorBox,
            w2: TensorBox,
            w3: TensorBox,
            hx: TensorBox,
            cx: TensorBox,
            reverse: bool,
            batch_sizes: List[int],
            mode: int,
            hidden_size: int,
            num_layers: int,
            has_biases: bool,
            bidirectional: bool,
            batch_first: bool,
            train: bool,
        ):
            return pytree.tree_map(
                TensorBox.create,
                mkldnn_ir.MkldnnRnnLayer.create(
                    x,
                    w0,
                    w1,
                    w2,
                    w3,
                    hx,
                    cx,
                    reverse,
                    batch_sizes,
                    mode,
                    hidden_size,
                    num_layers,
                    has_biases,
                    bidirectional,
                    batch_first,
                    train,
                ),
            )

        @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None)
        def qconvolution_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.QConvPointWisePT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
        )
        def qconvolution_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            accum: TensorBox,
            accum_scale,
            accum_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
        ):
            if (
                binary_attr == "sum"
                and output_dtype in [torch.float32, torch.bfloat16]
                and accum.get_dtype() in [torch.float32, torch.bfloat16]
                and accum.get_dtype() != output_dtype
            ):
                # For int8-mixed-bf16 quantization with an in-place add, the accum
                # dtype may be float32 while the output dtype is bfloat16. Since the
                # accum is modified in place by the post-op sum, convert the accum
                # dtype here.
                accum = to_dtype(accum, output_dtype)
            return TensorBox.create(
                mkldnn_ir.QConvPointWiseBinaryPT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    accum,
                    accum_scale,
                    accum_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    binary_attr,
                    alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithmm,
                )
            )

        @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
        def qlinear_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
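            """Lower onednn.qlinear_pointwise with autotuning.

            Normalizes >2D inputs to 2D, turns scalar x_scale/x_zp into tensor
            constants, and, when eligible, adds a C++ packed GEMM template choice
            whose epilogue dequantizes the int32 accumulator (scale compensation),
            applies the unary post-op, and requantizes or casts to the output
            dtype. The oneDNN extern kernel competes as the fallback choice.
            """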
            x_size = x.get_size()
            if len(x_size) > 2:
                # The GEMM template needs 2D inputs; normalize the input shape here
                x = view(x, [-1, x_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # instead of a ConstantBuffer. Refer to https://github.com/pytorch/pytorch/blob
            # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                # W_zp might be a ConstantBuffer with int64, convert it to int32
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )

            bias_dtype = None if bias is None else bias.get_dtype()

            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                *_, layout, x, packed_weight = mm_args(
                    x, packed_weight, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
                    and use_cpp_packed_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()
                    weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
                    weight_compens = V.graph.add_tensor_constant(
                        weight_compens_tensor,
                        name=packed_weight.get_name() + "_BMatrixCompens",
                    )

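                    # Dequantization epilogue: the GEMM micro-kernel produces an
                    # int32 accumulator acc[m, n]. With per-tensor activation quant
                    # and w_zp == 0 (both checked above), the fp32 result is
                    #   x_scale * w_scale[n] * (acc[m, n] - x_zp * sum_k W[k, n])
                    # where sum_k W[k, n] is precomputed above as weight_compens.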
                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                        ]
                        input_loader = input_buffer.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            # The micro-kernel output is int32; convert it to FP32
                            # before applying the compensation
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)
                            _x_scale = x_scale_loader(())
                            _x_zp = x_zp_loader(())
                            _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compo = weight_compens_loader(weight_compens_index)
                            # Step 1: apply the scale/zero-point compensation to get fp32
                            temp = ops.mul(
                                ops.mul(
                                    input,
                                    _x_scale,
                                ),
                                _w_scale,
                            )
                            temp = ops.sub(
                                temp,
                                ops.mul(
                                    ops.mul(
                                        ops.mul(
                                            _x_scale,
                                            _w_scale,
                                        ),
                                        _x_zp,
                                    ),
                                    _weight_compo,
                                ),
                            )
                            # Step 2: add the bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 3: apply the unary post-op fusion
                        if attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf, attr, scalars=scalars, algorithm=algorithm
                            )

                        # Step 4: cast the output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype == torch.uint8:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                qmin, qmax = _create_constants(
                                    0, 255, dtype=torch.float32
                                )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, torch.uint8)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    assert x.get_dtype() == torch.uint8
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        input_indices=[0, 3, 1, 2, 4, 5]
                        if bias is None
                        else [6, 0, 3, 1, 2, 4, 5],
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    post_op_name=attr,
                    post_op_args=scalars,
                    post_op_algorithm=algorithm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_unary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],
                4: lambda x: V.graph.constants[x.get_name()],
                5: lambda x: V.graph.constants[x.get_name()],
                6: lambda x: V.graph.constants[x.get_name()],  # For bias
            }
            result = autotune_select_algorithm(
                "qlinear_unary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qlinear_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            x2: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
            layout=None,
        ):
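            """Lower onednn.qlinear_pointwise.binary (quantized linear with a binary post-op).

            Same overall flow as qlinear_unary, with an extra binary operand x2:
            when binary_attr is "add" the C++ GEMM template path is eligible and
            the add is folded into the dequantization epilogue; for "sum"
            (in-place accumulation) only the oneDNN extern kernel is used.
            """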
            x_size = x.get_size()
            x2_size = x2.get_size()
            assert len(x_size) == len(x2_size)
            if len(x_size) > 2 and binary_attr == "add":
                # The GEMM template needs 2D inputs; normalize the input shape here
                x = view(x, [-1, x_size[-1]])
                x2 = view(x2, [-1, x2_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When the number of channels is less than 8, w_scale/w_zp is a Pointwise
            # instead of a ConstantBuffer. Refer to https://github.com/pytorch/pytorch/blob
            # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )
            if binary_attr == "sum":
                if output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                ] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
                    if x2.get_dtype() != output_dtype:
                        # For int8-mixed-bf16 quantization with an in-place add, the
                        # accum dtype may be float32 while the output dtype is bfloat16.
                        # Since the accum is modified in place by the post-op sum,
                        # convert the accum dtype here.
                        x2 = to_dtype(x2, output_dtype)
                else:
                    assert (
                        x2.get_dtype() == output_dtype
                    ), "dtype of accum for qlinear post op sum should be the same as output"
            x2_dtype = x2.get_dtype()
            bias_dtype = bias.get_dtype() if bias is not None else None
            choices: List[ChoiceCaller] = []
            if (
                use_max_autotune() and binary_attr == "add"
            ):  # <TODO> Support inplace sum fusion
                *_, layout, x, packed_weight, x2 = mm_args(
                    x, packed_weight, x2, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
                    and use_cpp_packed_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()]
                    W_tensor = W_tensor.to_dense()
                    weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
                    weight_compens = V.graph.add_tensor_constant(
                        weight_compens_tensor,
                        name=packed_weight.get_name() + "_BMatrixCompens",
                    )

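                    # Same s32 -> f32 compensation epilogue as in qlinear_unary,
                    # with the binary add of x2 folded in after the bias.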
                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                        ]

                        input_loader = input_buffer.make_loader()
                        x2_loader = x2.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            _x2 = x2_loader(index)
                            _x_scale = x_scale_loader(())
                            _x_zp = x_zp_loader(())

                            # The micro-kernel output is int32; convert it to FP32
                            # before applying the compensation
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)
                            _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compens = weight_compens_loader(
                                weight_compens_index
                            )
                            # Step 1: apply the scale/zero-point compensation to get fp32
                            temp = ops.mul(
                                ops.mul(
                                    input,
                                    _x_scale,
                                ),
                                _w_scale,
                            )
                            temp = ops.sub(
                                temp,
                                ops.mul(
                                    ops.mul(
                                        ops.mul(
                                            _x_scale,
                                            _w_scale,
                                        ),
                                        _x_zp,
                                    ),
                                    _weight_compens,
                                ),
                            )

                            # Step 2: add the bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            # Step 3: binary add of x2
                            nonlocal x2_dtype
                            assert x2_dtype in [torch.float32, torch.bfloat16]
                            if x2_dtype == torch.bfloat16:
                                _x2 = ops.to_dtype(_x2, torch.float32)
                            temp = ops.add(temp, _x2)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 4: unary post-op, if present
                        if unary_attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf,
                                unary_attr,
                                scalars=unary_scalars,
                                algorithm=unary_algorithmm,
                            )

                        # Step 5: cast the output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype == torch.uint8:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                qmin, qmax = _create_constants(
                                    0, 255, dtype=torch.float32
                                )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, torch.uint8)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=torch.uint8,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        # Reorder bias and x2
                        input_indices=[0, 3, 1, 2, 4, 5, 6]
                        if bias is None
                        else [7, 0, 3, 1, 2, 4, 5, 6],
                    )

            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    other_scale=x2_scale,
                    other_zp=x2_zp,
                    binary_post_op=binary_attr,
                    binary_alpha=alpha,
                    unary_post_op=unary_attr,
                    unary_post_op_args=unary_scalars,
                    unary_post_op_algorithm=unary_algorithmm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_binary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],
                4: lambda x: V.graph.constants[x.get_name()],
                5: lambda x: V.graph.constants[x.get_name()],
            }
            if bias is not None:
                input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()]  # For bias
            result = autotune_select_algorithm(
                "qlinear_binary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2 and binary_attr == "add":
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        if torch._C.has_mkl:
            aten_mkl_linear = ExternKernelChoice(
                torch.ops.mkl._mkl_linear,
                "mkl::_mkl_linear",
                has_out_variant=False,
                kernel_creator=mkldnn_ir.MKLPackedLinear.create,
            )
            cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)

            @register_lowering(torch.ops.mkl._mkl_linear)
            def mkl_packed_linear(
                x: TensorBox,
                packed_w: TensorBox,
                orig_w: TensorBox,
                b: Optional[TensorBox],
                batch_size,
                *,
                layout=None,
            ):
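                """Lower mkl._mkl_linear (MKL packed GEMM).

                packed_w is an opaque packed weight, so the original weight
                orig_w is used both for the C++ GEMM template and as the
                autotune input; the bias, if any, is added after the matmul.
                """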
                choices: List[ChoiceCaller] = []
                if use_max_autotune():
                    transposed_w = permute(orig_w, [1, 0])
                    *_, layout, x, transposed_w = mm_args(
                        x, transposed_w, layout=layout
                    )
                    if use_cpp_packed_gemm_template(layout, x, transposed_w):
                        CppPackedGemmTemplate.add_choices(
                            choices,
                            layout,
                            [x, packed_w, orig_w],
                            trans_w=True,
                            input_indices=[0, 2],
                        )

                if len(choices) == 0 or use_aten_gemm_kernels():
                    choices.append(
                        aten_mkl_linear.bind(
                            (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
                        )
                    )

                assert packed_w.get_name() in V.graph.constants
                assert orig_w.get_name() in V.graph.constants
                # packed_w is an opaque MKL-packed tensor which we can't generate
                # directly, so we use the weights from the original tensor in autotune.
                input_gen_fns = {
                    1: lambda x: V.graph.constants[x.get_name()],
                    2: lambda x: V.graph.constants[x.get_name()],
                }
                result: TensorBox = autotune_select_algorithm(
                    "packed_linear",
                    choices,
                    [x, packed_w, orig_w],
                    layout,
                    input_gen_fns=input_gen_fns,
                )
                if b is not None:
                    result = add(result, b)
                return result

        add_needs_realized_inputs(cpu_needs_realized_inputs)
    else:
        pass