# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Infers the write_method of every field in the IR.

Physical fields are marked as physically writeable.  Virtual fields are marked
as direct aliases of a writeable field, as invertible transforms of a single
writeable field, or as read-only.
"""

from compiler.front_end import attributes
from compiler.front_end import expression_bounds
from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir


def _find_field_reference_path(expression):
    """Returns a path to a field reference, or None.

    If the provided expression contains exactly one field_reference,
    _find_field_reference_path will return a list of indexes, such that
    recursively reading the index'th element of expression.function.args will
    find the field_reference.  For example, for:

        5 + (x * 2)

    _find_field_reference_path will return [1, 0]: from the top-level `+`
    expression, arg 1 is the `x * 2` expression, and `x` is arg 0 of the `*`
    expression.

    Arguments:
        expression: an ir_data.Expression to walk

    Returns:
        A list of indexes to find a field_reference, or None if the expression
        contains zero or more than one field_reference.
    """
    found, indexes = _recursively_find_field_reference_path(expression)
    return indexes if found == 1 else None


def _recursively_find_field_reference_path(expression):
    """Recursive implementation of _find_field_reference_path.

    Only `function` expressions are descended into: a `field_reference` counts
    as one reference, and any other kind of expression counts as zero.

    Arguments:
        expression: an ir_data.Expression to walk.

    Returns:
        A (field_count, path) tuple.  field_count is the number of
        field_references found.  path is the list of argument indexes leading
        to the field_reference when field_count == 1, and [] otherwise.
    """
    kind = expression.WhichOneof("expression")
    if kind == "field_reference":
        return 1, []
    if kind != "function":
        return 0, []

    field_count = 0
    path = []
    for index, arg in enumerate(expression.function.args):
        arg_field_count, arg_path = _recursively_find_field_reference_path(arg)
        # Record the route through this argument if it is the first one to
        # contain any reference and it contains exactly one; the route is
        # discarded below unless the overall count ends up as one.
        if arg_field_count == 1 and field_count == 0:
            path = [index] + arg_path
        field_count += arg_field_count
    if field_count == 1:
        return field_count, path
    return field_count, []


def _integer_binary_expression(function, lhs, rhs):
    """Returns an integer-typed `lhs <function> rhs` ir_data.Expression.

    Only the `integer` type is set here; _invert_expression computes the
    constraints of the finished inverse expression once it is complete.

    Arguments:
        function: an ir_data.FunctionMapping value (ADDITION or SUBTRACTION).
        lhs: the left-hand ir_data.Expression operand.
        rhs: the right-hand ir_data.Expression operand.

    Returns:
        A new ir_data.Expression applying `function` to [lhs, rhs].
    """
    return ir_data.Expression(
        function=ir_data.Function(function=function, args=[lhs, rhs]),
        type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
    )


def _invert_expression(expression, ir):
    """For the given expression, searches for an algebraic inverse expression.

    That is, it takes the notional equation:

        $logical_value = expression

    and, if there is exactly one `field_reference` in `expression`, it will
    attempt to solve the equation for that field.  For example, if the
    expression is `x + 1`, it will iteratively transform:

        $logical_value = x + 1
        $logical_value - 1 = x + 1 - 1
        $logical_value - 1 = x

    and finally return `x` and `$logical_value - 1`.

    Only addition and subtraction are inverted; any other operator on the path
    to the field_reference makes the expression non-invertible.

    The purpose of this transformation is to find an assignment statement that
    can be used to write back through certain virtual fields.  E.g., given:

        struct Foo:
          0 [+1]  UInt  raw_value
          let actual_value = raw_value + 100

    it should be possible to write a value to the `actual_value` field, and
    have it set `raw_value` to the appropriate value.

    Arguments:
        expression: an ir_data.Expression to be inverted.
        ir: the full IR, for looking up symbols.

    Returns:
        (field_reference, inverse_expression) if expression can be inverted,
        otherwise None.
    """
    reference_path = _find_field_reference_path(expression)
    if reference_path is None:
        return None

    # `$logical_value` stands for the value being written to the virtual
    # field; it starts out as the entire right-hand side of the equation.
    result = ir_data.Expression(
        builtin_reference=ir_data.Reference(
            canonical_name=ir_data.CanonicalName(
                module_file="",
                object_path=["$logical_value"],
            ),
            source_name=[
                ir_data.Word(
                    text="$logical_value",
                    source_location=ir_data.Location(is_synthetic=True),
                )
            ],
            source_location=ir_data.Location(is_synthetic=True),
        ),
        type=expression.type,
        source_location=ir_data.Location(is_synthetic=True),
    )

    # This loop essentially starts with:
    #
    #     f(g(x)) == $logical_value
    #
    # and ends with
    #
    #     x == g_inv(f_inv($logical_value))
    #
    # At each step, `subexpression` has one layer removed, and `result` has a
    # corresponding inverse function applied.  So, for example, it might start
    # with:
    #
    #     2 + ((3 - x) - 10)  ==  $logical_value
    #
    # On each iteration, `subexpression` and `result` will become:
    #
    #     (3 - x) - 10  ==  $logical_value - 2           [subtract 2 from both sides]
    #     (3 - x)  ==  ($logical_value - 2) + 10         [add 10 to both sides]
    #     x  ==  3 - (($logical_value - 2) + 10)         [subtract both sides from 3]
    #
    # This is an extremely limited algebraic solver, but it covers
    # common-enough cases.
    #
    # Note that any equation that can be solved here becomes part of Emboss's
    # contract, forever, so be conservative in expanding its solving
    # capabilities!
    subexpression = expression
    for index in reference_path:
        operation = subexpression.function.function
        args = subexpression.function.args
        if operation == ir_data.FunctionMapping.ADDITION:
            # x + b == r  (or  a + x == r)  =>  x == r - <other operand>
            result = _integer_binary_expression(
                ir_data.FunctionMapping.SUBTRACTION, result, args[1 - index]
            )
        elif operation == ir_data.FunctionMapping.SUBTRACTION and index == 0:
            # x - b == r  =>  x == r + b
            result = _integer_binary_expression(
                ir_data.FunctionMapping.ADDITION, result, args[1]
            )
        elif operation == ir_data.FunctionMapping.SUBTRACTION:
            # a - x == r  =>  x == a - r
            result = _integer_binary_expression(
                ir_data.FunctionMapping.SUBTRACTION, args[0], result
            )
        else:
            # Any other operator is deliberately not inverted.
            return None
        subexpression = args[index]
    expression_bounds.compute_constraints_of_expression(result, ir)
    return subexpression, result


def _reference_is_read_only(field_reference, ir):
    """Returns True if a write cannot be forwarded through field_reference.

    As a side effect, ensures the referenced field's own write_method has been
    computed (via _add_write_method) before it is inspected.

    Arguments:
        field_reference: the `field_reference` member of an ir_data.Expression,
            naming the prospective write target.
        ir: The IR in which to look up the reference.

    Returns:
        True if the reference does not resolve to an ir_data.Field (i.e., it
        names a parameter) or if the resolved field is itself read-only.
    """
    referenced_field = ir_util.find_object(field_reference.path[-1], ir)
    if not isinstance(referenced_field, ir_data.Field):
        return True
    _add_write_method(referenced_field, ir)
    return referenced_field.write_method.read_only


def _add_write_method(field, ir):
    """Adds an appropriate write_method to field, if applicable.

    Currently, the "alias" write_method will be added for virtual fields of the
    form `let v = some_field_reference` when `some_field_reference` is a
    physical field or a writeable alias.  The "physical" write_method will be
    added for physical fields.  The "transform" write_method will be added when
    the virtual field's value is an easily-invertible function of a single
    writeable field.  All other fields will have the "read_only" write_method;
    i.e., they will not be writeable.

    Fields that already have a write_method are left untouched, so this may be
    called recursively on the fields a virtual field refers to.

    Arguments:
        field: an ir_data.Field to which to add a write_method.
        ir: The IR in which to look up field_references.

    Returns:
        None
    """
    if field.HasField("write_method"):
        # Do not recompute anything.
        return

    if not ir_util.field_is_virtual(field):
        # If the field is not virtual, writes are physical.
        ir_data_utils.builder(field).write_method.physical = True
        return

    field_checker = ir_data_utils.reader(field)
    field_builder = ir_data_utils.builder(field)

    # A virtual field cannot be a direct alias if it has an additional
    # requirement.
    requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
    read_kind = field_checker.read_transform.WhichOneof("expression")
    if read_kind != "field_reference" or requires_attr is not None:
        inverse = _invert_expression(field.read_transform, ir)
        if inverse is None:
            # If the virtual field's expression is not invertible, it is
            # read-only.
            field_builder.write_method.read_only = True
            return
        field_reference, function_body = inverse
        if _reference_is_read_only(field_reference.field_reference, ir):
            # If the virtual field's expression is invertible, but its target
            # field is read-only, it is also read-only.
            field_builder.write_method.read_only = True
            return
        field_builder.write_method.transform.destination.CopyFrom(
            field_reference.field_reference
        )
        field_builder.write_method.transform.function_body.CopyFrom(function_body)
        return

    # From here on, the read_transform is a bare field reference with no
    # `requires` attribute.
    alias_target = field.read_transform.field_reference
    if _reference_is_read_only(alias_target, ir):
        # If the virtual field aliases a non-field (i.e., a parameter) or a
        # read-only field, it is read-only.
        field_builder.write_method.read_only = True
        return

    # Otherwise, it can be written as a direct alias.
    field_builder.write_method.alias.CopyFrom(alias_target)


def set_write_methods(ir):
    """Sets the write_method member of all ir_data.Fields in ir.

    Arguments:
        ir: The IR to which to add write_methods.

    Returns:
        A list of errors, or an empty list.  This pass currently never reports
        errors, so the returned list is always empty.
    """
    traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method)
    return []