Commit 8e4ac5d7 authored by Jean Ibarz

Added new script

parent 31c37eb3
import argparse
from datetime import datetime
from core.utils import load_ircam_hrirs_data, split_dataset, generate_signals, generate_augmented_labelled_dataset
import numpy as np
import tensorflow as tf
from core.logger import experiment_logger
from core.model import default_model_creator, conv_model_creator
import pandas as pd
from core.metrics import MeanAbsoluteAzimuthError
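# GPU setup: enabling memory growth stops TensorFlow from reserving all GPU
# memory up front; the guard below also lets the script run on CPU-only hosts.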
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
parser = argparse.ArgumentParser()
parser.add_argument("--in_dir", type=str, help="input directory")
args = parser.parse_args()  # note: --in_dir is declared but not used below
ExperimentConfiguration = dict  # an experiment configuration is, for now, a plain dict
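# Subject ids from the IRCAM LISTEN HRTF database. Besides 1059, the ranges
# below also skip ids 1010-1011, 1019-1020, 1024, 1027 and 1035-1036,
# presumably subjects missing from the database, despite the variable name.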
ALL_SUBJECTS_EXCEPT_1059 = (list(range(1002, 1010)) + list(range(1012, 1019))
                            + list(range(1021, 1024)) + [1025, 1026]
                            + list(range(1028, 1035)) + list(range(1037, 1059)))
exp_config = ExperimentConfiguration({
    'n_augment': 20,
    'n_epochs': 20,
    'n_iters': 20,
    'batch_size': 512,
    'learning_rate': 0.0005,
    'compensated_hrirs': True,
    'stimulus': 'dirac',  # possible values: ['dirac', 'any_sound']
    'train_subject_ids': ALL_SUBJECTS_EXCEPT_1059[0:30],
    'test_subject_ids': [1059],
    'movement': 'deterministic',  # possible values: ['none', 'random', 'deterministic']
    'sessions': [0]})
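# Interpretation of the configuration keys, as inferred from their usage below:
# - n_augment: number of augmented signal sets generated per subject split
# - n_iters:   number of fit() rounds per session, with n_samples growing each round
# - movement:  how the source azimuth evolves within a signal
# - sessions:  one independent model is trained and logged per session id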
if __name__ == '__main__':
    long_mp3_filepath = 'E:\\long_music_mixes'
    from core.utils import split_long_wav

    batch_size = exp_config['batch_size']
    # build the stimulus pool and split it between training and test
    if exp_config.get('stimulus') == 'any_sound':
        sounds = np.array(split_long_wav(rootpath=long_mp3_filepath, max_parts=10000))
        training_sounds, test_sounds = np.array_split(sounds, 2, axis=0)
    elif exp_config.get('stimulus') == 'dirac':
        sounds = np.ones(shape=(2, 1))
        training_sounds, test_sounds = np.array_split(sounds, 2, axis=0)
    else:
        raise ValueError("exp_config['stimulus'] must be 'dirac' or 'any_sound'")
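    # Note on the 'dirac' case: np.ones(shape=(2, 1)) is two length-1 unit
    # impulses, one per split. Assuming the sounds are later convolved with the
    # binaural signals, a Dirac stimulus leaves the HRIR-filtered signals unchanged.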
    df = load_ircam_hrirs_data(filepath='../databases/ircam_hrirs_512samples.p')
    # drop samplerate and distance columns
    del df['samplerate']
    del df['distance']
    # keep only HRIRs with elevation in [-45, 45]°; elevations are stored in
    # [0, 360[°, so negative elevations map to [315, 360[°
    df = df.loc[(df['elevation'] <= 45) | (df['elevation'] >= 315)]
    # a scalar float instead of an array likely indicates a missing HRIR (NaN)
    assert len([1 for hrir in df['hrir_l'] if isinstance(hrir, float)]) == 0
    assert len([1 for hrir in df['hrir_r'] if isinstance(hrir, float)]) == 0
    # keep only raw or compensated HRIRs
    if exp_config['compensated_hrirs']:
        df = df[df['type'] == 'comp']
    else:
        df = df[df['type'] == 'raw']
    # reset index
    df.reset_index(drop=True, inplace=True)
    all_subject_ids = df['subject_id'].cat.categories
    # a subject id list is usable when it is present, non-None and non-empty
    valid_train_subject_ids = bool(exp_config.get('train_subject_ids'))
    valid_test_subject_ids = bool(exp_config.get('test_subject_ids'))
    if valid_train_subject_ids and valid_test_subject_ids:
        train_subject_ids = exp_config['train_subject_ids']
        test_subject_ids = exp_config['test_subject_ids']
    elif valid_train_subject_ids:
        train_subject_ids = exp_config['train_subject_ids']
        test_subject_ids = [id for id in all_subject_ids if id not in train_subject_ids]
    elif valid_test_subject_ids:
        test_subject_ids = exp_config['test_subject_ids']
        train_subject_ids = [id for id in all_subject_ids if id not in test_subject_ids]
    else:
        # neither list is given: hold out 10 random subjects for testing
        test_subject_ids = np.random.choice(all_subject_ids.values, size=10, replace=False)
        train_subject_ids = [id for id in all_subject_ids if id not in test_subject_ids]
    # convert to str in case elements are integers
    test_subject_ids = [str(id) for id in test_subject_ids]
    train_subject_ids = [str(id) for id in train_subject_ids]
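    # Build the labelled datasets: generate_signals() is called n_augment times
    # per split and the results concatenated, so each (subject, azimuth) pair
    # appears n_augment times, presumably with different augmentations each time.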
    training_df = pd.concat([generate_signals(df[df['subject_id'].isin(train_subject_ids)],
                                              movement=exp_config['movement'],
                                              in_place=False, drop_hrirs=True)
                             for _ in range(exp_config['n_augment'])])
    test_df = pd.concat([generate_signals(df[df['subject_id'].isin(test_subject_ids)],
                                          movement=exp_config['movement'],
                                          in_place=False, drop_hrirs=True)
                         for _ in range(exp_config['n_augment'])])
    for k in exp_config['sessions']:
        print(f'TRAINING SESSION N°{k}')
        timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
        hrir_status = 'comp' if exp_config['compensated_hrirs'] else 'raw'
        logdir_prefix = f"{hrir_status}_{timestr}_{k}"
        # create and train the model
        model = default_model_creator()
        opt = tf.keras.optimizers.RMSprop(learning_rate=exp_config['learning_rate'])
        # opt = tf.keras.optimizers.Adam(learning_rate=lr)
        model.compile(optimizer=opt,
                      loss=MeanAbsoluteAzimuthError(),  # tf.keras.losses.MeanSquaredError(),
                      metrics=[])
        # model.summary()
        # model = default_model_creator(lr=exp_config['learning_rate'])
        n_samples = 1
        for l in range(exp_config['n_iters']):
            print(f'\tITERATION {l}')
            # copy the dataframe and add a 'signal' column, depending on the
            # movement case (deterministic, random, or no movement)
            # /!\ because we may run multiple training sessions, and hrirs are
            # dropped after generating the signals, we must not modify df in place /!\
            np.random.shuffle(training_sounds)
            np.random.shuffle(test_sounds)
            # creation of augmented labelled datasets
            training_signals, training_labels = generate_augmented_labelled_dataset(
                signals=np.array(training_df['signal']),
                labels=np.array(training_df['azimuth']),
                sounds=training_sounds,
                n_samples=n_samples,
                n_augment=1)
            test_signals, test_labels = generate_augmented_labelled_dataset(
                signals=np.array(test_df['signal']),
                labels=np.array(test_df['azimuth']),
                sounds=test_sounds,
                n_samples=n_samples,
                n_augment=1)
            model.fit(x=training_signals, y=training_labels, batch_size=batch_size,
                      epochs=exp_config['n_epochs'],
                      validation_data=(test_signals, test_labels))
            n_samples = int((n_samples + 1) * 1.3)
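            # n_samples (interpreted by generate_augmented_labelled_dataset) thus
            # follows the sequence 1, 2, 3, 5, 7, 10, 14, 19, 26, ...: roughly
            # geometric growth by a factor of ~1.3 per iteration.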
        # log the experiment configuration, the trained model, the training set and the test set
        experiment_logger(rootdir='../tf_results', logdir_prefix=logdir_prefix,
                          exp_config=exp_config, model=model,
                          training_df=training_df, test_df=test_df)