# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Infers the write_method of every field in the IR.

Physical fields are written physically; virtual fields are either direct
aliases of a writeable field, "transform" fields whose value can be written
back through a simple algebraic inverse, or read-only.
"""

from compiler.front_end import attributes
from compiler.front_end import expression_bounds
from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir


def _synthetic_location():
    """Returns a new Location marked as compiler-synthesized."""
    return ir_data.Location(is_synthetic=True)


def _integer_function(function, args):
    """Returns an integer-typed Expression applying `function` to `args`.

    The type is a bare IntegerType; _invert_expression computes constraints
    on the finished expression once it is fully built.

    Arguments:
      function: an ir_data.FunctionMapping value (e.g. ADDITION).
      args: the list of ir_data.Expression operands.

    Returns:
      A new ir_data.Expression.
    """
    return ir_data.Expression(
        function=ir_data.Function(function=function, args=args),
        type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
    )


def _find_field_reference_path(expression):
    """Returns a path to a field reference, or None.

    If the provided expression contains exactly one field_reference,
    _find_field_reference_path will return a list of indexes, such that
    recursively reading the index'th element of expression.function.args will
    find the field_reference.  For example, for:

        5 + (x * 2)

    _find_field_reference_path will return [1, 0]: from the top-level `+`
    expression, arg 1 is the `x * 2` expression, and `x` is arg 0 of the `*`
    expression.

    Arguments:
      expression: an ir_data.Expression to walk

    Returns:
      A list of indexes to find a field_reference, or None if the expression
      contains zero or more than one field_reference.
    """
    found, indexes = _recursively_find_field_reference_path(expression)
    return indexes if found == 1 else None


def _recursively_find_field_reference_path(expression):
    """Recursive implementation of _find_field_reference_path.

    Arguments:
      expression: an ir_data.Expression to walk.

    Returns:
      A (field_count, path) pair.  field_count is the number of
      field_references found in `expression` (only function arguments are
      descended into).  path is the index path to the field_reference when
      field_count == 1, and [] otherwise.
    """
    kind = expression.WhichOneof("expression")
    if kind == "field_reference":
        return 1, []
    if kind != "function":
        return 0, []

    field_count = 0
    path = []
    for index, arg in enumerate(expression.function.args):
        arg_field_count, arg_path = _recursively_find_field_reference_path(arg)
        # Record the path only for an argument holding exactly one reference
        # when no reference has been seen yet; if more references follow, the
        # total exceeds 1 and the path is discarded below.
        if arg_field_count == 1 and field_count == 0:
            path = [index] + arg_path
        field_count += arg_field_count
    return field_count, (path if field_count == 1 else [])


def _invert_expression(expression, ir):
    """For the given expression, searches for an algebraic inverse expression.

    That is, it takes the notional equation:

        $logical_value = expression

    and, if there is exactly one `field_reference` in `expression`, it will
    attempt to solve the equation for that field.  For example, if the
    expression is `x + 1`, it will iteratively transform:

        $logical_value = x + 1
        $logical_value - 1 = x + 1 - 1
        $logical_value - 1 = x

    and finally return `x` and `$logical_value - 1`.

    The purpose of this transformation is to find an assignment statement that
    can be used to write back through certain virtual fields.  E.g., given:

        struct Foo:
          0 [+1]  UInt  raw_value
          let actual_value = raw_value + 100

    it should be possible to write a value to the `actual_value` field, and
    have it set `raw_value` to the appropriate value.

    Only addition and subtraction are inverted; any other function on the path
    to the field_reference makes the expression non-invertible.

    Arguments:
      expression: an ir_data.Expression to be inverted.
      ir: the full IR, for looking up symbols.

    Returns:
      (field_reference, inverse_expression) if expression can be inverted,
      otherwise None.
    """
    reference_path = _find_field_reference_path(expression)
    if reference_path is None:
        return None

    subexpression = expression
    result = ir_data.Expression(
        builtin_reference=ir_data.Reference(
            canonical_name=ir_data.CanonicalName(
                module_file="", object_path=["$logical_value"]
            ),
            source_name=[
                ir_data.Word(
                    text="$logical_value",
                    source_location=_synthetic_location(),
                )
            ],
            source_location=_synthetic_location(),
        ),
        type=expression.type,
        source_location=_synthetic_location(),
    )

    # This loop essentially starts with:
    #
    #     f(g(x)) == $logical_value
    #
    # and ends with
    #
    #     x == g_inv(f_inv($logical_value))
    #
    # At each step, `subexpression` has one layer removed, and `result` has a
    # corresponding inverse function applied.  So, for example, it might start
    # with:
    #
    #     2 + ((3 - x) - 10)  ==  $logical_value
    #
    # On each iteration, `subexpression` and `result` will become:
    #
    #     (3 - x) - 10  ==  $logical_value - 2           [subtract 2 from both sides]
    #     (3 - x)  ==  ($logical_value - 2) + 10         [add 10 to both sides]
    #     x  ==  3 - (($logical_value - 2) + 10)         [subtract both sides from 3]
    #
    # This is an extremely limited algebraic solver, but it covers common-enough
    # cases.
    #
    # Note that any equation that can be solved here becomes part of Emboss's
    # contract, forever, so be conservative in expanding its solving
    # capabilities!
    for index in reference_path:
        function = subexpression.function.function
        args = subexpression.function.args
        if function == ir_data.FunctionMapping.ADDITION:
            # a + b == r: the referenced operand == r - (the other operand).
            result = _integer_function(
                ir_data.FunctionMapping.SUBTRACTION, [result, args[1 - index]]
            )
        elif function == ir_data.FunctionMapping.SUBTRACTION:
            if index == 0:
                # a - b == r  =>  a == r + b
                result = _integer_function(
                    ir_data.FunctionMapping.ADDITION, [result, args[1]]
                )
            else:
                # a - b == r  =>  b == a - r
                result = _integer_function(
                    ir_data.FunctionMapping.SUBTRACTION, [args[0], result]
                )
        else:
            return None
        subexpression = args[index]

    expression_bounds.compute_constraints_of_expression(result, ir)
    return subexpression, result


def _is_writeable_reference(field_reference, ir):
    """Returns True if `field_reference` names a field that can be written.

    Non-Field targets (e.g. parameters) are never writeable.  For a Field
    target, its own write_method is inferred first (recursively), and the
    target is writeable unless that write_method is read_only.

    Arguments:
      field_reference: the field_reference member of an ir_data.Expression.
      ir: the full IR, for looking up symbols.

    Returns:
      True if the referenced object is a Field whose write_method is not
      read_only, else False.
    """
    referenced_field = ir_util.find_object(field_reference.path[-1], ir)
    if not isinstance(referenced_field, ir_data.Field):
        return False
    _add_write_method(referenced_field, ir)
    return not referenced_field.write_method.read_only


def _add_write_method(field, ir):
    """Adds an appropriate write_method to field, if applicable.

    The rules are:
      - Non-virtual (physical) fields get the "physical" write_method.
      - A virtual field of the form `let v = some_field_reference` without a
        `[requires]` attribute gets the "alias" write_method when
        `some_field_reference` names a Field whose write_method is not
        read_only.
      - Any other virtual field gets the "transform" write_method when its
        value is an easily-invertible function (see _invert_expression) of a
        single writeable field.
      - All other fields get the "read_only" write_method; i.e., they will not
        be writeable.

    Fields that already have a write_method are left untouched.

    Arguments:
      field: an ir_data.Field to which to add a write_method.
      ir: The IR in which to look up field_references.

    Returns:
      None
    """
    if field.HasField("write_method"):
        # Do not recompute anything.
        return

    if not ir_util.field_is_virtual(field):
        # If the field is not virtual, writes are physical.
        ir_data_utils.builder(field).write_method.physical = True
        return

    field_checker = ir_data_utils.reader(field)
    field_builder = ir_data_utils.builder(field)

    # A virtual field cannot be a direct alias if it has an additional
    # requirement.
    requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
    is_direct_reference = (
        field_checker.read_transform.WhichOneof("expression") == "field_reference"
    )

    if is_direct_reference and requires_attr is None:
        aliased_reference = field.read_transform.field_reference
        if _is_writeable_reference(aliased_reference, ir):
            # It can be written as a direct alias.
            field_builder.write_method.alias.CopyFrom(aliased_reference)
        else:
            # The alias target is a non-field (i.e., a parameter) or a
            # read-only field, so the alias is read-only too.
            field_builder.write_method.read_only = True
        return

    inverse = _invert_expression(field.read_transform, ir)
    if inverse is None:
        # If the virtual field's expression is not invertible, it is read-only.
        field_builder.write_method.read_only = True
        return

    destination, function_body = inverse
    if _is_writeable_reference(destination.field_reference, ir):
        transform = field_builder.write_method.transform
        transform.destination.CopyFrom(destination.field_reference)
        transform.function_body.CopyFrom(function_body)
    else:
        # If the virtual field's expression is invertible, but its target
        # field is read-only, it is also read-only.
        field_builder.write_method.read_only = True


def set_write_methods(ir):
    """Sets the write_method member of all ir_data.Fields in ir.

    Arguments:
      ir: The IR to which to add write_methods.

    Returns:
      A list of errors, or an empty list.
    """
    traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method)
    return []