1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.op; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertNotNull; 20 import static org.junit.Assert.fail; 21 22 import java.util.HashMap; 23 import java.util.Map; 24 25 import org.junit.Test; 26 import org.junit.runner.RunWith; 27 import org.junit.runners.JUnit4; 28 import org.tensorflow.Graph; 29 import org.tensorflow.Output; 30 import org.tensorflow.Session; 31 import org.tensorflow.Tensor; 32 import org.tensorflow.Tensors; 33 import org.tensorflow.types.UInt8; 34 35 /** Unit tests for {@link org.tensorflow.Scope}. */ 36 @RunWith(JUnit4.class) 37 public class ScopeTest { 38 39 @Test basicNames()40 public void basicNames() { 41 try (Graph g = new Graph()) { 42 Scope root = new Scope(g); 43 assertEquals("add", root.makeOpName("add")); 44 assertEquals("add_1", root.makeOpName("add")); 45 assertEquals("add_2", root.makeOpName("add")); 46 assertEquals("mul", root.makeOpName("mul")); 47 } 48 } 49 50 @Test hierarchicalNames()51 public void hierarchicalNames() { 52 try (Graph g = new Graph()) { 53 Scope root = new Scope(g); 54 Scope child = root.withSubScope("child"); 55 assertEquals("child/add", child.makeOpName("add")); 56 assertEquals("child/add_1", child.makeOpName("add")); 57 assertEquals("child/mul", child.makeOpName("mul")); 58 59 Scope child_1 = root.withSubScope("child"); 60 assertEquals("child_1/add", child_1.makeOpName("add")); 61 assertEquals("child_1/add_1", child_1.makeOpName("add")); 62 assertEquals("child_1/mul", child_1.makeOpName("mul")); 63 64 Scope c_c = root.withSubScope("c").withSubScope("c"); 65 assertEquals("c/c/add", c_c.makeOpName("add")); 66 67 Scope c_1 = root.withSubScope("c"); 68 Scope c_1_c = c_1.withSubScope("c"); 69 assertEquals("c_1/c/add", c_1_c.makeOpName("add")); 70 71 Scope c_1_c_1 = c_1.withSubScope("c"); 72 assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add")); 73 } 74 } 75 76 @Test scopeAndOpNames()77 public void scopeAndOpNames() { 78 try (Graph g = new Graph()) { 79 Scope root = new Scope(g); 80 81 Scope child = root.withSubScope("child"); 82 83 assertEquals("child/add", child.makeOpName("add")); 84 assertEquals("child_1", root.makeOpName("child")); 85 assertEquals("child_2/p", root.withSubScope("child").makeOpName("p")); 86 } 87 } 88 89 @Test validateNames()90 public void validateNames() { 91 try (Graph g = new Graph()) { 92 Scope root = new Scope(g); 93 94 final String[] invalid_names = { 95 "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] 96 null, "", "a$", // Invalid characters 97 "a/b", // slashes not allowed 98 }; 99 100 for (String name : invalid_names) { 101 try { 102 root.withName(name); 103 fail("failed to catch invalid op name."); 104 } catch (IllegalArgumentException ex) { 105 // expected 106 } 107 // Subscopes follow the same rules 108 try { 109 root.withSubScope(name); 110 fail("failed to catch invalid scope name: " + name); 111 } catch (IllegalArgumentException ex) { 112 // expected 113 } 114 } 115 116 // Unusual but valid names. 117 final String[] valid_names = {".", "..", "._-.", "a--."}; 118 119 for (String name : valid_names) { 120 root.withName(name); 121 root.withSubScope(name); 122 } 123 } 124 } 125 126 @Test basic()127 public void basic() { 128 try (Graph g = new Graph()) { 129 Scope s = new Scope(g); 130 Const<Integer> c1 = Const.create(s, 42); 131 assertEquals("Const", c1.output().op().name()); 132 Const<Integer> c2 = Const.create(s, 7); 133 assertEquals("Const_1", c2.output().op().name()); 134 Const<Integer> c3 = Const.create(s.withName("four"), 4); 135 assertEquals("four", c3.output().op().name()); 136 Const<Integer> c4 = Const.create(s.withName("four"), 4); 137 assertEquals("four_1", c4.output().op().name()); 138 } 139 } 140 141 @Test hierarchy()142 public void hierarchy() { 143 try (Graph g = new Graph()) { 144 Scope root = new Scope(g); 145 Scope child = root.withSubScope("child"); 146 assertEquals("child/Const", Const.create(child, 42).output().op().name()); 147 assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name()); 148 } 149 } 150 151 @Test composite()152 public void composite() { 153 try (Graph g = new Graph(); 154 Session sess = new Session(g)) { 155 Scope s = new Scope(g); 156 Output<Integer> data = 157 Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output(); 158 159 // Create a composite op with a customized name 160 Variance<Integer> var1 = Variance.create(s.withName("example"), data, Integer.class); 161 assertEquals("example/variance", var1.output().op().name()); 162 163 // Confirm internally added ops have the right names. 164 assertNotNull(g.operation("example/squared_deviation")); 165 assertNotNull(g.operation("example/Mean")); 166 // assertNotNull(g.operation("example/zero")); 167 168 // Same composite op with a default name 169 Variance<Integer> var2 = Variance.create(s, data, Integer.class); 170 assertEquals("variance/variance", var2.output().op().name()); 171 172 // Confirm internally added ops have the right names. 173 assertNotNull(g.operation("variance/squared_deviation")); 174 assertNotNull(g.operation("variance/Mean")); 175 // assertNotNull(g.operation("variance/zero")); 176 177 // Verify correct results as well. 178 Tensor<Integer> result = 179 sess.runner().fetch(var1.output()).run().get(0).expect(Integer.class); 180 assertEquals(21704, result.intValue()); 181 result = sess.runner().fetch(var2.output()).run().get(0).expect(Integer.class); 182 assertEquals(21704, result.intValue()); 183 } 184 } 185 186 // "handwritten" sample operator classes 187 private static final class Const<T> { 188 private final Output<T> output; 189 create(Scope s, int v)190 static Const<Integer> create(Scope s, int v) { 191 return create(s, Tensors.create(v)); 192 } 193 create(Scope s, int[] v)194 static Const<Integer> create(Scope s, int[] v) { 195 return create(s, Tensors.create(v)); 196 } 197 create(Scope s, Tensor<T> value)198 static <T> Const<T> create(Scope s, Tensor<T> value) { 199 return new Const<T>( 200 s.env() 201 .opBuilder("Const", s.makeOpName("Const")) 202 .setAttr("dtype", value.dataType()) 203 .setAttr("value", value) 204 .build() 205 .<T>output(0)); 206 } 207 create(Scope s, Object v, Class<T> type)208 static <T> Const<T> create(Scope s, Object v, Class<T> type) { 209 try (Tensor<T> value = Tensor.create(v, type)) { 210 return new Const<T>( 211 s.env() 212 .opBuilder("Const", s.makeOpName("Const")) 213 .setAttr("dtype", value.dataType()) 214 .setAttr("value", value) 215 .build() 216 .<T>output(0)); 217 } 218 } 219 Const(Output<T> o)220 Const(Output<T> o) { 221 output = o; 222 } 223 output()224 Output<T> output() { 225 return output; 226 } 227 } 228 229 private static final class Mean<T> { 230 private final Output<T> output; 231 create(Scope s, Output<T> input, Output<T> reductionIndices)232 static <T> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) { 233 return new Mean<T>( 234 s.env() 235 .opBuilder("Mean", s.makeOpName("Mean")) 236 .addInput(input) 237 .addInput(reductionIndices) 238 .build() 239 .<T>output(0)); 240 } 241 Mean(Output<T> o)242 Mean(Output<T> o) { 243 output = o; 244 } 245 output()246 Output<T> output() { 247 return output; 248 } 249 } 250 251 private static final class SquaredDifference<T> { 252 private final Output<T> output; 253 create(Scope s, Output<T> x, Output<T> y)254 static <T> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) { 255 return new SquaredDifference<T>( 256 s.env() 257 .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference")) 258 .addInput(x) 259 .addInput(y) 260 .build() 261 .<T>output(0)); 262 } 263 SquaredDifference(Output<T> o)264 SquaredDifference(Output<T> o) { 265 output = o; 266 } 267 output()268 Output<T> output() { 269 return output; 270 } 271 } 272 273 /** 274 * Returns the zero value of type described by {@code c}, or null if the type (e.g., string) is 275 * not numeric and therefore has no zero value. 276 * 277 * @param c The class describing the TensorFlow type of interest. 278 */ zeroValue(Class<?> c)279 public static Object zeroValue(Class<?> c) { 280 return zeros.get(c); 281 } 282 283 private static final Map<Class<?>, Object> zeros = new HashMap<>(); 284 285 static { zeros.put(Float.class, 0.0f)286 zeros.put(Float.class, 0.0f); zeros.put(Double.class, 0.0)287 zeros.put(Double.class, 0.0); zeros.put(Integer.class, 0)288 zeros.put(Integer.class, 0); zeros.put(UInt8.class, (byte) 0)289 zeros.put(UInt8.class, (byte) 0); zeros.put(Long.class, 0L)290 zeros.put(Long.class, 0L); zeros.put(Boolean.class, false)291 zeros.put(Boolean.class, false); zeros.put(String.class, null)292 zeros.put(String.class, null); // no zero value 293 } 294 295 private static final class Variance<T> { 296 private final Output<T> output; 297 create(Scope base, Output<T> x, Class<T> type)298 static <T> Variance<T> create(Scope base, Output<T> x, Class<T> type) { 299 Scope s = base.withSubScope("variance"); 300 Output<T> zero = Const.create(base, zeroValue(type), type).output(); 301 Output<T> sqdiff = 302 SquaredDifference.create( 303 s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) 304 .output(); 305 306 return new Variance<T>(Mean.create(s.withName("variance"), sqdiff, zero).output()); 307 } 308 Variance(Output<T> o)309 Variance(Output<T> o) { 310 output = o; 311 } 312 output()313 Output<T> output() { 314 return output; 315 } 316 } 317 } 318