add: aligner.py
This commit is contained in:
257
aligner.py
Normal file
257
aligner.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def load_track_annot(fn):
    """Load one track's beat annotation from a JSON file.

    Args:
        fn: Path of the JSON annotation file.

    Returns:
        The decoded JSON document, exactly as ``json.load`` produces it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is exchanged as UTF-8; naming the encoding keeps the read
    # independent of the platform's locale default.
    with open(fn, encoding="utf-8") as fin:
        return json.load(fin)
|
||||
|
||||
# Paths of the beat-annotation JSON files: one inner list per numbered
# folder (01..03), each listing tracks 001..009 of that folder.
fns_all = [
    ['beat_annotations/%02d/Rhythmology_%03d.json' % (folder, track) for track in range(1, 10)]
    for folder in range(1, 4)
]
# Bare expression kept from the original (notebook-style echo); it has no
# effect when the file is executed as a module.
fns_all
|
||||
|
||||
"""
|
||||
Implement an algorithm to align two series of time intervals in Python. Each series will be up to 1000 entries long. Use dynamic programming if needed.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import inf
|
||||
from typing import Callable, Iterable, Optional, Sequence, Tuple, List, Any
|
||||
|
||||
|
||||
# An interval is a (start, end) pair of numbers; align_intervals rejects
# any interval whose start is greater than its end.
Interval = Tuple[float, float]


@dataclass(frozen=True)
class AlignmentStep:
    """One step of an alignment between series A and series B.

    A step either pairs an interval of A with an interval of B, or pairs
    an interval of one series with a gap. The fields belonging to the
    gapped side are None. Instances are immutable.
    """

    # Position of the interval in series A, or None when A has a gap here.
    a_index: Optional[int]
    # Position of the interval in series B, or None when B has a gap here.
    b_index: Optional[int]
    # The interval taken from series A (None for a gap in A).
    a_interval: Optional[Interval]
    # The interval taken from series B (None for a gap in B).
    b_interval: Optional[Interval]
    # Contribution of this step: the match score for a paired step, the
    # gap penalty for a step that contains a gap.
    score: float
|
||||
|
||||
|
||||
def overlap_length(x: Interval, y: Interval) -> float:
    """Return the length of the stretch shared by intervals *x* and *y*.

    Both arguments are read as (start, end). Intervals that are disjoint
    or that merely touch at an endpoint yield 0.0.
    """
    latest_start = max(x[0], y[0])
    earliest_end = min(x[1], y[1])
    # A negative difference means the intervals do not intersect.
    return max(0.0, earliest_end - latest_start)
|
||||
|
||||
|
||||
def union_length(x: Interval, y: Interval) -> float:
    """Return the length of the smallest span that covers both intervals.

    Both arguments are read as (start, end). For disjoint intervals the
    gap between them is part of the returned span.
    """
    span_start = min(x[0], y[0])
    span_end = max(x[1], y[1])
    return span_end - span_start
|
||||
|
||||
|
||||
def default_match_score(a: Interval, b: Interval) -> float:
    """Score the similarity of two intervals.

    For overlapping intervals the score is the overlap length divided by
    the length of the span covering both intervals.

    Returns:
        +1.0 for perfect overlap (including two zero-length intervals
             located at the same point)
        0.0 to 1.0 for partial overlap
        -1.0 for no overlap

    You can replace this with a domain-specific function.
    """
    u = union_length(a, b)
    if u == 0:
        # Zero covering span: for well-formed intervals (start <= end) both
        # are zero-length and sit at the same point, i.e. they coincide.
        # This must be tested before the overlap check, because the overlap
        # of such intervals is 0 and would otherwise be reported as -1.0.
        return 1.0

    ov = overlap_length(a, b)
    if ov <= 0:
        return -1.0

    # ov > 0 implies u >= ov > 0, so the division is safe.
    return ov / u
