# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import typing
from typing import List

import torch

from executorch.exir.schema import (
    AllocationDetails,
    Chain,
    ContainerMetadata,
    EValue,
    ExecutionPlan,
    Instruction,
    Int,
    KernelCall,
    Null,
    Operator,
    Program,
    ScalarType,
    String,
    SubsegmentOffsets,
    Tensor,
    TensorShapeDynamism,
)


def get_test_program() -> Program:
    """Build a small, hand-written ``Program`` fixture for tests.

    The program contains a single execution plan named ``"forward"``. It has:

    - five values: ``Int(1)``, ``Int(0)``, ``Null()``, ``String("pass")`` and a
      2x2 ``FLOAT`` tensor;
    - one chain holding a single ``KernelCall`` to operator index 0, which is
      the only operator listed (``aten::add`` / ``Tensor``);
    - no constant buffers, delegates or segments.

    Returns:
        A newly constructed ``Program``. A fresh object is built on every call,
        so callers may mutate the result without affecting other callers.
    """
    return Program(
        version=0,
        execution_plan=[
            ExecutionPlan(
                name="forward",
                values=[
                    EValue(Int(1)),
                    EValue(Int(0)),
                    EValue(Null()),
                    EValue(String("pass")),
                    EValue(
                        val=Tensor(
                            scalar_type=ScalarType.FLOAT,
                            storage_offset=0,
                            sizes=[2, 2],
                            # typing.cast does nothing at runtime: the value is
                            # still a plain list of ints. The cast only satisfies
                            # the List[bytes] annotation for the type checker.
                            dim_order=typing.cast(List[bytes], [0, 1]),
                            requires_grad=False,
                            layout=0,
                            data_buffer_idx=0,
                            allocation_info=AllocationDetails(
                                memory_id=1,
                                memory_offset_high=0,
                                memory_offset_low=16,
                            ),
                            shape_dynamism=TensorShapeDynamism.STATIC,
                        )
                    ),
                ],
                # NOTE(review): presumably these are indices into `values`
                # above; confirm against the schema.
                inputs=[0],
                outputs=[1],
                chains=[
                    Chain(
                        inputs=[],
                        outputs=[],
                        instructions=[
                            Instruction(KernelCall(op_index=0, args=[0, 1]))
                        ],
                        stacktrace=None,
                    )
                ],
                container_meta_type=ContainerMetadata(
                    encoded_inp_str="place", encoded_out_str="place"
                ),
                operators=[Operator(name="aten::add", overload="Tensor")],
                delegates=[],
                non_const_buffer_sizes=[0, 1024],
            )
        ],
        constant_buffer=[],
        backend_delegate_data=[],
        segments=[],
        constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
    )


def register_additional_test_aten_ops() -> None:
    """Tag a fixed set of ATen op overloads with ``torch.Tag.core``.

    The ``tags`` list of each overload object is mutated in place, so the
    change persists after this function returns. The tag is appended only when
    it is not already present. As a result:

    - repeated calls do not accumulate duplicate ``core`` entries;
    - overloads that are already tagged as core are left unchanged.
    """
    # TODO: either mark those ops as canonical in native_functions.yaml,
    # or stop using graphs with those in tests.
    canonical = torch.Tag.core
    overloads = (
        torch.ops.aten.max.default,
        torch.ops.aten.sum.default,
        torch.ops.aten.searchsorted.Tensor,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.upsample_nearest2d.default,
        torch.ops.aten.index.Tensor,
        torch.ops.aten.addbmm.default,
    )
    for overload in overloads:
        tags = overload.tags
        # Guard against duplicates: the original unconditional append added
        # another copy of the tag on every call.
        if canonical not in tags:
            tags.append(canonical)