1/* 2Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19import ( 20 "fmt" 21 "testing" 22 23 tspb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto" 24 typb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto" 25 corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto" 26) 27 28func TestSignatureFromProto(t *testing.T) { 29 got := signatureDefFromProto(&corepb.SignatureDef{ 30 Inputs: map[string]*corepb.TensorInfo{ 31 "input_1": &corepb.TensorInfo{ 32 Encoding: &corepb.TensorInfo_Name{ 33 Name: "tensor_1", 34 }, 35 Dtype: typb.DataType_DT_INT8, 36 TensorShape: &tspb.TensorShapeProto{ 37 Dim: []*tspb.TensorShapeProto_Dim{ 38 {Size: 1}, 39 {Size: 2}, 40 {Size: 3}, 41 }, 42 }, 43 }, 44 "input_2": &corepb.TensorInfo{ 45 Encoding: &corepb.TensorInfo_Name{ 46 Name: "tensor_2", 47 }, 48 Dtype: typb.DataType_DT_FLOAT, 49 TensorShape: &tspb.TensorShapeProto{ 50 Dim: []*tspb.TensorShapeProto_Dim{ 51 {Size: 4}, 52 {Size: 5}, 53 {Size: 6}, 54 }, 55 }, 56 }, 57 }, 58 Outputs: map[string]*corepb.TensorInfo{ 59 "output_1": &corepb.TensorInfo{ 60 Encoding: &corepb.TensorInfo_Name{ 61 Name: "tensor_3", 62 }, 63 Dtype: typb.DataType_DT_STRING, 64 TensorShape: &tspb.TensorShapeProto{ 65 Dim: []*tspb.TensorShapeProto_Dim{ 66 {Size: 1}, 67 {Size: 2}, 68 {Size: 3}, 69 }, 70 }, 71 }, 72 "output_2": &corepb.TensorInfo{ 73 Encoding: &corepb.TensorInfo_Name{ 74 Name: "tensor_4", 75 }, 76 Dtype: typb.DataType_DT_BOOL, 77 TensorShape: &tspb.TensorShapeProto{ 78 Dim: []*tspb.TensorShapeProto_Dim{ 79 {Size: 4}, 80 {Size: 5}, 81 {Size: 6}, 82 }, 83 }, 84 }, 85 }, 86 MethodName: "method", 87 }) 88 89 want := Signature{ 90 Inputs: map[string]TensorInfo{ 91 "input_1": TensorInfo{ 92 Name: "tensor_1", 93 DType: Int8, 94 Shape: MakeShape(1, 2, 3), 95 }, 96 "input_2": TensorInfo{ 97 Name: "tensor_2", 98 DType: Float, 99 Shape: MakeShape(4, 5, 6), 100 }, 101 }, 102 Outputs: map[string]TensorInfo{ 103 "output_1": TensorInfo{ 104 Name: "tensor_3", 105 DType: String, 106 Shape: MakeShape(1, 2, 3), 107 }, 108 "output_2": TensorInfo{ 109 Name: "tensor_4", 110 DType: Bool, 111 Shape: MakeShape(4, 5, 6), 112 }, 113 }, 114 MethodName: "method", 115 } 116 117 for k, input := range want.Inputs { 118 diff, err := diffTensorInfos(got.Inputs[k], input) 119 if err != nil { 120 t.Fatalf("Signature.Inputs[%s]: unable to diff TensorInfos: %v", k, err) 121 } 122 if diff != "" { 123 t.Errorf("Signature.Inputs[%s] diff:\n%s", k, diff) 124 } 125 } 126 127 for k, output := range want.Outputs { 128 diff, err := diffTensorInfos(got.Outputs[k], output) 129 if err != nil { 130 t.Fatalf("Signature.Outputs[%s]: unable to diff TensorInfos: %v", k, err) 131 } 132 if diff != "" { 133 t.Errorf("Signature.Outputs[%s] diff:\n%s", k, diff) 134 } 135 } 136 137 if got.MethodName != want.MethodName { 138 t.Errorf("Signature.MethodName: got %q, want %q", got.MethodName, want.MethodName) 139 } 140} 141 142func TestTensorInfoFromProto(t *testing.T) { 143 got := tensorInfoFromProto(&corepb.TensorInfo{ 144 Encoding: &corepb.TensorInfo_Name{ 145 Name: "tensor", 146 }, 147 Dtype: typb.DataType_DT_INT8, 148 TensorShape: &tspb.TensorShapeProto{ 149 Dim: []*tspb.TensorShapeProto_Dim{ 150 {Size: 1}, 151 {Size: 2}, 152 {Size: 3}, 153 }, 154 }, 155 }) 156 want := TensorInfo{ 157 Name: "tensor", 158 DType: Int8, 159 Shape: MakeShape(1, 2, 3), 160 } 161 162 diff, err := diffTensorInfos(got, want) 163 if err != nil { 164 t.Fatalf("Unable to diff TensorInfos: %v", err) 165 } 166 if diff != "" { 167 t.Errorf("tensorInfoFromProto produced a diff (got -> want): %s", diff) 168 } 169} 170 171func diffTensorInfos(a, b TensorInfo) (string, error) { 172 diff := "" 173 if a.Name != b.Name { 174 diff += fmt.Sprintf("Name: %q -> %q\n", a.Name, b.Name) 175 } 176 if a.DType != b.DType { 177 diff += fmt.Sprintf("DType: %v -> %v\n", a.DType, b.DType) 178 } 179 180 aShape, err := a.Shape.ToSlice() 181 if err != nil { 182 return "", err 183 } 184 bShape, err := b.Shape.ToSlice() 185 if err != nil { 186 return "", err 187 } 188 shapeLen := len(aShape) 189 if len(bShape) > shapeLen { 190 shapeLen = len(bShape) 191 } 192 for i := 0; i < shapeLen; i++ { 193 if i >= len(aShape) { 194 diff += fmt.Sprintf("+Shape[%d]: %d\n", i, bShape[i]) 195 continue 196 } 197 if i >= len(bShape) { 198 diff += fmt.Sprintf("-Shape[%d]: %d\n", i, aShape[i]) 199 continue 200 } 201 if aShape[i] != bShape[i] { 202 diff += fmt.Sprintf("Shape[%d]: %d -> %d\n", i, aShape[i], bShape[i]) 203 } 204 } 205 206 return diff, nil 207} 208