# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for conv with activations."""
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function


def make_conv_activation_tests(activation_op):
  """Make a set of tests to do convolution with activation.

  Args:
    activation_op: a callable taking a single tensor and returning a tensor
      (e.g. `tf.nn.relu`); applied to the conv2d output.

  Returns:
    A function of one argument (`options`) that, when called, generates the
    zipped test set via `make_zip_of_tests`.
  """

  def f(options):
    """Actual function that generates examples."""
    # Three parameter groups: float, fully-quantized (8-bit and 16x8), and
    # dynamic-range-quantized variants of the same conv+activation graph.
    test_parameters = [
        {
            "input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]],
            "filter_shape": [[1, 1], [2, 3], [3, 3]],
            "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
            "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
            "padding": ["SAME", "VALID"],
            "data_format": ["NHWC"],  # TODO(aselle): NCHW would be good
            "constant_filter": [True, False],
            "channel_multiplier": [1, 2],
            "fully_quantize": [False],
            "quant_16x8": [False],
            "dynamic_range_quantize": [False],
        },
        # TODO(b/134702301): The fully_quantize param is just ignored by the
        # MLIR testing path now, resulting in duplicate tests. Either ignore
        # these tests or handle it properly in the mlir_convert() function.
        {
            "input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]],
            "filter_shape": [[1, 1], [2, 3]],
            "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
            "dilations": [[1, 1, 1, 1], [1, 3, 2, 1]],
            "padding": ["SAME", "VALID"],
            "data_format": ["NHWC"],  # TODO(aselle): NCHW would be good
            "constant_filter": [True],
            "channel_multiplier": [1, 2],
            "fully_quantize": [True],
            "quant_16x8": [False, True],
            "dynamic_range_quantize": [False],
        },
        {
            "input_shape": [[1, 3, 4, 3]],
            "filter_shape": [[1, 1], [2, 3], [3, 3]],
            "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
            "dilations": [[1, 1, 1, 1]],
            "padding": ["SAME", "VALID"],
            "data_format": ["NHWC"],
            "constant_filter": [True],
            "channel_multiplier": [1, 2],
            "fully_quantize": [False],
            "quant_16x8": [False],
            "dynamic_range_quantize": [True],
        },
    ]

    def get_tensor_shapes(parameters):
      """Derive the [input, filter] shapes for one parameter combination.

      The filter is HWIO: [height, width] from "filter_shape" plus the input
      channel count (input_shape[3], NHWC) and the channel multiplier as the
      output-channel dimension.
      """
      input_shape = parameters["input_shape"]
      filter_size = parameters["filter_shape"]
      filter_shape = filter_size + [
          input_shape[3], parameters["channel_multiplier"]
      ]
      return [input_shape, filter_shape]

    def build_graph(parameters):
      """Build a conv graph given `parameters`."""
      input_shape, filter_shape = get_tensor_shapes(parameters)
      # The file imports tensorflow.compat.v1 as tf, so tf.placeholder is the
      # v1 placeholder — consistent with the tf.nn.* usage below.
      input_tensor = tf.placeholder(
          dtype=tf.float32, name="input", shape=input_shape)

      # Get filter input either as a placeholder or constants. Also get a list
      # of the input tensors that are represented as placeholders.
      if parameters["constant_filter"]:
        filter_input = create_tensor_data(
            np.float32, filter_shape, min_value=-10, max_value=10)
        input_tensors = [input_tensor]
      else:
        filter_input = tf.placeholder(
            dtype=tf.float32, name="filter", shape=filter_shape)
        input_tensors = [input_tensor, filter_input]

      out = tf.nn.conv2d(
          input_tensor,
          filter_input,
          strides=parameters["strides"],
          dilations=parameters["dilations"],
          padding=parameters["padding"],
          data_format=parameters["data_format"])
      # The activation under test is fused onto the conv output.
      out = activation_op(out)
      return input_tensors, [out]

    def build_inputs(parameters, sess, inputs, outputs):
      """Build inputs for conv with activation."""

      input_shape, filter_shape = get_tensor_shapes(parameters)
      # Input data is kept in [-1, 1] so the quantized variants have a
      # well-behaved range.
      values = [
          create_tensor_data(
              np.float32, input_shape, min_value=-1, max_value=1)
      ]
      # Only feed the filter when it was built as a placeholder.
      if not parameters["constant_filter"]:
        values.append(create_tensor_data(np.float32, filter_shape))
      return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

    make_zip_of_tests(
        options,
        test_parameters,
        build_graph,
        build_inputs,
        expected_tf_failures=48)

  return f


@register_make_test_function()
def make_conv_relu6_tests(options):
  """Make a set of tests to do conv_relu6."""
  return make_conv_activation_tests(tf.nn.relu6)(options)


@register_make_test_function()
def make_conv_relu_tests(options):
  """Make a set of tests to do conv_relu."""
  return make_conv_activation_tests(tf.nn.relu)(options)


def relu1(input_tensor):
  """Clamp `input_tensor` to [-1.0, 1.0] (the RELU_N1_TO_1 activation)."""
  # Note that the following is not supported:
  # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
  out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0))
  return out


@register_make_test_function()
def make_conv_relu1_tests(options):
  """Make a set of tests to do conv_relu1."""
  return make_conv_activation_tests(relu1)(options)