xref: /aosp_15_r20/external/executorch/exir/tests/common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8import typing
9from typing import List
10
11import torch
12
13from executorch.exir.schema import (
14    AllocationDetails,
15    Chain,
16    ContainerMetadata,
17    EValue,
18    ExecutionPlan,
19    Instruction,
20    Int,
21    KernelCall,
22    Null,
23    Operator,
24    Program,
25    ScalarType,
26    String,
27    SubsegmentOffsets,
28    Tensor,
29    TensorShapeDynamism,
30)
31
32
def get_test_program() -> Program:
    """Build a small, hand-written ``Program`` used as a test fixture.

    The program has a single execution plan named ``"forward"`` whose value
    table holds, in order: ``Int(1)``, ``Int(0)``, ``Null()``,
    ``String("pass")`` and a 2x2 FLOAT tensor.

    - Plan inputs: value index 0.
    - Plan outputs: value index 1.
    - Chain: one chain with one ``KernelCall`` to operator index 0
      (``aten::add`` with overload ``Tensor``), passing value indices 0 and 1.
    """
    # Scalar entries of the value table (indices 0-3).
    plan_values = [
        EValue(Int(1)),
        EValue(Int(0)),
        EValue(Null()),
        EValue(String("pass")),
    ]

    # Tensor entry of the value table (index 4).
    float_tensor = Tensor(
        scalar_type=ScalarType.FLOAT,
        storage_offset=0,
        sizes=[2, 2],
        # The schema annotates dim_order as List[bytes]; the cast only
        # satisfies the type checker; the runtime value stays [0, 1].
        dim_order=typing.cast(List[bytes], [0, 1]),
        requires_grad=False,
        layout=0,
        data_buffer_idx=0,
        allocation_info=AllocationDetails(
            memory_id=1,
            memory_offset_high=0,
            memory_offset_low=16,
        ),
        shape_dynamism=TensorShapeDynamism.STATIC,
    )
    plan_values.append(EValue(val=float_tensor))

    kernel_chain = Chain(
        inputs=[],
        outputs=[],
        instructions=[Instruction(KernelCall(op_index=0, args=[0, 1]))],
        stacktrace=None,
    )

    forward_plan = ExecutionPlan(
        name="forward",
        values=plan_values,
        inputs=[0],
        outputs=[1],
        chains=[kernel_chain],
        container_meta_type=ContainerMetadata(
            encoded_inp_str="place", encoded_out_str="place"
        ),
        operators=[Operator(name="aten::add", overload="Tensor")],
        delegates=[],
        non_const_buffer_sizes=[0, 1024],
    )

    return Program(
        version=0,
        execution_plan=[forward_plan],
        constant_buffer=[],
        backend_delegate_data=[],
        segments=[],
        constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
    )
85
86
def register_additional_test_aten_ops() -> None:
    """Tag a fixed set of ATen op overloads with ``torch.Tag.core`` for tests.

    Test graphs use these ops, so they are marked as canonical (core) here.
    Each op's ``tags`` list is mutated in place. The tag is added only when
    it is not already present, so calling this function more than once
    (e.g. from several test setUp methods) does not accumulate duplicate
    tags.
    """
    # TODO: either mark those ops as canonical in native_functions.yaml,
    # or stop using graphs with those in tests.
    canonical = torch.Tag.core
    for op in (
        torch.ops.aten.max.default,
        torch.ops.aten.sum.default,
        torch.ops.aten.searchsorted.Tensor,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.upsample_nearest2d.default,
        torch.ops.aten.index.Tensor,
        torch.ops.aten.addbmm.default,
    ):
        # The tags list is shared, process-global op metadata: appending
        # unconditionally would grow it with duplicates on every call.
        if canonical not in op.tags:
            op.tags.append(canonical)
98