xref: /aosp_15_r20/external/yapf/yapftests/yapf_test_helper.py (revision 7249d1a64f4850ccf838e62a46276f891f72998e)
1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Support module for tests for yapf."""
15
16import difflib
17import sys
18import unittest
19
20from yapf.yapflib import blank_line_calculator
21from yapf.yapflib import comment_splicer
22from yapf.yapflib import continuation_splicer
23from yapf.yapflib import identify_container
24from yapf.yapflib import py3compat
25from yapf.yapflib import pytree_unwrapper
26from yapf.yapflib import pytree_utils
27from yapf.yapflib import pytree_visitor
28from yapf.yapflib import split_penalty
29from yapf.yapflib import style
30from yapf.yapflib import subtype_assigner
31
32
class YAPFTest(unittest.TestCase):
  """Base TestCase for yapf tests, adding a readable code-comparison check."""

  def __init__(self, *args, **kwargs):
    # Forward keyword arguments as well, so the standard
    # TestCase(methodName='...') construction keeps working.
    super(YAPFTest, self).__init__(*args, **kwargs)
    # On Python 2, unittest only offers the older assertRaisesRegexp name;
    # alias it so tests can use the Python 3 spelling on both versions.
    if not py3compat.PY3:
      self.assertRaisesRegex = self.assertRaisesRegexp

  @staticmethod
  def _MarkLines(code, linelen):
    """Returns the lines of `code`, each prefixed for a failure message.

    Lines longer than `linelen` get the prefix '!> ' so overlong lines stand
    out; every other line gets the prefix ' > '.

    Arguments:
      code: source text to split into lines.
      linelen: maximum line length before a line is flagged.

    Returns:
      List of prefixed lines (empty if `code` has no lines).
    """
    return [('!> %s' if len(line) > linelen else ' > %s') % line
            for line in code.splitlines()]

  def assertCodeEqual(self, expected_code, code):
    """Fails the test if `code` differs from `expected_code`.

    On mismatch, the failure message lists the expected code and then the
    actual code, flagging lines longer than the style's COLUMN_LIMIT with
    '!>'. It ends with a unified diff from the actual code to the expected
    code.

    Arguments:
      expected_code: the expected source text.
      code: the actual source text produced by the code under test.

    Raises:
      The test's failure exception (via self.fail) when the texts differ.
    """
    if code == expected_code:
      return
    linelen = style.Get('COLUMN_LIMIT')
    msg = ['Code format mismatch:', 'Expected:']
    msg.extend(self._MarkLines(expected_code, linelen))
    msg.append('Actual:')
    msg.extend(self._MarkLines(code, linelen))
    msg.append('Diff:')
    msg.extend(
        difflib.unified_diff(
            code.splitlines(),
            expected_code.splitlines(),
            fromfile='actual',
            tofile='expected',
            lineterm=''))
    self.fail('\n'.join(msg))
64
65
def ParseAndUnwrap(code, dumptree=False):
  """Parses source code and returns its logical lines.

  The code is parsed into a pytree, which then goes through, in order:
  comment splicing, continuation splicing, subtype assignment, container
  identification, split-penalty computation and blank-line calculation.
  The annotated tree is then unwrapped into logical lines. Each logical line
  has CalculateFormattingInformation() called on it before it is returned.

  Arguments:
    code: code to parse as a string
    dumptree: if True, the pytree is dumped to stderr after all of the
              passes above have run. Useful for debugging.

  Returns:
    List of logical lines.
  """
  tree = pytree_utils.ParseCodeToTree(code)

  # Each pass mutates the tree in place; the order matters.
  annotation_passes = (
      comment_splicer.SpliceComments,
      continuation_splicer.SpliceContinuations,
      subtype_assigner.AssignSubtypes,
      identify_container.IdentifyContainers,
      split_penalty.ComputeSplitPenalties,
      blank_line_calculator.CalculateBlankLines,
  )
  for annotate in annotation_passes:
    annotate(tree)

  if dumptree:
    pytree_visitor.DumpPyTree(tree, target_stream=sys.stderr)

  logical_lines = pytree_unwrapper.UnwrapPyTree(tree)
  for logical_line in logical_lines:
    logical_line.CalculateFormattingInformation()
  return logical_lines
96