feat: song: bass reg beat detector in slices

This commit is contained in:
2026-05-17 12:32:39 +02:00
parent ee5a1376ee
commit 71f1975a97
4 changed files with 175 additions and 6 deletions

161
song.py Normal file
View File

@@ -0,0 +1,161 @@
import numpy as np
from rhythm import BassAnalyzer, GuitarAnalyzer
from segmenter import Segmenter
from beat import SsfZxing, RegularBeatFinder
from sqi import gauss, shift
class SongBeatDetector:
    """Slice-wise beat detector for a song.

    The song is segmented using the 'guitar' feature; every segment is then
    cut into fixed-length slices and beats are detected on the 'bass' feature
    of each slice: first raw beats (SSF zero-crossings), then a regular
    (constant-frequency) beat grid fitted to them.
    """

    #: slice length for processing (long enough to contain bar structure;
    #: short enough for a constant freq. beat placement)
    SEGMENT_SLICE_LEN_SEC = 8.0
    #: optimize for slope of error (mae) function over beat frequency
    SSF_REL_THRES = 1.5

    def __init__(self):
        pass

    def detect(self, fs, sig, debug_fe_idx=None):
        """Run beat detection over all slices of all segments of ``sig``.

        Stores the intermediate results on the instance (``fs``, ``ba``,
        ``bass``, ``ga``, ``guitar``, ``D``, ``sg``, ``i_seg``, ``t_seg``,
        ``fsd``, ``zds``).

        :param fs: sampling rate of ``sig``
        :param sig: signal handed to BassAnalyzer and GuitarAnalyzer
        :param debug_fe_idx: optional index in ``sig`` sample units; slices of
            the segment containing it are processed with debug plots enabled
        :return: list with one result dict per processed slice
            (see :meth:`_process_slice`)
        """
        self.fs = fs
        self.ba = BassAnalyzer(fs, sig)
        # dbg_time=True makes the call return a second value (durations of
        # the different stages), which is not used here.
        self.bass, _stage_times = self.ba.viterbi_wavelet_scalogram_amplitudes(dbg_time=True)
        self.ga = GuitarAnalyzer(fs, sig)
        self.guitar = self.ga.spectrogram_power_amplitudes()
        # D and the downsampled rate are taken from the guitar analyzer ('ga')
        self.D = self.ga.D
        fsd = fs / self.ga.D
        self.fsd = fsd  # reciprocal window step size
        # self.bass, self.guitar: functions on windowed spectrum 0.008 sec apart (125 Hz)

        # we segment on 'guitar' info, but process 'bass' later
        self.sg = Segmenter()
        self.i_seg = self.sg.get_segments(fsd, self.guitar)
        self.t_seg = self.i_seg / fsd

        seg_sl = int(SongBeatDetector.SEGMENT_SLICE_LEN_SEC * fsd)
        # debug position converted from signal samples to feature (1/fsd) units
        debug_fe_sidx = None if debug_fe_idx is None else debug_fe_idx / fs * fsd

        self.zds = []
        # for each segment (pair of consecutive segment boundaries)
        for i1, i2 in zip(self.i_seg[:-1], self.i_seg[1:]):
            # there will be many (upto 50) different slices - do not debug-plot them all
            # NOTE(review): the check is against the *segment* bounds, so every
            # slice of the matching segment gets debug plots - confirm intent.
            debug_fe = debug_fe_sidx is not None and i1 <= debug_fe_sidx < i2
            # split segment into slices (segments shorter than a slice yield none)
            for m, j1, j2 in self._slice_bounds(i1, i2, seg_sl):
                sig_slice = self.bass[j1:j2]  # <- bass
                zdd = self._process_slice(j1, j2, m, seg_sl, sig_slice, debug_fe=debug_fe)
                self.zds.append(zdd)
        return self.zds

    @staticmethod
    def _slice_bounds(i1, i2, seg_sl):
        """Yield ``(m, j1, j2)`` for each complete slice of length ``seg_sl`` in ``[i1, i2)``.

        A trailing remainder shorter than ``seg_sl`` is dropped; nothing is
        yielded when ``seg_sl`` is not positive.
        """
        if seg_sl < 1:
            return
        for m in range((i2 - i1) // seg_sl):
            yield m, i1 + m * seg_sl, i1 + (m + 1) * seg_sl

    @staticmethod
    def _place_regular_beats(est_ibis, ssf_zxings, n):
        """Turn inter-beat intervals into beat positions relative to the slice.

        :param est_ibis: estimated inter-beat intervals
        :param ssf_zxings: raw detected beats; the first one (if any) is used
            as the phase, i.e. the regular pattern starts there
        :param n: length of the slice; positions ``>= n`` are dropped
        :return: array of regular beat positions
        """
        ssf_zxings = np.asarray(ssf_zxings)
        # leading 0 so the cumulative sum starts at the first beat
        beats = np.cumsum(np.pad(est_ibis, (1, 0)))
        if ssf_zxings.shape[0] > 0:
            beats += ssf_zxings[0]  # add phase = currently we just start at first detected beat
        # trim back to max. index
        return beats[beats < n]

    def _process_slice(self, j1, j2, m, seg_sl, sig_slice, debug_fe=False):
        """Detect raw and regular beats in one slice.

        :param j1: lower index of the slice in the full bass feature
            (``sig_slice`` is ``self.bass[j1:j2]``)
        :param j2: upper index (exclusive) of the slice in the full bass feature
        :param m: slice number within its segment (not used by the computation)
        :param seg_sl: segment slice length in 1/fsd units
        :param sig_slice: the bass feature values of this slice
        :param debug_fe: show plots for SSF and raw/reg beat placement
        :return: dict with keys ``i1``, ``i2`` (slice bounds scaled by ``D``),
            ``zd``, ``ssf``, ``ssf_zxings`` (raw beats, relative to slice),
            ``sig_slice``, ``sig_source``, ``ssf_th`` (threshold expanded to an
            array of ``ssf`` length), ``bf``, ``fb``, ``ne`` and ``est_zxings``
            (regular beats, relative to slice)
        """
        if debug_fe:
            # Imported lazily: pyplot is only needed for the debug plots and
            # is not among the module-level imports.
            import matplotlib.pyplot as plt

        # TODO: C++ impl of SsfZxing._ssf_det_zxings() has diverged.
        #  - refractory period changes
        #  - ssf_th filter with 6-points
        #  - ?? others ??
        # NOTE: SsfZxing here is always getting short 8-sec slices only (nb. for 'ssf_th' comput.)
        fsd = self.fsd  #: reciprocal window step size
        # NOTE(review): this assigns on the SsfZxing *class*, so it affects
        # every SsfZxing user, not only the instance created below.
        SsfZxing.ssf_rel_thres = SongBeatDetector.SSF_REL_THRES
        zd = SsfZxing()
        ssf, ssf_th = zd._ssf_function(fsd, sig_slice)
        ssf_zxings = zd._ssf_det_zxings(fsd, ssf, ssf_th)
        n_ssf = ssf.shape[0]

        zdd = {
            'i1': j1 * self.D, 'i2': j2 * self.D,
            # ssf_zxings: raw beats (relative to slice)
            'zd': zd, 'ssf': ssf, 'ssf_zxings': ssf_zxings,
            'sig_slice': sig_slice, 'sig_source': 'bass',
            'ssf_th': np.ones(n_ssf) * ssf_th,
        }

        if debug_fe:
            # scalogram image, with viterbi path
            self.ba.debug_plot(j1, j2)  # TODO: adapt 'bass'
            plt.title(f'scalogram & viterbi path, slice [{j1}:{j2}]')
            # SSF function and detected raw beats
            zd.debug_plot(0, seg_sl)
            plt.title(f'raw beats, slice [{j1}:{j2}]')

        # nice-to: optimize phase, (maybe iteratively, optimize phase and freq each)
        bf = RegularBeatFinder()
        fb, ne = bf.find_beat(fsd, ssf_zxings, debug_fe=debug_fe, debug_i=None)
        if debug_fe:
            plt.title(f'regular-beat placement error (mae), slice [{j1}:{j2}]')
        # bf: beat finder
        # fb: beat frequency, in Hz
        # ne: mae error as returned by find_beat
        # NOTE(review): the original notes call 'ne' both "normalized" and
        # "unnormalized, as returned from
        # RegularBeatFinder._get_opt_ibi_freq_2()" - confirm which holds.
        zdd.update({'bf': bf, 'fb': fb, 'ne': ne})

        # TODO: ne > 30 is suspiciously bad - filter those "detections" out eventually
        # TODO: catch basic errors: ne == 0, or len(est_zxings) == 0, means slice is bad
        # NOTE: since 2x the zero-crossings, we get twice the frequency here.
        # NOTE: this means 0.5 lower freq bound of RegularBeatFinder will find at most 60 bpm in the song.
        # TODO: RegularBeatFinder currently not using 'phase' info, but should be optimized
        # TODO: (currently we start the pattern at the first detected beat, may or may not be good)
        est_ibis = bf.freq_to_est_ibis(fsd, fb, j2 - j1)
        est_zxings = self._place_regular_beats(est_ibis, ssf_zxings, n_ssf)
        # nice-to: median-filter the freq, etc.pp.
        # nice-to: avoid adding len(est_zxings)=0 entries later
        # est_zxings: regular beats (relative to slice)
        zdd.update({'est_zxings': est_zxings})

        if debug_fe:
            plt.figure(figsize=(8, 2))
            plt.plot(ssf)
            plt.plot(np.arange(n_ssf), np.ones(n_ssf) * ssf_th)
            plt.scatter(ssf_zxings, np.ones(ssf_zxings.shape[0]) * ssf_th, c='r')
            plt.scatter(est_zxings, np.ones(est_zxings.shape[0]) * ssf_th, c='g')
            plt.title(f'ssf, ssf_th, raw beats (r), reg beats (g), slice [{j1}:{j2}]')
        return zdd

    def _place_fmt_zxings(self, fsd, ssf, ssf_zxings):
        """Build a signal of ``ssf`` length with a gaussian bump per beat location.

        NOTE(review): relies on ``gauss`` and ``shift`` from ``sqi``; presumably
        ``gauss(W, c, sigma)`` returns a length-``W`` window centred at ``c``
        and ``shift(n, i, gb)`` returns a length-``n`` array with ``gb`` placed
        at offset ``i`` - confirm against ``sqi``.

        :param fsd: reciprocal window step size (feature rate)
        :param ssf: SSF function; only its length is used
        :param ssf_zxings: beat locations (relative to ``ssf``)
        :return: array of ``ssf`` length, the sum of the placed bumps
        """
        gauss_beat_template_win_sec = 0.25542  #: gauss window width (as compared to beats in ssf function)
        gauss_beat_template_sigma_sec = 0.027  #: gauss bump half-width parameter (as compared to beats in ssf function)

        sigma = fsd * gauss_beat_template_sigma_sec
        W = int(fsd * gauss_beat_template_win_sec)
        gb = gauss(W, W // 2, sigma)

        # place gaussians on estimated beat locations
        ssf_est = np.zeros(ssf.shape[0])
        for i in ssf_zxings:
            ssf_est += shift(ssf.shape[0], i, gb)
        ssf_est /= gb[W // 2]  # normalize amplitude to 1.0
        # shift to right (beat loc = gauss beginning, not center)
        return np.roll(ssf_est, int(sigma))