xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/transforms/lambda.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/transforms/base.h>
4 
5 #include <functional>
6 #include <utility>
7 #include <vector>
8 
9 namespace torch {
10 namespace data {
11 namespace transforms {
12 
13 /// A `BatchTransform` that applies a user-provided functor to a batch.
14 template <typename Input, typename Output = Input>
15 class BatchLambda : public BatchTransform<Input, Output> {
16  public:
17   using typename BatchTransform<Input, Output>::InputBatchType;
18   using typename BatchTransform<Input, Output>::OutputBatchType;
19   using FunctionType = std::function<OutputBatchType(InputBatchType)>;
20 
21   /// Constructs the `BatchLambda` from the given `function` object.
BatchLambda(FunctionType function)22   explicit BatchLambda(FunctionType function)
23       : function_(std::move(function)) {}
24 
25   /// Applies the user-provided function object to the `input_batch`.
apply_batch(InputBatchType input_batch)26   OutputBatchType apply_batch(InputBatchType input_batch) override {
27     return function_(std::move(input_batch));
28   }
29 
30  private:
31   FunctionType function_;
32 };
33 
34 // A `Transform` that applies a user-provided functor to individual examples.
35 template <typename Input, typename Output = Input>
36 class Lambda : public Transform<Input, Output> {
37  public:
38   using typename Transform<Input, Output>::InputType;
39   using typename Transform<Input, Output>::OutputType;
40   using FunctionType = std::function<Output(Input)>;
41 
42   /// Constructs the `Lambda` from the given `function` object.
Lambda(FunctionType function)43   explicit Lambda(FunctionType function) : function_(std::move(function)) {}
44 
45   /// Applies the user-provided function object to the `input`.
apply(InputType input)46   OutputType apply(InputType input) override {
47     return function_(std::move(input));
48   }
49 
50  private:
51   FunctionType function_;
52 };
53 
54 } // namespace transforms
55 } // namespace data
56 } // namespace torch
57