1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19import ( 20 "bytes" 21 "fmt" 22 "strings" 23 "testing" 24) 25 26func hasOperations(g *Graph, ops ...string) error { 27 var missing []string 28 for _, op := range ops { 29 if g.Operation(op) == nil { 30 missing = append(missing, op) 31 } 32 } 33 if len(missing) != 0 { 34 return fmt.Errorf("Graph does not have the operations %v", missing) 35 } 36 37 inList := map[string]bool{} 38 for _, op := range g.Operations() { 39 inList[op.Name()] = true 40 } 41 42 for _, op := range ops { 43 if !inList[op] { 44 missing = append(missing, op) 45 } 46 } 47 48 if len(missing) != 0 { 49 return fmt.Errorf("Operations %v are missing from graph.Operations()", missing) 50 } 51 52 return nil 53} 54 55func TestGraphWriteToAndImport(t *testing.T) { 56 // Construct a graph 57 g := NewGraph() 58 v, err := NewTensor(int64(1)) 59 if err != nil { 60 t.Fatal(err) 61 } 62 input, err := Placeholder(g, "input", v.DataType()) 63 if err != nil { 64 t.Fatal(err) 65 } 66 if _, err := Neg(g, "neg", input); err != nil { 67 t.Fatal(err) 68 } 69 70 // Serialize the graph 71 buf := new(bytes.Buffer) 72 if _, err := g.WriteTo(buf); err != nil { 73 t.Fatal(err) 74 } 75 76 // Import it into the same graph, with a prefix 77 if err := g.Import(buf.Bytes(), "imported"); err != nil { 78 t.Error(err) 79 } 80 if err := hasOperations(g, "input", "neg", "imported/input", "imported/neg"); err != nil { 81 t.Error(err) 82 } 83} 84 85func TestGraphInputMapping(t *testing.T) { 86 // Construct a graph 87 g := NewGraph() 88 v, err := NewTensor(int64(1)) 89 if err != nil { 90 t.Fatal(err) 91 } 92 input, err := Placeholder(g, "input", v.DataType()) 93 if err != nil { 94 t.Fatal(err) 95 } 96 neg, err := Neg(g, "neg", input) 97 if err != nil { 98 t.Fatal(err) 99 } 100 101 // Serialize the graph 102 buf := new(bytes.Buffer) 103 if _, err := g.WriteTo(buf); err != nil { 104 t.Fatal(err) 105 } 106 107 g = NewGraph() 108 v, err = NewTensor(int64(1)) 109 if err != nil { 110 t.Fatal(err) 111 } 112 113 replacement, err := Placeholder(g, "replacement", v.DataType()) 114 if err != nil { 115 t.Fatal(err) 116 } 117 118 options := GraphImportOptions{ 119 Prefix: "imported", 120 } 121 options.AddInputMapping("input", 0, replacement) 122 // Import it into the same graph, with a prefix and replacement 123 if err := g.ImportWithOptions(buf.Bytes(), options); err != nil { 124 t.Error(err) 125 } 126 if err := hasOperations(g, "replacement", "imported/neg"); err != nil { 127 t.Error(err) 128 } 129 130 sess, err := NewSession(g, nil) 131 if err != nil { 132 t.Fatal(err) 133 } 134 135 neg = g.Operation("imported/neg").Output(0) 136 137 outputs, err := sess.Run( 138 map[Output]*Tensor{replacement: v}, 139 []Output{neg}, 140 nil) 141 if err != nil { 142 t.Fatal(err) 143 } 144 if len(outputs) != 1 { 145 t.Fatal(len(outputs)) 146 } 147 if outputs[0].Value().(int64) != -1 { 148 t.Fatalf("Got %v, wanted int64 -1", outputs[0].Value()) 149 } 150} 151 152func TestGraphAddGradients(t *testing.T) { 153 g := NewGraph() 154 x1, err := Placeholder(g, "x1", Float) 155 if err != nil { 156 t.Fatal(err) 157 } 158 x2, err := Placeholder(g, "x2", Float) 159 if err != nil { 160 t.Fatal(err) 161 } 162 op0, err := g.AddOperation(OpSpec{ 163 Type: "Square", 164 Name: "y0", 165 Input: []Input{x1}, 166 }) 167 if err != nil { 168 t.Fatal(err) 169 } 170 y0 := op0.Output(0) 171 op1, err := g.AddOperation(OpSpec{ 172 Type: "Square", 173 Name: "y1", 174 Input: []Input{y0}, 175 }) 176 if err != nil { 177 t.Fatal(err) 178 } 179 y1 := op1.Output(0) 180 op2, err := g.AddOperation(OpSpec{ 181 Type: "AddN", 182 Input: []Input{OutputList([]Output{y0, x2})}, 183 }) 184 if err != nil { 185 t.Fatal(err) 186 } 187 y2 := op2.Output(0) 188 189 grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil) 190 if err != nil { 191 t.Fatal(err) 192 } 193 if len(grads0) != 1 { 194 t.Fatal(len(grads0)) 195 } 196 if grads0[0].DataType() != Float { 197 t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float) 198 } 199 200 grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil) 201 if err != nil { 202 t.Fatal(err) 203 } 204 if len(grads1) != 2 { 205 t.Fatal(len(grads1)) 206 } 207 if grads1[0].DataType() != Float { 208 t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float) 209 } 210 if grads1[1].DataType() != Float { 211 t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), Float) 212 } 213 214 sess, err := NewSession(g, nil) 215 if err != nil { 216 t.Fatal(err) 217 } 218 219 c1, _ := NewTensor(float32(3.0)) 220 c2, _ := NewTensor(float32(2.0)) 221 outputs, err := sess.Run( 222 map[Output]*Tensor{x1: c1, x2: c2}, 223 []Output{grads0[0], grads1[0], grads1[1]}, 224 nil) 225 if err != nil { 226 t.Fatal(err) 227 } 228 if len(outputs) != 3 { 229 t.Fatal(len(outputs)) 230 } 231 if outputs[0].Value().(float32) != 108.0 { 232 t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value()) 233 } 234 if outputs[1].Value().(float32) != 6.0 { 235 t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value()) 236 } 237 if outputs[2].Value().(float32) != 1.0 { 238 t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value()) 239 } 240} 241 242func TestGraphAddGradientsSums(t *testing.T) { 243 g := NewGraph() 244 x, err := Placeholder(g, "x", Float) 245 if err != nil { 246 t.Fatal(err) 247 } 248 op0, err := g.AddOperation(OpSpec{ 249 Type: "Square", 250 Name: "y0", 251 Input: []Input{x}, 252 }) 253 if err != nil { 254 t.Fatal(err) 255 } 256 y0 := op0.Output(0) 257 op1, err := g.AddOperation(OpSpec{ 258 Type: "Square", 259 Name: "y1", 260 Input: []Input{y0}, 261 }) 262 y1 := op1.Output(0) 263 264 grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil) 265 if err != nil { 266 t.Fatal(err) 267 } 268 if len(grad) != 1 { 269 t.Fatal(len(grad)) 270 } 271 if grad[0].DataType() != Float { 272 t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), Float) 273 } 274 275 sess, err := NewSession(g, nil) 276 if err != nil { 277 t.Fatal(err) 278 } 279 280 c, _ := NewTensor(float32(3.0)) 281 outputs, err := sess.Run( 282 map[Output]*Tensor{x: c}, 283 []Output{grad[0]}, 284 nil) 285 if err != nil { 286 t.Fatal(err) 287 } 288 if outputs[0].Value().(float32) != 114.0 { 289 t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value()) 290 } 291} 292 293func TestGraphAddGradientsWithInitialValues(t *testing.T) { 294 g := NewGraph() 295 x, err := Placeholder(g, "x", Float) 296 op0, err := g.AddOperation(OpSpec{ 297 Type: "Square", 298 Name: "y0", 299 Input: []Input{x}, 300 }) 301 if err != nil { 302 t.Fatal(err) 303 } 304 y0 := op0.Output(0) 305 op1, err := g.AddOperation(OpSpec{ 306 Type: "Square", 307 Name: "y1", 308 Input: []Input{y0}, 309 }) 310 if err != nil { 311 t.Fatal(err) 312 } 313 y1 := op1.Output(0) 314 315 grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil) 316 if err != nil { 317 t.Fatal(err) 318 } 319 if len(grads0) != 1 { 320 t.Fatal(len(grads0)) 321 } 322 if grads0[0].DataType() != Float { 323 t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float) 324 } 325 326 grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]}) 327 if err != nil { 328 t.Fatal(err) 329 } 330 if len(grads1) != 1 { 331 t.Fatal(len(grads1)) 332 } 333 if grads1[0].DataType() != Float { 334 t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float) 335 } 336 337 sess, err := NewSession(g, nil) 338 if err != nil { 339 t.Fatal(err) 340 } 341 342 c, _ := NewTensor(float32(3.0)) 343 outputs, err := sess.Run( 344 map[Output]*Tensor{x: c}, 345 []Output{grads1[0]}, 346 nil) 347 if err != nil { 348 t.Fatal(err) 349 } 350 if outputs[0].Value().(float32) != 108.0 { 351 t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value()) 352 } 353} 354 355func TestGraphValidateGradientsNames(t *testing.T) { 356 g := NewGraph() 357 x, err := Placeholder(g, "x", Float) 358 if err != nil { 359 t.Fatal(err) 360 } 361 op0, err := g.AddOperation(OpSpec{ 362 Type: "Square", 363 Name: "y0", 364 Input: []Input{x}, 365 }) 366 if err != nil { 367 t.Fatal(err) 368 } 369 y0 := op0.Output(0) 370 371 grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil) 372 if err != nil { 373 t.Fatal(err) 374 } 375 if !strings.HasPrefix(grads0[0].Op.Name(), "gradients/") { 376 t.Fatalf("Got name %v, wanted started with gradients/", grads0[0].Op.Name()) 377 } 378 379 grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil) 380 if err != nil { 381 t.Fatal(err) 382 } 383 if !strings.HasPrefix(grads1[0].Op.Name(), "gradients_1/") { 384 t.Fatalf("Got name %v, wanted started with gradients_1/", grads1[0].Op.Name()) 385 } 386 387 grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil) 388 if err != nil { 389 t.Fatal(err) 390 } 391 if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") { 392 t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name()) 393 } 394 395 grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil) 396 if err != nil { 397 t.Fatal(err) 398 } 399 if !strings.HasPrefix(grads3[0].Op.Name(), "even_more_gradients/") { 400 t.Fatalf("Got name %v, wanted started with even_more_gradients/", grads3[0].Op.Name()) 401 } 402 403 _, err = g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil) 404 if err == nil { 405 t.Error("AddGradients should have failed if gradients name is already existing") 406 } 407} 408