From 11934b1f616f63141b62bd4437992cd2afb29965 Mon Sep 17 00:00:00 2001 From: David Madl Date: Wed, 13 May 2026 05:16:17 +0200 Subject: [PATCH] feat: restrict B in BassAnalyzer to sensible freq range --- rhythm.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/rhythm.py b/rhythm.py index 6c3e9c3..d8c6714 100644 --- a/rhythm.py +++ b/rhythm.py @@ -13,6 +13,7 @@ import numpy as np from numpy.fft import fft +import matplotlib.pyplot as plt # for debug only from hsh_signal.signal import lowpass_fft import time @@ -81,7 +82,30 @@ def gabor_wavelet(omega, nu, fs, T, tt=None): psi = 1.0 / np.sqrt(omega) * np.exp(-np.pi * (t / omega)**2) * np.exp(1j*2*np.pi * nu * t / omega) return psi -class BassAnalyzer: +class Analyzer: + def __init__(self): pass + def debug_plot(self, i1, i2): + Scp2, path = self.Scp2, self.path + fs, Dp = self.fs, self.Dp + + ss, omega, nu, fsp, Wp, I, J, freqs = self.pms + + Scp2_slice = np.abs(Scp2[i1:i2]) + + plt.figure(figsize=(8,2)) + plt.imshow(Scp2_slice.T, origin='lower') + x_positions = np.arange(Scp2_slice.shape[0]//250+1)*250 + if x_positions[-1] == Scp2_slice.shape[0]: + x_positions[-1] -= 1 # so last tick is shown properly + t1 = i1 / (fs / Dp) + x_labels = ['{:.1f}'.format(t1+x*Dp/fs) for x in x_positions] + plt.xticks(x_positions, x_labels) + y_positions = np.arange(Scp2_slice.shape[1]//50)*50 + y_labels = ['{:.1f}'.format((nu/(omega*ss[y]))) for y in y_positions] # Hz equivalents of wavelet scale + plt.yticks(y_positions, y_labels) + plt.plot(np.arange(Scp2_slice.shape[0]), path[i1:i2], c='r') + +class BassAnalyzer(Analyzer): """ Rhythm analysis from songs. Provides a beat amplitude signal from the audio signal. @@ -112,6 +136,7 @@ class BassAnalyzer: :param fs: sampling rate :param sig: audio signal normalized to [-1,1] """ + super(BassAnalyzer, self).__init__() self.D = int(self.shift_sec * fs) #: spectrogram step if self.Wp_force: self.Wp = self.Wp_force @@ -151,6 +176,11 @@ class BassAnalyzer: t8 = time.time() ampl = self._viterbi_ampl(Scp2, path) t9 = time.time() + + self.Scp2 = Scp2 + self.path = path + self.pms = pms + if not dbg_time: return ampl else: @@ -192,6 +222,11 @@ class BassAnalyzer: pt_re = (np.diff(pt) == 1).astype(int) # rising edge self.B = max(np.sum(pt_re), 1) # total number of pulses in the 'pt' pulse train signal + # clip B, to force **reasonable** frequency range for wavelets + # (noise will otherwise cause many transitions -> high B -> bass falls below freq range -> algo fail) + B_min, B_max = 0.5 * M / (fs / self.D), 5.0 * M / (fs / self.D) + self.B = np.clip(self.B, a_min=B_min, a_max=B_max) + # resample 'pt' (M) at these indices -> 'ptr' (L), like original 'f' (signal padded) squashed_idxs = np.floor(np.linspace(0, L-1, L) * (M/L)).astype(int) ptr = pt[squashed_idxs]