Noise reduction using spectral gating in python
Posted on Sat 07 July 2018 in Signal Processing
A pip installable package for this code is now available at https://github.com/timsainb/noisereduce
Noise reduction in Python using spectral gating
- This algorithm is based on (but does not completely reproduce) the one outlined by Audacity for its noise reduction effect (Link to C++ code)
- The algorithm requires two inputs:
- A noise audio clip containing prototypical noise of the audio clip
- A signal audio clip containing the signal and the noise intended to be removed
Steps of the algorithm
- An STFT (spectrogram) is calculated over the noise audio clip
- Statistics are calculated over the STFT of the noise (per frequency channel)
- A threshold is calculated based upon the statistics of the noise (and the desired sensitivity of the algorithm)
- An STFT is calculated over the signal
- A mask is determined by comparing the signal STFT to the threshold
- The mask is smoothed with a filter over frequency and time
- The mask is applied to the STFT of the signal, which is then inverted back to the time domain (a condensed sketch of these steps follows below)
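Condensed into code, the steps look roughly like the sketch below. The function name spectral_gate_sketch and the kernel sizes are illustrative only; the fully parameterized implementation, removeNoise, appears further down, and it gates the dB magnitude toward a floor value rather than simply attenuating the complex STFT as this sketch does.

import numpy as np
import scipy.signal
import librosa

def spectral_gate_sketch(signal, noise, n_std=1.5, n_fft=2048, hop=512):
    # 1-2. STFT of the noise clip and per-frequency statistics (in dB)
    noise_db = librosa.amplitude_to_db(np.abs(librosa.stft(noise, n_fft=n_fft, hop_length=hop)))
    # 3. threshold per frequency channel
    thresh = noise_db.mean(axis=1) + n_std * noise_db.std(axis=1)
    # 4. STFT of the signal
    sig_stft = librosa.stft(signal, n_fft=n_fft, hop_length=hop)
    sig_db = librosa.amplitude_to_db(np.abs(sig_stft))
    # 5. mask time-frequency bins that fall below the noise threshold
    mask = sig_db < thresh[:, None]
    # 6. smooth the boolean mask over frequency and time with a triangular kernel
    kernel = np.outer(scipy.signal.windows.triang(5), scipy.signal.windows.triang(9))
    mask = scipy.signal.fftconvolve(mask.astype(float), kernel / kernel.sum(), mode="same")
    # 7. attenuate the masked bins and invert back to a waveform
    return librosa.istft(sig_stft * (1 - mask), hop_length=hop)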
In [1]:
import IPython
from scipy.io import wavfile
import scipy.signal
import numpy as np
import matplotlib.pyplot as plt
import librosa
%matplotlib inline
Load data
In [2]:
wav_loc = "assets/audio/fish.wav"
rate, data = wavfile.read(wav_loc)
data = data / 32768  # convert 16-bit int samples to floats in [-1, 1]
In [3]:
# from https://stackoverflow.com/questions/33933842/how-to-generate-noise-in-frequency-range-with-numpy
def fftnoise(f):
    """Generate real-valued noise whose spectrum magnitude is given by f."""
    f = np.array(f, dtype="complex")
    Np = (len(f) - 1) // 2
    # assign random phases to the positive-frequency bins
    phases = np.random.rand(Np) * 2 * np.pi
    phases = np.cos(phases) + 1j * np.sin(phases)
    f[1 : Np + 1] *= phases
    # mirror the conjugate into the negative-frequency bins so the IFFT is real
    f[-1 : -1 - Np : -1] = np.conj(f[1 : Np + 1])
    return np.fft.ifft(f).real

def band_limited_noise(min_freq, max_freq, samples=1024, samplerate=1):
    """Generate noise with energy only between min_freq and max_freq (in Hz)."""
    freqs = np.abs(np.fft.fftfreq(samples, 1 / samplerate))
    f = np.zeros(samples)
    f[np.logical_and(freqs >= min_freq, freqs <= max_freq)] = 1
    return fftnoise(f)
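As a quick optional sanity check, one could plot the spectrum of the generated noise to confirm that its energy really does lie between min_freq and max_freq. For example:

# optional check: the power spectrum of the generated noise should be
# concentrated between min_freq and max_freq
check = band_limited_noise(min_freq=4000, max_freq=12000, samples=44100, samplerate=44100)
freqs = np.fft.rfftfreq(len(check), d=1 / 44100)
plt.plot(freqs, np.abs(np.fft.rfft(check)))
plt.xlabel("Frequency (Hz)")
plt.ylabel("|FFT|")
plt.show()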
In [4]:
IPython.display.Audio(data=data, rate=rate)
Out[4]:
In [5]:
fig, ax = plt.subplots(figsize=(20,4))
ax.plot(data)
Out[5]:
Add noise
In [6]:
noise_len = 2 # seconds
noise = band_limited_noise(min_freq=4000, max_freq=12000, samples=len(data), samplerate=rate) * 10
noise_clip = noise[:rate*noise_len]
audio_clip_band_limited = data+noise
In [7]:
fig, ax = plt.subplots(figsize=(20,4))
ax.plot(audio_clip_band_limited)
IPython.display.Audio(data=audio_clip_band_limited, rate=rate)
Out[7]:
Denoise
In [8]:
import time
from datetime import timedelta as td
def _stft(y, n_fft, hop_length, win_length):
    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

def _istft(y, hop_length, win_length):
    return librosa.istft(y, hop_length=hop_length, win_length=win_length)

def _amp_to_db(x):
    # convert amplitude to decibels (20 * log10), clipped to an 80 dB range
    return librosa.core.amplitude_to_db(x, ref=1.0, amin=1e-20, top_db=80.0)

def _db_to_amp(x):
    return librosa.core.db_to_amplitude(x, ref=1.0)
def plot_spectrogram(signal, title):
fig, ax = plt.subplots(figsize=(20, 4))
cax = ax.matshow(
signal,
origin="lower",
aspect="auto",
cmap=plt.cm.seismic,
vmin=-1 * np.max(np.abs(signal)),
vmax=np.max(np.abs(signal)),
)
fig.colorbar(cax)
ax.set_title(title)
plt.tight_layout()
plt.show()
def plot_statistics_and_filter(
mean_freq_noise, std_freq_noise, noise_thresh, smoothing_filter
):
fig, ax = plt.subplots(ncols=2, figsize=(20, 4))
    plt_mean, = ax[0].plot(mean_freq_noise, label="Mean power of noise")
    plt_std, = ax[0].plot(std_freq_noise, label="Std. power of noise")
    plt_thresh, = ax[0].plot(noise_thresh, label="Noise threshold (by frequency)")
ax[0].set_title("Threshold for mask")
ax[0].legend()
cax = ax[1].matshow(smoothing_filter, origin="lower")
fig.colorbar(cax)
ax[1].set_title("Filter for smoothing Mask")
plt.show()
def removeNoise(
audio_clip,
noise_clip,
n_grad_freq=2,
n_grad_time=4,
n_fft=2048,
win_length=2048,
hop_length=512,
n_std_thresh=1.5,
prop_decrease=1.0,
verbose=False,
visual=False,
):
"""Remove noise from audio based upon a clip containing only noise
Args:
        audio_clip (array): the signal to be denoised (signal plus noise).
        noise_clip (array): a clip containing only prototypical noise.
        n_grad_freq (int): how many frequency channels to smooth over with the mask.
        n_grad_time (int): how many time channels to smooth over with the mask.
        n_fft (int): length of the FFT window (number of samples per frame).
        win_length (int): Each frame of audio is windowed by `window()`. The window will be of length `win_length` and then padded with zeros to match `n_fft`.
        hop_length (int): number of audio samples between adjacent STFT columns.
        n_std_thresh (float): how many standard deviations louder than the mean dB of the noise (at each frequency level) to be considered signal.
prop_decrease (float): To what extent should you decrease noise (1 = all, 0 = none)
visual (bool): Whether to plot the steps of the algorithm
Returns:
array: The recovered signal with noise subtracted
"""
if verbose:
start = time.time()
# STFT over noise
noise_stft = _stft(noise_clip, n_fft, hop_length, win_length)
noise_stft_db = _amp_to_db(np.abs(noise_stft)) # convert to dB
# Calculate statistics over noise
mean_freq_noise = np.mean(noise_stft_db, axis=1)
std_freq_noise = np.std(noise_stft_db, axis=1)
noise_thresh = mean_freq_noise + std_freq_noise * n_std_thresh
if verbose:
print("STFT on noise:", td(seconds=time.time() - start))
start = time.time()
# STFT over signal
if verbose:
start = time.time()
sig_stft = _stft(audio_clip, n_fft, hop_length, win_length)
sig_stft_db = _amp_to_db(np.abs(sig_stft))
if verbose:
print("STFT on signal:", td(seconds=time.time() - start))
start = time.time()
# Calculate value to mask dB to
mask_gain_dB = np.min(_amp_to_db(np.abs(sig_stft)))
    if verbose:
        print("Noise threshold (dB by frequency):", noise_thresh, "mask gain (dB):", mask_gain_dB)
# Create a smoothing filter for the mask in time and frequency
smoothing_filter = np.outer(
np.concatenate(
[
np.linspace(0, 1, n_grad_freq + 1, endpoint=False),
np.linspace(1, 0, n_grad_freq + 2),
]
)[1:-1],
np.concatenate(
[
np.linspace(0, 1, n_grad_time + 1, endpoint=False),
np.linspace(1, 0, n_grad_time + 2),
]
)[1:-1],
)
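    # (the result is an outer product of two triangular ramps over frequency and time,
    # normalized below so the kernel sums to 1)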
smoothing_filter = smoothing_filter / np.sum(smoothing_filter)
# calculate the threshold for each frequency/time bin
db_thresh = np.repeat(
np.reshape(noise_thresh, [1, len(mean_freq_noise)]),
np.shape(sig_stft_db)[1],
axis=0,
).T
    # mask where the signal is below the threshold (i.e., where noise dominates)
    sig_mask = sig_stft_db < db_thresh
if verbose:
print("Masking:", td(seconds=time.time() - start))
start = time.time()
# convolve the mask with a smoothing filter
sig_mask = scipy.signal.fftconvolve(sig_mask, smoothing_filter, mode="same")
sig_mask = sig_mask * prop_decrease
if verbose:
print("Mask convolution:", td(seconds=time.time() - start))
start = time.time()
# mask the signal
sig_stft_db_masked = (
sig_stft_db * (1 - sig_mask)
+ np.ones(np.shape(mask_gain_dB)) * mask_gain_dB * sig_mask
) # mask real
sig_imag_masked = np.imag(sig_stft) * (1 - sig_mask)
sig_stft_amp = (_db_to_amp(sig_stft_db_masked) * np.sign(sig_stft)) + (
1j * sig_imag_masked
)
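    # (an approximation: the gated magnitude, converted back from dB, is recombined with
    # an attenuated imaginary part rather than masking the complex STFT directly)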
if verbose:
print("Mask application:", td(seconds=time.time() - start))
start = time.time()
# recover the signal
recovered_signal = _istft(sig_stft_amp, hop_length, win_length)
recovered_spec = _amp_to_db(
np.abs(_stft(recovered_signal, n_fft, hop_length, win_length))
)
if verbose:
print("Signal recovery:", td(seconds=time.time() - start))
if visual:
plot_spectrogram(noise_stft_db, title="Noise")
if visual:
plot_statistics_and_filter(
mean_freq_noise, std_freq_noise, noise_thresh, smoothing_filter
)
if visual:
plot_spectrogram(sig_stft_db, title="Signal")
if visual:
plot_spectrogram(sig_mask, title="Mask applied")
if visual:
plot_spectrogram(sig_stft_db_masked, title="Masked signal")
if visual:
plot_spectrogram(recovered_spec, title="Recovered spectrogram")
return recovered_signal
In [9]:
output = removeNoise(audio_clip=audio_clip_band_limited, noise_clip=noise_clip, verbose=True, visual=True)
In [10]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(20,4))
plt.plot(output, color='black')
ax.set_xlim((0, len(output)))
plt.show()
# play back the denoised audio
IPython.display.Audio(data=output, rate=rate)
Out[10]:
Let's try this on a more challenging example
In [53]:
wav_loc = "assets/audio/fish.wav"
src_rate, src_data = wavfile.read(wav_loc)
# src_data = np.concatenate((src_data, np.zeros(src_rate*3)))
src_data = src_data / 32768
wav_loc = "assets/audio/cafe.wav"
noise_rate, noise_data = wavfile.read(wav_loc)
# get some noise to add to the signal
noise_to_add = noise_data[len(src_data) : len(src_data) * 2]
# get a different part of the noise clip for calculating statistics
noise_clip = noise_data[: len(src_data)]
noise_clip = noise_clip / max(noise_to_add)
noise_to_add = noise_to_add / max(noise_to_add)
# apply noise
snr = 1 # signal to noise ratio
audio_clip_cafe = src_data + noise_to_add / snr
noise_clip = noise_clip / snr
In [54]:
fig, ax = plt.subplots(figsize=(20, 4))
ax.plot(noise_clip)
IPython.display.Audio(data=noise_clip, rate=src_rate)
Out[54]:
In [55]:
fig, ax = plt.subplots(figsize=(20,4))
ax.plot(audio_clip_cafe)
IPython.display.Audio(data=audio_clip_cafe, rate=src_rate)
Out[55]:
In [56]:
output = removeNoise(
audio_clip=audio_clip_cafe,
noise_clip=noise_clip,
n_std_thresh=2,
prop_decrease=0.95,
visual=True,
)
In [57]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 4))
plt.plot(output, color="black")
ax.set_xlim((0, len(output)))
plt.show()
# play back the denoised audio
IPython.display.Audio(data=output, rate=src_rate)
Out[57]: