1 #pragma once 2 3 #include <torch/data/transforms/base.h> 4 5 #include <functional> 6 #include <utility> 7 #include <vector> 8 9 namespace torch { 10 namespace data { 11 namespace transforms { 12 13 /// A `BatchTransform` that applies a user-provided functor to a batch. 14 template <typename Input, typename Output = Input> 15 class BatchLambda : public BatchTransform<Input, Output> { 16 public: 17 using typename BatchTransform<Input, Output>::InputBatchType; 18 using typename BatchTransform<Input, Output>::OutputBatchType; 19 using FunctionType = std::function<OutputBatchType(InputBatchType)>; 20 21 /// Constructs the `BatchLambda` from the given `function` object. BatchLambda(FunctionType function)22 explicit BatchLambda(FunctionType function) 23 : function_(std::move(function)) {} 24 25 /// Applies the user-provided function object to the `input_batch`. apply_batch(InputBatchType input_batch)26 OutputBatchType apply_batch(InputBatchType input_batch) override { 27 return function_(std::move(input_batch)); 28 } 29 30 private: 31 FunctionType function_; 32 }; 33 34 // A `Transform` that applies a user-provided functor to individual examples. 35 template <typename Input, typename Output = Input> 36 class Lambda : public Transform<Input, Output> { 37 public: 38 using typename Transform<Input, Output>::InputType; 39 using typename Transform<Input, Output>::OutputType; 40 using FunctionType = std::function<Output(Input)>; 41 42 /// Constructs the `Lambda` from the given `function` object. Lambda(FunctionType function)43 explicit Lambda(FunctionType function) : function_(std::move(function)) {} 44 45 /// Applies the user-provided function object to the `input`. apply(InputType input)46 OutputType apply(InputType input) override { 47 return function_(std::move(input)); 48 } 49 50 private: 51 FunctionType function_; 52 }; 53 54 } // namespace transforms 55 } // namespace data 56 } // namespace torch 57