Commit ffdc29ea authored by Jean Ibarz's avatar Jean Ibarz
Browse files

Removed data_generator(), and updated scripts accordingly.

parent fbed93ec
......@@ -128,25 +128,6 @@ def generate_augmented_labelled_dataset(signals, labels, convolve_mode='valid',
return augmented_signals, augmented_labels
def data_generator(signal):
    """
    Randomly time-shift and rescale a signal of shape (samples, channels).

    A total of 100 zero rows are padded around the signal, split between the
    start and the end by a random offset, and the padded result is multiplied
    by a gain drawn uniformly from [0.1, 10).

    :param signal: array of shape (samples, channels)
    :return: augmented copy of shape (samples + 100, channels)
    """
    # Draw the shift first, then the gain: the RNG consumption order matters
    # for reproducibility under a fixed seed.
    offset = np.random.randint(low=1, high=100)
    shifted = np.copy(signal)
    n_channels = shifted.shape[1]
    head = np.zeros(shape=(offset, n_channels))
    tail = np.zeros(shape=(100 - offset, n_channels))
    padded = np.concatenate([head, shifted, tail], axis=0)
    # Uniform gain in [0.1, 10); np.random.random() yields [0, 1).
    gain = np.random.random() * 9.9 + 0.1
    return padded * gain
def split_long_wav(rootpath, max_chunks=None, chunk_length=1223, quantile=0.5):
"""
Read all wav files in rootpath, and return the list of wav parts of length 40000 samples who's RMS values are
......@@ -172,15 +153,15 @@ def split_long_wav(rootpath, max_chunks=None, chunk_length=1223, quantile=0.5):
data = data[:, 0]
# keep only a length that is splittable into N chunks of equal length
_nb_chunks = len(data) // chunk_length
if nb_chunks + _nb_chunks > max_chunks/quantile:
_nb_chunks = int(max_chunks//quantile - nb_chunks)
if nb_chunks + _nb_chunks > max_chunks / quantile:
_nb_chunks = int(max_chunks // quantile - nb_chunks)
data = data[0:_nb_chunks * chunk_length]
max_abs = np.max(np.abs(data))
assert max_abs >= 0.1
all_data.append(data / max_abs) # keep only the left channel
nb_chunks += _nb_chunks
if max_chunks is not None and nb_chunks >= max_chunks/quantile:
if max_chunks is not None and nb_chunks >= max_chunks / quantile:
# it is only a heuristic of the max parts... in reality, the number of parts may be lower in the end
break
all_data = np.concatenate(all_data)
......
......@@ -7,7 +7,6 @@ import pandas as pd
import plotly.graph_objects as go
import tensorflow as tf
from ipywidgets import widgets
from core.utils import data_generator
from core.metrics import MeanAbsoluteAzimuthError
#set GPUs if any
......@@ -93,7 +92,7 @@ def generate_df_from_exp_dir(exp_dir):
true_azimuth = np.array(test_df['azimuth'])
for i in range(10):
test_signals = np.array([data_generator(signal) for signal in list(test_df['signal'])])
test_signals = np.array([signal for signal in list(test_df['signal'])])
# use the model to make predictions
predicted_azimuth = model.predict_step(data=test_signals).numpy()
......
......@@ -7,7 +7,6 @@ import pandas as pd
import plotly.graph_objects as go
import tensorflow as tf
from ipywidgets import widgets
from core.utils import data_generator
from core.metrics import MeanAbsoluteAzimuthError
import math
import tensorflow.keras.backend as kb
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment