"""Test case generators for the bignum mod_raw test suite.

Each class below targets one ``mpi_mod_raw_*`` test function and computes
the expected output of that operation for the operands it is given.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
#

from typing import Iterator, List

from . import test_case
from . import test_data_generation
from . import bignum_common
from .bignum_data import ONLY_PRIME_MODULI


class BignumModRawTarget(test_data_generation.BaseTarget):
    #pylint: disable=abstract-method, too-few-public-methods
    """Base target shared by every mod_raw generator in this module."""
    target_basename = 'test_suite_bignum_mod_raw.generated'


class BignumModRawSub(bignum_common.ModOperationCommon,
                      BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_sub(): (A - B) mod N."""
    symbol = "-"
    test_function = "mpi_mod_raw_sub"
    test_name = "mbedtls_mpi_mod_raw_sub"
    input_style = "fixed"
    arity = 2

    def arguments(self) -> List[str]:
        # A, B and N are passed as quoted strings, followed by the expected
        # result.
        quoted = [bignum_common.quote_str(value)
                  for value in (self.arg_a, self.arg_b, self.arg_n)]
        return quoted + self.result()

    def result(self) -> List[str]:
        difference = (self.int_a - self.int_b) % self.int_n
        return [self.format_result(difference)]


class BignumModRawFixQuasiReduction(bignum_common.ModOperationCommon,
                                    BignumModRawTarget):
    """Generate test cases for the ecp quasi_reduction() fix-up.

    Operands may be anything below twice the modulus; the expected output
    is the fully reduced value A mod N.
    """
    symbol = "-"
    test_function = "mpi_mod_raw_fix_quasi_reduction"
    test_name = "fix_quasi_reduction"
    input_style = "fixed"
    arity = 1

    # On top of the inherited default values, add operands x with n < x < 2n.
    input_values = bignum_common.ModOperationCommon.input_values + [
        "73",

        # First number generated by random.getrandbits(1024) - seed(3,2)
        # NOTE(review): this literal has 48 hex digits (192 bits), not 1024
        # bits -- confirm how it was derived.
        "ea7b5bf55eb561a4216363698b529b4a97b750923ceb3ffd",

        # First number generated by random.getrandbits(1024) - seed(1,2)
        ("cd447e35b8b6d8fe442e3d437204e52db2221a58008a05a6c4647159c324c985"
         "9b810e766ec9d28663ca828dd5f4b3b2e4b06ce60741c7a87ce42c8218072e8c"
         "35bf992dc9e9c616612e7696a6cecc1b78e510617311d8a3c2ce6f447ed4d57b"
         "1e2feb89414c343c1027c4d1c386bbc4cd613e30d8f16adf91b7584a2265b1f5")
    ]  # type: List[str]

    def result(self) -> List[str]:
        return [self.format_result(self.int_a % self.int_n)]

    @property
    def is_valid(self) -> bool:
        # Only requires A < 2 * N; this override does not consult the
        # inherited validity check.
        upper_bound = 2 * self.int_n
        return bool(self.int_a < upper_bound)


