Commit b7388a26 authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Added a model factory to create a 'default' model or a 'left_center_right' model.

parent 9b54deb9
......@@ -3,6 +3,15 @@ import tensorflow.keras.layers as tkl
from core.layers import RandomShift2DLayer, RandomScale2DLayer
def model_factory(model_config, model_name: str = 'default'):
if model_name == 'default':
return default_model_creator(model_config)
elif model_name == 'left_center_right':
return left_center_or_right_model_creator(model_config)
else:
raise ValueError(f'model_name {model_name} unknown. Possible values: {{''default'', ''left_center_right''}}')
def default_model_creator(model_config):
model = tf.keras.Sequential([
RandomScale2DLayer(minval=model_config['random_scale']['minval'],
......
......@@ -4,7 +4,7 @@ from core.utils import load_ircam_hrirs_data, split_dataset, generate_signals, g
import numpy as np
import tensorflow as tf
from core.logger import experiment_logger
from core.model import default_model_creator
from core.model import model_factory
import pandas as pd
from core.metrics import MeanAbsoluteAzimuthError
......@@ -26,6 +26,7 @@ exp_config = ExperimentConfiguration({'n_augment': 1,
'n_iters': 20,
'batch_size': 512,
'learning_rate': 0.0005,
'model_name': 'default', # possible values: {'default', 'left_center_right'}
'model_config': {
'random_scale': {
'minval': -10,
......@@ -131,7 +132,7 @@ if __name__ == '__main__':
logdir_prefix = f"{hrir_status}_{timestr}_{k}"
# create and train the model
model = default_model_creator(model_config=exp_config['model_config'])
model = model_factory(model_name=exp_config['model_name'], model_config=exp_config['model_config'])
opt = tf.keras.optimizers.RMSprop(learning_rate=exp_config['learning_rate'])
# opt = tf.keras.optimizers.Adam(learning_rate=lr)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment