1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""TestDataGenerator factory class.
9"""
10
11import logging
12
13from . import exceptions
14from . import test_data_generation
15
16
class TestDataGeneratorFactory(object):
    """Factory class used to create test data generators.

  Usage: Create a factory passing parameters to the ctor with which the
  generators will be produced, set the output directory prefix via
  SetOutputDirectoryPrefix(), then call GetInstance() with a
  TestDataGenerator class object.
  """

    def __init__(self, aechen_ir_database_path, noise_tracks_path,
                 copy_with_identity):
        """Ctor.

    Args:
      aechen_ir_database_path: Path to the Aachen Impulse Response database
                               (passed to ReverberationTestDataGenerator).
      noise_tracks_path: Path to the noise tracks to add (passed to
                         AdditiveNoiseTestDataGenerator).
      copy_with_identity: Flag indicating whether the identity generator has to
                          make copies of the clean speech input files.
    """
        # Must be set via SetOutputDirectoryPrefix() before GetInstance().
        self._output_directory_prefix = None
        self._aechen_ir_database_path = aechen_ir_database_path
        self._noise_tracks_path = noise_tracks_path
        self._copy_with_identity = copy_with_identity

    def SetOutputDirectoryPrefix(self, prefix):
        """Sets the output directory prefix passed to every generator.

    Args:
      prefix: Output directory prefix; required before calling GetInstance().
    """
        self._output_directory_prefix = prefix

    def GetInstance(self, test_data_generators_class):
        """Creates a TestDataGenerator instance given a class object.

    The identity, reverberation and additive noise generators receive, in
    addition to the output directory prefix, the matching argument given to
    this factory's ctor. Any other class is instantiated with the output
    directory prefix only.

    Args:
      test_data_generators_class: TestDataGenerator class object (not an
                                  instance).

    Returns:
      TestDataGenerator instance.

    Raises:
      exceptions.InitializationException: if the output directory prefix has
                                          not been set.
    """
        if self._output_directory_prefix is None:
            raise exceptions.InitializationException(
                'The output directory prefix for test data generators is not set'
            )
        logging.debug('factory producing %s', test_data_generators_class)

        # Class objects are matched by identity: only these exact classes
        # (not subclasses) get the extra ctor argument.
        if test_data_generators_class is (
                test_data_generation.IdentityTestDataGenerator):
            return test_data_generation.IdentityTestDataGenerator(
                self._output_directory_prefix, self._copy_with_identity)
        if test_data_generators_class is (
                test_data_generation.ReverberationTestDataGenerator):
            return test_data_generation.ReverberationTestDataGenerator(
                self._output_directory_prefix, self._aechen_ir_database_path)
        if test_data_generators_class is (
                test_data_generation.AdditiveNoiseTestDataGenerator):
            return test_data_generation.AdditiveNoiseTestDataGenerator(
                self._output_directory_prefix, self._noise_tracks_path)
        return test_data_generators_class(self._output_directory_prefix)
72