# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Any, cast

import numpy as np
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    q_op,
)
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp
from torch.fx import Node

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
# Setting TOSA_DBG_VERBOSE=1 in the environment configures the root logger and
# raises this module's logger to INFO so the dbg_* helpers below are visible.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


def dbg_node(node):
    """Log an FX node's op, name, target, args, kwargs and meta entries at INFO level.

    List-valued meta entries additionally have each element logged on its own line.
    """
    logger.info("OP")
    logger.info(f"  op is {node.op}")
    logger.info(f"  name is {node.name}")
    logger.info(f"  node target is {node.target}")
    logger.info(f"  node args is {node.args}")
    logger.info(f"  node kwargs is {node.kwargs}")
    logger.info("  node.meta = ")
    for k, v in node.meta.items():
        logger.info(f"    '{k}' = {v}")
        if isinstance(v, list):
            for i in v:
                logger.info(f"      {i} ")


def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    """Write the TOSA flatbuffer and its JSON description into ``path``.

    Creates ``path`` if needed, then writes ``output{suffix}.tosa`` (the
    serialized flatbuffer) and ``desc{suffix}.json`` (the output of
    ``tosa_graph.writeJson``).

    Raises:
        AssertionError: if either file does not exist after writing.
    """
    filename = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")

    os.makedirs(path, exist_ok=True)

    fb = tosa_graph.serialize()
    js = tosa_graph.writeJson(filename)

    filepath_tosa_fb = os.path.join(path, filename)
    with open(filepath_tosa_fb, "wb") as f:
        f.write(fb)
    assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer"

    filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
    with open(filepath_desc_json, "w") as f:
        f.write(js)
    assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"


def dbg_fail(node, tosa_graph, path):
    """Dump the partial TOSA graph to ``path``, log ``node``, then abort.

    Raises:
        RuntimeError: always.
    """
    dbg_tosa_dump(tosa_graph, path)
    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warning(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")


# Helper function to match TOSA's broadcasting rank requirement
# Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from
# https://www.mlplatform.org/tosa/tosa_spec.html
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    """Emit a RESHAPE of ``arg`` to ``promoted_shape`` and return the new tensor.

    The element count must be preserved; only the rank/shape changes.
    """
    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
    reshape_res = tosa_fb.addIntermediate(promoted_shape, out_dtype)
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(promoted_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [reshape_res.name], attr)
    return reshape_res


# Helper transpose function to match TOSA's shape requirements
# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes:
# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    """Emit a TRANSPOSE of ``input`` by permutation ``new_order``; return the result.

    ``new_order`` must be a permutation of ``range(len(input.shape))``.

    Raises:
        AssertionError: if ``new_order`` has the wrong length, contains
            duplicates, or contains an index outside ``[0, rank)``.
    """
    rank = len(input.shape)

    # Check new_order's length is equal to input rank
    assert rank == len(new_order), "Wrong shape order length"

    # Check no duplications
    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"

    # Check all dims are valid. The previous `assert True` checks never fired,
    # so negative indices silently indexed input.shape from the end.
    for idx in new_order:
        assert idx >= 0, "Negative dim number"
        assert idx < rank, "Dim is greater than input rank"

    input_shape_transposed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transposed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed


def getNodeArgs(node: Node) -> list[TosaArg]:
    """Wrap every positional arg of ``node`` in a ``TosaArg``."""
    return [TosaArg(arg) for arg in node.args]


def get_input_tensor(node: Node) -> TosaArg:
    """Return ``node``'s first positional arg wrapped in a ``TosaArg``."""
    return TosaArg(node.args[0])


def get_output_node(node: Node) -> Node:
    """Return the first user of ``node`` (raises IndexError if it has none)."""
    return list(node.users)[0]


def build_reshape(tosa_fb, input_name, new_shape, output_name):
    """Emit a TOSA RESHAPE from ``input_name`` to ``output_name`` with ``new_shape``.

    TOSA reshape returns a tensor with the same type/values as the input.
    No data conversion happens during a reshape operation.
    """
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(new_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)


def is_bias_node_for_quantized_conv(node):
    """Return True if ``node``'s first user is a convolution followed by ``q_op``.

    Only the first user of ``node`` and the first user of that convolution are
    inspected.
    """
    consumer_node = list(node.users)[0]
    return (
        consumer_node.target == exir_ops.edge.aten.convolution.default
        and list(consumer_node.users)[0].target == q_op
    )


def is_consumer_node_depthwise_conv2d(node):
    """Return True if ``node``'s first user is a depthwise convolution.

    A convolution counts as depthwise when its ``groups`` arg (last positional
    arg) equals the input's dim 1 and the weight's dim 0 is a multiple of it.
    NOTE(review): dim 1 / dim 0 being channels assumes NCHW input and
    OIHW-style weight layout — confirm against callers.
    """
    consumer_node = list(node.users)[0]
    if consumer_node.target == exir_ops.edge.aten.convolution.default:
        inputs = getNodeArgs(consumer_node)
        group = inputs[-1]
        in_channels = inputs[0].shape[1]
        out_channels = inputs[1].shape[0]
        if (in_channels == group.number) and (out_channels % in_channels) == 0:
            return True

    return False


def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    """Emit a TOSA AVG_POOL2D from ``input_tensor`` into ``output``.

    For quantized nodes the accumulator is INT32 and the input/output zero
    points are taken from the upstream/downstream quantization parameters;
    otherwise the accumulator uses the input dtype and both zero points are 0.
    """
    accumulator_type = input_tensor.dtype
    # Initialize zero points to zero; overridden for quantized nodes.
    input_zp = 0
    output_zp = 0

    if is_quant_node:
        # Accumulator type always is int32 when input tensor is an integer type.
        accumulator_type = ts.DType.INT32
        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D,
        [input_tensor.name],
        [output.name],
        attr,
    )