class BignumModRawMul(bignum_common.ModOperationCommon,
                      BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_mul() on Montgomery-form operands."""
    symbol = "*"
    test_function = "mpi_mod_raw_mul"
    test_name = "mbedtls_mpi_mod_raw_mul"
    input_style = "arch_split"
    arity = 2

    def arguments(self) -> List[str]:
        # Both operands are converted to Montgomery form; the modulus is
        # passed as its quoted hex string; the expected result comes last.
        operands = [self.format_result(self.to_montgomery(value))
                    for value in (self.int_a, self.int_b)]
        operands.append(bignum_common.quote_str(self.arg_n))
        return operands + self.result()

    def result(self) -> List[str]:
        canonical = (self.int_a * self.int_b) % self.int_n
        return [self.format_result(self.to_montgomery(canonical))]


class BignumModRawInvPrime(bignum_common.ModOperationCommon,
                           BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_inv_prime() over prime moduli.

    The expected inverse is emitted in Montgomery form.
    """
    moduli = ONLY_PRIME_MODULI
    symbol = "^ -1"
    test_function = "mpi_mod_raw_inv_prime"
    test_name = "mbedtls_mpi_mod_raw_inv_prime (Montgomery form only)"
    input_style = "arch_split"
    arity = 1
    suffix = True
    montgomery_form_a = True
    disallow_zero_a = True

    def result(self) -> List[str]:
        inverse = bignum_common.invmod_positive(self.int_a, self.int_n)
        return [self.format_result(self.to_montgomery(inverse))]


class BignumModRawAdd(bignum_common.ModOperationCommon,
                      BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_add(): (A + B) mod N."""
    symbol = "+"
    test_function = "mpi_mod_raw_add"
    test_name = "mbedtls_mpi_mod_raw_add"
    input_style = "fixed"
    arity = 2

    def result(self) -> List[str]:
        total = (self.int_a + self.int_b) % self.int_n
        return [self.format_result(total)]


class BignumModRawConvertRep(bignum_common.ModOperationCommon,
                             BignumModRawTarget):
    # Abstract on purpose: in this file the subclasses supply test_function,
    # test_name and result().
    #pylint: disable=abstract-method
    """Shared machinery for the representation-conversion test cases."""
    symbol = ""
    input_style = "arch_split"
    arity = 1
    rep = bignum_common.ModulusRepresentation.INVALID

    def set_representation(self, r: bignum_common.ModulusRepresentation) -> None:
        """Record the modulus representation exercised by this test case."""
        self.rep = r

    def arguments(self) -> List[str]:
        # Order: modulus, representation symbol, operand, expected result.
        modulus = bignum_common.quote_str(self.arg_n)
        rep_symbol = self.rep.symbol()
        operand = bignum_common.quote_str(self.arg_a)
        return [modulus, rep_symbol, operand] + self.result()

    def description(self) -> str:
        # Tag the first "mod" in the inherited description with the
        # representation name, e.g. "mod(MONTGOMERY)".
        tagged = 'mod({name})'.format(name=self.rep.name)
        return super().description().replace('mod', tagged, 1)

    @classmethod
    def test_cases_for_values(cls, rep: bignum_common.ModulusRepresentation,
                              n: str, a: str) -> Iterator[test_case.TestCase]:
        """Yield the test cases for modulus n, operand a and representation rep.

        A test case is only yielded when it is valid, so this may yield
        nothing. The Montgomery representation depends on the limb size and
        gets one candidate per limb size. Every other representation gets a
        single candidate, with its MBEDTLS_HAVE_INT* dependencies removed.
        """
        per_limb_size = rep is bignum_common.ModulusRepresentation.MONTGOMERY
        for bits_in_limb in cls.limb_sizes:
            candidate = cls(n, a, bits_in_limb=bits_in_limb)
            candidate.set_representation(rep)
            if not per_limb_size:
                # The data does not depend on the limb size, so neither
                # should the test case.
                candidate.dependencies = [
                    dep for dep in candidate.dependencies
                    if not dep.startswith('MBEDTLS_HAVE_INT')
                ]
            if candidate.is_valid:
                yield candidate.create_test_case()
            if not per_limb_size:
                # The first limb size settles it, whether or not a case was
                # emitted: the other limb sizes would give the same case.
                break

    # The parent class only handles bignum parameters, so test generation is
    # overridden here to iterate over the representation as well.
    @classmethod
    def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
        representations = \
            bignum_common.ModulusRepresentation.supported_representations()
        for representation in representations:
            for modulus in cls.moduli:
                for operand in cls.input_values:
                    yield from cls.test_cases_for_values(representation,
                                                         modulus, operand)


class BignumModRawCanonicalToModulusRep(BignumModRawConvertRep):
    """Generate test cases for mpi_mod_raw_canonical_to_modulus_rep()."""
    test_function = "mpi_mod_raw_canonical_to_modulus_rep"
    test_name = "Rep canon->mod"

    def result(self) -> List[str]:
        # Expected output: A converted from canonical form to self.rep.
        converted = self.convert_from_canonical(self.int_a, self.rep)
        return [self.format_result(converted)]


class BignumModRawModulusToCanonicalRep(BignumModRawConvertRep):
    """Generate test cases for mpi_mod_raw_modulus_to_canonical_rep()."""
    test_function = "mpi_mod_raw_modulus_to_canonical_rep"
    test_name = "Rep mod->canon"

    @property
    def arg_a(self) -> str:
        # The input operand is A already converted into self.rep.
        in_rep = self.convert_from_canonical(self.int_a, self.rep)
        return self.format_arg(format(in_rep, 'x'))

    def result(self) -> List[str]:
        # The expected output is the canonical value A.
        return [self.format_result(self.int_a)]


class BignumModRawConvertToMont(bignum_common.ModOperationCommon,
                                BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_to_mont_rep(): A -> A * R mod N."""
    test_function = "mpi_mod_raw_to_mont_rep"
    test_name = "Convert into Mont: "
    symbol = "R *"
    input_style = "arch_split"
    arity = 1

    def result(self) -> List[str]:
        return [self.format_result(self.to_montgomery(self.int_a))]


class BignumModRawConvertFromMont(bignum_common.ModOperationCommon,
                                  BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_from_mont_rep(): A -> A / R mod N."""
    test_function = "mpi_mod_raw_from_mont_rep"
    test_name = "Convert from Mont: "
    symbol = "1/R *"
    input_style = "arch_split"
    arity = 1

    def result(self) -> List[str]:
        return [self.format_result(self.from_montgomery(self.int_a))]


class BignumModRawModNegate(bignum_common.ModOperationCommon,
                            BignumModRawTarget):
    """Generate test cases for mpi_mod_raw_neg(): (N - A) mod N."""
    test_function = "mpi_mod_raw_neg"
    test_name = "Modular negation: "
    symbol = "-"
    input_style = "arch_split"
    arity = 1

    def result(self) -> List[str]:
        negated = (self.int_n - self.int_a) % self.int_n
        return [self.format_result(negated)]