Add readahead for segmenters

This commit is contained in:
Chris Cannam
2022-05-25 11:14:19 +01:00
parent 680393c5c6
commit 973a334f75
2 changed files with 129 additions and 54 deletions

View File

@@ -233,80 +233,134 @@ R3StretcherImpl::consume()
// We have a single unwindowed frame at the longest FFT
// size ("scale"). Populate the shorter FFT sizes from the
// centre of it, windowing as we copy
// centre of it, windowing as we copy. The classification
// scale is handled separately because it has readahead,
// so skip it here as well as the longest. (In practice
// this means we are probably only populating one scale)
for (auto &it: cd->scales) {
int fftSize = it.first;
auto &scale = it.second;
if (fftSize == longest) continue;
if (fftSize == classify || fftSize == longest) continue;
int offset = (longest - fftSize) / 2;
m_scaleData.at(fftSize)->analysisWindow.cut
(buf + offset, scale->timeDomain.data());
(buf + offset, it.second->timeDomain.data());
}
// Then window the longest one
// The classification scale has a one-hop readahead (note
// that inhop is fixed), so populate its current data from
// the readahead and the readahead from further down the
// long unwindowed frame.
auto &classifyScale = cd->scales.at(classify);
ClassificationReadaheadData &readahead = cd->readahead;
m_scaleData.at(classify)->analysisWindow.cut
(buf + (longest - classify) / 2 + m_inhop,
readahead.timeDomain.data());
// Finally window the longest scale
m_scaleData.at(longest)->analysisWindow.cut(buf);
// FFT shift, forward FFT, and carry out cartesian-polar
// conversion for each FFT size
// conversion for each FFT size.
// For the classification scale we need magnitudes for the
// full range (polar only in a subset) and we operate in
// the readahead, pulling current values from the existing
// readahead
v_fftshift(readahead.timeDomain.data(), classify);
v_copy(classifyScale->mag.data(),
readahead.mag.data(),
classifyScale->bufSize);
v_copy(classifyScale->phase.data(),
readahead.phase.data(),
classifyScale->bufSize);
m_scaleData.at(classify)->fft.forward(readahead.timeDomain.data(),
classifyScale->real.data(),
classifyScale->imag.data());
for (const auto &b : m_guideConfiguration.fftBandLimits) {
if (b.fftSize == classify) {
if (b.b0min > 0) {
v_cartesian_to_magnitudes(readahead.mag.data(),
classifyScale->real.data(),
classifyScale->imag.data(),
b.b0min);
}
v_cartesian_to_polar(readahead.mag.data() + b.b0min,
readahead.phase.data() + b.b0min,
classifyScale->real.data() + b.b0min,
classifyScale->imag.data() + b.b0min,
b.b1max - b.b0min);
if (b.b1max < classify/2 + 1) {
v_cartesian_to_magnitudes
(readahead.mag.data() + b.b1max,
classifyScale->real.data() + b.b1max,
classifyScale->imag.data() + b.b1max,
classify/2 + 1 - b.b1max);
}
v_scale(classifyScale->mag.data(),
1.0 / double(classify),
classifyScale->mag.size());
break;
}
}
// For the others we operate directly in the scale data
// and restrict the range for cartesian-polar conversion
for (auto &it: cd->scales) {
int fftSize = it.first;
if (fftSize == classify) continue;
auto &scale = it.second;
v_fftshift(scale->timeDomain.data(), fftSize);
if (fftSize == m_guideConfiguration.classificationFftSize) {
// For the classification scale we need the full range
m_scaleData.at(fftSize)->fft.forwardPolar
(scale->timeDomain.data(),
scale->mag.data(),
scale->phase.data());
} else {
// For other scales we only need do
// cartesian-polar conversion for the necessary
// frequency subset
m_scaleData.at(fftSize)->fft.forward
(scale->timeDomain.data(),
m_scaleData.at(fftSize)->fft.forward(scale->timeDomain.data(),
scale->real.data(),
scale->imag.data());
//!!! This should be a map
for (const auto &b : m_guideConfiguration.fftBandLimits) {
if (b.fftSize == fftSize) {
int offset = b.b0min;
v_cartesian_to_polar
(scale->mag.data() + offset,
scale->phase.data() + offset,
scale->real.data() + offset,
scale->imag.data() + offset,
b.b1max - offset);
v_cartesian_to_polar(scale->mag.data() + b.b0min,
scale->phase.data() + b.b0min,
scale->real.data() + b.b0min,
scale->imag.data() + b.b0min,
b.b1max - b.b0min);
v_scale(scale->mag.data() + b.b0min,
1.0 / double(fftSize),
b.b1max - b.b0min);
break;
}
}
}
v_scale(scale->mag.data(), 1.0 / double(fftSize),
scale->mag.size());
}
// Use the classification scale to get a bin segmentation
// and calculate the adaptive frequency guide for this
// channel
auto &classifyScale = cd->scales.at(classify);
cd->prevSegmentation = cd->segmentation;
cd->segmentation =
cd->segmenter->segment(classifyScale->mag.data());
cd->segmentation = cd->nextSegmentation;
cd->nextSegmentation = cd->segmenter->segment(readahead.mag.data());
m_troughPicker.findNearestAndNextPeaks
(classifyScale->mag.data(), 3, nullptr,
classifyScale->nextTroughs.data());
classifyScale->troughs.data());
m_guide.calculate(instantaneousRatio,
classifyScale->mag.data(),
classifyScale->nextTroughs.data(),
classifyScale->troughs.data(),
classifyScale->prevMag.data(),
cd->segmentation,
cd->prevSegmentation,
BinSegmenter::Segmentation(), //!!!
cd->nextSegmentation,
cd->guidance);
}
@@ -320,7 +374,7 @@ R3StretcherImpl::consume()
m_channelAssembly.mag[c] = classifyScale->mag.data();
m_channelAssembly.phase[c] = classifyScale->phase.data();
m_channelAssembly.guidance[c] = &cd->guidance;
m_channelAssembly.outPhase[c] = classifyScale->outPhase.data();
m_channelAssembly.outPhase[c] = classifyScale->advancedPhase.data();
}
m_scaleData.at(fftSize)->guided.advance
(m_channelAssembly.outPhase.data(),
@@ -342,8 +396,12 @@ R3StretcherImpl::consume()
auto &scale = it.second;
int bufSize = scale->bufSize;
// copy to prevMag before filtering
v_copy(scale->prevMag.data(), scale->mag.data(), bufSize);
v_copy(scale->prevOutPhase.data(), scale->outPhase.data(), bufSize);
v_copy(scale->prevMag.data(),
scale->mag.data(),
bufSize);
v_copy(scale->prevAdvancedPhase.data(),
scale->advancedPhase.data(),
bufSize);
}
for (const auto &band : cd->guidance.fftBands) {
@@ -398,7 +456,7 @@ R3StretcherImpl::consume()
(scale->real.data() + offset,
scale->imag.data() + offset,
scale->mag.data() + offset,
scale->outPhase.data() + offset,
scale->advancedPhase.data() + offset,
b.b1max - offset);
break;
}

