#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
#include <torch/torch.h>

using namespace torch::autograd;
using namespace torch::test;

namespace {

torch::Tensor functional_op(torch::Tensor& x) {
  return x * x;
}

void inplace_op(torch::Tensor& x) {
  x.mul_(1);
}

torch::Tensor view_op(torch::Tensor& x) {
  return x.view({2, 3});
}

/*
  Only the following combos of Autograd & ADInplaceOrView keys on tensors are
  valid:
    - Autograd=true, ADInplaceOrView=true (normal tensor)
    - Autograd=false, ADInplaceOrView=false (inference tensor)

  Tensors created in InferenceMode are mostly inference tensors. The only
  exception is that a view of a normal tensor created in InferenceMode still
  produces a normal tensor.
*/
void assert_TLS_states(bool inference_mode) {
  ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
  ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
      c10::DispatchKey::ADInplaceOrView));
  ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
      c10::autograd_dispatch_keyset));
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
      inference_mode);
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_key_included(
          c10::DispatchKey::ADInplaceOrView),
      !inference_mode);
  ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
}

} // namespace

TEST(InferenceModeTest, TestTLSState) {
  assert_TLS_states(false);
  {
    InferenceMode guard;
    assert_TLS_states(true);
    {
      InferenceMode guard(false);
      assert_TLS_states(false);
    }
    assert_TLS_states(true);
  }
  assert_TLS_states(false);
}

TEST(InferenceModeTest, TestInferenceTensorCreation) {
  {
    InferenceMode guard;
    // New tensors created through constructors are inference tensors.
    torch::Tensor c = torch::ones({1, 2, 3});
    ASSERT_FALSE(c.requires_grad());
    ASSERT_TRUE(c.is_inference());

    // requires_grad doesn't change inference tensor behavior inside
    // InferenceMode.
    torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
    ASSERT_TRUE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());

    tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
    ASSERT_FALSE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());
  }
}

TEST(InferenceModeTest, TestExistingAutogradSession) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  // Save `a` in an existing autograd session.
  torch::Tensor out = a * a;
  {
    InferenceMode guard;
    inplace_op(a);
  }
  // Performing backward should trigger an error since `a`'s version has been
  // bumped.
  ASSERT_THROWS_WITH(
      out.backward(torch::ones_like(out)),
      "one of the variables needed for gradient computation has been modified by an inplace operation");
}
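
// A hedged companion sketch, not part of the original suite (the test name is
// hypothetical): the usual fix for the failure above is to clone the saved
// tensor before mutating it inside InferenceMode, so the version of the
// variable saved for backward stays untouched.
TEST(InferenceModeTest, TestExistingAutogradSessionCloneWorkaroundSketch) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  torch::Tensor out = a * a; // saves `a` for backward
  {
    InferenceMode guard;
    torch::Tensor a_copy = a.clone(); // mutate the copy, not the saved tensor
    inplace_op(a_copy);
  }
  // Backward succeeds because `a`'s version counter was never bumped.
  out.backward(torch::ones_like(out));
}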

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor func_out = functional_op(c); // go through kernels: CPU
    ASSERT_TRUE(func_out.is_inference());
    ASSERT_FALSE(func_out.requires_grad());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    inplace_op(c); // go through kernels: CPU
    ASSERT_TRUE(c.is_inference());
    ASSERT_EQ(c.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor view_out = view_op(c); // go through kernels: CPU
    ASSERT_TRUE(view_out.is_inference());
    // Note this is different from NoGradMode but makes sense.
    ASSERT_FALSE(view_out.requires_grad());
    ASSERT_FALSE(view_out.is_view());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // Due to issue #54614, this might run slower than inside InferenceMode
    // since intermediate tensors are normal tensors, and they might dispatch
    // to VariableType kernels. This is fine since users can easily fix it by
    // moving the op inside an InferenceMode block.
    torch::Tensor tmp = functional_op(inference_tensor);
    // go through kernels: ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_FALSE(tmp.requires_grad());
  }
}
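
// A hedged sketch of the workaround mentioned above (hypothetical test name):
// running the same functional op inside an InferenceMode block keeps the
// output an inference tensor and skips the VariableType/ADInplaceOrView
// kernels entirely.
TEST(InferenceModeTest, TestFunctionalOpInsideInferenceModeSketch) {
  InferenceMode guard;
  torch::Tensor inference_tensor = torch::ones({1, 2, 3});
  torch::Tensor tmp = functional_op(inference_tensor); // kernels: CPU only
  ASSERT_TRUE(tmp.is_inference());
  ASSERT_FALSE(tmp.requires_grad());
}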

TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    ASSERT_THROWS_WITH(
        inplace_op(inference_tensor),
        // go through kernels: ADInplaceOrView, CPU
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    torch::Tensor out = view_op(inference_tensor);
    // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(out.is_inference());
    ASSERT_FALSE(out.requires_grad());
    ASSERT_FALSE(out.is_view());
    ASSERT_TRUE(out.is_leaf());
  }
}

TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace
      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace -> view
      torch::Tensor view_out = view_op(a);
      // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
    }
  }
}

TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;
      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);
    }

    torch::Tensor tmp = functional_op(a);
    // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(a.is_inference());
    ASSERT_EQ(a.requires_grad(), requires_grad);

    tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      // View ops on a normal tensor produce normal tensors as output.
      // - The view output has both dispatch keys due to the way we create
      //   view tensors in alias_with_sizes_and_strides:
      //   ```
      //   auto impl = c10::make_intrusive<TensorImpl>(
      //       Storage(self.storage()), self.key_set(), self.dtype());
      //   ```
      //   These view outputs are normal in the sense that they have both
      //   Autograd and ADInplaceOrView keys. But they're still special since
      //   they carry CreationMeta::INFERENCE_MODE. In other words they behave
      //   exactly the same as a view tensor created in no_grad mode.

      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());

      // view -> view
      tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(tmp.is_inference());
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());

      // view -> view -> inplace
      inplace_op(tmp); // kernels: ADInplaceOrView, CPU
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_FALSE(tmp.is_inference());
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());
      ASSERT_EQ(a._version(), tmp._version());
    }
  }
}
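
// A hedged comparison sketch for the comment above (hypothetical test name,
// assuming the standard no_grad view error message): a view taken under
// NoGradMode carries CreationMeta::NO_GRAD_MODE and rejects in-place updates
// in the same way as the inference-mode views tested above.
TEST(InferenceModeTest, TestNoGradViewComparisonSketch) {
  torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor view_out;
  {
    torch::NoGradGuard no_grad;
    view_out = view_op(a); // normal tensor with CreationMeta::NO_GRAD_MODE
  }
  ASSERT_FALSE(view_out.is_inference());
  assert_tensor_creation_meta(view_out, CreationMeta::NO_GRAD_MODE);
  ASSERT_THROWS_WITH(
      inplace_op(view_out),
      "A view was created in no_grad mode and is being modified inplace");
}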

TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());
    }

    tmp = functional_op(view_out);
    ASSERT_FALSE(view_out.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out),
          // go through kernels: VariableType, ADInplaceOrView, CPU
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }

    tmp = view_op(view_out);
    ASSERT_FALSE(view_out.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // add(Tensor, Tensor) is safe with an inference tensor since it doesn't
    // save any variable for backward.
    torch::Tensor out = c.add(s);
    // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(out.is_inference());
    ASSERT_EQ(out.requires_grad(), requires_grad);

    if (requires_grad) {
      // A leaf inference tensor with requires_grad=true can still have a
      // gradient. Note this behavior is different from NoGradMode, which
      // leaves the grad empty.
      out.backward(torch::ones_like(out));
      assert_tensor_equal(c.grad(), torch::ones_like(c));
    }

    if (requires_grad) {
      // mul(self, other) saves variables when requires_grad=true.
      ASSERT_THROWS_WITH(
          c.mul(s), "Inference tensors cannot be saved for backward.");

      // Inference tensor in a TensorList input.
      // stack does not capture anymore, so disabled.
      // TODO: find an alternative function that captures a list (maybe a
      // custom fn)
      /*
      std::vector<torch::Tensor> inputs = {s, c};
      ASSERT_THROWS_WITH(
          torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
                                // ADInplaceOrView(fallthrough), CPU
          "Inference tensors cannot be saved for backward.")
      */
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          a.mul_(c),
          // go through kernels: VariableType(ERROR!), InferenceMode, CPU
          "Inference tensors cannot be saved for backward.");

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s),
          // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
          "out=... arguments don't support automatic differentiation, but one of the arguments requires grad");
    } else {
      a.mul_(c);

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s),
          // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU
          "Inplace update to inference tensor outside InferenceMode is not allowed");
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    // view_as is a composite op which calls view() with only one tensor
    // argument, so there's no mixed inference/normal tensor input for view
    // ops.
    torch::Tensor tmp1 = c.view_as(s);
    // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(tmp1.is_inference());
    ASSERT_FALSE(tmp1.requires_grad());

    // This is fine since it's equivalent to s.view(c.sizes()), which isn't a
    // mixed-input scenario.
    torch::Tensor tmp2 = s.view_as(c);
    // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp2.is_inference());
    ASSERT_EQ(tmp2.requires_grad(), requires_grad);
  }
}
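
// A hedged sketch of the equivalence noted above (hypothetical test name):
// since view_as only dispatches view() on `self`, s.view_as(c) should match
// s.view(c.sizes()) in both value and autograd metadata.
TEST(InferenceModeTest, TestViewAsEquivalenceSketch) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor c;
  {
    InferenceMode guard;
    c = torch::ones({1, 2, 3});
  }
  torch::Tensor via_view_as = s.view_as(c);
  torch::Tensor via_view = s.view(c.sizes()); // reading sizes is allowed
  assert_tensor_equal(via_view_as, via_view);
  ASSERT_EQ(via_view_as.requires_grad(), via_view.requires_grad());
  ASSERT_FALSE(via_view_as.is_inference());
}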

TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out),
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }
  }
}

TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    inplace_op(a);
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          view_out.grad_fn(),
          "A view was created in inference mode and its base or another view of its base has been modified inplace");
    } else {
      view_out.grad_fn();
    }
  }
}

TEST(InferenceModeTest, TestCreationMetaPropagation) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor b, c;
  {
    InferenceMode guard;
    b = s.view_as(s);
  }
  ASSERT_THROWS_WITH(
      b.add_(1),
      "A view was created in inference mode and is being modified inplace");
  {
    AutoGradMode mode(false);
    c = b.view_as(b);
  }
  ASSERT_THROWS_WITH(
      c.add_(1),
      "A view was created in inference mode and is being modified inplace");
}

TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
  torch::Tensor s = torch::ones({2, 2, 3}).set_requires_grad(true);
  auto s_view = s.view_as(s);
  std::vector<torch::Tensor> b, c;
  {
    InferenceMode guard;
    b = s_view.split_with_sizes({1, 1});

    s = s.view_as(s);
    c = s.split_with_sizes({1, 1});
  }
  for (auto& b_el : b) {
    assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(
        b_el.add_(1),
        "A view was created in inference mode and is being modified inplace");
  }
  for (auto& c_el : c) {
    assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(
        c_el.add_(1),
        "A view was created in inference mode and is being modified inplace");
  }
}

TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor t;
    {
      InferenceMode guard;
      t = torch::ones({1, 2, 3});
      t.copy_(s);
      ASSERT_TRUE(t.is_inference());
      ASSERT_FALSE(t.requires_grad());
    }

    ASSERT_THROWS_WITH(
        t.copy_(s),
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
  }
  t.set_requires_grad(false);
  ASSERT_THROWS_WITH(
      t.set_requires_grad(true),
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}

TEST(InferenceModeTest, TestAccessVersionCounter) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    ASSERT_THROWS_WITH(
        t.unsafeGetTensorImpl()->version_counter().current_version(),
        "Inference tensors do not track version counter.");
    t.unsafeGetTensorImpl()->bump_version();
  }
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->version_counter().current_version(),
      "Inference tensors do not track version counter.");
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->bump_version(),
      "Inplace update to inference tensor outside InferenceMode is not allowed.");

  // Suggested workaround: cloning produces a normal tensor whose version
  // counter works as usual.
  torch::Tensor c = t.clone();
  uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
  c.unsafeGetTensorImpl()->bump_version();
  ASSERT_EQ(
      c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
}
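
// A hedged sketch (hypothetical test name) tying the two restrictions above
// together: the same clone() workaround yields a normal tensor that can opt
// back into autograd via set_requires_grad(true).
TEST(InferenceModeTest, TestCloneToEnableAutogradSketch) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
  }
  torch::Tensor n = t.clone(); // functional op outside the guard => normal tensor
  ASSERT_FALSE(n.is_inference());
  n.set_requires_grad(true); // allowed on a normal tensor
  ASSERT_TRUE(n.requires_grad());
}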

TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
  torch::Tensor s = torch::ones({1, 2, 3});
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});

    // Test both copy_ from VariableTypeManual and add_ from generated code.
    s.copy_(t);
    s.add_(t);
    t.add_(s);
    t.copy_(s);
  }
  s.copy_(t);
  s.add_(t);

  ASSERT_THROWS_WITH(
      t.copy_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");

  ASSERT_THROWS_WITH(
      t.add_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
}

TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
  torch::Tensor s = torch::ones({3, 3, 2});
  torch::Tensor t = torch::view_as_complex(s);
  {
    InferenceMode guard;
    torch::Tensor tmp;

    tmp = torch::view_as_real(t);
    ASSERT_FALSE(tmp.is_inference());
    tmp = torch::view_as_complex(s);
    ASSERT_FALSE(tmp.is_inference());

    torch::Tensor e = torch::ones({3, 3, 2});
    tmp = torch::view_as_complex(e);
    ASSERT_TRUE(tmp.is_inference());
    tmp = torch::view_as_real(tmp);
    ASSERT_TRUE(tmp.is_inference());
  }
}

TEST(InferenceModeTest, TestComplexViewInNormalMode) {
  torch::Tensor s;
  {
    InferenceMode guard;
    s = torch::ones({3, 3, 2});
  }
  torch::Tensor tmp = torch::view_as_complex(s);
  ASSERT_TRUE(tmp.is_inference());
  tmp = torch::view_as_real(tmp);
  ASSERT_TRUE(tmp.is_inference());
}

TEST(InferenceModeTest, TestCustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  {
    InferenceMode guard;
    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
    auto var2 = var1.clone();
    int mul = 2;
    // If InferenceMode didn't set NoGradGuard automatically, this line
    // would error out when trying to save `var1` and `var2` for backward.
    auto y = MyFunction::apply(var1, mul, var2);
    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
    assert_tensor_equal(y, expected);
  }
}

TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
  c10::WarningUtils::WarnAlways warn_always(true);
  WarningCapture warnings;
  at::AutoNonVariableTypeMode guard;
  ASSERT_TRUE(
      warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
      std::string::npos);
}
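
// A hedged end-to-end usage sketch (hypothetical test; torch::nn::Linear
// stands in for an arbitrary user model): wrapping the forward pass in an
// InferenceMode guard makes every tensor allocated inside the block an
// inference tensor, skipping autograd and view/version-counter tracking.
TEST(InferenceModeTest, TestTypicalForwardPassSketch) {
  torch::nn::Linear model(3, 3);
  torch::Tensor input = torch::ones({2, 3});
  torch::Tensor output;
  {
    c10::InferenceMode guard;
    // Grad mode is off inside the guard, so nothing is saved for backward
    // even though the model parameters require grad.
    output = model->forward(input); // newly allocated => inference tensor
    ASSERT_TRUE(output.is_inference());
    ASSERT_FALSE(output.requires_grad());
  }
}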