xref: /aosp_15_r20/external/executorch/backends/arm/tosa_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2023-2024 Arm Limited and/or its affiliates.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
5*523fa7a6SAndroid Build Coastguard Worker
6*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Workerimport logging
9*523fa7a6SAndroid Build Coastguard Workerimport os
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, cast
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
13*523fa7a6SAndroid Build Coastguard Workerimport serializer.tosa_serializer as ts
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_mapping import TosaArg
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_quant_utils import (
18*523fa7a6SAndroid Build Coastguard Worker    get_quant_arg_downstream,
19*523fa7a6SAndroid Build Coastguard Worker    get_quant_arg_upstream,
20*523fa7a6SAndroid Build Coastguard Worker    q_op,
21*523fa7a6SAndroid Build Coastguard Worker)
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
23*523fa7a6SAndroid Build Coastguard Workerfrom serializer.tosa_serializer import TosaOp
24*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import Node
25*523fa7a6SAndroid Build Coastguard Worker
# Module logger. Output is limited to WARNING and above unless the environment
# variable TOSA_DBG_VERBOSE is exactly "1", in which case INFO messages are
# emitted and the root logger is configured at INFO as well.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
logger = logging.getLogger(__name__)
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.WARNING)
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker
def dbg_node(node):
    """Log debug information about an FX node at INFO level.

    Logs the node's op, name, target, args and kwargs, then every entry of
    ``node.meta`` as ``'key' = value``. Entries whose value is a list are
    additionally logged one element per line.
    """
    logger.info("OP")
    logger.info(f"  op is {node.op}")
    logger.info(f"  name is {node.name}")
    logger.info(f"  node target is {node.target}")
    logger.info(f"  node args is {node.args}")
    logger.info(f"  node kwargs is {node.kwargs}")
    logger.info("  node.meta = ")
    for meta_key, meta_value in node.meta.items():
        logger.info(f"    '{meta_key}' = {meta_value}")
        if not isinstance(meta_value, list):
            continue
        # Expand list-valued metadata so each element gets its own log line.
        for element in meta_value:
            logger.info(f"      {element} ")
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Worker
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    """Output the TOSA flatbuffer and its JSON description for debugging.

    Creates ``path`` if it does not exist, then writes two files into it:

    * ``output{suffix}.tosa``: the bytes returned by ``tosa_graph.serialize()``.
    * ``desc{suffix}.json``: the text returned by
      ``tosa_graph.writeJson("output{suffix}.tosa")``.

    Args:
        tosa_graph: The serializer holding the graph to dump.
        path: Directory to write into.
        suffix: Appended to both file names, e.g. to distinguish several dumps
            in the same directory.

    Raises:
        AssertionError: if either file does not exist after being written.
    """
    filename = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")

    os.makedirs(path, exist_ok=True)

    fb = tosa_graph.serialize()
    # The JSON description is generated with the flatbuffer's file name.
    # Presumably this lets it refer to that file; confirm against the serializer.
    js = tosa_graph.writeJson(filename)

    filepath_tosa_fb = os.path.join(path, filename)
    with open(filepath_tosa_fb, "wb") as f:
        f.write(fb)
    assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer"

    filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
    # Use an explicit encoding. Without one, text-mode open() uses the
    # locale's preferred encoding, which can fail or differ across hosts.
    with open(filepath_desc_json, "w", encoding="utf-8") as f:
        f.write(js)
    assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker
