// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/dataloader.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/data/dataloader/stateful.h>
#include <torch/data/dataloader/stateless.h>

#include <torch/csrc/utils/variadic.h>

#include <c10/util/Exception.h>

#include <cstddef>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

15 namespace torch {
16 namespace data {
17 
18 /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
19 /// some `options`.
20 template <typename Dataset, typename Sampler>
21 std::enable_if_t<
22     !Dataset::is_stateful,
23     std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
make_data_loader(Dataset dataset,Sampler sampler,DataLoaderOptions options)24 make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
25   return std::make_unique<StatelessDataLoader<Dataset, Sampler>>(
26       std::move(dataset), std::move(sampler), std::move(options));
27 }
28 
29 /// Creates a `DataLoader` instance for a stateless `dataset` and some
30 /// `options`. A sampler (by default a `RandomSampler`) will be constructed from
31 /// the size of the dataset.
32 template <typename Sampler = samplers::RandomSampler, typename Dataset>
33 std::enable_if_t<
34     !Dataset::is_stateful && std::is_constructible_v<Sampler, size_t>,
35     std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
36 make_data_loader(
37     Dataset dataset,
38     DataLoaderOptions options = DataLoaderOptions()) {
39   const std::optional<size_t> size = dataset.size();
40   TORCH_CHECK(
41       size.has_value(),
42       "Expected the dataset to be sized in "
43       "order to construct the Sampler");
44   return make_data_loader(
45       std::move(dataset), Sampler(*size), std::move(options));
46 }
47 
48 /// Creates a `DataLoader` for a stateful `dataset` and some `options`.
49 template <typename Dataset, typename = std::enable_if_t<Dataset::is_stateful>>
50 std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
51     Dataset dataset,
52     DataLoaderOptions options = DataLoaderOptions()) {
53   return std::make_unique<StatefulDataLoader<Dataset>>(
54       std::move(dataset), std::move(options));
55 }
56 } // namespace data
57 } // namespace torch
58