diff --git a/beat.py b/beat.py index 61a9c82..7e1d73a 100644 --- a/beat.py +++ b/beat.py @@ -172,14 +172,15 @@ class RegularBeatFinder: # evaluate mean absolute errors for all frequencies freqs, freq_errs = self._get_opt_ibi_freq_2(fs, act_ibis, debug_i) # bias with f_hint - once we know the beat freq, make it more likely for it to be found everywhere - nf, f1, f2 = RegularBeatFinder.num_freqs, RegularBeatFinder.range_f2, RegularBeatFinder.range_f1 - bias = gauss( - nf, - (f_hint - f1) / (f2 - f1) * nf, - RegularBeatFinder.f_bias_width * nf - ) - freqs_bias = 1.0 / (np.max(bias)+bias) # make 'f_hint' at most 2x more likely -- (1+bias) if normalized - freq_errs *= freqs_bias + if f_hint is not None: + nf, f1, f2 = RegularBeatFinder.num_freqs, RegularBeatFinder.range_f2, RegularBeatFinder.range_f1 + bias = gauss( + nf, + (f_hint - f1) / (f2 - f1) * nf, + RegularBeatFinder.f_bias_width * nf + ) + freqs_bias = 1.0 / (np.max(bias)+bias) # make 'f_hint' at most 2x more likely -- (1+bias) if normalized + freq_errs *= freqs_bias # if debug_fe: plt.figure(figsize=(8,2)) diff --git a/song.py b/song.py index c8c4703..017ad78 100644 --- a/song.py +++ b/song.py @@ -33,8 +33,12 @@ class SongBeatDetector: seg_sl = int(SongBeatDetector.SEGMENT_SLICE_LEN_SEC * fsd) - self.zds = [] + self.zds = self._estimate_segments(debug_fe_idx=debug_fe_idx) + return self.zds + def _estimate_segments(self, f_hint=None, debug_fe_idx=None): + zds = [] + fsd = self.fsd # for each segment for i in np.arange(self.i_seg.shape[0]-1): i1, i2 = self.i_seg[i], self.i_seg[i+1] @@ -52,12 +56,12 @@ class SongBeatDetector: debug_fe = i1 <= debug_fe_sidx < i2 else: debug_fe = False - zdd = self._process_slice(j1, j2, m, seg_sl, sig_slice, debug_fe=debug_fe) - self.zds.append(zdd) + zdd = self._process_slice(j1, j2, m, seg_sl, sig_slice, f_hint=f_hint, debug_fe=debug_fe) + zds.append(zdd) - return self.zds + return zds - def _process_slice(self, j1, j2, m, seg_sl, sig_slice, debug_fe=False): + def _process_slice(self, j1, j2, m, seg_sl, sig_slice, f_hint=None, debug_fe=False): """ :param j1: lower index into 'sig_slice' :param j2: upper index into 'sig_slice' @@ -100,7 +104,7 @@ class SongBeatDetector: # nice-to: optimize phase, (maybe iteratively, optimize phase and freq each) bf = RegularBeatFinder() - fb, ne = bf.find_beat(fsd, ssf_zxings, debug_fe=debug_fe, debug_i=None) + fb, ne = bf.find_beat(fsd, ssf_zxings, f_hint=f_hint, debug_fe=debug_fe, debug_i=None) if debug_fe: plt.title(f'regular-beat placement error (mae), slice [{j1}:{j2}]') # mae is unnurmalized here, as returned from RegularBeatFinder._get_opt_ibi_freq_2() zdd.update({