# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import typing
from typing import List

import torch

from executorch.exir.schema import (
    AllocationDetails,
    Chain,
    ContainerMetadata,
    EValue,
    ExecutionPlan,
    Instruction,
    Int,
    KernelCall,
    Null,
    Operator,
    Program,
    ScalarType,
    String,
    SubsegmentOffsets,
    Tensor,
    TensorShapeDynamism,
)


def get_test_program() -> Program:
    """Build a small, fixed `Program` for use in tests.

    The program has a single execution plan named "forward" with:
      * five values: Int(1), Int(0), Null, String("pass"), and a 2x2 FLOAT
        tensor allocated in memory_id 1;
      * inputs [0] and outputs [1] (indices into ``values``);
      * one chain holding a single KernelCall to operator index 0
        (``aten::add.Tensor``) with value indices [0, 1] as args.

    Returns:
        A freshly constructed `Program`; each call builds new objects.
    """
    return Program(
        version=0,
        execution_plan=[
            ExecutionPlan(
                name="forward",
                values=[
                    EValue(Int(1)),
                    EValue(Int(0)),
                    EValue(Null()),
                    EValue(String("pass")),
                    EValue(
                        val=Tensor(
                            scalar_type=ScalarType.FLOAT,
                            storage_offset=0,
                            sizes=[2, 2],
                            # NOTE(review): the cast presumably exists because the
                            # schema annotates dim_order as List[bytes] while ints
                            # are passed here — confirm against the schema.
                            dim_order=typing.cast(List[bytes], [0, 1]),
                            requires_grad=False,
                            layout=0,
                            data_buffer_idx=0,
                            allocation_info=AllocationDetails(
                                memory_id=1,
                                memory_offset_high=0,
                                memory_offset_low=16,
                            ),
                            shape_dynamism=TensorShapeDynamism.STATIC,
                        )
                    ),
                ],
                inputs=[0],
                outputs=[1],
                chains=[
                    Chain(
                        inputs=[],
                        outputs=[],
                        instructions=[
                            Instruction(KernelCall(op_index=0, args=[0, 1]))
                        ],
                        stacktrace=None,
                    )
                ],
                container_meta_type=ContainerMetadata(
                    encoded_inp_str="place", encoded_out_str="place"
                ),
                operators=[Operator(name="aten::add", overload="Tensor")],
                delegates=[],
                non_const_buffer_sizes=[0, 1024],
            )
        ],
        constant_buffer=[],
        backend_delegate_data=[],
        segments=[],
        constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
    )


def register_additional_test_aten_ops() -> None:
    """Tag a fixed set of ATen op overloads with ``torch.Tag.core``.

    This mutates the process-global ``torch.ops.aten`` overload objects by
    appending to their ``tags`` lists. It is idempotent: an overload that
    already carries the core tag is left untouched. Repeated calls, or a
    PyTorch build that already marks an op as core, therefore do not
    accumulate duplicate tags.
    """
    # TODO: either mark those ops as canonical in native_functions.yaml,
    # or stop using graphs with those in tests.
    canonical = torch.Tag.core
    extra_ops = (
        torch.ops.aten.max.default,
        torch.ops.aten.sum.default,
        torch.ops.aten.searchsorted.Tensor,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.upsample_nearest2d.default,
        torch.ops.aten.index.Tensor,
        torch.ops.aten.addbmm.default,
    )
    for op in extra_ops:
        # Guard so a second call does not append the same tag again.
        if canonical not in op.tags:
            op.tags.append(canonical)