def dbg_fail(node, tosa_graph, path):
    """Capture debug state for a node that could not be handled, then raise.

    Dumps the current ``tosa_graph`` into ``path`` (via ``dbg_tosa_dump``),
    logs warnings, and logs the node's details (via ``dbg_node``, at INFO
    level).

    Raises:
        RuntimeError: always.
    """
    dbg_tosa_dump(tosa_graph, path)
    # Logger.warn is a deprecated alias of Logger.warning that emits a
    # DeprecationWarning; use the supported name.
    logger.warning("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warning(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
78*523fa7a6SAndroid Build Coastguard Worker
79*523fa7a6SAndroid Build Coastguard Worker
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    """Reshape ``arg`` to ``promoted_shape`` to match TOSA's broadcasting rank rule.

    Ref: TOSA 0.80.0 specification, section 1.9.3 "Data Layouts":
    https://www.mlplatform.org/tosa/tosa_spec.html

    A new intermediate tensor of ``promoted_shape``/``out_dtype`` is added to
    ``tosa_fb``, and a RESHAPE operator from ``arg`` to it is appended.

    Args:
        tosa_fb: Graph under construction (``addIntermediate``/``addOperator``).
        arg: Source tensor; ``.shape`` and ``.name`` are read.
        promoted_shape: Target shape; must have the same element count.
        out_dtype: Dtype of the created intermediate tensor.

    Returns:
        The intermediate tensor that receives the reshaped value.

    Raises:
        AssertionError: if the element counts of the two shapes differ.
    """
    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"

    promoted = tosa_fb.addIntermediate(promoted_shape, out_dtype)

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(promoted_shape)
    tosa_fb.addOperator(
        TosaOp.Op().RESHAPE,
        [arg.name],
        [promoted.name],
        reshape_attr,
    )

    return promoted
90*523fa7a6SAndroid Build Coastguard Worker
91*523fa7a6SAndroid Build Coastguard Worker
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    """Append a TOSA TRANSPOSE of ``input`` using the permutation ``new_order``.

    Helper to match TOSA's shape requirements, e.g. TOSA 0.80.0 specification
    section 2.3.3 (CONV2D shapes):
    https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d

    Args:
        tosa_fb: Graph under construction (``addIntermediate``/``addOperator``).
        input: Tensor to transpose; ``.shape`` and ``.name`` are read.
        new_order: Permutation of ``range(len(input.shape))``.
        out_dtype: Dtype of the created intermediate tensor.

    Returns:
        The intermediate tensor holding the result, whose shape is
        ``[input.shape[i] for i in new_order]``.

    Raises:
        AssertionError: if ``new_order`` is not a permutation of the input's
            dimensions (wrong length, duplicates, negative or too-large dims).
    """
    rank = len(input.shape)

    # Check new_order's length is equal to input rank
    assert rank == len(new_order), "Wrong shape order length"

    # Check no duplications
    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"

    # Check all dims are valid. Negative dims are rejected rather than wrapped,
    # because new_order is passed verbatim to the TOSA TransposeAttribute below.
    for idx in new_order:
        assert idx >= 0, "Negative dim number"
        assert idx < rank, "Dim is greater than input rank"

    input_shape_transposed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transposed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker
def getNodeArgs(node: Node) -> list[TosaArg]:
    """Wrap every positional argument of ``node`` in a ``TosaArg``, in order."""
    return list(map(TosaArg, node.args))
121*523fa7a6SAndroid Build Coastguard Worker
122*523fa7a6SAndroid Build Coastguard Worker
def get_input_tensor(node: Node) -> TosaArg:
    """Wrap the first positional argument of ``node`` in a ``TosaArg``."""
    first_arg = node.args[0]
    return TosaArg(first_arg)
125*523fa7a6SAndroid Build Coastguard Worker
126*523fa7a6SAndroid Build Coastguard Worker
def get_output_node(node: Node) -> Node:
    """Return the first user (consumer) of ``node``.

    Raises:
        IndexError: if ``node`` has no users.
    """
    consumers = list(node.users)
    return consumers[0]
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker
def build_reshape(tosa_fb, input_name, new_shape, output_name):
    """Append a TOSA RESHAPE operator from ``input_name`` to ``output_name``.

    TOSA reshape returns a tensor with the same type/values as the input.
    No data conversion happens during a reshape operation.

    Args:
        tosa_fb: Graph under construction (``addOperator`` is called).
        input_name: Name of the source tensor.
        new_shape: Shape recorded in the ReshapeAttribute.
        output_name: Name of the (already created) destination tensor.
    """
    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(new_shape)
    tosa_fb.addOperator(
        TosaOp.Op().RESHAPE,
        [input_name],
        [output_name],
        reshape_attr,
    )
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker
def is_bias_node_for_quantized_conv(node):
    """Return whether ``node`` feeds a convolution whose output is quantized.

    True when the first user of ``node`` is an edge ``aten.convolution`` op
    and that convolution's first user is the quantize op ``q_op``. Based on
    the name, this is meant to detect the bias input of a quantized conv.

    Raises:
        IndexError: if ``node`` (or the convolution) has no users.
    """
    conv = list(node.users)[0]
    if conv.target != exir_ops.edge.aten.convolution.default:
        return False
    return list(conv.users)[0].target == q_op
147*523fa7a6SAndroid Build Coastguard Worker
148*523fa7a6SAndroid Build Coastguard Worker
def is_consumer_node_depthwise_conv2d(node):
    """Return whether the first user of ``node`` is a depthwise convolution.

    The consumer must be an edge ``aten.convolution`` whose last argument
    (group count) equals ``args[0].shape[1]`` and whose ``args[1].shape[0]``
    is a multiple of that value. This assumes the standard aten.convolution
    argument layout (input, weight, ..., groups) with channels at input dim 1
    and output channels at weight dim 0; confirm if that layout changes.
    """
    consumer = list(node.users)[0]
    if consumer.target != exir_ops.edge.aten.convolution.default:
        return False

    conv_args = getNodeArgs(consumer)
    groups = conv_args[-1]
    in_channels = conv_args[0].shape[1]
    out_channels = conv_args[1].shape[0]
    return bool(in_channels == groups.number and out_channels % in_channels == 0)
160*523fa7a6SAndroid Build Coastguard Worker
161*523fa7a6SAndroid Build Coastguard Worker
def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    """Append a TOSA AVG_POOL2D operator from ``input_tensor`` to ``output``.

    For a non-quantized node the accumulator dtype is ``input_tensor.dtype``
    and both zero points are 0. For a quantized node the accumulator is INT32.
    The input zero point is then read with ``get_quant_arg_upstream`` on
    ``node.args[0]``, and the output zero point with
    ``get_quant_arg_downstream`` on the first user of ``node``.

    Args:
        node: The FX node being lowered.
        tosa_graph: Graph under construction.
        input_tensor: Pooling input.
        kernel_size: Pool kernel, passed to PoolAttribute ``kernel``.
        stride: Pool stride, passed to PoolAttribute ``stride``.
        padding: Pool padding, passed to PoolAttribute ``pad``.
        is_quant_node: Whether ``node`` is part of a quantized pattern.
        output: Pooling output.
    """
    accumulator_type = input_tensor.dtype
    input_zp = output_zp = 0

    if is_quant_node:
        # Quantized (integer) inputs always accumulate in int32.
        accumulator_type = ts.DType.INT32
        producer = cast(torch.fx.Node, node.args[0])
        input_zp = get_quant_arg_upstream(producer).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D, [input_tensor.name], [output.name], pool_attr
    )
