# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RISC Operations.

Thin Python wrappers around the generated op functions in `gen_risc_ops`.
Each wrapper forwards its arguments to the generated function of the same
name and supplies a default op `name` (for example 'RISC_ABS').
"""


from tensorflow.python.ops import gen_risc_ops


def risc_abs(x, name='RISC_ABS'):
  """Forwards `x` to `gen_risc_ops.risc_abs`."""
  return gen_risc_ops.risc_abs(x, name=name)


def risc_add(
    input_lhs,
    input_rhs,
    name='RISC_ADD'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_add`."""
  return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)


def risc_binary_arithmetic(x, y, op_type, name='RISC_BinaryArithmetic'):
  """Forwards `x`, `y` and `op_type` to `gen_risc_ops.risc_binary_arithmetic`."""
  return gen_risc_ops.risc_binary_arithmetic(x, y, op_type=op_type, name=name)


def risc_binary_comparison(x, y, op_type, name='RISC_BinaryComparison'):
  """Forwards `x`, `y` and `op_type` to `gen_risc_ops.risc_binary_comparison`."""
  return gen_risc_ops.risc_binary_comparison(x, y, op_type=op_type, name=name)


def risc_bitcast(x, dtype, name='RISC_BITCAST'):
  """Forwards `x` and `dtype` to `gen_risc_ops.risc_bitcast`."""
  return gen_risc_ops.risc_bitcast(x, dtype, name=name)


def risc_broadcast(x, shape, name='RISC_BROADCAST'):
  """Forwards `x` and `shape` to `gen_risc_ops.risc_broadcast`."""
  return gen_risc_ops.risc_broadcast(x, shape, name=name)


def risc_cast(x, dtype, name='RISC_CAST'):
  """Forwards `x` and `dtype` to `gen_risc_ops.risc_cast`."""
  return gen_risc_ops.risc_cast(x, dtype, name=name)


def risc_ceil(x, name='RISC_CEIL'):
  """Forwards `x` to `gen_risc_ops.risc_ceil`."""
  return gen_risc_ops.risc_ceil(x, name=name)


def risc_cos(x, name='RISC_COS'):
  """Forwards `x` to `gen_risc_ops.risc_cos`."""
  return gen_risc_ops.risc_cos(x, name=name)


def risc_cholesky(x, name='RISC_CHOLESKY'):
  """Forwards `x` to `gen_risc_ops.risc_cholesky`."""
  return gen_risc_ops.risc_cholesky(x, name=name)


def risc_concat(x, axis, name='RISC_CONCAT'):
  """Forwards `x` and `axis` to `gen_risc_ops.risc_concat`."""
  return gen_risc_ops.risc_concat(x, axis, name=name)


def risc_condition(pred,
                   input_true,
                   input_false,
                   func_true,
                   func_false,
                   name='RISC_CONDITION'):
  """Forwards the predicate, both inputs and both branch functions.

  `func_true` and `func_false` are passed to `gen_risc_ops.risc_condition`
  as keyword arguments; `pred`, `input_true` and `input_false` positionally.
  """
  return gen_risc_ops.risc_condition(
      pred,
      input_true,
      input_false,
      func_true=func_true,
      func_false=func_false,
      name=name)


def risc_conv(x,
              kernel,
              strides,
              data_format='NHWC',
              dilations=None,
              name='RISC_CONV'):
  """Forwards `x`, `kernel`, `strides` and layout options to `risc_conv`.

  `data_format` (default 'NHWC') and `dilations` (default None) are passed
  to `gen_risc_ops.risc_conv` as keyword arguments.
  """
  return gen_risc_ops.risc_conv(
      x,
      kernel,
      strides,
      data_format=data_format,
      dilations=dilations,
      name=name)


def risc_div(input_lhs, input_rhs, name='RISC_DIV'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_div`."""
  return gen_risc_ops.risc_div(input_lhs, input_rhs, name=name)


def risc_dot(input_lhs,
             input_rhs,
             transpose_a=False,
             transpose_b=False,
             name='RISC_DOT'):
  """Forwards both operands and transpose flags to `gen_risc_ops.risc_dot`."""
  return gen_risc_ops.risc_dot(
      input_lhs,
      input_rhs,
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      name=name)


def risc_exp(x, name='RISC_EXP'):
  """Forwards `x` to `gen_risc_ops.risc_exp`."""
  return gen_risc_ops.risc_exp(x, name=name)


def risc_fft(x, name='RISC_FFT'):
  """Forwards `x` to `gen_risc_ops.risc_fft`."""
  return gen_risc_ops.risc_fft(x, name=name)


def risc_floor(x, name='RISC_FLOOR'):
  """Forwards `x` to `gen_risc_ops.risc_floor`."""
  return gen_risc_ops.risc_floor(x, name=name)


def risc_gather(params,
                indices,
                validate_indices=None,
                axis=None,
                batch_dims=0,
                name='RISC_GATHER'):
  """Forwards `params`, `indices` and gather options to `risc_gather`.

  `validate_indices`, `axis` and `batch_dims` are passed to
  `gen_risc_ops.risc_gather` as keyword arguments.
  """
  # NOTE(review): confirm the generated `risc_gather` accepts
  # `validate_indices` and tolerates `axis=None`; neither is checked here.
  return gen_risc_ops.risc_gather(
      params,
      indices,
      validate_indices=validate_indices,
      name=name,
      axis=axis,
      batch_dims=batch_dims)


def risc_imag(x, name='RISC_IMAG'):
  """Forwards `x` to `gen_risc_ops.risc_imag`."""
  return gen_risc_ops.risc_imag(x, name=name)


def risc_is_finite(x, name='RISC_IS_FINITE'):
  """Forwards `x` to `gen_risc_ops.risc_is_finite`."""
  return gen_risc_ops.risc_is_finite(x, name=name)


def risc_log(x, name='RISC_LOG'):
  """Forwards `x` to `gen_risc_ops.risc_log`."""
  return gen_risc_ops.risc_log(x, name=name)


def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
  """Forwards `a` and `b` to `gen_risc_ops.risc_logical_and`."""
  return gen_risc_ops.risc_logical_and(a, b, name=name)


def risc_logical_not(a, b=None, name='RISC_LOGICAL_NOT'):
  """Forwards `a` to `gen_risc_ops.risc_logical_not`.

  Logical NOT has a single operand, so only `a` is forwarded.

  Args:
    a: The operand.
    b: Unused. Kept (now optional) so existing callers that pass a second
      positional argument keep working.
    name: Op name.

  Returns:
    Whatever `gen_risc_ops.risc_logical_not` returns.
  """
  # NOTE(review): assumes the generated op takes one input. Forwarding `b`
  # positionally (as before) would bind it to the generated wrapper's `name`
  # parameter and clash with `name=name` — confirm against the
  # RiscLogicalNot op registration.
  del b  # Unused; see docstring.
  return gen_risc_ops.risc_logical_not(a, name=name)


def risc_logical_or(a, b, name='RISC_LOGICAL_OR'):
  """Forwards `a` and `b` to `gen_risc_ops.risc_logical_or`."""
  return gen_risc_ops.risc_logical_or(a, b, name=name)


def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_max`."""
  return gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)


def risc_min(input_lhs, input_rhs, name='RISC_MIN'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_min`."""
  return gen_risc_ops.risc_min(input_lhs, input_rhs, name=name)


def risc_mul(input_lhs, input_rhs, name='RISC_MUL'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_mul`."""
  return gen_risc_ops.risc_mul(input_lhs, input_rhs, name=name)


def risc_neg(x, name='RISC_NEG'):
  """Forwards `x` to `gen_risc_ops.risc_neg`."""
  return gen_risc_ops.risc_neg(x, name=name)


def risc_pad(x, padding, constant_values, name='RISC_PAD'):
  """Forwards `x`, `padding` and `constant_values` to `gen_risc_ops.risc_pad`."""
  return gen_risc_ops.risc_pad(x, padding, constant_values, name=name)


def risc_pool(x, ksize, strides, pooling_type='MAX', name='RISC_POOL'):
  """Forwards `x`, `ksize`, `strides` and `pooling_type` to `risc_pool`.

  `pooling_type` (default 'MAX') is passed to `gen_risc_ops.risc_pool` as a
  keyword argument.
  """
  return gen_risc_ops.risc_pool(
      x, ksize, strides, pooling_type=pooling_type, name=name)


def risc_pow(input_lhs, input_rhs, name='RISC_POW'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_pow`."""
  return gen_risc_ops.risc_pow(input_lhs, input_rhs, name=name)


def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
  """Forwards `shape` and `seed` to `gen_risc_ops.risc_random_uniform`."""
  return gen_risc_ops.risc_random_uniform(shape, seed, name=name)


def risc_real(x, name='RISC_REAL'):
  """Forwards `x` to `gen_risc_ops.risc_real`."""
  return gen_risc_ops.risc_real(x, name=name)


def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
  """Forwards `x`, `axis` and `reduce_type` to `gen_risc_ops.risc_reduce`."""
  return gen_risc_ops.risc_reduce(x, axis, reduce_type=reduce_type, name=name)


def risc_rem(x, name='RISC_REM'):
  """Forwards `x` to `gen_risc_ops.risc_rem`."""
  # NOTE(review): a remainder is normally a binary op, yet only `x` is
  # forwarded here — confirm the arity of the generated `risc_rem`.
  return gen_risc_ops.risc_rem(x, name=name)


def risc_reshape(x, shape, name='RISC_RESHAPE'):
  """Forwards `x` and `shape` to `gen_risc_ops.risc_reshape`."""
  return gen_risc_ops.risc_reshape(x, shape, name=name)


def risc_reverse(x, axis, name='RISC_REVERSE'):
  """Forwards `x` and `axis` to `gen_risc_ops.risc_reverse`."""
  return gen_risc_ops.risc_reverse(x, axis, name=name)


def risc_scatter(indices, updates, shape, name='RISC_SCATTER'):
  """Forwards `indices`, `updates` and `shape` to `gen_risc_ops.risc_scatter`."""
  return gen_risc_ops.risc_scatter(indices, updates, shape, name=name)


def risc_shape(x, name='RISC_SHAPE'):
  """Forwards `x` to `gen_risc_ops.risc_shape`."""
  return gen_risc_ops.risc_shape(x, name=name)


def risc_sign(x, name='RISC_SIGN'):
  """Forwards `x` to `gen_risc_ops.risc_sign`."""
  return gen_risc_ops.risc_sign(x, name=name)


def risc_slice(x, begin, size, name='RISC_SLICE'):
  """Forwards `x`, `begin` and `size` to `gen_risc_ops.risc_slice`."""
  return gen_risc_ops.risc_slice(x, begin, size, name=name)


def risc_sub(input_lhs, input_rhs, name='RISC_SUB'):
  """Forwards `input_lhs` and `input_rhs` to `gen_risc_ops.risc_sub`."""
  return gen_risc_ops.risc_sub(input_lhs, input_rhs, name=name)


def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
  """Forwards `x`, `axis` and `direction` to `gen_risc_ops.risc_sort`."""
  return gen_risc_ops.risc_sort(x, axis, direction=direction, name=name)


def risc_squeeze(x, axis=None, name='RISC_SQUEEZE'):
  """Forwards `x` and `axis` (default None) to `gen_risc_ops.risc_squeeze`."""
  return gen_risc_ops.risc_squeeze(x, axis, name=name)


def risc_transpose(x, perm=None, name='RISC_TRANSPOSE'):
  """Forwards `x` and `perm` (default None) to `gen_risc_ops.risc_transpose`."""
  return gen_risc_ops.risc_transpose(x, perm, name=name)


def risc_triangular_solve(matrix,
                          rhs,
                          lower=True,
                          adjoint=False,
                          name='RISC_TRIANGULAR_SOLVE'):
  """Forwards `matrix`, `rhs` and the `lower`/`adjoint` flags.

  The flags are passed to `gen_risc_ops.risc_triangular_solve` as keyword
  arguments.
  """
  return gen_risc_ops.risc_triangular_solve(
      matrix, rhs, lower=lower, adjoint=adjoint, name=name)


def risc_unary(x, op_type='ABL', name='RISC_UNARY'):
  """Forwards `x` and `op_type` to `gen_risc_ops.risc_unary`."""
  # NOTE(review): the default op_type 'ABL' may be intended as 'ABS' —
  # confirm against the op's allowed `op_type` values.
  return gen_risc_ops.risc_unary(x, op_type=op_type, name=name)


def risc_while(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               maximum_iterations=None,
               name='RISC_WHILE'):
  """Forwards loop functions, loop variables and options to `risc_while`.

  All arguments are passed to `gen_risc_ops.risc_while` as keyword
  arguments, together with a fixed `return_same_structure=True`.
  """
  # NOTE(review): these keyword arguments mirror `tf.while_loop`; confirm the
  # generated `risc_while` accepts all of them (notably
  # `return_same_structure`, `back_prop` and `swap_memory`).
  return gen_risc_ops.risc_while(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name,
      maximum_iterations=maximum_iterations,
      return_same_structure=True)