xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/datasets/mnist.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/datasets/base.h>
4 #include <torch/data/example.h>
5 #include <torch/types.h>
6 
7 #include <torch/csrc/Export.h>
8 
9 #include <cstddef>
10 #include <string>
11 
12 namespace torch {
13 namespace data {
14 namespace datasets {
15 /// The MNIST dataset.
16 class TORCH_API MNIST : public Dataset<MNIST> {
17  public:
18   /// The mode in which the dataset is loaded.
19   enum class Mode { kTrain, kTest };
20 
21   /// Loads the MNIST dataset from the `root` path.
22   ///
23   /// The supplied `root` path should contain the *content* of the unzipped
24   /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
25   explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);
26 
27   /// Returns the `Example` at the given `index`.
28   Example<> get(size_t index) override;
29 
30   /// Returns the size of the dataset.
31   std::optional<size_t> size() const override;
32 
33   /// Returns true if this is the training subset of MNIST.
34   // NOLINTNEXTLINE(bugprone-exception-escape)
35   bool is_train() const noexcept;
36 
37   /// Returns all images stacked into a single tensor.
38   const Tensor& images() const;
39 
40   /// Returns all targets stacked into a single tensor.
41   const Tensor& targets() const;
42 
43  private:
44   Tensor images_, targets_;
45 };
46 } // namespace datasets
47 } // namespace data
48 } // namespace torch
49