xref: /aosp_15_r20/external/emboss/compiler/front_end/write_inference.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Adds auto-generated virtual fields to the IR."""
16
17from compiler.front_end import attributes
18from compiler.front_end import expression_bounds
19from compiler.util import ir_data
20from compiler.util import ir_data_utils
21from compiler.util import ir_util
22from compiler.util import traverse_ir
23
24
def _find_field_reference_path(expression):
  """Returns the argument-index path to the sole field_reference, or None.

  When `expression` contains exactly one field_reference, the result is a list
  of indexes such that repeatedly descending into
  `<current>.function.args[index]` reaches that field_reference.  For example,
  given:

      5 + (x * 2)

  the result is [1, 0]: arg 1 of the top-level `+` is `x * 2`, and `x` is
  arg 0 of the `*`.  An expression that is itself a bare field_reference yields
  [] (an empty path, not None).  Zero or several field references yield None.

  Arguments:
    expression: an ir_data.Expression to walk

  Returns:
    A list of indexes to find a field_reference, or None.
  """
  count, path = _recursively_find_field_reference_path(expression)
  if count != 1:
    return None
  return path
50
51
def _recursively_find_field_reference_path(expression):
  """Counts field references in `expression` and locates a lone one.

  Only "function" expressions are descended into (via `function.args`); any
  other non-field_reference expression kind counts as zero references.

  Arguments:
    expression: an ir_data.Expression to walk

  Returns:
    A (count, path) tuple.  `count` is the number of field_reference nodes
    found.  When `count` is exactly 1, `path` is the list of argument indexes
    leading to that field_reference; otherwise `path` is an empty list.
  """
  kind = expression.WhichOneof("expression")
  if kind == "field_reference":
    return 1, []
  if kind != "function":
    return 0, []

  total = 0
  found_path = []
  for position, argument in enumerate(expression.function.args):
    sub_count, sub_path = _recursively_find_field_reference_path(argument)
    # Remember the path only if this argument holds the first (and so far
    # only) reference; a later reference will push `total` past 1 anyway.
    if total == 0 and sub_count == 1:
      found_path = [position] + sub_path
    total += sub_count
  return total, (found_path if total == 1 else [])
72
73
def _invert_expression(expression, ir):
  """Tries to solve `$logical_value = expression` for its one field reference.

  If `expression` contains exactly one `field_reference`, this attempts to
  isolate that field algebraically.  For example, for `x + 1` it transforms:

      $logical_value = x + 1
      $logical_value - 1 = x + 1 - 1
      $logical_value - 1 = x

  and returns `x` together with `$logical_value - 1`.

  This is used to find an assignment that lets a write to certain virtual
  fields be pushed back into the field they are computed from.  E.g., given:

      struct Foo:
        0 [+1]  UInt  raw_value
        let actual_value = raw_value + 100

  writing `actual_value` should set `raw_value` appropriately.

  Arguments:
    expression: an ir_data.Expression to be inverted.
    ir: the full IR, for looking up symbols.

  Returns:
    (field_reference, inverse_expression) if expression can be inverted,
    otherwise None.
  """
  reference_path = _find_field_reference_path(expression)
  if reference_path is None:
    return None

  def synthetic_location():
    # A fresh Location for every synthesized node; instances are not shared.
    return ir_data.Location(is_synthetic=True)

  def integer_function(function, lhs, rhs):
    # Builds the integer-typed expression `function(lhs, rhs)`.  The finished
    # tree is passed to expression_bounds.compute_constraints_of_expression
    # below.
    return ir_data.Expression(
        function=ir_data.Function(function=function, args=[lhs, rhs]),
        type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
    )

  # Right-hand side of the equation; starts as the bare `$logical_value`.
  inverse = ir_data.Expression(
      builtin_reference=ir_data.Reference(
          canonical_name=ir_data.CanonicalName(
              module_file="", object_path=["$logical_value"]),
          source_name=[
              ir_data.Word(text="$logical_value",
                           source_location=synthetic_location()),
          ],
          source_location=synthetic_location()),
      type=expression.type,
      source_location=synthetic_location())

  # This loop essentially starts with:
  #
  #     f(g(x)) == $logical_value
  #
  # and ends with
  #
  #     x == g_inv(f_inv($logical_value))
  #
  # At each step, `remaining` has one layer removed, and `inverse` has a
  # corresponding inverse function applied.  So, for example, it might start
  # with:
  #
  #     2 + ((3 - x) - 10)  ==  $logical_value
  #
  # On each iteration, `remaining` and `inverse` will become:
  #
  #     (3 - x) - 10  ==  $logical_value - 2    [subtract 2 from both sides]
  #     (3 - x)  ==  ($logical_value - 2) + 10  [add 10 to both sides]
  #     x  ==  3 - (($logical_value - 2) + 10)  [subtract both sides from 3]
  #
  # This is an extremely limited algebraic solver, but it covers common-enough
  # cases.
  #
  # Note that any equation that can be solved here becomes part of Emboss's
  # contract, forever, so be conservative in expanding its solving capabilities!
  addition = ir_data.FunctionMapping.ADDITION
  subtraction = ir_data.FunctionMapping.SUBTRACTION
  remaining = expression
  for index in reference_path:
    op = remaining.function.function
    args = remaining.function.args
    if op == addition:
      # a + b == v: subtract whichever operand is NOT on the path.
      inverse = integer_function(subtraction, inverse, args[1 - index])
    elif op == subtraction and index == 0:
      # x - b == v  =>  x == v + b
      inverse = integer_function(addition, inverse, args[1])
    elif op == subtraction:
      # a - x == v  =>  x == a - v
      inverse = integer_function(subtraction, args[0], inverse)
    else:
      # Only `+` and `-` are invertible by this solver.
      return None
    remaining = args[index]
  expression_bounds.compute_constraints_of_expression(inverse, ir)
  return remaining, inverse
194
195
def _reference_is_writeable(reference, ir):
  """Returns True if `reference` names a field that is not read-only.

  The object named by `reference.path[-1]` is looked up in `ir`.  Anything
  that is not an ir_data.Field (e.g., a parameter) is treated as read-only.
  Otherwise the referenced field's own write_method is computed first (this is
  a no-op if it already has one), so that its read_only flag can be consulted.

  Arguments:
    reference: the `field_reference` member of an ir_data.Expression; its
      `path[-1]` names the target.
    ir: The IR in which to look up the reference.

  Returns:
    True if the referenced object is a field whose write_method is not
    read_only, False otherwise.
  """
  referenced_field = ir_util.find_object(reference.path[-1], ir)
  if not isinstance(referenced_field, ir_data.Field):
    return False
  # Mutual recursion with _add_write_method: the target's write_method must be
  # known before the referring field's can be decided.
  _add_write_method(referenced_field, ir)
  return not referenced_field.write_method.read_only


def _add_write_method(field, ir):
  """Adds an appropriate write_method to field, if applicable.

  - Non-virtual fields get the "physical" write_method.
  - A virtual field of the form `let v = some_field_reference`, with no
    `requires` attribute, gets the "alias" write_method when
    `some_field_reference` names a field that is not read-only.
  - A virtual field whose value is an easily-invertible function of a single
    field reference (see _invert_expression), or a plain field reference that
    also carries a `requires` attribute, gets the "transform" write_method when
    the target field is not read-only.
  - All other fields get the "read_only" write_method; i.e., they will not be
    writeable.

  Fields that already have a write_method are left untouched.  Referenced
  fields have their own write_method computed on demand, recursively.

  Arguments:
    field: an ir_data.Field to which to add a write_method.
    ir: The IR in which to look up field_references.

  Returns:
    None
  """
  if field.HasField("write_method"):
    # Do not recompute anything.
    return

  if not ir_util.field_is_virtual(field):
    # If the field is not virtual, writes are physical.
    ir_data_utils.builder(field).write_method.physical = True
    return

  field_checker = ir_data_utils.reader(field)
  field_builder = ir_data_utils.builder(field)

  # A virtual field cannot be a direct alias if it has an additional
  # requirement.
  requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
  read_kind = field_checker.read_transform.WhichOneof("expression")
  is_plain_reference = read_kind == "field_reference"

  if is_plain_reference and requires_attr is None:
    alias_target = field.read_transform.field_reference
    if _reference_is_writeable(alias_target, ir):
      # Otherwise, it can be written as a direct alias.
      field_builder.write_method.alias.CopyFrom(alias_target)
    else:
      # Aliasing a non-field (i.e., a parameter) or a read-only field makes
      # this virtual field read-only, too.
      field_builder.write_method.read_only = True
    return

  inverse = _invert_expression(field.read_transform, ir)
  if inverse is None:
    # If the virtual field's expression is not invertible, it is read-only.
    field_builder.write_method.read_only = True
    return

  field_reference, function_body = inverse
  destination = field_reference.field_reference
  if _reference_is_writeable(destination, ir):
    field_builder.write_method.transform.destination.CopyFrom(destination)
    field_builder.write_method.transform.function_body.CopyFrom(function_body)
  else:
    # If the virtual field's expression is invertible, but its target is not
    # a writeable field, it is also read-only.
    field_builder.write_method.read_only = True
272
273
def set_write_methods(ir):
  """Populates the write_method member of every ir_data.Field in `ir`.

  Every Field found by a top-down traversal is handed to _add_write_method,
  which skips fields that already have a write_method.

  Arguments:
      ir: The IR to which to add write_methods.

  Returns:
      A list of errors, or an empty list.  This pass currently reports no
      errors, so the list is always empty.
  """
  field_types = [ir_data.Field]
  traverse_ir.fast_traverse_ir_top_down(ir, field_types, _add_write_method)
  no_errors = []
  return no_errors
285