View File

@@ -134,6 +134,21 @@ public:
size_t getChannelCount() const;
protected:
// Buffers holding the one-hop readahead for the classification FFT
// scale: a time-domain frame plus the magnitude/phase spectra
// computed from it (see consume(), which fills the readahead one
// inhop ahead of the current classification frame).
struct ClassificationReadaheadData {
    FixedVector<double> timeDomain;   // _fftSize samples
    FixedVector<double> mag;          // _fftSize/2 + 1 bins
    FixedVector<double> phase;        // _fftSize/2 + 1 bins

    // explicit: this is a buffer bundle sized from the
    // classification FFT size, not a conversion from int
    explicit ClassificationReadaheadData(int _fftSize) :
        timeDomain(_fftSize, 0.0),
        mag(_fftSize/2 + 1, 0.0),
        phase(_fftSize/2 + 1, 0.0)
    { }

    // Non-copyable: these are large per-channel working buffers
    ClassificationReadaheadData(const ClassificationReadaheadData &) = delete;
    ClassificationReadaheadData &operator=(const ClassificationReadaheadData &) = delete;
};
struct ChannelScaleData {
int fftSize;
int bufSize; // size of every freq-domain array here: fftSize/2 + 1
@@ -143,10 +158,10 @@ protected:
FixedVector<double> imag;
FixedVector<double> mag;
FixedVector<double> phase;
FixedVector<double> outPhase; //!!! "advanced"?
FixedVector<int> nextTroughs; //!!! not used in every scale
FixedVector<double> advancedPhase;
FixedVector<int> troughs; //!!! not used in every scale
FixedVector<double> prevMag; //!!! not used in every scale
FixedVector<double> prevOutPhase;
FixedVector<double> prevAdvancedPhase;
FixedVector<double> accumulator;
ChannelScaleData(int _fftSize, int _longestFftSize) :
@@ -157,10 +172,10 @@ protected:
imag(bufSize, 0.f),
mag(bufSize, 0.f),
phase(bufSize, 0.f),
outPhase(bufSize, 0.f),
nextTroughs(bufSize, 0),
advancedPhase(bufSize, 0.f),
troughs(bufSize, 0),
prevMag(bufSize, 0.f),
prevOutPhase(bufSize, 0.f),
prevAdvancedPhase(bufSize, 0.f),
accumulator(_longestFftSize, 0.f)
{ }
@@ -171,6 +186,7 @@ protected:
struct ChannelData {
std::map<int, std::shared_ptr<ChannelScaleData>> scales;
ClassificationReadaheadData readahead;
std::unique_ptr<BinSegmenter> segmenter;
BinSegmenter::Segmentation segmentation;
BinSegmenter::Segmentation prevSegmentation;
@@ -183,6 +199,7 @@ protected:
BinClassifier::Parameters classifierParameters,
int ringBufferSize) :
scales(),
readahead(segmenterParameters.fftSize),
segmenter(new BinSegmenter(segmenterParameters,
classifierParameters)),
segmentation(), prevSegmentation(), nextSegmentation(),