1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for lower/upper bounds validators for numeric flags."""
16
17from unittest import mock
18from absl import flags
19from absl.flags import _validators
20from absl.testing import absltest
21
22
class NumericFlagBoundsTest(absltest.TestCase):
  """Checks lower/upper bound handling of numeric flags at parse time.

  Every test defines its flags on a private `flags.FlagValues` instance
  created in `setUp`, so definitions do not leak between tests.
  """

  def setUp(self):
    """Creates a fresh, empty FlagValues container for each test."""
    super().setUp()
    self.flag_values = flags.FlagValues()

  def test_no_validator_if_no_bounds(self):
    """Validator is not registered if lower and upper bound are None."""
    # With a bound given, exactly one validator registration is expected.
    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('positive_flag', None, 'positive int',
                           lower_bound=0, flag_values=self.flag_values)
      register_validator.assert_called_once_with(
          'positive_flag', mock.ANY, flag_values=self.flag_values)
    # Without any bound, no validator registration is expected.
    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('int_flag', None, 'just int',
                           flag_values=self.flag_values)
      register_validator.assert_not_called()

  def test_success(self):
    """An unbounded integer flag accepts parsed and assigned values."""
    flags.DEFINE_integer('int_flag', 5, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_success_if_none(self):
    """A bounded flag left at its None default parses without error."""
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.int_flag)

  def test_success_if_exactly_equals(self):
    """A value equal to both the lower and the upper bound is accepted."""
    flags.DEFINE_float('float_flag', None, '',
                       lower_bound=1, upper_bound=1,
                       flag_values=self.flag_values)
    argv = ('./program', '--float_flag=1')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.float_flag)

  def test_exception_if_smaller(self):
    """Parsing a value below the lower bound raises IllegalFlagValueError."""
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=-1')
    # assertRaises makes the test fail when no exception is raised at all;
    # a bare try/except would let that case pass silently.
    with self.assertRaises(flags.IllegalFlagValueError) as raised:
      self.flag_values(argv)
    text = 'flag --int_flag=-1: -1 is not an integer in the range [0, 5]'
    self.assertEqual(text, str(raised.exception))
78
79
class SettingFlagAfterStartTest(absltest.TestCase):
  """Checks assigning integer flag values after the command line is parsed.

  Every test defines its flags on a private `flags.FlagValues` instance
  created in `setUp`.
  """

  def setUp(self):
    """Creates a fresh, empty FlagValues container for each test."""
    # Run the base-class setup too, so its fixtures are not skipped.
    super().setUp()
    self.flag_values = flags.FlagValues()

  def test_success(self):
    """An unbounded integer flag can be reassigned after parsing."""
    flags.DEFINE_integer('int_flag', None, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_exception_if_setting_integer_flag_outside_bounds(self):
    """Assigning a value below lower_bound raises IllegalFlagValueError."""
    flags.DEFINE_integer('int_flag', None, 'Just integer', lower_bound=0,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    with self.assertRaises(flags.IllegalFlagValueError):
      self.flag_values.int_flag = -2
102
103
# Run this module's tests via absltest when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
106