202*523fa7a6SAndroid Build Coastguard Worker
203*523fa7a6SAndroid Build Coastguard Worker
def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Return the first two input nodes of ``node``, in order.

    If ``node`` has exactly one input node, that node is returned twice.

    Raises:
        AssertionError: if ``node`` has no input nodes, or if ``check`` is
            True and ``node`` has more than two input nodes.
    """
    input_nodes = node.all_input_nodes
    count = len(input_nodes)
    assert count > 0, f"Node '{node.name}' requires >0 input, got {count}."
    if check:
        assert count <= 2, f"Node '{node.name}' requires <=2 inputs, got {count}."

    first = input_nodes[0]
    second = input_nodes[1] if count > 1 else first
    return first, second
226*523fa7a6SAndroid Build Coastguard Worker
227*523fa7a6SAndroid Build Coastguard Worker
def tosa_shape(shape, dim_order):
    """Return ``shape`` permuted by ``dim_order``, as a tuple.

    Element ``i`` of the result is ``shape[dim_order[i]]``.
    """
    return tuple(shape[axis] for axis in dim_order)
230*523fa7a6SAndroid Build Coastguard Worker
231*523fa7a6SAndroid Build Coastguard Worker
def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Inserts TOSA operators into the tosa_graph, that perform the equivalent
    of the expand_dims (a.k.a unsqueeze) operation. A new axis of size 1 is
    created at the dim location, implemented as a single RESHAPE.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dim operations.
        dtype: The data type of the created intermediate tensor (presumably a
            ``ts.DType`` value; it is passed straight to ``addIntermediate``).
        dim (int): Position of the new axis in the output shape, in
            ``[-rank - 1, rank]`` where ``rank = len(input_node.shape)``.
            Negative values count from the end of the output shape, as in
            ``torch.unsqueeze`` (``-1`` appends a trailing axis).

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.

    Raises:
        AssertionError: if ``dim`` is outside ``[-rank - 1, rank]``.
    """
    new_shape = list(input_node.shape)
    rank = len(new_shape)
    assert (
        -rank - 1 <= dim <= rank
    ), f"expand_dims: dim {dim} out of range for input of rank {rank}"
    if dim < 0:
        # list.insert(-1, x) places x *before* the last element. Convert to
        # the unsqueeze convention, where -1 means "after the last axis".
        dim += rank + 1
    new_shape.insert(dim, 1)

    intermediate = tosa_graph.addIntermediate(new_shape, dtype)

    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

    return intermediate
259*523fa7a6SAndroid Build Coastguard Worker
260*523fa7a6SAndroid Build Coastguard Worker
def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Get the tosa.resize parameters based on the input and output size.

    Per dimension, the scale is ``output_size / input_size`` reduced to lowest
    terms as ``scale_n / scale_d``. With ``align_corners``, any dimension where
    both sizes exceed 1 uses ``(output_size - 1) / (input_size - 1)`` instead.
    The offset is always zero, and the border is
    ``scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset``.

    Args:
        input_size (torch.Tensor): Size of the input; every entry must be > 0.
        output_size (torch.Tensor): Size of the output; every entry must be > 0.
        resize_mode (tosa.ResizeMode): The TOSA resize mode (not used by this
            computation).
        align_corners (bool): Align the corner pixels of the input and output.

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    numerators = []
    denominators = []
    for in_dim, out_dim in zip(input_size, output_size):
        if align_corners and in_dim > 1 and out_dim > 1:
            # Scale by the number of pixel intervals (size - 1), not pixels.
            numerators.append(out_dim - 1)
            denominators.append(in_dim - 1)
        else:
            numerators.append(out_dim)
            denominators.append(in_dim)

    scale_n = torch.tensor(numerators)
    scale_d = torch.tensor(denominators)

    # Reduce each fraction to lowest terms.
    common = torch.gcd(scale_n, scale_d)
    scale_n, scale_d = scale_n // common, scale_d // common

    # No half-pixel centre support in PyTorch, no offset needed
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border
304