#include #include #include #include namespace torch { namespace jit { class UnionTypeTest : public ::testing::Test { public: // None const TypePtr none = NoneType::get(); // List[str] const TypePtr l1 = ListType::ofStrings(); // Optional[int] const TypePtr opt1 = OptionalType::create(IntType::get()); // Optional[float] const TypePtr opt2 = OptionalType::create(FloatType::get()); // Optional[List[str]] const TypePtr opt3 = OptionalType::create(ListType::ofStrings()); // Tuple[Optional[int], int] const TypePtr tup1 = TupleType::create({OptionalType::create(IntType::get()), IntType::get()}); // Tuple[int, int] const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()}); bool hasType(UnionTypePtr u, TypePtr t) { auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t); return res != u->getTypes().end(); } }; TEST_F(UnionTypeTest, UnionOperatorEquals) { const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()}); // Same thing, but using different TypePtrs const TypePtr l1_ = ListType::ofStrings(); const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()}); const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()}); ASSERT_TRUE(*u1 == *u2); } TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) { // Goal: Union[int, float, None] const UnionTypePtr u = UnionType::create({opt1, opt2}); ASSERT_EQ(u->getTypes().size(), 3); ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); } TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) { // Goal: Union[int, None] const UnionTypePtr u = UnionType::create({opt1, IntType::get()}); ASSERT_EQ(u->getTypes().size(), 2); ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); } TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) { // Goal: Union[Tuple[Optional[int], int], str] const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2}); ASSERT_EQ(u->getTypes().size(), 2); ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, tup1)); } TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) { // Goal: Union[List[str], str] const UnionTypePtr u = UnionType::create({l1, StringType::get()}); ASSERT_EQ(u->getTypes().size(), 2); ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); } TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) { // Goal: Union[List[str], None, str] const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()}); ASSERT_EQ(u->getTypes().size(), 3); ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); } TEST_F(UnionTypeTest, Subtyping_NumberType) { // Union[int, float, Complex] const UnionTypePtr union1 = UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()}); // Union[int, float, Complex, None] const UnionTypePtr union2 = UnionType::create( {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()}); const NumberTypePtr num = NumberType::get(); ASSERT_TRUE(num->isSubtypeOf(*union1)); ASSERT_TRUE(union1->isSubtypeOf(*num)); ASSERT_TRUE(*num == *union1); ASSERT_TRUE(num->isSubtypeOf(*union2)); ASSERT_FALSE(union2->isSubtypeOf(*num)); ASSERT_FALSE(*num == *union2); } TEST_F(UnionTypeTest, Subtyping_OptionalType) { // Union[int, None] const UnionTypePtr union1 = UnionType::create({IntType::get(), NoneType::get()}); // Union[int, str, None] const UnionTypePtr union2 = UnionType::create({IntType::get(), StringType::get(), NoneType::get()}); // Union[int, str, List[str]] const UnionTypePtr union3 = UnionType::create( {IntType::get(), StringType::get(), ListType::ofStrings()}); ASSERT_TRUE(none->isSubtypeOf(opt1)); ASSERT_TRUE(none->isSubtypeOf(union1)); ASSERT_TRUE(none->isSubtypeOf(union2)); ASSERT_FALSE(none->isSubtypeOf(union3)); ASSERT_FALSE(opt1->isSubtypeOf(none)); ASSERT_TRUE(opt1->isSubtypeOf(union1)); ASSERT_TRUE(opt1->isSubtypeOf(union2)); ASSERT_FALSE(opt1->isSubtypeOf(union3)); ASSERT_FALSE(union1->isSubtypeOf(none)); ASSERT_TRUE(union1->isSubtypeOf(opt1)); ASSERT_TRUE(union1->isSubtypeOf(union2)); ASSERT_FALSE(union1->isSubtypeOf(union3)); ASSERT_FALSE(union2->isSubtypeOf(union1)); } } // namespace jit } // namespace torch