# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMeanDim(unittest.TestCase):
    """Tests for lowering ``aten.mean.dim`` to the XNNPACK delegate.

    The supported cases reduce over the innermost two dimensions, (-1, -2),
    of a 4D input and expect the op to be fully delegated. The unsupported
    cases expect ``mean.dim`` to remain in the edge graph after lowering.
    """

    class MeanDim(torch.nn.Module):
        """Computes ``mean(x + x)`` over ``dims``, keeping reduced dims."""

        def __init__(self, dims):
            super().__init__()
            # int or tuple of ints; forwarded unchanged to torch.mean.
            self.dims = dims

        def forward(self, x):
            # NOTE(review): the add gives the graph an op ahead of mean
            # (the qs8 test counts 3 quantize ops: input, add, mean) --
            # confirm this is the intended purpose.
            y = x + x
            z = torch.mean(y, self.dims, keepdim=True)
            return z

    def _test_mean_dim(self, inputs, dims=(-1, -2)):
        """Export, lower and run MeanDim(dims), expecting full delegation.

        Args:
            inputs: tuple of example input tensors.
            dims: reduction dims passed to MeanDim; defaults to the
                innermost two dimensions, (-1, -2).
        """
        (
            Tester(self.MeanDim(dims), inputs)
            .export()
            .check_count({"torch.ops.aten.mean.dim": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _test_mean_dim_unsupported(self, inputs, dims):
        """Export and lower MeanDim(dims), expecting mean.dim to stay un-delegated.

        Args:
            inputs: tuple of example input tensors.
            dims: reduction dims passed to MeanDim.
        """
        (
            Tester(self.MeanDim(dims), inputs)
            .export()
            .check_count({"torch.ops.aten.mean.dim": 1})
            .to_edge_transform_and_lower()
            # The edge-dialect mean.dim still being present means the
            # partitioner did not claim it.
            .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
        )

    def test_fp16_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4).to(torch.float16),)
        self._test_mean_dim(inputs)

    def test_fp32_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        self._test_mean_dim(inputs)

    def test_fp32_mean_dim_unsupported(self):
        """
        XNNPack mean.dim is only delegated when reducing over both of the
        innermost two dimensions together, (-1, -2). Reducing over dim 3
        alone is therefore expected to fail to partition.
        """
        inputs = (torch.randn(1, 5, 4, 4),)
        # (3,) -- not (3) -- so dims is actually a one-element tuple.
        self._test_mean_dim_unsupported(inputs, dims=(3,))

    def test_fp32_mean_dim_unsupported_3d(self):
        """
        XNNPack mean.dim implementation only supports 4D tensors, so a 3D
        input is expected to fail to partition.
        """
        inputs = (torch.randn(1, 5, 4),)
        self._test_mean_dim_unsupported(inputs, dims=(-1, -2))

    def test_qs8_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        (
            Tester(self.MeanDim((-1, -2)), inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    torch.ops.aten.mean.dim: 1,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                }
            )
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )