diff --git a/beat.py b/beat.py index ca2fbc1..b10b007 100644 --- a/beat.py +++ b/beat.py @@ -13,6 +13,11 @@ class SsfZxing: ssf_rel_thres = 3 #: magic number from Zong 2003, threshold from mean SSF amplitude ssf_rel_rise = 0.8 #: minimum rise of SSF edge (from foot to peak) relative to 'ssf_th' + # TODO: C++ impl has diverged. + # - refractory period changes + # - ssf_th filter with 6-points + # - ?? others ?? + def __init__(self): pass def _ssf_det_zxings(self, fs, ssf, ssf_th): @@ -37,6 +42,7 @@ class SsfZxing: ssf_z[:i_range] = 0 # force-zero the bounds where we cannot check the amplitude rise ssf_zxings = np.where(ssf_z)[0] # only integer-index resolution (no interpolation) + self.ssf_zxings = ssf_zxings return ssf_zxings def _ssf_function(self, fs, y): @@ -50,8 +56,17 @@ class SsfZxing: # compute threshold # TODO: check if we need lowpass instead of mean for 'ssf_th' ssf_th = self.ssf_rel_thres * np.mean(ssf) + self.ssf, self.ssf_th = ssf, ssf_th return ssf, ssf_th + def debug_plot(self, i1, i2): + ssf, ssf_th, ssf_zxings = self.ssf, self.ssf_th, self.ssf_zxings + ssf_slice = ssf[i1:i2] + ssf_th_slice = ssf_th[i1:i2] if isinstance(ssf_th, np.ndarray) else ssf_th + plt.figure(figsize=(8, 2)) + plt.plot(ssf_slice) + plt.plot(np.arange(ssf_slice.shape[0]), np.ones(ssf_slice.shape[0]) * ssf_th_slice) + plt.scatter(ssf_zxings[i1:i2], np.ones(ssf_zxings[i1:i2].shape[0]) * ssf_th_slice, c='r') def get_mae_dist(ibis): """make triangular wave between beats, representing absolute beat placement error."""