|
||||
|
||||
|
||||
def align_intervals(
    series_a: Sequence[Interval],
    series_b: Sequence[Interval],
    *,
    match_score: Callable[[Interval, Interval], float] = default_match_score,
    gap_penalty: float = -0.5,
) -> Tuple[float, List[AlignmentStep]]:
    """Compute a global, order-preserving alignment of two interval series.

    The alignment is a list of steps covering every interval of both
    series exactly once, in order. Each step is one of:
      - an interval of A paired with an interval of B,
      - an interval of A paired with a gap,
      - a gap paired with an interval of B.
    The steps are chosen by dynamic programming so that the sum of match
    scores and gap penalties is maximal.

    Args:
        series_a: First ordered series of (start, end) intervals.
        series_b: Second ordered series of (start, end) intervals.
        match_score: Function scoring a pair of intervals; higher is better.
        gap_penalty: Score added for every step that contains a gap.

    Returns:
        (total_score, steps): the score of the best alignment and its
        steps ordered from the beginning of the series to the end.

    Raises:
        ValueError: If an interval has start > end.
        RuntimeError: If the traceback table is inconsistent.

    Complexity:
        Time:   O(len(series_a) * len(series_b))
        Memory: O(len(series_a) * len(series_b)) bytes for the traceback
                table, O(len(series_b)) for the DP scores (only two rows
                are alive at any time).
    """
    a = list(series_a)
    b = list(series_b)
    n, m = len(a), len(b)

    # Reject malformed intervals up front (all of A first, then B).
    for idx, iv in enumerate(a):
        if iv[0] > iv[1]:
            raise ValueError(f"Invalid interval in series_a at index {idx}: {iv}")
    for idx, iv in enumerate(b):
        if iv[0] > iv[1]:
            raise ValueError(f"Invalid interval in series_b at index {idx}: {iv}")

    # Traceback codes stored per DP cell (0 is left only in cell [0][0]):
    #   DIAG: A[i - 1] matched with B[j - 1]
    #   UP:   A[i - 1] against a gap
    #   LEFT: gap against B[j - 1]
    DIAG, UP, LEFT = 1, 2, 3
    moves = [bytearray(m + 1) for _ in range(n + 1)]

    # Row 0: nothing of A consumed, so every B interval faces a gap.
    above = [0.0]
    for j in range(1, m + 1):
        above.append(above[-1] + gap_penalty)
        moves[0][j] = LEFT

    # Remaining rows, keeping only the previous row of scores.
    for i in range(1, n + 1):
        item_a = a[i - 1]

        # Column 0: nothing of B consumed, so A[i - 1] faces a gap.
        row = [above[0] + gap_penalty]
        moves[i][0] = UP

        for j in range(1, m + 1):
            via_match = above[j - 1] + match_score(item_a, b[j - 1])
            via_up = above[j] + gap_penalty
            via_left = row[j - 1] + gap_penalty

            # Strict comparisons give a deterministic tie-break:
            # match first, then gap in B (UP), then gap in A (LEFT).
            best, move = via_match, DIAG
            for candidate, candidate_move in ((via_up, UP), (via_left, LEFT)):
                if candidate > best:
                    best, move = candidate, candidate_move

            row.append(best)
            moves[i][j] = move

        above = row

    total_score = above[m]

    # Walk the traceback table from the last cell to the origin; steps are
    # collected back to front and reversed at the end.
    steps: List[AlignmentStep] = []
    i, j = n, m
    while i > 0 or j > 0:
        move = moves[i][j]

        if move == DIAG:
            i -= 1
            j -= 1
            pair_score = match_score(a[i], b[j])
            steps.append(
                AlignmentStep(
                    a_index=i,
                    b_index=j,
                    a_interval=a[i],
                    b_interval=b[j],
                    score=pair_score,
                )
            )
        elif move == UP:
            i -= 1
            steps.append(
                AlignmentStep(
                    a_index=i,
                    b_index=None,
                    a_interval=a[i],
                    b_interval=None,
                    score=gap_penalty,
                )
            )
        elif move == LEFT:
            j -= 1
            steps.append(
                AlignmentStep(
                    a_index=None,
                    b_index=j,
                    a_interval=None,
                    b_interval=b[j],
                    score=gap_penalty,
                )
            )
        else:
            # Code 0 exists only at moves[0][0], where the loop has ended.
            raise RuntimeError("Invalid traceback state")

    steps.reverse()
    return total_score, steps
|
||||
|
||||
###
|
||||
|
||||
def make_intv_pairs(annot):
    """Turn a track's beat times into consecutive (start, end) intervals.

    A 0 is prepended to ``annot['beatTimesSec']``, so the first interval
    runs from 0 to the first beat and interval ``k`` ends at beat ``k``.

    Args:
        annot: Mapping holding the beat times under 'beatTimesSec'
            (presumably seconds, going by the key name).

    Returns:
        A list with one (previous_time, beat_time) tuple per beat time.
    """
    # NOTE: an earlier variant (previously left here as commented-out code)
    # paired successive inter-beat durations of the sorted beat times
    # instead of the raw time stamps.
    edges = np.pad(annot['beatTimesSec'], (1, 0))
    return list(zip(edges[:-1], edges[1:]))
|
||||
|
||||
def interp_missing_1(ts, ib_th=0.2):
    """Interpolate missing beats in a single series of beat times.

    Every inter-beat interval whose length differs from the mean
    inter-beat interval by more than *ib_th* receives one extra beat at
    its midpoint. The test uses the absolute difference, so intervals
    much shorter than the mean are flagged as well.

    Args:
        ts: 1-D sequence of beat times.
        ib_th: Inter-beat interval threshold for interpolation, in the
            same unit as *ts* (the original constant was annotated "sec").

    Returns:
        A new NumPy array holding the beat times plus the inserted beats.

    Side effects:
        Prints the number of added beats when at least one was inserted.
    """
    if len(ts) < 2:
        # No inter-beat interval exists; taking the mean of an empty diff
        # would only yield NaN and a RuntimeWarning.
        return np.array(list(ts))

    ibis = np.diff(ts)
    idxs_interp = np.where(np.abs(ibis - np.mean(ibis)) > ib_th)[0]

    tsl = list(ts)
    # idxs_interp is ascending; each earlier insertion shifts the later
    # insertion points by one, which the running counter k compensates.
    for k, i in enumerate(idxs_interp):
        tsl.insert(i + k + 1, ts[i] + (ts[i + 1] - ts[i]) / 2)

    if len(idxs_interp) > 0:
        print('interp_missing: added %d beats' % len(idxs_interp))
    return np.array(tsl)
