xref: /aosp_15_r20/external/libopus/dnn/torch/osce/create_testvectors.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31import argparse
32
33import torch
34import numpy as np
35
36from models import model_dict
37from utils import endoscopy
38
# Command-line interface: checkpoint folder, output folder and an optional
# debug switch that enables endoscopy output in the output folder.
parser = argparse.ArgumentParser()

# fixed: the help text previously lacked the opening quote around nolace_checkpoint.pth
parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and "nolace_checkpoint.pth"')
parser.add_argument('output_folder', type=str, help='output folder for testvectors')
parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
44
45
def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
    """Dump a random input/output test vector for an adaptive convolution layer.

    Random conditioning features of shape (1, num_frames, feature_dim) and a
    random signal of shape (1, in_channels, num_frames * frame_size) are fed
    through ``adaconv``. The features, input and output are then written as raw
    float32 files ``<prefix>_features.f32``, ``<prefix>_x_in.f32`` and
    ``<prefix>_x_out.f32``. The signals are stored frame-major, i.e. laid out
    as (num_frames, channels, frame_size).

    Args:
        prefix: path prefix for the three output files.
        adaconv: layer exposing ``feature_dim``, ``in_channels``,
            ``out_channels`` and ``frame_size``, callable as
            ``adaconv(x, features, debug=...)``.
        num_frames: number of frames to generate.
        debug: forwarded unchanged to ``adaconv``.
    """
    n_feat = adaconv.feature_dim
    c_in = adaconv.in_channels
    c_out = adaconv.out_channels
    frame_len = adaconv.frame_size

    # Keep this draw order (features, then signal): it decides which random
    # values each tensor receives under a given torch seed.
    feat = torch.randn((1, num_frames, n_feat))
    sig_in = torch.randn((1, c_in, num_frames * frame_len))

    sig_out = adaconv(sig_in, feat, debug=debug)

    def frames_first(signal, channels):
        # (1, channels, num_frames * frame_len) -> (num_frames, channels, frame_len)
        framed = signal[0].reshape(channels, num_frames, frame_len)
        return framed.permute(1, 0, 2).detach().numpy()

    arrays = {
        '_features.f32': feat[0].detach().numpy(),
        '_x_in.f32': frames_first(sig_in, c_in),
        '_x_out.f32': frames_first(sig_out, c_out),
    }
    for suffix, arr in arrays.items():
        arr.tofile(prefix + suffix)
64
def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
    """Dump a random input/output test vector for an adaptive comb filter layer.

    Draws random features (1, num_frames, feature_dim), a random single-channel
    signal (1, 1, num_frames * frame_size) and one random integer period per
    frame in ``[adacomb.kernel_size, 250)``, runs ``adacomb`` on them and writes:

    * ``<prefix>_features.f32`` - features, float32
    * ``<prefix>_x_in.f32``     - input signal, float32
    * ``<prefix>_p_in.s32``     - periods, int32
    * ``<prefix>_x_out.f32``    - output signal, float32

    ``debug`` is forwarded unchanged to ``adacomb``.
    """
    n_feat = adacomb.feature_dim
    frame_len = adacomb.frame_size
    n_samples = num_frames * frame_len

    # draw order (features, signal, periods) fixes the values under a torch seed
    feat = torch.randn((1, num_frames, n_feat))
    sig_in = torch.randn((1, 1, n_samples))  # single input channel
    periods = torch.randint(adacomb.kernel_size, 250, (1, num_frames))

    sig_out = adacomb(sig_in, feat, periods, debug=debug)

    feat_np = feat[0].detach().numpy()
    # (1, T) -> (T, 1): one sample per row
    sig_in_np = sig_in[0].permute(1, 0).detach().numpy()
    periods_np = periods[0].detach().numpy().astype(np.int32)
    sig_out_np = sig_out[0].permute(1, 0).detach().numpy()

    for suffix, arr in (('_features.f32', feat_np),
                        ('_x_in.f32', sig_in_np),
                        ('_p_in.s32', periods_np),
                        ('_x_out.f32', sig_out_np)):
        arr.tofile(prefix + suffix)
