Commit 0d7bf843 authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Added utility functions to generate multilabel azimuth encoding from azimuth...

Added utility functions to generate multilabel azimuth encoding from azimuth values, and corresponding test.
parent c552ff75
......@@ -231,3 +231,32 @@ def azimuth_to_left_center_right_onehot(labels, in_place=False):
labels[(labels > 0) & (labels < 180)] = 2 # right side
labels = tf.one_hot(labels, depth=3)
return labels
def is_power_of_two(n):
# Code from https://stackoverflow.com/users/6045800/tomerikoo answer,
# see: https://stackoverflow.com/questions/57025836/how-to-check-if-a-given-number-is-a-power-of-two
return (n & (n - 1) == 0) and n != 0
def generate_multilabel_from_azimuth(azimuth, n_labels):
"""
n_labels must be a power of 2, i.e. it should exist an integer k such that n_labels=2^k
Take azimuth labels, and encode in a binary word of length 2+2^2+2^3+...+2^k
:param n_labels:
:return:
"""
if not is_power_of_two(n_labels):
raise ValueError('n_labels must be a power of two')
import math
import tensorflow as tf
n_iters = int(math.log2(n_labels))
labels = []
for i in range(n_iters):
bins = np.linspace(start=0, stop=360, num=2 ** (i + 1), endpoint=False)
inds = np.digitize(x=azimuth, bins=bins)
one_hot = tf.one_hot(inds - 1, depth=2 ** (i + 1))
labels.append(one_hot)
labels = tf.concat(values=labels, axis=1)
return labels.shape[-1], labels
......@@ -20,3 +20,46 @@ def test_get_gtf_kernels():
from core.utils import get_gtf_kernels
kernels = get_gtf_kernels("../gtf/")
assert kernels.shape == (512, 20)
def test_generate_multilabel_from_azimuth():
from core.utils import generate_multilabel_from_azimuth
import numpy as np
azimuth = np.array([0, 10, 45, 80, 90, 95, 170, 190, 355, 360])
n_labels = 32
total_labels, labels = generate_multilabel_from_azimuth(azimuth=azimuth, n_labels=n_labels)
expected_total_labels = 2 + 4 + 8 + 16 + 32
expected_labels = np.array(
[[1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
assert total_labels == expected_total_labels
assert np.array_equal(a1=labels, a2=expected_labels)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment