Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Jean Ibarz
ml_binaural_audio
Commits
6d78a0af
Commit
6d78a0af
authored
Feb 11, 2021
by
Jean Ibarz
Browse files
Refactoring of experiment_logger into a class.
parent
639047a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
core/logger.py
View file @
6d78a0af
...
...
@@ -3,15 +3,19 @@ import os
import
pickle
def
experiment_logger
(
rootdir
,
logdir_prefix
,
exp_config
,
model
,
training_df
,
test_df
):
out_dir
=
os
.
path
.
join
(
rootdir
,
logdir_prefix
)
class
Logger
:
def
__init__
(
self
,
rootdir
):
self
.
rootdir
=
rootdir
# save trained model, exp_config.json, training_df, test_df, and computed errors on test set
os
.
makedirs
(
name
=
out_dir
,
exist_ok
=
True
)
model
.
save
(
filepath
=
os
.
path
.
join
(
out_dir
,
'model'
))
with
open
(
os
.
path
.
join
(
out_dir
,
'exp_config.json'
),
mode
=
'wt'
)
as
exp_config_f
:
json
.
dump
(
exp_config
,
fp
=
exp_config_f
,
indent
=
4
,
sort_keys
=
True
)
with
open
(
os
.
path
.
join
(
out_dir
,
'training_df.p'
),
mode
=
'wb'
)
as
training_df_f
:
pickle
.
dump
(
training_df
,
file
=
training_df_f
)
with
open
(
os
.
path
.
join
(
out_dir
,
'test_df.p'
),
mode
=
'wb'
)
as
test_df_f
:
pickle
.
dump
(
test_df
,
file
=
test_df_f
)
def
log_model_training
(
self
,
logdir_prefix
,
exp_config
,
model
,
training_df
,
test_df
):
out_dir
=
os
.
path
.
join
(
self
.
rootdir
,
logdir_prefix
)
# save trained model, exp_config.json, training_df, test_df, and computed errors on test set
os
.
makedirs
(
name
=
out_dir
,
exist_ok
=
True
)
model
.
save
(
filepath
=
os
.
path
.
join
(
out_dir
,
'model'
))
with
open
(
os
.
path
.
join
(
out_dir
,
'exp_config.json'
),
mode
=
'wt'
)
as
exp_config_f
:
json
.
dump
(
exp_config
,
fp
=
exp_config_f
,
indent
=
4
,
sort_keys
=
True
)
with
open
(
os
.
path
.
join
(
out_dir
,
'training_df.p'
),
mode
=
'wb'
)
as
training_df_f
:
pickle
.
dump
(
training_df
,
file
=
training_df_f
)
with
open
(
os
.
path
.
join
(
out_dir
,
'test_df.p'
),
mode
=
'wb'
)
as
test_df_f
:
pickle
.
dump
(
test_df
,
file
=
test_df_f
)
scripts/train_model.py
View file @
6d78a0af
...
...
@@ -4,7 +4,7 @@ from core.utils import load_ircam_hrirs_data, split_dataset, generate_signals, g
azimuth_to_left_center_right_onehot
import
numpy
as
np
import
tensorflow
as
tf
from
core.logger
import
experiment_l
ogger
from
core.logger
import
L
ogger
from
core.model
import
model_factory
import
pandas
as
pd
from
core.metrics
import
MeanAbsoluteAzimuthError
...
...
@@ -21,6 +21,7 @@ ExperimentConfiguration = dict
ALL_SUBJECTS_EXCEPT_1059
=
list
(
range
(
1002
,
1010
))
+
list
(
range
(
1012
,
1019
))
+
list
(
range
(
1021
,
1024
))
+
list
(
[
1025
,
1026
])
+
list
(
range
(
1028
,
1035
))
+
list
(
range
(
1037
,
1059
))
RESULTS_ROOTDIR
=
'../tf_results'
exp_config
=
ExperimentConfiguration
({
'n_augment'
:
1
,
'n_epochs'
:
20
,
...
...
@@ -141,6 +142,7 @@ if __name__ == '__main__':
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
hrir_status = 'comp' if exp_config['compensated_hrirs'] else 'raw'
logdir_prefix = f"{hrir_status}_{timestr}_{k}"
logger = Logger(rootdir=RESULTS_ROOTDIR)
# create and train the model
model = model_factory(model_name=exp_config['model_name'], model_config=exp_config['model_config'])
...
...
@@ -197,10 +199,19 @@ if __name__ == '__main__':
# Here, width=time axis, and Height=channel (left or right)
training_signals = np.expand_dims(training_signals, axis=-1)
test_signals = np.expand_dims(test_signals, axis=-1)
model.fit(x=training_signals, y=training_labels, batch_size=batch_size, epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels))
model.fit(
x=training_signals, y=training_labels,
batch_size=batch_size,
epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels),
callbacks=[tensorboard_callback],
)
# log the experiment configuration file, the trained model, the training set and the test set
experiment_logger(rootdir='../tf_results', logdir_prefix=logdir_prefix, exp_config=exp_config, model=model,
training_df=training_df,
test_df=test_df)
logger.log_model_training(
exp_config=exp_config,
logdir_prefix=logdir_prefix,
model=model,
training_df=training_df,
test_df=test_df
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment