Flesh out the implementation a bit

This commit is contained in:
Chris Cannam
2022-05-23 15:04:34 +01:00
parent af97c70e69
commit 5cc4833820
7 changed files with 408 additions and 48 deletions

View File

@@ -235,6 +235,9 @@ protected:
// Report whether frequency f lies in the half-open interval
// [r.f0, r.f1) of range r. A range whose `present` flag is false
// never contains anything.
bool inRange(double f, const Guide::Range &r) {
    if (!r.present) {
        return false;
    }
    return f >= r.f0 && f < r.f1;
}
GuidedPhaseAdvance(const GuidedPhaseAdvance &) =delete;
GuidedPhaseAdvance &operator=(const GuidedPhaseAdvance &) =delete;
};
}

View File

@@ -0,0 +1,277 @@
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
/*
Rubber Band Library
An audio time-stretching and pitch-shifting library.
Copyright 2007-2022 Particular Programs Ltd.
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License as
published by the Free Software Foundation; either version 2 of the
License, or (at your option) any later version. See the file
COPYING included with this distribution for more information.
Alternatively, if you have a valid commercial licence for the
Rubber Band Library obtained by agreement with the copyright
holders, you may redistribute and/or modify it under the terms
described in that licence.
If you wish to distribute code using the Rubber Band Library
under terms other than those of the GNU General Public License,
you must obtain a valid commercial licence before doing so.
*/
#include "R3StretcherImpl.h"
#include <array>
namespace RubberBand {
void
R3StretcherImpl::setTimeRatio(double ratio)
{
m_timeRatio = ratio;
}
void
R3StretcherImpl::setPitchScale(double scale)
{
m_pitchScale = scale;
}
double
R3StretcherImpl::getTimeRatio() const
{
return m_timeRatio;
}
double
R3StretcherImpl::getPitchScale() const
{
return m_pitchScale;
}
// Placeholder: always reports zero latency. The real figure has not
// been worked out yet (see the //!!! marker).
size_t
R3StretcherImpl::getLatency() const
{
    const size_t latency = 0; //!!!
    return latency;
}
// Return the channel count held in the construction parameters,
// converted to size_t.
size_t
R3StretcherImpl::getChannelCount() const
{
    return static_cast<size_t>(m_parameters.channels);
}
// Placeholder: currently does nothing. No buffers or per-channel
// state are cleared yet (see the //!!! marker).
void
R3StretcherImpl::reset()
{
    //!!!
}
// Report how many further input samples are needed before channel
// 0's input buffer holds a full frame at the longest FFT size.
// Returns 0 if that many samples are already readable. Only channel
// 0 is consulted.
size_t
R3StretcherImpl::getSamplesRequired() const
{
    // Do the comparison and the subtraction in a single unsigned
    // type rather than mixing int with size_t.
    const size_t longest =
        static_cast<size_t>(m_guideConfiguration.longestFftSize);
    const size_t rs = m_channelData[0]->inbuf->getReadSpace();

    if (rs >= longest) {
        return 0;
    }
    return longest - rs;
}
// Append `samples` values from each channel pointer in `input` to
// that channel's input ring buffer, then call consume() to process
// whatever can now be processed.
//
// If channel 0's buffer has less write space than `samples`, every
// channel's input buffer is replaced by a resized copy big enough to
// take the new data, and a warning is logged.
//
// `final` is accepted but not yet acted upon.
void
R3StretcherImpl::process(const float *const *input, size_t samples, bool final)
{
    //!!! todo: final
    (void)final;

    size_t ws = m_channelData[0]->inbuf->getWriteSpace();
    if (samples > ws) {
        //!!! check this
        m_parameters.logger("R3StretcherImpl::process: WARNING: Forced to increase input buffer size. Either setMaxProcessSize was not properly called or process is being called repeatedly without retrieve.");
        // Current size minus free space is what is already queued;
        // add room for the incoming block on top of that.
        size_t newSize = m_channelData[0]->inbuf->getSize() - ws + samples;
        for (int c = 0; c < m_parameters.channels; ++c) {
            m_channelData[c]->inbuf =
                std::unique_ptr<RingBuffer<float>>
                (m_channelData[c]->inbuf->resized(newSize));
        }
    }

    for (int c = 0; c < m_parameters.channels; ++c) {
        m_channelData[c]->inbuf->write(input[c], samples);
    }

    consume();
}
// Return the number of samples currently readable from channel 0's
// output ring buffer, as an int.
int
R3StretcherImpl::available() const
{
    const auto readable = m_channelData[0]->outbuf->getReadSpace();
    return static_cast<int>(readable);
}
// Read up to `samples` output values per channel into the buffers
// pointed to by `output`, and return the number obtained.
//
// Channels are read in order. If a channel yields fewer values than
// requested, the request for the remaining channels, and the return
// value, shrink to match. A shortfall on any channel other than the
// first is logged as a channel imbalance. Earlier channels will
// already have given up more samples than the returned count in that
// case.
size_t
R3StretcherImpl::retrieve(float *const *output, size_t samples) const
{
    size_t got = samples;

    // int index, matching the channel count type and the other
    // channel loops in this file.
    for (int c = 0; c < m_parameters.channels; ++c) {
        size_t gotHere = m_channelData[c]->outbuf->read(output[c], got);
        if (gotHere < got) {
            if (c > 0) {
                m_parameters.logger("R3StretcherImpl::retrieve: WARNING: channel imbalance detected");
            }
            got = gotHere;
        }
    }

    return got;
}
// Process as many hops as the buffers allow.
//
// Each loop iteration needs a full longest-FFT-size frame readable
// on channel 0's input buffer and room for one output hop on channel
// 0's output buffer. One iteration runs the following stages:
//
//   1. copy the current input frame and cut a windowed frame from
//      it for every FFT size ("scale");
//   2. forward-transform each scale to magnitude/phase;
//   3. segment the classification scale, find troughs and compute
//      guidance for each channel;
//   4. run the guided phase advance for each scale, over all
//      channels at once;
//   5. save the magnitudes and output phases as "previous" values
//      and wrap the output phase into `phase`;
//   6. inverse-transform each scale and window-add it into that
//      scale's accumulator;
//   7. sum the first `outhop` samples of every accumulator into the
//      channel mixdown, write that to the output buffer, shift the
//      accumulators along and advance the input by `inhop`.
//
// The hop sizes are still hard-coded (see the //!!! marker);
// m_timeRatio and m_pitchScale are not consulted here.
void
R3StretcherImpl::consume()
{
    int inhop = 171, outhop = 256; //!!!
    double ratio = double(outhop) / double(inhop);

    int longest = m_guideConfiguration.longestFftSize;
    int classify = m_guideConfiguration.classificationFftSize;

    while (m_channelData[0]->inbuf->getReadSpace() >= longest &&
           m_channelData[0]->outbuf->getWriteSpace() >= outhop) {

        m_parameters.logger("consume looping");

        // Stage 1: analysis framing

        for (int c = 0; c < m_parameters.channels; ++c) {

            const auto &cd = m_channelData[c];
            const auto &longestScale = cd->scales.at(longest);

            // Peek rather than read: the frame must stay in the
            // input buffer, because the buffer is advanced by only
            // inhop samples at the end of this iteration. Consuming
            // the whole frame here as well would swallow
            // longest + inhop samples per hop.
            cd->inbuf->peek(longestScale->timeDomainFrame.data(), longest);

            for (const auto &it : cd->scales) {
                int fftSize = it.first;
                if (fftSize == longest) continue;
                const auto &scale = it.second;
                // Shorter frames are taken from the middle of the
                // longest frame
                int offset = (longest - fftSize) / 2;
                m_scaleData.at(fftSize)->analysisWindow.cut
                    (longestScale->timeDomainFrame.data() + offset,
                     scale->timeDomainFrame.data());
            }

            // The longest frame is windowed in place, and only after
            // the shorter ones have been cut from the unwindowed data
            m_scaleData.at(longest)->analysisWindow.cut
                (longestScale->timeDomainFrame.data());
        }

        // Stage 2: forward transforms

        for (int c = 0; c < m_parameters.channels; ++c) {
            const auto &cd = m_channelData[c];
            //!!! There are some aspects of scaling etc handled in bsq
            //!!! that are not yet here
            for (const auto &it : cd->scales) {
                int fftSize = it.first;
                const auto &scale = it.second;
                m_scaleData.at(fftSize)->fft.forwardPolar
                    (scale->timeDomainFrame.data(),
                     scale->mag.data(),
                     scale->phase.data());
            }
        }

        // Stage 3: segmentation, troughs and guidance, using the
        // classification scale only

        for (int c = 0; c < m_parameters.channels; ++c) {
            const auto &cd = m_channelData[c];
            const auto &classifyScale = cd->scales.at(classify);
            cd->prevSegmentation = cd->segmentation;
            cd->segmentation =
                cd->segmenter->segment(classifyScale->mag.data());
            m_troughPicker.findNearestAndNextPeaks
                (classifyScale->mag.data(), 3, nullptr,
                 classifyScale->nextTroughs.data());
            m_guide.calculate(ratio, classifyScale->mag.data(),
                              classifyScale->nextTroughs.data(),
                              classifyScale->prevMag.data(),
                              cd->segmentation,
                              cd->prevSegmentation,
                              BinSegmenter::Segmentation(), //!!!
                              cd->guidance);
        }

        // Stage 4: guided phase advance, one call per scale covering
        // every channel. The set of scales is taken from channel 0.

        for (const auto &it : m_channelData[0]->scales) {
            int fftSize = it.first;
            for (int c = 0; c < m_parameters.channels; ++c) {
                const auto &cd = m_channelData[c];
                const auto &scale = cd->scales.at(fftSize);
                m_channelAssembly.mag[c] = scale->mag.data();
                m_channelAssembly.phase[c] = scale->phase.data();
                m_channelAssembly.guidance[c] = &cd->guidance;
                m_channelAssembly.outPhase[c] = scale->outPhase.data();
            }
            m_scaleData.at(fftSize)->guided.advance
                (m_channelAssembly.outPhase.data(),
                 m_channelAssembly.mag.data(),
                 m_channelAssembly.phase.data(),
                 m_guideConfiguration,
                 m_channelAssembly.guidance.data(),
                 inhop,
                 outhop);
        }

        // Stage 5: remember this frame's values for the next hop and
        // prepare the phases for resynthesis

        for (int c = 0; c < m_parameters.channels; ++c) {
            for (const auto &it : m_channelData[c]->scales) {
                const auto &scale = it.second;
                int bufSize = scale->bufSize;
                // copy to prevMag before filtering
                v_copy(scale->prevMag.data(), scale->mag.data(), bufSize);
                v_copy(scale->prevOutPhase.data(),
                       scale->outPhase.data(), bufSize);
                for (int i = 0; i < bufSize; ++i) {
                    scale->phase[i] = princarg(scale->outPhase[i]);
                }
            }
        }

        //!!! + filter here

        // Stage 6: inverse transforms and windowed overlap-add

        for (int c = 0; c < m_parameters.channels; ++c) {
            for (const auto &it : m_channelData[c]->scales) {
                int fftSize = it.first;
                const auto &scale = it.second;
                const auto &scaleData = m_scaleData.at(fftSize);

                scaleData->fft.inversePolar(scale->mag.data(),
                                            scale->phase.data(),
                                            scale->timeDomainFrame.data());

                int synthesisWindowSize =
                    scaleData->synthesisWindow.getSize();
                int accumulatorSize = int(scale->accumulator.size());
                int fromOffset = (fftSize - synthesisWindowSize) / 2;
                int toOffset = (longest - synthesisWindowSize) / 2;

                //!!! not right - accumulator is of scale data size, not full longest size - we need offset when mixing into mixdown buffer below as well
                // Until the accumulators are allocated at the
                // longest FFT size, the offset above would write
                // past the end of the accumulator for any shorter
                // scale. Fall back to centring the window within the
                // accumulator we actually have. This keeps the write
                // in bounds, but leaves the scales out of time
                // alignment with one another in the mixdown.
                if (toOffset + synthesisWindowSize > accumulatorSize) {
                    toOffset = (accumulatorSize - synthesisWindowSize) / 2;
                }

                scaleData->synthesisWindow.cutAndAdd
                    (scale->timeDomainFrame.data() + fromOffset,
                     scale->accumulator.data() + toOffset);
            }
        }

        // Stage 7: mix down one output hop and advance the buffers

        for (int c = 0; c < m_parameters.channels; ++c) {
            const auto &cd = m_channelData[c];
            v_zero(cd->mixdown.data(), outhop);
            for (const auto &it : cd->scales) {
                const auto &scale = it.second;
                auto &acc = scale->accumulator;
                v_add(cd->mixdown.data(), acc.data(), outhop);
                // Shift the remainder down by one hop and clear the
                // vacated tail
                int n = acc.size() - outhop;
                v_move(acc.data(), acc.data() + outhop, n);
                v_zero(acc.data() + n, outhop);
            }
            cd->outbuf->write(cd->mixdown.data(), outhop);
            cd->inbuf->skip(inhop);
        }
    }
}
}

View File

@@ -56,7 +56,9 @@ public:
R3StretcherImpl(Parameters parameters) :
m_parameters(parameters),
m_guide(Guide::Parameters(m_parameters.sampleRate)),
m_guideConfiguration(m_guide.getConfiguration())
m_guideConfiguration(m_guide.getConfiguration()),
m_channelAssembly(m_parameters.channels),
m_troughPicker(m_guideConfiguration.classificationFftSize / 2 + 1)
{
BinSegmenter::Parameters segmenterParameters
(m_guideConfiguration.classificationFftSize,
@@ -64,7 +66,9 @@ public:
BinClassifier::Parameters classifierParameters
(m_guideConfiguration.classificationFftSize / 2 + 1,
9, 1, 10, 2.0, 2.0, 1.0e-7);
int ringBufferSize = m_guideConfiguration.longestFftSize * 2;
for (int c = 0; c < m_parameters.channels; ++c) {
m_channelData.push_back(std::make_shared<ChannelData>
(segmenterParameters,
@@ -72,11 +76,18 @@ public:
ringBufferSize));
for (auto band: m_guideConfiguration.fftBandLimits) {
int fftSize = band.fftSize;
m_ffts[fftSize] = std::make_shared<FFT>(fftSize);
m_channelData[c]->scales[fftSize] =
std::make_shared<ChannelScaleData>(fftSize);
}
}
for (auto band: m_guideConfiguration.fftBandLimits) {
int fftSize = band.fftSize;
GuidedPhaseAdvance::Parameters guidedParameters
(fftSize, m_parameters.sampleRate, m_parameters.channels,
m_parameters.logger);
m_scaleData[fftSize] = std::make_shared<ScaleData>(guidedParameters);
}
}
~R3StretcherImpl() { }
@@ -89,29 +100,39 @@ public:
double getTimeRatio() const;
double getPitchScale() const;
size_t getSamplesRequired() const;
void process(const float *const *input, size_t samples, bool final);
int available() const;
size_t retrieve(float *const *output, size_t samples) const;
size_t getLatency() const;
size_t getChannelCount() const;
protected:
struct ChannelScaleData {
int fftSize;
int bufSize; // size of every freq-domain array here: fftSize/2 + 1
//!!! review later which of these we are actually using!
FixedVector<float> timeDomainFrame;
FixedVector<float> mag;
FixedVector<float> phase;
FixedVector<int> nearestPeaks;
FixedVector<int> nearestTroughs;
FixedVector<float> prevOutMag;
FixedVector<double> outPhase; //!!! "advanced"?
FixedVector<int> nextTroughs; //!!! not used in every scale
FixedVector<float> prevMag; //!!! not used in every scale
FixedVector<double> prevOutPhase;
FixedVector<int> prevNearestPeaks;
FixedVector<float> timeDomainFrame;
Window<float> analysisWindow;
Window<float> synthesisWindow;
FixedVector<float> accumulator;
ChannelScaleData(int _fftSize) :
fftSize(_fftSize), bufSize(fftSize/2 + 1),
mag(bufSize, 0.f), phase(bufSize, 0.f),
nearestPeaks(bufSize, 0), nearestTroughs(bufSize, 0),
prevOutMag(bufSize, 0.f), prevOutPhase(bufSize, 0.f),
prevNearestPeaks(bufSize, 0), timeDomainFrame(fftSize, 0.f),
analysisWindow(HannWindow, fftSize),
synthesisWindow(HannWindow, fftSize/2)
fftSize(_fftSize),
bufSize(fftSize/2 + 1),
timeDomainFrame(fftSize, 0.f),
mag(bufSize, 0.f),
phase(bufSize, 0.f),
outPhase(bufSize, 0.f),
nextTroughs(bufSize, 0),
prevMag(bufSize, 0.f),
prevOutPhase(bufSize, 0.f),
accumulator(fftSize, 0.f)
{ }
private:
@@ -126,8 +147,9 @@ protected:
BinSegmenter::Segmentation prevSegmentation;
BinSegmenter::Segmentation nextSegmentation;
Guide::Guidance guidance;
RingBuffer<float> inbuf;
RingBuffer<float> outbuf;
FixedVector<float> mixdown;
std::unique_ptr<RingBuffer<float>> inbuf;
std::unique_ptr<RingBuffer<float>> outbuf;
ChannelData(BinSegmenter::Parameters segmenterParameters,
BinClassifier::Parameters classifierParameters,
int ringBufferSize) :
@@ -135,19 +157,49 @@ protected:
segmenter(new BinSegmenter(segmenterParameters,
classifierParameters)),
segmentation(), prevSegmentation(), nextSegmentation(),
inbuf(ringBufferSize), outbuf(ringBufferSize) { }
mixdown(ringBufferSize, 0.f), //!!! could be much shorter (bound is the max outhop)
inbuf(new RingBuffer<float>(ringBufferSize)),
outbuf(new RingBuffer<float>(ringBufferSize)) { }
};
struct ChannelAssembly {
    // Vectors of bare pointers, used to package container data
    // from different channels into arguments for PhaseAdvance.
    // Each vector has one entry per channel. The entries are
    // non-owning and start out null.
    FixedVector<float *> mag;
    FixedVector<float *> phase;
    FixedVector<Guide::Guidance *> guidance;
    FixedVector<double *> outPhase;

    // explicit, so that an int is never silently converted into
    // a ChannelAssembly
    explicit ChannelAssembly(int channels) :
        mag(channels, nullptr),
        phase(channels, nullptr),
        guidance(channels, nullptr),
        outPhase(channels, nullptr) { }
};
struct ScaleData {
FFT fft;
Window<float> analysisWindow;
Window<float> synthesisWindow;
GuidedPhaseAdvance guided;
ScaleData(GuidedPhaseAdvance::Parameters guidedParameters) :
fft(guidedParameters.fftSize),
analysisWindow(HannWindow, guidedParameters.fftSize),
synthesisWindow(HannWindow, guidedParameters.fftSize/2),
guided(guidedParameters) { }
};
Parameters m_parameters;
double m_timeRatio;
double m_pitchScale;
std::vector<std::shared_ptr<ChannelData>> m_channelData;
std::map<int, std::shared_ptr<FFT>> m_ffts;
std::map<int, std::shared_ptr<ScaleData>> m_scaleData;
Guide m_guide;
Guide::Configuration m_guideConfiguration;
ChannelAssembly m_channelAssembly;
Peak<float, std::less<float>> m_troughPicker;
void consume();
static void logCerr(const std::string &message) {
std::cerr << "RubberBandStretcher: " << message << std::endl;
}