1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/types.h> 5 6 #include <chrono> 7 #include <cstddef> 8 9 namespace torch { 10 namespace data { 11 12 /// Options to configure a `DataLoader`. 13 struct DataLoaderOptions { 14 DataLoaderOptions() = default; DataLoaderOptionsDataLoaderOptions15 /* implicit */ DataLoaderOptions(size_t batch_size) 16 : batch_size_(batch_size) {} 17 18 /// The size of each batch to fetch. 19 TORCH_ARG(size_t, batch_size) = 1; 20 21 /// The number of worker threads to launch. If zero, the main thread will 22 /// synchronously perform the data loading. 23 TORCH_ARG(size_t, workers) = 0; 24 25 /// The maximum number of jobs to enqueue for fetching by worker threads. 26 /// Defaults to two times the number of worker threads. 27 TORCH_ARG(std::optional<size_t>, max_jobs); 28 29 /// An optional limit on the time to wait for the next batch. 30 TORCH_ARG(std::optional<std::chrono::milliseconds>, timeout); 31 32 /// Whether to enforce ordering of batches when multiple are loaded 33 /// asynchronously by worker threads. Set to `false` for better performance if 34 /// you do not care about determinism. 35 TORCH_ARG(bool, enforce_ordering) = true; 36 37 /// Whether to omit the last batch if it contains less than `batch_size` 38 /// examples. 39 TORCH_ARG(bool, drop_last) = false; 40 }; 41 42 /// Like `DataLoaderOptions`, but without any unconfigured state. 43 /// `DataLoaderOptions` has some options that depend on other options 44 /// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type 45 /// system, `DataLoaderOptions` allows only setting values. To access values, 46 /// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions` 47 /// instance, which will do any necessary coalescing. 48 struct FullDataLoaderOptions { FullDataLoaderOptionsFullDataLoaderOptions49 explicit FullDataLoaderOptions(DataLoaderOptions options) 50 : batch_size(options.batch_size()), 51 workers(options.workers()), 52 max_jobs(options.max_jobs().value_or(2 * workers)), 53 timeout(options.timeout()), 54 enforce_ordering(options.enforce_ordering()), 55 drop_last(options.drop_last()) {} 56 57 size_t batch_size; 58 size_t workers; 59 size_t max_jobs; 60 std::optional<std::chrono::milliseconds> timeout; 61 bool enforce_ordering; 62 bool drop_last; 63 }; 64 } // namespace data 65 } // namespace torch 66