def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Returns two input nodes to 'node' in order. If 'node' only has one input,
    it is returned twice.

    Fails if there are no input nodes.
    Fails if there are >2 input nodes and 'check' is True.
    """

    num_inputs = len(node.all_input_nodes)
    assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}."

    input1 = node.all_input_nodes[0]
    if num_inputs == 1:
        input2 = node.all_input_nodes[0]
    else:
        input2 = node.all_input_nodes[1]
    if check:
        assert (
            num_inputs <= 2
        ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}."

    return input1, input2


def tosa_shape(shape, dim_order):
    """Return ``shape`` permuted by ``dim_order`` as a tuple."""
    return tuple(shape[dim] for dim in dim_order)


def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Inserts TOSA operators into the tosa_graph, that perform the equivalent
    of the expand_dims (a.k.a unsqueeze) operation. A new axis is created at the
    dim location.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dim operations.
        dtype (ts.DType): The data type expand dims operations.
        dim (int): The dimension to expand.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    new_shape = list(input_node.shape)
    # list.insert semantics: dim may be negative or past the end.
    new_shape.insert(dim, 1)

    intermediate = tosa_graph.addIntermediate(new_shape, dtype)

    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

    return intermediate


def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Get the tosa.resize parameters based on the input and output size.

    Args:
        input_size (torch.Tensor): Size of the input
        output_size (torch.Tensor): Size of the output
        resize_mode (tosa.ResizeMode): The TOSA resize mode (currently unused here)
        align_corners (bool): Align the corners pixels of the input and output

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    # With align_corners the corner pixels map exactly, so the scale ratio is
    # (out - 1) / (in - 1) per dimension, unless a size of 1 makes that degenerate.
    scale_n = torch.tensor(
        [
            so - 1 if align_corners and si > 1 and so > 1 else so
            for si, so in zip(input_size, output_size)
        ]
    )
    scale_d = torch.tensor(
        [
            si - 1 if align_corners and si > 1 and so > 1 else si
            for si, so in zip(input_size, output_size)
        ]
    )

    # Reduce each ratio to lowest terms.
    gcd = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // gcd
    scale_d = scale_d // gcd

    # No half-pixel centre support in PyTorch, no offset needed
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border