Commit e1efb3a9 authored by jean Ibarz's avatar jean Ibarz
Browse files

Added helper function to load Gammatone filters impulses responses from a directory.

parent 63aff984
......@@ -72,7 +72,8 @@ def generate_signals(df, movement, in_place=True, drop_hrirs=True):
return df
def generate_augmented_labelled_dataset(signals, labels, symmetry=True, n_augment: int = 1, n_samples=None, sounds=None):
def generate_augmented_labelled_dataset(signals, labels, symmetry=True, n_augment: int = 1, n_samples=None,
sounds=None):
n_channels = 2
if sounds is not None:
......@@ -120,7 +121,7 @@ def generate_augmented_labelled_dataset(signals, labels, symmetry=True, n_augmen
# inverse left and right ears signals, and changes the azimuth accordingly
# e.g. if signal is coming from 45°, it comes from 315° after left/right ears signals inversion
augmented_signals2 = np.flip(augmented_signals.copy(), axis=2)
augmented_labels2 = 360-augmented_labels.copy()
augmented_labels2 = 360 - augmented_labels.copy()
augmented_signals = np.concatenate([augmented_signals, augmented_signals2])
augmented_labels = np.concatenate([augmented_labels, augmented_labels2])
return augmented_signals, augmented_labels
......@@ -201,3 +202,26 @@ def set_gpus():
print(e)
def get_gtf_kernels(gtf_dirpath):
"""
Load all files found in gtf_dirpath that starts with 'filter' and ends with '.wav'
and return a n-dimensional array of time-reversed loaded signals
:return:
"""
import os
import scipy.io.wavfile
wav_files = [file for file in os.listdir(gtf_dirpath) if file.startswith('filter') and file.endswith('.wav')]
print(f'found {len(wav_files)} filters to be loaded...')
samplerates = []
gtf_irs = []
for wav_file in wav_files:
samplerate, data = scipy.io.wavfile.read(filename=os.path.join(gtf_dirpath, wav_file))
samplerates.append(samplerate)
gtf_irs.append(data)
gtf_irs = np.array(gtf_irs)
assert gtf_irs.shape[0] == len(wav_files)
kernels = np.flip(gtf_irs, axis=1)
kernels = np.transpose(kernels)
return kernels
......@@ -14,3 +14,9 @@ def test_load_ircam_morphological_data():
# there should be a dataframe of shape (51,60)
assert all_infos.shape == (51, 60)
def test_get_gtf_kernels():
from core.utils import get_gtf_kernels
kernels = get_gtf_kernels("../gtf/")
assert kernels.shape == (512, 20)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment