180 lines
8.0 KiB
Python
180 lines
8.0 KiB
Python
import numpy as np
|
|
|
|
from rhythm import BassAnalyzer, GuitarAnalyzer
|
|
from segmenter import Segmenter
|
|
from beat import SsfZxing, RegularBeatFinder
|
|
from sqi import gauss, shift
|
|
|
|
class SongBeatDetector:
    """Estimate a regular beat for a song, one fixed-length slice at a time.

    The signal is analysed twice: a 'guitar' feature (spectrogram power) is
    used to segment the song, and a 'bass' feature (wavelet scalogram
    amplitudes along a Viterbi path) is used to place beats inside each
    segment.  Every segment is cut into slices of ``SEGMENT_SLICE_LEN_SEC``
    seconds and each slice yields one result dict (see ``_process_slice``).
    """

    SEGMENT_SLICE_LEN_SEC = 8.0  #: slice length for processing (long enough to contain bar structure; short enough for a constant freq. beat placement)
    SSF_REL_THRES = 1.5  #: optimize for slope of error (mae) function over beat frequency
    NE_THRES = 30.0  #: normalized error threshold for 'good' slices

    def __init__(self):
        pass

    def detect(self, fs, sig, use_f_hint=True, debug_fe_idx=None):
        """Run the full beat detection on a signal.

        :param fs: sampling frequency of 'sig'
        :param sig: input signal, handed unchanged to BassAnalyzer / GuitarAnalyzer
        :param use_f_hint: if true, run a first pass without hint, derive the
            most common beat frequency of the 'good' slices, and use it to
            bias the second (returned) pass
        :param debug_fe_idx: index in units of 1/fs; slices of the segment
            containing it get debug plots (second pass only)
        :return: list of per-slice result dicts (also stored as ``self.zds``)
        """
        self.fs = fs

        self.ba = BassAnalyzer(fs, sig)
        # second return value: durations of the different stages (not used here)
        self.bass, _stage_times = self.ba.viterbi_wavelet_scalogram_amplitudes(dbg_time=True)

        self.ga = GuitarAnalyzer(fs, sig)
        self.guitar = self.ga.spectrogram_power_amplitudes()

        # feature rate and step size are taken from the guitar analyzer ('ga')
        fsd = fs / self.ga.D
        self.D = self.ga.D

        # self.bass, self.guitar: functions on windowed spectrum 0.008 sec apart (125 Hz)
        self.sg = Segmenter()
        self.i_seg = self.sg.get_segments(fsd, self.guitar)  # <- guitar
        self.t_seg = self.i_seg / fsd
        self.fsd = fsd  # reciprocal window step size

        # we segment on 'guitar' info, but process 'bass' later
        if use_f_hint:
            # initial estimate (without 'f_hint'), never debug-plotted
            self.zds_initial = self._estimate_segments(debug_fe_idx=None)
            self.f_hint = self._f_hint_from_estimates(self.zds_initial)
        else:
            self.f_hint = None

        # actual estimate (using 'f_hint' to bias each segment)
        self.zds = self._estimate_segments(f_hint=self.f_hint, debug_fe_idx=debug_fe_idx)
        return self.zds

    @staticmethod
    def _f_hint_from_estimates(zds):
        """Return the dominant beat frequency among the 'good' slices.

        A slice is 'good' when its 'ne' entry is below ``NE_THRES``.  The 'fb'
        values of the good slices are histogrammed (numpy default binning) and
        the center of the fullest bin is returned.

        :param zds: list of slice dicts carrying 'fb' and 'ne'
        :return: center frequency of the fullest bin, or None when there is no
            good slice (histogramming an empty array would otherwise produce
            a meaningless hint from numpy's default (0, 1) range)
        """
        fbs = np.array([zdd['fb'] for zdd in zds if zdd['ne'] < SongBeatDetector.NE_THRES])
        if fbs.size == 0:
            return None
        counts, edges = np.histogram(fbs)
        ih = np.argmax(counts)
        return np.mean((edges[ih], edges[ih + 1]))  # center freq of bin

    def _estimate_segments(self, f_hint=None, debug_fe_idx=None):
        """Process every slice of every segment and collect the results.

        Reads ``self.fs``, ``self.fsd``, ``self.i_seg`` and ``self.bass``.
        Segments shorter than one slice are skipped; a trailing remainder
        shorter than one slice is dropped as well.

        :param f_hint: beat frequency hint forwarded to each slice (or None)
        :param debug_fe_idx: index in units of 1/fs; all slices of the segment
            containing it are processed with debug plots enabled
        :return: list of slice dicts as returned by ``_process_slice``
        """
        zds = []
        fsd = self.fsd
        seg_sl = int(SongBeatDetector.SEGMENT_SLICE_LEN_SEC * fsd)  # segment slice length in 1/fsd units

        # debug position converted from 1/fs units to 1/fsd units; uses
        # self.fs (a bare 'fs' is not defined in this scope)
        if debug_fe_idx is None:
            debug_fe_sidx = None
        else:
            debug_fe_sidx = debug_fe_idx / self.fs * fsd

        # for each segment
        for i1, i2 in zip(self.i_seg[:-1], self.i_seg[1:]):
            if i2 - i1 < seg_sl:
                continue

            # there will be many (up to 50) different slices - do not debug-plot them all
            debug_fe = debug_fe_sidx is not None and i1 <= debug_fe_sidx < i2

            # split segment into slices
            num_sl = (i2 - i1) // seg_sl
            for m in np.arange(num_sl):
                j1, j2 = i1 + m * seg_sl, i1 + (m + 1) * seg_sl
                sig_slice = self.bass[slice(j1, j2)]  # <- bass
                zds.append(
                    self._process_slice(j1, j2, m, sig_slice, f_hint=f_hint, debug_fe=debug_fe)
                )

        return zds

    def _process_slice(self, j1, j2, m, sig_slice, f_hint=None, debug_fe=False):
        """Detect raw beats in one slice and fit a regular beat to them.

        :param j1: lower index of the slice (1/fsd units, as used to cut ``self.bass``)
        :param j2: upper index of the slice (1/fsd units, exclusive)
        :param m: slice number within its segment (currently not used)
        :param sig_slice: the slice data, i.e. ``self.bass[j1:j2]``
        :param f_hint: beat frequency hint passed to RegularBeatFinder (or None)
        :param debug_fe: show plots for SSF and raw/reg beat placement
        :return: dict with keys 'i1', 'i2' (j1, j2 scaled by ``self.D``),
            'zd', 'ssf', 'ssf_zxings', 'sig_slice', 'sig_source', 'ssf_th',
            'bf', 'fb', 'ne' and 'est_zxings'
        """
        # TODO: C++ impl of SsfZxing._ssf_det_zxings() has diverged.
        #   - refractory period changes
        #   - ssf_th filter with 6-points
        #   - ?? others ??
        # NOTE: SsfZxing here is always getting short 8-sec slices only (nb. for 'ssf_th' comput.)
        # NOTE(review): 'plt' (presumably matplotlib.pyplot) is used by the
        #   debug branches below but is not imported in the visible part of
        #   this module - confirm it is in scope before enabling debug_fe.

        fsd = self.fsd  # reciprocal window step size
        seg_sl = int(SongBeatDetector.SEGMENT_SLICE_LEN_SEC * fsd)  # segment slice length in 1/fsd units

        # note: this sets the threshold on the SsfZxing class, not just on 'zd'
        SsfZxing.ssf_rel_thres = SongBeatDetector.SSF_REL_THRES
        zd = SsfZxing()
        ssf, ssf_th = zd._ssf_function(fsd, sig_slice)
        ssf_zxings = zd._ssf_det_zxings(fsd, ssf, ssf_th)

        zdd = {
            'i1': j1 * self.D, 'i2': j2 * self.D,
            # ssf_zxings: raw beats (relative to slice)
            'zd': zd, 'ssf': ssf, 'ssf_zxings': ssf_zxings,
            'sig_slice': sig_slice, 'sig_source': 'bass',
            'ssf_th': np.ones(ssf.shape[0]) * ssf_th,
        }

        if debug_fe:
            # scalogram image, with viterbi path
            self.ba.debug_plot(j1, j2)  # TODO: adapt 'bass'
            plt.title(f'scalogram & viterbi path, slice [{j1}:{j2}]')

            # SSF function and detected raw beats
            zd.debug_plot(0, seg_sl)
            plt.title(f'raw beats, slice [{j1}:{j2}]')

        # nice-to: optimize phase, (maybe iteratively, optimize phase and freq each)
        bf = RegularBeatFinder()
        fb, ne = bf.find_beat(fsd, ssf_zxings, f_hint=f_hint, debug_fe=debug_fe, debug_i=None)
        if debug_fe:
            plt.title(f'regular-beat placement error (mae), slice [{j1}:{j2}]')
        # mae is unnormalized here, as returned from RegularBeatFinder._get_opt_ibi_freq_2()
        zdd.update({
            # bf: beat finder
            # fb: beat frequency, in Hz
            # ne: normalized mae error
            'bf': bf, 'fb': fb, 'ne': ne,
        })
        # TODO: ne > 30 is suspiciously bad - filter those "detections" out eventually
        # TODO: catch basic errors: ne == 0, or len(est_zxings) == 0, means slice is bad
        # NOTE: since 2x the zero-crossings, we get twice the frequency here.
        # NOTE: this means 0.5 lower freq bound of RegularBeatFinder will find at most 60 bpm in the song.

        # TODO: RegularBeatFinder currently not using 'phase' info, but should be optimized
        # TODO: (currently we start the pattern at the first detected beat, may or may not be good)
        # cumulative sum of the estimated inter-beat intervals, with a leading 0
        est_zxings = np.cumsum(np.pad(bf.freq_to_est_ibis(fsd, fb, j2 - j1), (1, 0)))  # rel. to slice
        if ssf_zxings.shape[0] > 0:
            est_zxings += ssf_zxings[0]  # add phase = currently we just start at first detected beat
        # nice-to: median-filter the freq, etc.pp.
        # nice-to: avoid adding len(est_zxings)=0 entries later

        # trim back to max. index
        est_zxings = est_zxings[np.where(est_zxings < ssf.shape[0])[0]]

        zdd.update({
            # est_zxings: regular beats (relative to slice)
            'est_zxings': est_zxings,
        })

        if debug_fe:
            plt.figure(figsize=(8, 2))
            plt.plot(ssf)
            plt.plot(np.arange(ssf.shape[0]), np.ones(ssf.shape[0]) * ssf_th)
            plt.scatter(ssf_zxings, np.ones(ssf_zxings.shape[0]) * ssf_th, c='r')
            plt.scatter(est_zxings, np.ones(est_zxings.shape[0]) * ssf_th, c='g')
            plt.title(f'ssf, ssf_th, raw beats (r), reg beats (g), slice [{j1}:{j2}]')

        return zdd

    # _debug_fmt_est_zxings
    def _place_fmt_zxings(self, fsd, ssf, ssf_zxings):
        """Build a synthetic SSF-like curve with one gauss bump per beat.

        :param fsd: feature rate (reciprocal window step size)
        :param ssf: SSF function; only its length is used
        :param ssf_zxings: beat locations (indices relative to 'ssf')
        :return: array of the same length as 'ssf' holding the summed bumps,
            scaled by the template's center value and rolled right by
            int(sigma) samples
        """
        gauss_beat_template_win_sec = 0.25542  #: gauss window width (as compared to beats in ssf function)
        gauss_beat_template_sigma_sec = 0.027  #: gauss bump half-width parameter (as compared to beats in ssf function)

        sigma = fsd * gauss_beat_template_sigma_sec
        W = int(fsd * gauss_beat_template_win_sec)
        gb = gauss(W, W // 2, sigma)

        # place gaussians on estimated beat locations
        ssf_est = np.zeros(ssf.shape[0])
        for i in ssf_zxings:
            ssf_est += shift(ssf.shape[0], i, gb)
        ssf_est /= gb[W // 2]  # normalize amplitude to 1.0
        ssf_est = np.roll(ssf_est, int(sigma))  # shift to right (beat loc = gauss beginning, not center)
        return ssf_est
|
|
|