xref: /aosp_15_r20/external/tensorflow/tensorflow/go/graph_test.go (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19import (
20	"bytes"
21	"fmt"
22	"strings"
23	"testing"
24)
25
26func hasOperations(g *Graph, ops ...string) error {
27	var missing []string
28	for _, op := range ops {
29		if g.Operation(op) == nil {
30			missing = append(missing, op)
31		}
32	}
33	if len(missing) != 0 {
34		return fmt.Errorf("Graph does not have the operations %v", missing)
35	}
36
37	inList := map[string]bool{}
38	for _, op := range g.Operations() {
39		inList[op.Name()] = true
40	}
41
42	for _, op := range ops {
43		if !inList[op] {
44			missing = append(missing, op)
45		}
46	}
47
48	if len(missing) != 0 {
49		return fmt.Errorf("Operations %v are missing from graph.Operations()", missing)
50	}
51
52	return nil
53}
54
55func TestGraphWriteToAndImport(t *testing.T) {
56	// Construct a graph
57	g := NewGraph()
58	v, err := NewTensor(int64(1))
59	if err != nil {
60		t.Fatal(err)
61	}
62	input, err := Placeholder(g, "input", v.DataType())
63	if err != nil {
64		t.Fatal(err)
65	}
66	if _, err := Neg(g, "neg", input); err != nil {
67		t.Fatal(err)
68	}
69
70	// Serialize the graph
71	buf := new(bytes.Buffer)
72	if _, err := g.WriteTo(buf); err != nil {
73		t.Fatal(err)
74	}
75
76	// Import it into the same graph, with a prefix
77	if err := g.Import(buf.Bytes(), "imported"); err != nil {
78		t.Error(err)
79	}
80	if err := hasOperations(g, "input", "neg", "imported/input", "imported/neg"); err != nil {
81		t.Error(err)
82	}
83}
84
85func TestGraphInputMapping(t *testing.T) {
86	// Construct a graph
87	g := NewGraph()
88	v, err := NewTensor(int64(1))
89	if err != nil {
90		t.Fatal(err)
91	}
92	input, err := Placeholder(g, "input", v.DataType())
93	if err != nil {
94		t.Fatal(err)
95	}
96	neg, err := Neg(g, "neg", input)
97	if err != nil {
98		t.Fatal(err)
99	}
100
101	// Serialize the graph
102	buf := new(bytes.Buffer)
103	if _, err := g.WriteTo(buf); err != nil {
104		t.Fatal(err)
105	}
106
107	g = NewGraph()
108	v, err = NewTensor(int64(1))
109	if err != nil {
110		t.Fatal(err)
111	}
112
113	replacement, err := Placeholder(g, "replacement", v.DataType())
114	if err != nil {
115		t.Fatal(err)
116	}
117
118	options := GraphImportOptions{
119		Prefix: "imported",
120	}
121	options.AddInputMapping("input", 0, replacement)
122	// Import it into the same graph, with a prefix and replacement
123	if err := g.ImportWithOptions(buf.Bytes(), options); err != nil {
124		t.Error(err)
125	}
126	if err := hasOperations(g, "replacement", "imported/neg"); err != nil {
127		t.Error(err)
128	}
129
130	sess, err := NewSession(g, nil)
131	if err != nil {
132		t.Fatal(err)
133	}
134
135	neg = g.Operation("imported/neg").Output(0)
136
137	outputs, err := sess.Run(
138		map[Output]*Tensor{replacement: v},
139		[]Output{neg},
140		nil)
141	if err != nil {
142		t.Fatal(err)
143	}
144	if len(outputs) != 1 {
145		t.Fatal(len(outputs))
146	}
147	if outputs[0].Value().(int64) != -1 {
148		t.Fatalf("Got %v, wanted int64 -1", outputs[0].Value())
149	}
150}
151
152func TestGraphAddGradients(t *testing.T) {
153	g := NewGraph()
154	x1, err := Placeholder(g, "x1", Float)
155	if err != nil {
156		t.Fatal(err)
157	}
158	x2, err := Placeholder(g, "x2", Float)
159	if err != nil {
160		t.Fatal(err)
161	}
162	op0, err := g.AddOperation(OpSpec{
163		Type:  "Square",
164		Name:  "y0",
165		Input: []Input{x1},
166	})
167	if err != nil {
168		t.Fatal(err)
169	}
170	y0 := op0.Output(0)
171	op1, err := g.AddOperation(OpSpec{
172		Type:  "Square",
173		Name:  "y1",
174		Input: []Input{y0},
175	})
176	if err != nil {
177		t.Fatal(err)
178	}
179	y1 := op1.Output(0)
180	op2, err := g.AddOperation(OpSpec{
181		Type:  "AddN",
182		Input: []Input{OutputList([]Output{y0, x2})},
183	})
184	if err != nil {
185		t.Fatal(err)
186	}
187	y2 := op2.Output(0)
188
189	grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil)
190	if err != nil {
191		t.Fatal(err)
192	}
193	if len(grads0) != 1 {
194		t.Fatal(len(grads0))
195	}
196	if grads0[0].DataType() != Float {
197		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
198	}
199
200	grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil)
201	if err != nil {
202		t.Fatal(err)
203	}
204	if len(grads1) != 2 {
205		t.Fatal(len(grads1))
206	}
207	if grads1[0].DataType() != Float {
208		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
209	}
210	if grads1[1].DataType() != Float {
211		t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), Float)
212	}
213
214	sess, err := NewSession(g, nil)
215	if err != nil {
216		t.Fatal(err)
217	}
218
219	c1, _ := NewTensor(float32(3.0))
220	c2, _ := NewTensor(float32(2.0))
221	outputs, err := sess.Run(
222		map[Output]*Tensor{x1: c1, x2: c2},
223		[]Output{grads0[0], grads1[0], grads1[1]},
224		nil)
225	if err != nil {
226		t.Fatal(err)
227	}
228	if len(outputs) != 3 {
229		t.Fatal(len(outputs))
230	}
231	if outputs[0].Value().(float32) != 108.0 {
232		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
233	}
234	if outputs[1].Value().(float32) != 6.0 {
235		t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
236	}
237	if outputs[2].Value().(float32) != 1.0 {
238		t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
239	}
240}
241
242func TestGraphAddGradientsSums(t *testing.T) {
243	g := NewGraph()
244	x, err := Placeholder(g, "x", Float)
245	if err != nil {
246		t.Fatal(err)
247	}
248	op0, err := g.AddOperation(OpSpec{
249		Type:  "Square",
250		Name:  "y0",
251		Input: []Input{x},
252	})
253	if err != nil {
254		t.Fatal(err)
255	}
256	y0 := op0.Output(0)
257	op1, err := g.AddOperation(OpSpec{
258		Type:  "Square",
259		Name:  "y1",
260		Input: []Input{y0},
261	})
262	y1 := op1.Output(0)
263
264	grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil)
265	if err != nil {
266		t.Fatal(err)
267	}
268	if len(grad) != 1 {
269		t.Fatal(len(grad))
270	}
271	if grad[0].DataType() != Float {
272		t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), Float)
273	}
274
275	sess, err := NewSession(g, nil)
276	if err != nil {
277		t.Fatal(err)
278	}
279
280	c, _ := NewTensor(float32(3.0))
281	outputs, err := sess.Run(
282		map[Output]*Tensor{x: c},
283		[]Output{grad[0]},
284		nil)
285	if err != nil {
286		t.Fatal(err)
287	}
288	if outputs[0].Value().(float32) != 114.0 {
289		t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
290	}
291}
292
293func TestGraphAddGradientsWithInitialValues(t *testing.T) {
294	g := NewGraph()
295	x, err := Placeholder(g, "x", Float)
296	op0, err := g.AddOperation(OpSpec{
297		Type:  "Square",
298		Name:  "y0",
299		Input: []Input{x},
300	})
301	if err != nil {
302		t.Fatal(err)
303	}
304	y0 := op0.Output(0)
305	op1, err := g.AddOperation(OpSpec{
306		Type:  "Square",
307		Name:  "y1",
308		Input: []Input{y0},
309	})
310	if err != nil {
311		t.Fatal(err)
312	}
313	y1 := op1.Output(0)
314
315	grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil)
316	if err != nil {
317		t.Fatal(err)
318	}
319	if len(grads0) != 1 {
320		t.Fatal(len(grads0))
321	}
322	if grads0[0].DataType() != Float {
323		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
324	}
325
326	grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]})
327	if err != nil {
328		t.Fatal(err)
329	}
330	if len(grads1) != 1 {
331		t.Fatal(len(grads1))
332	}
333	if grads1[0].DataType() != Float {
334		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
335	}
336
337	sess, err := NewSession(g, nil)
338	if err != nil {
339		t.Fatal(err)
340	}
341
342	c, _ := NewTensor(float32(3.0))
343	outputs, err := sess.Run(
344		map[Output]*Tensor{x: c},
345		[]Output{grads1[0]},
346		nil)
347	if err != nil {
348		t.Fatal(err)
349	}
350	if outputs[0].Value().(float32) != 108.0 {
351		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
352	}
353}
354
355func TestGraphValidateGradientsNames(t *testing.T) {
356	g := NewGraph()
357	x, err := Placeholder(g, "x", Float)
358	if err != nil {
359		t.Fatal(err)
360	}
361	op0, err := g.AddOperation(OpSpec{
362		Type:  "Square",
363		Name:  "y0",
364		Input: []Input{x},
365	})
366	if err != nil {
367		t.Fatal(err)
368	}
369	y0 := op0.Output(0)
370
371	grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
372	if err != nil {
373		t.Fatal(err)
374	}
375	if !strings.HasPrefix(grads0[0].Op.Name(), "gradients/") {
376		t.Fatalf("Got name %v, wanted started with gradients/", grads0[0].Op.Name())
377	}
378
379	grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
380	if err != nil {
381		t.Fatal(err)
382	}
383	if !strings.HasPrefix(grads1[0].Op.Name(), "gradients_1/") {
384		t.Fatalf("Got name %v, wanted started with gradients_1/", grads1[0].Op.Name())
385	}
386
387	grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil)
388	if err != nil {
389		t.Fatal(err)
390	}
391	if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
392		t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
393	}
394
395	grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
396	if err != nil {
397		t.Fatal(err)
398	}
399	if !strings.HasPrefix(grads3[0].Op.Name(), "even_more_gradients/") {
400		t.Fatalf("Got name %v, wanted started with even_more_gradients/", grads3[0].Op.Name())
401	}
402
403	_, err = g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
404	if err == nil {
405		t.Error("AddGradients should have failed if gradients name is already existing")
406	}
407}
408