85
def create_adashape_testvector(prefix, adashape, num_frames):
    """Dump a random input/output test vector for an adaptive shaping layer.

    Random features (1, num_frames, feature_dim) and a random single-channel
    signal (1, 1, num_frames * frame_size) are passed through
    ``adashape(x, features)``. The features are written to
    ``<prefix>_features.f32``; the input and output signals are flattened and
    written to ``<prefix>_x_in.f32`` and ``<prefix>_x_out.f32`` (raw float32).
    """
    feat_dim = adashape.feature_dim
    frame_len = adashape.frame_size

    # draw order (features, then signal) fixes the values under a torch seed
    feat = torch.randn((1, num_frames, feat_dim))
    sig_in = torch.randn((1, 1, num_frames * frame_len))

    sig_out = adashape(sig_in, feat)

    pending = (
        ('_features.f32', feat[0]),
        ('_x_in.f32', sig_in.flatten()),
        ('_x_out.f32', sig_out.flatten()),
    )
    # convert everything first, then write, in the same file order as before
    arrays = [(suffix, tensor.detach().numpy()) for suffix, tensor in pending]
    for suffix, arr in arrays:
        arr.tofile(prefix + suffix)
102
def create_feature_net_testvector(prefix, model, num_frames):
    """Dump a random input/output test vector for ``model.feature_net``.

    Builds the feature-net input from:

    * random per-subframe input features, (1, 4 * num_frames, num_features),
    * random integer periods in [32, 300), one per subframe, embedded with
      ``model.pitch_embedding``,
    * random per-frame numbits pairs drawn uniformly from
      ``model.numbits_range``, (1, num_frames, 2), embedded with
      ``model.numbits_embedding`` and repeated for each of the 4 subframes.

    The three parts are concatenated on the last axis and run through
    ``model.feature_net``. The raw inputs, the concatenated features and the
    network output are written to ``<prefix>_in_features.f32``,
    ``<prefix>_periods.s32`` (int32), ``<prefix>_numbits.f32``,
    ``<prefix>_full_features.f32`` and ``<prefix>_out_features.f32``.
    """
    subframes_per_frame = 4
    n_feat = model.num_features
    n_sub = subframes_per_frame * num_frames

    # draw order (features, periods, numbits) fixes the values under a torch seed
    in_feat = torch.randn((1, n_sub, n_feat))
    periods = torch.randint(32, 300, (1, n_sub))
    lo, hi = model.numbits_range[0], model.numbits_range[1]
    numbits = lo + torch.rand((1, num_frames, 2)) * (hi - lo)

    pitch_emb = model.pitch_embedding(periods)
    # per-frame numbits embedding, flattened from axis 2 on, then repeated so
    # every subframe of a frame carries that frame's embedding
    bits_emb = model.numbits_embedding(numbits).flatten(2)
    bits_emb = torch.repeat_interleave(bits_emb, subframes_per_frame, dim=1)
    net_in = torch.cat((in_feat, pitch_emb, bits_emb), dim=-1)

    net_out = model.feature_net(net_in)

    dumps = (
        ("_in_features.f32", in_feat.float().numpy()),
        ("_periods.s32", periods.numpy().astype(np.int32)),
        ("_numbits.f32", numbits.float().numpy()),
        ("_full_features.f32", net_in.detach().numpy()),
        ("_out_features.f32", net_out.detach().numpy()),
    )
    for suffix, array in dumps:
        array.tofile(prefix + suffix)
123
124
125
if __name__ == "__main__":
    # Script entry point: load the LACE and NoLACE checkpoints and write one
    # test vector per selected sub-module into args.output_folder.
    args = parser.parse_args()

    os.makedirs(args.output_folder, exist_ok=True)

    # NOTE(review): torch.load unpickles the checkpoint files, so only point
    # checkpoint_path at trusted checkpoints.
    lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
    nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')

    # rebuild both models from the constructor kwargs stored in each checkpoint
    lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
    nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])

    lace.load_state_dict(lace_checkpoint['state_dict'])
    nolace.load_state_dict(nolace_checkpoint['state_dict'])

    # NOTE(review): the models are never switched to eval() here; confirm that
    # no train-only behaviour affects the generated vectors.

    # every test vector below covers the same number of frames
    num_frames = 5

    if args.debug:
        endoscopy.init(args.output_folder)

    try:
        # lace af1, 1 input channel, 1 output channel
        create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, num_frames, debug=args.debug)

        # nolace af1, 1 input channel, 2 output channels
        create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, num_frames, debug=args.debug)

        # nolace af4, 2 input channels, 1 output channel
        create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, num_frames, debug=args.debug)

        # nolace af2, 2 input channels, 2 output channels
        create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, num_frames, debug=args.debug)

        # lace cf1
        create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, num_frames, debug=args.debug)

        # nolace tdshape1
        create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, num_frames)

        # lace feature net
        create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, num_frames)
    finally:
        # run endoscopy.close() even if writing one of the test vectors raises
        if args.debug:
            endoscopy.close()
166