Commit ba8c39c8 authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Update train_model.py

parent 5655cba0
......@@ -136,7 +136,6 @@ if __name__ == '__main__':
model = model_factory(model_name=exp_config['model_name'], model_config=exp_config['model_config'])
opt = tf.keras.optimizers.RMSprop(learning_rate=exp_config['learning_rate'])
# opt = tf.keras.optimizers.Adam(learning_rate=lr)
if exp_config['model_name'] == 'default':
model.compile(optimizer=opt,
......@@ -144,8 +143,9 @@ if __name__ == '__main__':
metrics=[])
elif exp_config['model_name'] == 'left_center_right':
model.compile(optimizer=opt,
loss=tf.keras.losses.CategoricalCrossentropy(), # tf.keras.losses.MeanSquaredError(),
metrics=[ # tf.keras.metrics.TruePositives(name='tp'),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[
# tf.keras.metrics.TruePositives(name='tp'),
# tf.keras.metrics.FalsePositives(name='fp'),
# tf.keras.metrics.TrueNegatives(name='tn'),
# tf.keras.metrics.FalseNegatives(name='fn'),
......@@ -156,8 +156,6 @@ if __name__ == '__main__':
])
# model.summary()
# model = default_model_creator(lr=exp_config['learning_rate'])
n_samples = None
for l in range(exp_config['n_iters']):
print(f'\tITERATION {l}')
......@@ -191,8 +189,6 @@ if __name__ == '__main__':
model.fit(x=training_signals, y=training_labels, batch_size=batch_size, epochs=exp_config['n_epochs'],
validation_data=(test_signals, test_labels))
# n_samples = int((n_samples + 1) * 1.3)
# log the experiment configuration file, the trained model, the training set and the test set
experiment_logger(rootdir='../tf_results', logdir_prefix=logdir_prefix, exp_config=exp_config, model=model,
training_df=training_df,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment