xref: /aosp_15_r20/external/tensorflow/tensorflow/go/signature_test.go (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/*
2Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19import (
20	"fmt"
21	"testing"
22
23	tspb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
24	typb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
25	corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
26)
27
28func TestSignatureFromProto(t *testing.T) {
29	got := signatureDefFromProto(&corepb.SignatureDef{
30		Inputs: map[string]*corepb.TensorInfo{
31			"input_1": &corepb.TensorInfo{
32				Encoding: &corepb.TensorInfo_Name{
33					Name: "tensor_1",
34				},
35				Dtype: typb.DataType_DT_INT8,
36				TensorShape: &tspb.TensorShapeProto{
37					Dim: []*tspb.TensorShapeProto_Dim{
38						{Size: 1},
39						{Size: 2},
40						{Size: 3},
41					},
42				},
43			},
44			"input_2": &corepb.TensorInfo{
45				Encoding: &corepb.TensorInfo_Name{
46					Name: "tensor_2",
47				},
48				Dtype: typb.DataType_DT_FLOAT,
49				TensorShape: &tspb.TensorShapeProto{
50					Dim: []*tspb.TensorShapeProto_Dim{
51						{Size: 4},
52						{Size: 5},
53						{Size: 6},
54					},
55				},
56			},
57		},
58		Outputs: map[string]*corepb.TensorInfo{
59			"output_1": &corepb.TensorInfo{
60				Encoding: &corepb.TensorInfo_Name{
61					Name: "tensor_3",
62				},
63				Dtype: typb.DataType_DT_STRING,
64				TensorShape: &tspb.TensorShapeProto{
65					Dim: []*tspb.TensorShapeProto_Dim{
66						{Size: 1},
67						{Size: 2},
68						{Size: 3},
69					},
70				},
71			},
72			"output_2": &corepb.TensorInfo{
73				Encoding: &corepb.TensorInfo_Name{
74					Name: "tensor_4",
75				},
76				Dtype: typb.DataType_DT_BOOL,
77				TensorShape: &tspb.TensorShapeProto{
78					Dim: []*tspb.TensorShapeProto_Dim{
79						{Size: 4},
80						{Size: 5},
81						{Size: 6},
82					},
83				},
84			},
85		},
86		MethodName: "method",
87	})
88
89	want := Signature{
90		Inputs: map[string]TensorInfo{
91			"input_1": TensorInfo{
92				Name:  "tensor_1",
93				DType: Int8,
94				Shape: MakeShape(1, 2, 3),
95			},
96			"input_2": TensorInfo{
97				Name:  "tensor_2",
98				DType: Float,
99				Shape: MakeShape(4, 5, 6),
100			},
101		},
102		Outputs: map[string]TensorInfo{
103			"output_1": TensorInfo{
104				Name:  "tensor_3",
105				DType: String,
106				Shape: MakeShape(1, 2, 3),
107			},
108			"output_2": TensorInfo{
109				Name:  "tensor_4",
110				DType: Bool,
111				Shape: MakeShape(4, 5, 6),
112			},
113		},
114		MethodName: "method",
115	}
116
117	for k, input := range want.Inputs {
118		diff, err := diffTensorInfos(got.Inputs[k], input)
119		if err != nil {
120			t.Fatalf("Signature.Inputs[%s]: unable to diff TensorInfos: %v", k, err)
121		}
122		if diff != "" {
123			t.Errorf("Signature.Inputs[%s] diff:\n%s", k, diff)
124		}
125	}
126
127	for k, output := range want.Outputs {
128		diff, err := diffTensorInfos(got.Outputs[k], output)
129		if err != nil {
130			t.Fatalf("Signature.Outputs[%s]: unable to diff TensorInfos: %v", k, err)
131		}
132		if diff != "" {
133			t.Errorf("Signature.Outputs[%s] diff:\n%s", k, diff)
134		}
135	}
136
137	if got.MethodName != want.MethodName {
138		t.Errorf("Signature.MethodName: got %q, want %q", got.MethodName, want.MethodName)
139	}
140}
141
142func TestTensorInfoFromProto(t *testing.T) {
143	got := tensorInfoFromProto(&corepb.TensorInfo{
144		Encoding: &corepb.TensorInfo_Name{
145			Name: "tensor",
146		},
147		Dtype: typb.DataType_DT_INT8,
148		TensorShape: &tspb.TensorShapeProto{
149			Dim: []*tspb.TensorShapeProto_Dim{
150				{Size: 1},
151				{Size: 2},
152				{Size: 3},
153			},
154		},
155	})
156	want := TensorInfo{
157		Name:  "tensor",
158		DType: Int8,
159		Shape: MakeShape(1, 2, 3),
160	}
161
162	diff, err := diffTensorInfos(got, want)
163	if err != nil {
164		t.Fatalf("Unable to diff TensorInfos: %v", err)
165	}
166	if diff != "" {
167		t.Errorf("tensorInfoFromProto produced a diff (got -> want): %s", diff)
168	}
169}
170
171func diffTensorInfos(a, b TensorInfo) (string, error) {
172	diff := ""
173	if a.Name != b.Name {
174		diff += fmt.Sprintf("Name: %q -> %q\n", a.Name, b.Name)
175	}
176	if a.DType != b.DType {
177		diff += fmt.Sprintf("DType: %v -> %v\n", a.DType, b.DType)
178	}
179
180	aShape, err := a.Shape.ToSlice()
181	if err != nil {
182		return "", err
183	}
184	bShape, err := b.Shape.ToSlice()
185	if err != nil {
186		return "", err
187	}
188	shapeLen := len(aShape)
189	if len(bShape) > shapeLen {
190		shapeLen = len(bShape)
191	}
192	for i := 0; i < shapeLen; i++ {
193		if i >= len(aShape) {
194			diff += fmt.Sprintf("+Shape[%d]: %d\n", i, bShape[i])
195			continue
196		}
197		if i >= len(bShape) {
198			diff += fmt.Sprintf("-Shape[%d]: %d\n", i, aShape[i])
199			continue
200		}
201		if aShape[i] != bShape[i] {
202			diff += fmt.Sprintf("Shape[%d]: %d -> %d\n", i, aShape[i], bShape[i])
203		}
204	}
205
206	return diff, nil
207}
208