# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for lower/upper bounds validators for numeric flags."""

from unittest import mock

from absl import flags
from absl.flags import _validators
from absl.testing import absltest


class NumericFlagBoundsTest(absltest.TestCase):
  """Checks lower/upper bound handling when numeric flags are parsed."""

  def setUp(self):
    super().setUp()
    # Each test defines its flags on a private FlagValues instance (passed via
    # flag_values=...), so definitions do not collide across tests.
    self.flag_values = flags.FlagValues()

  def test_no_validator_if_no_bounds(self):
    """A bounds validator is registered only when a bound is given."""
    # With a lower bound, exactly one validator must be registered. Its
    # callable is built internally, hence mock.ANY.
    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('positive_flag', None, 'positive int',
                           lower_bound=0, flag_values=self.flag_values)
      register_validator.assert_called_once_with(
          'positive_flag', mock.ANY, flag_values=self.flag_values)
    # With neither lower nor upper bound, no validator is registered at all.
    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('int_flag', None, 'just int',
                           flag_values=self.flag_values)
      register_validator.assert_not_called()

  def test_success(self):
    """An unbounded integer flag accepts argv and later assignment."""
    flags.DEFINE_integer('int_flag', 5, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_success_if_none(self):
    """A bounded flag left at its None default passes validation."""
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.int_flag)

  def test_success_if_exactly_equals(self):
    """Bounds are inclusive: a value equal to both bounds is accepted."""
    flags.DEFINE_float('float_flag', None, '',
                       lower_bound=1, upper_bound=1,
                       flag_values=self.flag_values)
    argv = ('./program', '--float_flag=1')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.float_flag)

  def test_exception_if_smaller(self):
    """A value below the lower bound is rejected with a descriptive error."""
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=-1')
    # assertRaises, unlike a bare try/except, fails the test if parsing
    # unexpectedly succeeds instead of raising.
    with self.assertRaises(flags.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    text = 'flag --int_flag=-1: -1 is not an integer in the range [0, 5]'
    self.assertEqual(text, str(cm.exception))


class SettingFlagAfterStartTest(absltest.TestCase):
  """Checks assignments made to flags after argv has been parsed."""

  def setUp(self):
    super().setUp()
    # Private registry so these flag definitions are isolated per test.
    self.flag_values = flags.FlagValues()

  def test_success(self):
    """An unbounded flag can be reassigned after parsing."""
    flags.DEFINE_integer('int_flag', None, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_exception_if_setting_integer_flag_outside_bounds(self):
    """Assigning an out-of-bounds value after parsing raises."""
    flags.DEFINE_integer('int_flag', None, 'Just integer', lower_bound=0,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    with self.assertRaises(flags.IllegalFlagValueError):
      self.flag_values.int_flag = -2


if __name__ == '__main__':
  absltest.main()