Commit b2d41cb7 authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Updated train2.py script to do left and right zero padding when using...

Updated train2.py script to do left and right zero padding when using convolution with mode 'valid'. Stimulus truncature with n_samples, which was incremented for each training iteration, has also been disabled. It should be removed in the future.
parent be8c48de
......@@ -26,8 +26,19 @@ exp_config = ExperimentConfiguration({'n_augment': 1,
'n_iters': 20,
'batch_size': 512,
'learning_rate': 0.0005,
'model_config': {
'random_scale': {
'minval': -10,
'maxval': 10,
},
'random_shift': {
'minval': 0,
'maxval': 100,
}
},
'compensated_hrirs': True,
'stimulus': 'dirac', # possible values: ['dirac', 'any_sound']
'stimulus': 'dirac', # possible values: {'dirac', 'any_sound'}
'convolve_mode': 'valid', # possible values: {'full', 'valid', 'same'}
'train_subject_ids': ALL_SUBJECTS_EXCEPT_1059[0:30],
'test_subject_ids': [1059],
'movement': 'none', # possible values: ['none', 'random', 'deterministic']
......@@ -41,12 +52,24 @@ if __name__ == '__main__':
batch_size = exp_config['batch_size']
if 'stimulus' in exp_config.keys():
if exp_config['stimulus'] == 'any_sound':
sounds = np.array(split_long_wav(rootpath=long_mp3_filepath, max_parts=10000))
training_sounds, test_sounds = np.array_split(sounds, 2, axis=0)
elif exp_config['stimulus'] == 'dirac':
sounds = np.ones(shape=(2, 1))
if exp_config['stimulus'] == 'dirac':
if not 'convolve_mode' in exp_config.keys() or exp_config['convolve_mode'] == 'valid':
# the valid part from the convolved sound is sound[len(hrir)-1:-len(hrir)+1]
# hence we have to do len(hrir)-1 left and right zero padding so that the dirac belong to the result
n_zeros = 512-1+exp_config['model_config']['random_shift']['maxval']
zeros = np.zeros(shape=(2, n_zeros))
sounds = np.concatenate([zeros, np.ones(shape=(2, 1)), zeros], axis=1)
training_sounds, test_sounds = np.array_split(sounds, 2, axis=0)
elif exp_config['convolve_mode'] == 'full':
# the valid part from the convolved sound is the full discrete convolution: no zero padding required
sounds = np.ones(shape=(2, 1))
else:
raise ValueError('convolve_mode must be ''valid'' (default value) or ''full''')
elif exp_config['stimulus'] == 'any_sound':
sounds = np.array(split_long_wav(rootpath=long_mp3_filepath, max_chunks=5000, chunk_length=1223))
training_sounds, test_sounds = np.array_split(sounds, 2, axis=0)
else:
raise ValueError('stimulus must be ''dirac'' or ''any_sound''')
df = load_ircam_hrirs_data(filepath='../databases/ircam_hrirs_512samples.p')
......@@ -108,7 +131,8 @@ if __name__ == '__main__':
logdir_prefix = f"{hrir_status}_{timestr}_{k}"
# create and train the model
model = default_model_creator()
model = default_model_creator(model_config=exp_config['model_config'])
opt = tf.keras.optimizers.RMSprop(learning_rate=exp_config['learning_rate'])
# opt = tf.keras.optimizers.Adam(learning_rate=lr)
model.compile(optimizer=opt,
......@@ -117,8 +141,7 @@ if __name__ == '__main__':
# model.summary()
# model = default_model_creator(lr=exp_config['learning_rate'])
n_samples = 1
n_samples = None
for l in range(exp_config['n_iters']):
print(f'\tITERATION {l}')
......@@ -148,7 +171,7 @@ if __name__ == '__main__':
model.fit(x=training_signals, y=training_labels, batch_size=batch_size, epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels))
n_samples = int((n_samples + 1) * 1.3)
# n_samples = int((n_samples + 1) * 1.3)
# log the experiment configuration file, the trained model, the training set and the test set
experiment_logger(rootdir='../tf_results', logdir_prefix=logdir_prefix, exp_config=exp_config, model=model,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment