#pragma once

#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <torch/csrc/api/include/torch/types.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <cstdint>

namespace torch {

/// A RAII, thread-local guard that disables gradient calculation.
///
/// Disabling gradient calculation is useful for inference, when you are sure
/// that you will not call `at::Tensor::backward`. It will reduce memory
/// consumption for computations that would otherwise have `requires_grad() ==
/// true`.
///
/// In this mode, the result of every computation will have
/// `requires_grad() == false`, even when the inputs have `requires_grad() ==
/// true`.
///
/// This context manager is thread-local; it will not affect computation
/// in other threads.
///
/// Example:
/// @code
/// auto x = torch::tensor({1.}, torch::requires_grad());
/// {
///   torch::NoGradGuard no_grad;
///   auto y = x * 2;
///   std::cout << y.requires_grad() << std::endl; // prints `false`
/// }
/// {
///   auto doubler = [](torch::Tensor x) {
///     torch::NoGradGuard no_grad;
///     return x * 2;
///   };
///   auto z = doubler(x);
///   std::cout << z.requires_grad() << std::endl; // prints `false`
/// }
/// @endcode
using NoGradGuard = at::NoGradGuard;

/// A RAII, thread-local guard that sets gradient calculation to on or off.
///
/// ``AutoGradMode`` will enable or disable grads based on its argument
/// `enabled`.
///
/// This context manager is thread-local; it will not affect computation
/// in other threads.
///
/// \param enabled: Whether to enable grad (``true``) or disable it
///                 (``false``). This can be used to conditionally enable
///                 gradients.
///
/// Example:
/// @code
/// auto x = torch::tensor({1.}, torch::requires_grad());
/// {
///   torch::AutoGradMode enable_grad(true);
///   auto y = x * 2;
///   std::cout << y.requires_grad() << std::endl; // prints `true`
/// }
/// {
///   torch::AutoGradMode enable_grad(false);
///   auto y = x * 2;
///   std::cout << y.requires_grad() << std::endl; // prints `false`
/// }
/// @endcode
using AutoGradMode = at::AutoGradMode;

/// Sets the global random seed for all newly created CPU and CUDA tensors.
using at::manual_seed;
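
// Example (a minimal usage sketch; `torch::randn` is used purely for
// illustration):
// @code
// torch::manual_seed(0);
// auto a = torch::randn({2, 2});
// torch::manual_seed(0);
// auto b = torch::randn({2, 2});
// // `a` and `b` hold identical values because the seed was reset.
// @endcode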

// Called during new thread initialization.
using at::init_num_threads;

// Returns the number of threads used in intra-op parallel regions.
using at::get_num_threads;

// Sets the number of threads to be used in intra-op parallel regions.
using at::set_num_threads;

// Returns the number of threads used for inter-op parallelism.
using at::get_num_interop_threads;

// Sets the number of threads to be used for inter-op parallelism.
using at::set_num_interop_threads;
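
// Example (a minimal configuration sketch; note that the inter-op thread
// count can typically only be set once, before any parallel work starts):
// @code
// torch::init_num_threads();  // per-thread initialization
// torch::set_num_threads(4);  // intra-op thread pool size
// std::cout << torch::get_num_threads() << std::endl;
// // prints `4` on most configurations
// std::cout << torch::get_num_interop_threads() << std::endl;
// // prints the current inter-op pool size
// @endcode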

// Returns true if both t1, t2 are undefined, or both are defined and equal.
inline bool equal_if_defined(Tensor t1, Tensor t2) {
  return (
      (!t1.defined() && !t2.defined()) ||
      (t1.defined() && t2.defined() && torch::equal(t1, t2)));
}
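
// Example (a minimal usage sketch; `torch::ones` is used purely for
// illustration):
// @code
// torch::Tensor a, b;                        // both undefined
// bool same = torch::equal_if_defined(a, b); // true: both undefined
// auto c = torch::ones({2});
// same = torch::equal_if_defined(a, c);      // false: only one is defined
// @endcode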

// RecordFunction API
using at::addGlobalCallback;
using at::addThreadLocalCallback;
using at::CallbackHandle;
using at::clearCallbacks;
using at::clearGlobalCallbacks;
using at::clearThreadLocalCallbacks;
using at::DisableRecordFunctionGuard;
using at::enableRecordFunction;
using at::hasCallbacks;
using at::hasGlobalCallbacks;
using at::hasThreadLocalCallbacks;
using at::isRecordFunctionEnabled;
using at::RecordFunction;
using at::RecordFunctionCallback;
using at::RecordFunctionGuard;
using at::removeCallback;
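
// Example (a minimal sketch using only the RAII guards; callback
// registration via `addGlobalCallback` / `RecordFunctionCallback` has a
// richer, version-dependent signature, see ATen/record_function.h):
// @code
// {
//   torch::RecordFunctionGuard enable_record;
//   // torch::isRecordFunctionEnabled() returns `true` in this scope, so
//   // registered callbacks may observe operator invocations here.
// }
// {
//   torch::DisableRecordFunctionGuard disable_record;
//   // torch::isRecordFunctionEnabled() returns `false` in this scope.
// }
// @endcode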

} // namespace torch