|
||||
|
||||
|
||||
def interp_missing(ts_a, ts_b, ib_th=0.2):
    """Interpolate missing beats in two index-aligned series of beat times.

    An inter-beat interval is flagged when its length differs from the
    mean inter-beat interval of its own series by more than *ib_th*. The
    positions flagged in EITHER series are merged, and a midpoint beat is
    inserted at each of those positions in BOTH series, so the two
    results keep equal length and stay index-aligned.

    Args:
        ts_a: 1-D sequence of beat times of the first series.
        ts_b: 1-D sequence of beat times of the second series; must have
            the same length as *ts_a*.
        ib_th: Inter-beat interval threshold for interpolation, in the
            same unit as the beat times (the original constant was
            annotated "sec").

    Returns:
        (new_a, new_b): NumPy arrays with the original and inserted beats.

    Raises:
        AssertionError: If the two series differ in length.

    Side effects:
        Prints the number of added beats when at least one was inserted.
    """
    # Kept as an assert so callers observe the same exception type as before.
    assert len(ts_a) == len(ts_b)

    if len(ts_a) < 2:
        # No inter-beat interval exists; taking the mean of an empty diff
        # would only yield NaN and a RuntimeWarning.
        return np.array(list(ts_a)), np.array(list(ts_b))

    def irregular_gaps(ts):
        """Return indices of the intervals deviating from the mean by > ib_th."""
        ibis = np.diff(ts)
        return np.where(np.abs(ibis - np.mean(ibis)) > ib_th)[0]

    # Join the two sets and apply the interpolation to both time series
    # (ensures they have the same length afterwards). The original code
    # united the B indices with themselves, dropping gaps seen only in A.
    idxs_interp = sorted(set(irregular_gaps(ts_a)) | set(irregular_gaps(ts_b)))

    tsl_a, tsl_b = list(ts_a), list(ts_b)
    # Indices are ascending; each earlier insertion shifts the later
    # insertion points by one, which the running counter k compensates.
    for k, i in enumerate(idxs_interp):
        tsl_a.insert(i + k + 1, ts_a[i] + (ts_a[i + 1] - ts_a[i]) / 2)
        tsl_b.insert(i + k + 1, ts_b[i] + (ts_b[i + 1] - ts_b[i]) / 2)

    if len(idxs_interp) > 0:
        print('interp_missing: added %d beats' % len(idxs_interp))
    return np.array(tsl_a), np.array(tsl_b)
|
||||
|
||||
|
||||
def align(annot_a, annot_b):
    """Align the beat annotations of two tracks and return matched beats.

    The beat times of both annotations are converted to intervals with
    make_intv_pairs and aligned with align_intervals. Steps that pair an
    interval of A with an interval of B are kept, except poorly matching
    ones (score < 0.5) among the first five intervals of either series.
    The interval indices of the kept steps are used to pick beat times
    from both annotations (make_intv_pairs builds one interval per beat,
    ending at that beat), and interp_missing is applied to the result.

    Args:
        annot_a: Annotation mapping with a 'beatTimesSec' sequence.
        annot_b: Annotation mapping with a 'beatTimesSec' sequence.

    Returns:
        (ts_a, ts_b, tsa_a, tsa_b): all beat times of A and of B as NumPy
        arrays, followed by the aligned (and interpolated) beat times of
        A and of B, which have equal length.
    """
    _score, alignment = align_intervals(make_intv_pairs(annot_a), make_intv_pairs(annot_b))

    def check_intv(al):
        # - neither interval is None (we are aligned)
        if al.a_interval is None or al.b_interval is None:
            return False
        # - cut the first few bad intervals (they are not interpolated and
        #   may be noisy input from the user anyway)
        init_bad = al.score < 0.5 and (al.a_index < 5 or al.b_index < 5)
        return not init_bad

    # Filter once so both index lists come from the same set of steps.
    kept = [al for al in alignment if check_intv(al)]

    # dtype=int keeps the arrays usable as indices even when nothing was
    # kept: a plain np.array([]) is float-typed and raises IndexError.
    idxs_a = np.array([al.a_index for al in kept], dtype=int)
    idxs_b = np.array([al.b_index for al in kept], dtype=int)

    ts_a, ts_b = np.array(annot_a['beatTimesSec']), np.array(annot_b['beatTimesSec'])
    tsa_a, tsa_b = interp_missing(ts_a[idxs_a], ts_b[idxs_b])

    return ts_a, ts_b, tsa_a, tsa_b
|
||||
Reference in New Issue
Block a user