1 #pragma once 2 3 #include <torch/data/datasets/base.h> 4 #include <torch/data/example.h> 5 #include <torch/types.h> 6 7 #include <torch/csrc/Export.h> 8 9 #include <cstddef> 10 #include <string> 11 12 namespace torch { 13 namespace data { 14 namespace datasets { 15 /// The MNIST dataset. 16 class TORCH_API MNIST : public Dataset<MNIST> { 17 public: 18 /// The mode in which the dataset is loaded. 19 enum class Mode { kTrain, kTest }; 20 21 /// Loads the MNIST dataset from the `root` path. 22 /// 23 /// The supplied `root` path should contain the *content* of the unzipped 24 /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist. 25 explicit MNIST(const std::string& root, Mode mode = Mode::kTrain); 26 27 /// Returns the `Example` at the given `index`. 28 Example<> get(size_t index) override; 29 30 /// Returns the size of the dataset. 31 std::optional<size_t> size() const override; 32 33 /// Returns true if this is the training subset of MNIST. 34 // NOLINTNEXTLINE(bugprone-exception-escape) 35 bool is_train() const noexcept; 36 37 /// Returns all images stacked into a single tensor. 38 const Tensor& images() const; 39 40 /// Returns all targets stacked into a single tensor. 41 const Tensor& targets() const; 42 43 private: 44 Tensor images_, targets_; 45 }; 46 } // namespace datasets 47 } // namespace data 48 } // namespace torch 49