Commit 0c586bf2 authored by jean Ibarz's avatar jean Ibarz
Browse files

Added GtfLayer(), a Conv1D layer with non trainable kernels who's values are...

Added GtfLayer(), a Conv1D layer with non trainable kernels who's values are loaded from time-reversed Gammatone filters impulses responses, which are either passed directly or loaded from a directory containing the Gammatone filters. A test is used to verify that the layer can process some input, and some lines can be uncommented to check visually that the filters are correctly applied to the input (a dirac with some silence before and after).
parent e1efb3a9
......@@ -66,3 +66,20 @@ class RandomScale2DLayer(tf.keras.layers.Layer):
def compute_output_shape(self, input_shape):
return input_shape
class GtfLayer(tf.keras.layers.Conv1D):
def __init__(self, kernels=None, gtf_dirpath=None):
if kernels is None:
if gtf_dirpath is None:
raise ValueError('at least one argument of kernels or gtf_dirpath must be provided')
from core.utils import get_gtf_kernels
kernels = get_gtf_kernels(gtf_dirpath=gtf_dirpath)
assert len(kernels.shape) == 2
elif gtf_dirpath is not None:
raise UserWarning('gtf_dirpath argument is ignored because kernels argument is also provided')
self.kernels = tf.constant_initializer(kernels)
super(GtfLayer, self).__init__(filters=kernels.shape[1], kernel_size=kernels.shape[0], strides=1,
padding='causal', kernel_initializer=self.kernels, trainable=False,
......@@ -83,3 +83,28 @@ def test_random_scale_2d_layer():
np_output = layer(input, training=False)
np_output = np_output.numpy()
assert np.allclose(a=input, b=np_output)
def test_gtf_layer():
import tensorflow as tf
from core.utils import get_gtf_kernels
kernels = get_gtf_kernels(gtf_dirpath='../gtf')
assert kernels.shape == (512, 20)
from core.layers import GtfLayer
gtf_layer = GtfLayer(kernels=kernels)
diracs = np.zeros(shape=(1, 512 * 2, 2, 1))
diracs[:, 511, :, 0] = 1
np_r = gtf_layer(diracs[:, :, 0, :]).numpy()
assert True
# # for debugging purposes:
# # plot some diracs convolved by the gtf_layer to ensure the filters are applied correctly
# import matplotlib.pyplot as plt
# plt.plot(np_r[0, :, 0], 'r')
# plt.plot(np_r[0, :, 10], 'b')
# plt.plot(np_r[0, :, -1], 'g')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment