xref: /aosp_15_r20/external/executorch/exir/tests/common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Workerimport typing
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import List
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import (
14*523fa7a6SAndroid Build Coastguard Worker    AllocationDetails,
15*523fa7a6SAndroid Build Coastguard Worker    Chain,
16*523fa7a6SAndroid Build Coastguard Worker    ContainerMetadata,
17*523fa7a6SAndroid Build Coastguard Worker    EValue,
18*523fa7a6SAndroid Build Coastguard Worker    ExecutionPlan,
19*523fa7a6SAndroid Build Coastguard Worker    Instruction,
20*523fa7a6SAndroid Build Coastguard Worker    Int,
21*523fa7a6SAndroid Build Coastguard Worker    KernelCall,
22*523fa7a6SAndroid Build Coastguard Worker    Null,
23*523fa7a6SAndroid Build Coastguard Worker    Operator,
24*523fa7a6SAndroid Build Coastguard Worker    Program,
25*523fa7a6SAndroid Build Coastguard Worker    ScalarType,
26*523fa7a6SAndroid Build Coastguard Worker    String,
27*523fa7a6SAndroid Build Coastguard Worker    SubsegmentOffsets,
28*523fa7a6SAndroid Build Coastguard Worker    Tensor,
29*523fa7a6SAndroid Build Coastguard Worker    TensorShapeDynamism,
30*523fa7a6SAndroid Build Coastguard Worker)
31*523fa7a6SAndroid Build Coastguard Worker
32*523fa7a6SAndroid Build Coastguard Worker
def get_test_program() -> Program:
    """Assemble a small, hand-built ``Program`` fixture for tests.

    The result has a single execution plan named ``"forward"``. Its value
    table holds, in order: ``Int(1)``, ``Int(0)``, ``Null()``,
    ``String("pass")`` and one 2x2 FLOAT tensor. Its single chain contains
    one kernel call to operator index 0 (``aten::add`` / ``Tensor``) with
    ``args=[0, 1]``, which are presumably indices into the value table.

    Returns:
        A newly constructed ``Program``. Each call builds fresh objects, so
        callers may mutate the result freely.
    """
    # Static-shape 2x2 float tensor. Its allocation_info places it in
    # memory_id 1 with offset (high=0, low=16).
    float_tensor = Tensor(
        scalar_type=ScalarType.FLOAT,
        storage_offset=0,
        sizes=[2, 2],
        # typing.cast is a no-op at runtime. It only tells the type checker
        # to accept plain ints for a field annotated as List[bytes].
        dim_order=typing.cast(List[bytes], [0, 1]),
        requires_grad=False,
        layout=0,
        data_buffer_idx=0,
        allocation_info=AllocationDetails(
            memory_id=1,
            memory_offset_high=0,
            memory_offset_low=16,
        ),
        shape_dynamism=TensorShapeDynamism.STATIC,
    )

    value_table = [
        EValue(Int(1)),
        EValue(Int(0)),
        EValue(Null()),
        EValue(String("pass")),
        EValue(val=float_tensor),
    ]

    single_kernel_chain = Chain(
        inputs=[],
        outputs=[],
        instructions=[Instruction(KernelCall(op_index=0, args=[0, 1]))],
        stacktrace=None,
    )

    forward_plan = ExecutionPlan(
        name="forward",
        values=value_table,
        inputs=[0],
        outputs=[1],
        chains=[single_kernel_chain],
        container_meta_type=ContainerMetadata(
            encoded_inp_str="place", encoded_out_str="place"
        ),
        operators=[Operator(name="aten::add", overload="Tensor")],
        delegates=[],
        # NOTE(review): presumably indexed by AllocationDetails.memory_id,
        # so memory_id 1 above maps to the 1024-entry buffer; confirm.
        non_const_buffer_sizes=[0, 1024],
    )

    return Program(
        version=0,
        execution_plan=[forward_plan],
        constant_buffer=[],
        backend_delegate_data=[],
        segments=[],
        constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
    )
85*523fa7a6SAndroid Build Coastguard Worker
86*523fa7a6SAndroid Build Coastguard Worker
def register_additional_test_aten_ops() -> None:
    """Tag a fixed set of ATen overloads with ``torch.Tag.core`` for tests.

    Each listed overload's ``tags`` list is mutated in place, so the change
    is process-global and persists after this function returns.

    The function is idempotent. An overload that already carries the tag is
    left untouched, so calling it repeatedly (e.g. from several test setUp
    hooks) does not keep appending duplicate entries.
    """
    # TODO: either mark those ops as canonical in native_functions.yaml,
    # or stop using graphs with those in tests.
    canonical = torch.Tag.core
    for overload in (
        torch.ops.aten.max.default,
        torch.ops.aten.sum.default,
        torch.ops.aten.searchsorted.Tensor,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.upsample_nearest2d.default,
        torch.ops.aten.index.Tensor,
        torch.ops.aten.addbmm.default,
    ):
        # The membership check keeps repeated calls from growing the list.
        if canonical not in overload.tags:
            overload.tags.append(canonical)
98