Commit 6d78a0af authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Refactoring of experiment_logger into a class.

parent 639047a5
......@@ -3,15 +3,19 @@ import os
import pickle
def experiment_logger(rootdir, logdir_prefix, exp_config, model, training_df, test_df):
out_dir = os.path.join(rootdir, logdir_prefix)
class Logger:
def __init__(self, rootdir):
self.rootdir = rootdir
# save trained model, exp_config.json, training_df, test_df, and computed errors on test set
os.makedirs(name=out_dir, exist_ok=True)
model.save(filepath=os.path.join(out_dir, 'model'))
with open(os.path.join(out_dir, 'exp_config.json'), mode='wt') as exp_config_f:
json.dump(exp_config, fp=exp_config_f, indent=4, sort_keys=True)
with open(os.path.join(out_dir, 'training_df.p'), mode='wb') as training_df_f:
pickle.dump(training_df, file=training_df_f)
with open(os.path.join(out_dir, 'test_df.p'), mode='wb') as test_df_f:
pickle.dump(test_df, file=test_df_f)
def log_model_training(self, logdir_prefix, exp_config, model, training_df, test_df):
out_dir = os.path.join(self.rootdir, logdir_prefix)
# save trained model, exp_config.json, training_df, test_df, and computed errors on test set
os.makedirs(name=out_dir, exist_ok=True)
model.save(filepath=os.path.join(out_dir, 'model'))
with open(os.path.join(out_dir, 'exp_config.json'), mode='wt') as exp_config_f:
json.dump(exp_config, fp=exp_config_f, indent=4, sort_keys=True)
with open(os.path.join(out_dir, 'training_df.p'), mode='wb') as training_df_f:
pickle.dump(training_df, file=training_df_f)
with open(os.path.join(out_dir, 'test_df.p'), mode='wb') as test_df_f:
pickle.dump(test_df, file=test_df_f)
......@@ -4,7 +4,7 @@ from core.utils import load_ircam_hrirs_data, split_dataset, generate_signals, g
azimuth_to_left_center_right_onehot
import numpy as np
import tensorflow as tf
from core.logger import experiment_logger
from core.logger import Logger
from core.model import model_factory
import pandas as pd
from core.metrics import MeanAbsoluteAzimuthError
......@@ -21,6 +21,7 @@ ExperimentConfiguration = dict
ALL_SUBJECTS_EXCEPT_1059 = list(range(1002, 1010)) + list(range(1012, 1019)) + list(range(1021, 1024)) + list(
[1025, 1026]) + list(range(1028, 1035)) + list(range(1037, 1059))
RESULTS_ROOTDIR = '../tf_results'
exp_config = ExperimentConfiguration({'n_augment': 1,
'n_epochs': 20,
......@@ -141,6 +142,7 @@ if __name__ == '__main__':
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
hrir_status = 'comp' if exp_config['compensated_hrirs'] else 'raw'
logdir_prefix = f"{hrir_status}_{timestr}_{k}"
logger = Logger(rootdir=RESULTS_ROOTDIR)
# create and train the model
model = model_factory(model_name=exp_config['model_name'], model_config=exp_config['model_config'])
......@@ -197,10 +199,19 @@ if __name__ == '__main__':
# Here, width=time axis, and Height=channel (left or right)
training_signals = np.expand_dims(training_signals, axis=-1)
test_signals = np.expand_dims(test_signals, axis=-1)
model.fit(x=training_signals, y=training_labels, batch_size=batch_size, epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels))
model.fit(
x=training_signals, y=training_labels,
batch_size=batch_size,
epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels),
callbacks=[tensorboard_callback],
)
# log the experiment configuration file, the trained model, the training set and the test set
experiment_logger(rootdir='../tf_results', logdir_prefix=logdir_prefix, exp_config=exp_config, model=model,
training_df=training_df,
test_df=test_df)
logger.log_model_training(
exp_config=exp_config,
logdir_prefix=logdir_prefix,
model=model,
training_df=training_df,
test_df=test_df
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment