xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/dataloader_options.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/types.h>
5 
6 #include <chrono>
7 #include <cstddef>
8 
9 namespace torch {
10 namespace data {
11 
12 /// Options to configure a `DataLoader`.
13 struct DataLoaderOptions {
14   DataLoaderOptions() = default;
DataLoaderOptionsDataLoaderOptions15   /* implicit */ DataLoaderOptions(size_t batch_size)
16       : batch_size_(batch_size) {}
17 
18   /// The size of each batch to fetch.
19   TORCH_ARG(size_t, batch_size) = 1;
20 
21   /// The number of worker threads to launch. If zero, the main thread will
22   /// synchronously perform the data loading.
23   TORCH_ARG(size_t, workers) = 0;
24 
25   /// The maximum number of jobs to enqueue for fetching by worker threads.
26   /// Defaults to two times the number of worker threads.
27   TORCH_ARG(std::optional<size_t>, max_jobs);
28 
29   /// An optional limit on the time to wait for the next batch.
30   TORCH_ARG(std::optional<std::chrono::milliseconds>, timeout);
31 
32   /// Whether to enforce ordering of batches when multiple are loaded
33   /// asynchronously by worker threads. Set to `false` for better performance if
34   /// you do not care about determinism.
35   TORCH_ARG(bool, enforce_ordering) = true;
36 
37   /// Whether to omit the last batch if it contains less than `batch_size`
38   /// examples.
39   TORCH_ARG(bool, drop_last) = false;
40 };
41 
42 /// Like `DataLoaderOptions`, but without any unconfigured state.
43 /// `DataLoaderOptions` has some options that depend on other options
44 /// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type
45 /// system, `DataLoaderOptions` allows only setting values. To access values,
46 /// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions`
47 /// instance, which will do any necessary coalescing.
48 struct FullDataLoaderOptions {
FullDataLoaderOptionsFullDataLoaderOptions49   explicit FullDataLoaderOptions(DataLoaderOptions options)
50       : batch_size(options.batch_size()),
51         workers(options.workers()),
52         max_jobs(options.max_jobs().value_or(2 * workers)),
53         timeout(options.timeout()),
54         enforce_ordering(options.enforce_ordering()),
55         drop_last(options.drop_last()) {}
56 
57   size_t batch_size;
58   size_t workers;
59   size_t max_jobs;
60   std::optional<std::chrono::milliseconds> timeout;
61   bool enforce_ordering;
62   bool drop_last;
63 };
64 } // namespace data
65 } // namespace torch
66