Commit c552ff75 authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Added some code to do some class balancing, repeating 'center' class inputs,...

Added some code to do some class balancing, repeating 'center' class inputs, when training the 'left_center_right' model.
parent 8565b72b
......@@ -196,6 +196,21 @@ if __name__ == '__main__':
n_augment=1)
if exp_config['model_name'] == 'left_center_right':
# for each individual, classes are 11 left, 2 center, 11 right
# to reduce class unbalancing, we repeat center examples 5 times (11//2)
i_left = np.argwhere((training_labels > 0) & (training_labels < 180))
i_center = np.argwhere((training_labels == 0) | (training_labels == 180) | (training_labels == 360))
i_right = np.argwhere((training_labels > 180) & (training_labels < 360))
n_left = len(i_left)
n_center = len(i_center)
n_right = len(i_right)
n_max = max(n_left, n_center, n_right)
i_balanced = np.concatenate([(indices * (n_max // len(indices))) for indices in
[list(i_left), list(i_center), list(i_right)]]).flatten()
np.random.shuffle(i_balanced)
training_signals = training_signals[i_balanced]
training_labels = training_labels[i_balanced]
# convert labels with azimuth in [0,360] to left,center,right with one_hot encoding
training_labels = azimuth_to_left_center_right_onehot(training_labels, in_place=True)
test_labels = azimuth_to_left_center_right_onehot(test_labels, in_place=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment