# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest
import pyarmnn as ann


def test_tensor_shape_tuple():
    """A TensorShape built from a tuple reports its rank and element count."""
    tensor_shape = ann.TensorShape((1, 2, 3))

    assert 3 == tensor_shape.GetNumDimensions()
    assert 6 == tensor_shape.GetNumElements()


def test_tensor_shape_one():
    """A one-element tuple yields a one-dimensional shape."""
    tensor_shape = ann.TensorShape((10,))
    assert 1 == tensor_shape.GetNumDimensions()
    assert 10 == tensor_shape.GetNumElements()


def test_tensor_shape_empty():
    """An empty tuple is rejected with a RuntimeError."""
    with pytest.raises(RuntimeError) as err:
        ann.TensorShape(())

    assert "Tensor numDimensions must be greater than 0" in str(err.value)


def test_tensor_shape_tuple_mess():
    """Elements that convert to numbers (here "2" and 3.0) are accepted."""
    tensor_shape = ann.TensorShape((1, "2", 3.0))

    assert 3 == tensor_shape.GetNumDimensions()
    assert 6 == tensor_shape.GetNumElements()


def test_tensor_shape_list():
    """Only a tuple is accepted; passing a list raises TypeError."""
    with pytest.raises(TypeError) as err:
        ann.TensorShape([1, 2, 3])

    assert "Argument is not a tuple" in str(err.value)


def test_tensor_shape_tuple_mess_fail():
    """A tuple element that cannot be converted to a number raises TypeError."""
    with pytest.raises(TypeError) as err:
        ann.TensorShape((1, "two", 3.0))

    assert "All elements must be numbers" in str(err.value)


def test_tensor_shape_varags():
    """Dimensions must be given as one tuple, not as separate positional args."""
    with pytest.raises(TypeError) as err:
        ann.TensorShape(1, 2, 3)

    # NOTE(review): this pins the exact arity-error text produced for the
    # generated __init__; update it if the binding's signature changes.
    assert "__init__() takes 2 positional arguments but 4 were given" in str(err.value)


def test_tensor_shape__get_item_out_of_bounds():
    """In-bounds reads return the dimensions; reading index == rank raises ValueError."""
    tensor_shape = ann.TensorShape((1, 2, 3))

    # Valid indices are checked outside the raises block so a failure here is
    # reported as such, rather than being mistaken for the expected error.
    assert [tensor_shape[i] for i in range(3)] == [1, 2, 3]

    # Only the single out-of-bounds access belongs inside pytest.raises.
    with pytest.raises(ValueError) as err:
        _ = tensor_shape[3]

    assert "Invalid dimension index: 3 (number of dimensions is 3)" in str(err.value)


def test_tensor_shape__set_item_out_of_bounds():
    """In-bounds writes take effect; writing index == rank raises ValueError."""
    tensor_shape = ann.TensorShape((1, 2, 3))

    # Valid writes happen outside the raises block and are read back to
    # confirm they were applied.
    for i in range(3):
        tensor_shape[i] = 1
    assert [tensor_shape[i] for i in range(3)] == [1, 1, 1]

    # Only the single out-of-bounds write belongs inside pytest.raises.
    with pytest.raises(ValueError) as err:
        tensor_shape[3] = 1

    assert "Invalid dimension index: 3 (number of dimensions is 3)" in str(err.value)


def test_tensor_shape___str__():
    """str() renders the dimensions, rank and element count."""
    tensor_shape = ann.TensorShape((1, 2, 3))

    assert str(tensor_shape) == "TensorShape{Shape(1, 2, 3), NumDimensions: 3, NumElements: 6}"