package ch.tachyon.sonics.effect.pitchtime;

import java.util.*;

import org.corebounce.common.audio.*;
import org.corebounce.common.dsp.fft.*;
import org.corebounce.common.math.*;

import ch.tachyon.sonics.effect.base.stft.*;
import ch.tachyon.tunnel.common.*;
import ch.tachyon.tunnel.plugin.opt.doc.*;
import ch.tachyon.tunnel.plugin.opt.thread.*;
import ch.tachyon.tunnel.plugin.param.*;


/**
 * Slows the tempo down by a factor of two without affecting the pitch.
 * <p>
 * The effect is built on an STFT whose input overlap is configured to twice
 * the output overlap. The spectrum of each frame is given a random phase
 * before resynthesis (see {@link #processSpectrum(Cmplx[])}).
 * <p>
 * Reference material for general time-stretching:
 * <ul>
 *   <li>Implementation 1: VocoderPhd.pdf, VocoderPhasiness.pdf</li>
 *   <li>Implementation 2: icmc00-bonada.pdf</li>
 * </ul>
 * @author Nicolas
 */
// TODO: add frequency resolution slider
@Category("Tempo")
@Name("Half Tempo")
@Description("Slow down the tempo by 2 without affecting the pitch")
@MultiThreading
public class HalfTempo extends StftPushEffectBase {

    /** FFT block size used at quality 0. */
    private final static int BLOCK_SIZE_0 = 8192;
    /** FFT block size used at quality 1 and above. */
    private final static int BLOCK_SIZE_1 = BLOCK_SIZE_0 * 2;
    /** Output overlap factors for quality 0, 1 and 2 (the input overlap is twice these). */
    private final static int OVERLAP_OUT_0 = 2;
    private final static int OVERLAP_OUT_1 = 8;
    private final static int OVERLAP_OUT_2 = 64;
    private final static int MAX_QUALITY = 2;

    /**
     * Gain (sqrt(2)) applied to every phase-randomized bin.
     * NOTE(review): presumably compensates the level change caused by the
     * randomized phases; confirm.
     */
    private final static float RANDOM_PHASE_GAIN = (float) Math.sqrt(2.0);

    /** Accuracy setting; expected to lie in [0, MAX_QUALITY] per the {@code @Range} on the getter. */
    private int quality = 1;
    /** Extra latency (in samples) added for quality >= 1; set in {@link #startProcessing}. */
    private int shrinkLatency;

    /**
     * Source of the random phases.
     * NOTE(review): shared by all threads calling processSpectrum (the class is
     * annotated @MultiThreading). java.util.Random is thread-safe but may contend.
     */
    private final Random rnd = new Random();


    @Order(1)
    @Range(minValue=0, maxValue=MAX_QUALITY, defaultValue=1)
    @Name("Accuracy")
    @Description("Accuracy of the result\n" +
                "Higher values mean better quality but slower processing")
    public int getQuality() {
        return quality;
    }

    public void setQuality(int quality) {
        this.quality = quality;
    }

    /**
     * Re-declared only so the inherited flag can carry the
     * {@code @Order/@Name/@Description} annotations; delegates to super.
     */
    @Order(2)
    @Name("NUMA Mode")
    @Description("Experimental - Optimize for NUMA architectures")
    @Override
    public boolean isAccuMode() {
        return super.isAccuMode();
    }

    @Override
    public void setAccuMode(boolean accuMode) {
        super.setAccuMode(accuMode);
    }

    // Processing

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably because the output is twice as long as the input.
     */
    @Override
    public boolean canWriteFasterThanRead() {
        return true;
    }

    /**
     * Configures block size, overlaps and windows for the current quality,
     * then starts the base STFT processing.
     *
     * @param info processing information forwarded to the base class
     */
    @Override
    public void startProcessing(IProcessingInfo info) {
        int blockSize = (quality == 0 ? BLOCK_SIZE_0 : BLOCK_SIZE_1);
        float[] analysisWindow;
        float[] synthesisWindow;
        if (quality == 0) {
            this.shrinkLatency = 0;
            analysisWindow = WindowsFactory.getHannWindow(blockSize);
            synthesisWindow = WindowsFactory.getHannWindow(blockSize);
            super.setModifiedWindow(createFoldedModifiedWindow(blockSize));
        } else {
            this.shrinkLatency = blockSize / 4;
            analysisWindow = WindowsFactory.getHannWindow(blockSize, 2);
            synthesisWindow = WindowsFactory.getHannWindow(blockSize, 2);
            /*
             * Here the analysis window is convolved with itself (without time folding).
             * The overlap is sufficient to yield a flat curve after the process, so
             * a folded modified window is not needed: a flat 0.5 window is supplied instead.
             */
            float[] modifiedWindow = new float[blockSize];
            Arrays.fill(modifiedWindow, 0.5f);
            super.setModifiedWindow(modifiedWindow);
        }

        int overlapOut = (quality == 0 ? OVERLAP_OUT_0 : (quality == 1 ? OVERLAP_OUT_1 : OVERLAP_OUT_2));
        super.setBlockSizeLog(AudioMath.log2(blockSize));
        // Input overlap is twice the output overlap: the 2x time stretch.
        super.setOverlapInLog(AudioMath.log2(overlapOut * 2));
        super.setOverlapOutLog(AudioMath.log2(overlapOut));
        super.setAnalysisWindow(analysisWindow);
        super.setSynthesisWindow(synthesisWindow);
        super.startProcessing(info);
    }

    /**
     * Builds the modified window used at quality 0.
     * <p>
     * After the time stretch, the analysis window ends up convolved with itself.
     * That self-convolution is approximated on a buffer of {@code 2 * blockSize}
     * samples. Because no zero-padding is used, it is then folded back onto
     * {@code blockSize} samples: the central half is kept and the two outer
     * quarters are wrapped around onto it.
     *
     * @param blockSize FFT block size
     * @return the folded window, of length {@code blockSize}
     */
    private static float[] createFoldedModifiedWindow(int blockSize) {
        // Empirical coefficients of the self-convolved Hann window (halved).
        final float[] hannSelfConvolvedCoefs = new float[] {0.5f / 2.0f, -0.36f / 2.0f, 0.125f / 2.0f, -0.01438f / 2.0f};
        float[] convolved = new float[blockSize * 2];
        Windows.fillWindow(convolved, hannSelfConvolvedCoefs);
        float[] modifiedWindow = new float[blockSize];
        for (int i = 0; i < blockSize; i++) {
            modifiedWindow[i] = convolved[i + blockSize / 2] + convolved[(i + blockSize * 3 / 2) % convolved.length];
        }
        return modifiedWindow;
    }

    /**
     * Adds {@code shrinkLatency} to the base latency.
     * <p>
     * The base latency is halved on the input side. The output side receives
     * twice the shrink latency, consistent with the output being stretched by 2.
     */
    @Override
    public int getLatency(IoDirection ioDirection) {
        if (ioDirection == IoDirection.INPUT)
            return super.getLatency(ioDirection) / 2 + shrinkLatency;
        else
            return super.getLatency(ioDirection) + shrinkLatency * 2;
    }

    /**
     * Processes one frame in place.
     * <p>
     * The first and last bins (presumably DC and Nyquist) are only negated.
     * Every other bin keeps (approximately) its magnitude, receives a uniformly
     * random phase in [-pi, pi) and is scaled by {@link #RANDOM_PHASE_GAIN}.
     *
     * @param spectrum the frame's spectrum, modified in place
     */
    @Override
    protected void processSpectrum(Cmplx[] spectrum) {
        spectrum[0].mul(-1.0f);
        spectrum[spectrum.length - 1].mul(-1.0f);
        for (int i = 1; i < spectrum.length - 1; i++) {
            Cmplx value = spectrum[i];
            float mag = value.magApprox();
            if (mag > 0.0f) {
                value.mul(value); // Doubles the phase and squares the magnitude
                value.mul(1.0f / mag); // Restores the original magnitude (up to magApprox accuracy)
            }
            // NOTE(review): the phase doubled above is overwritten by the random
            // phase below, so only the magnitude rescaling above has any effect.
            value.toPolar();
            // In polar form, im holds the phase: replace it with a random one in [-pi, pi).
            value.im = (float) ((rnd.nextFloat() - 0.5f) * Math.PI * 2.0);
            value.toCartesian();
            value.mul(RANDOM_PHASE_GAIN);
        }
    }

    /**
     * Swaps the two halves of {@code output} in place and negates every sample,
     * i.e. a circular shift by {@code blockSize / 2} combined with a sign flip.
     *
     * @param output    time-domain frame, at least {@code blockSize} long
     * @param blockSize frame length whose halves are swapped
     */
    private static void swapHalvesNegated(float[] output, int blockSize) {
        final int middle = blockSize / 2;
        for (int i = 0; i < middle; i++) {
            float temp = output[i];
            output[i] = -output[i + middle];
            output[i + middle] = -temp;
        }
    }

    // Post processing

    /**
     * Creates the accumulating (NUMA) engine. Its synthesizer rotates each
     * inverse-FFT frame by half a block and negates it.
     */
    @Override
    protected AbstractStftEngine newAccuStftEngine() {
        return new AccuStftEngine() {

            @Override
            protected void createEngine(int blockSize, int inHopSize, int outHopSize, int numHops, float[] analysisWindow,
                    float[] modifiedWindow, float[] synthesisWindow) {
                analyzer = new StftAnalyzerAccu(blockSize, inHopSize, numHops, analysisWindow);
                synthesizer = new StftSynthesizerAccu(blockSize, outHopSize, numHops, modifiedWindow, synthesisWindow) {
                    @Override
                    protected void backwardFFT(Cmplx[] spectrum, float[] output) {
                        super.backwardFFT(spectrum, output);
                        // Rotate buffer by blockSize / 2 (with sign flip)
                        HalfTempo.swapHalvesNegated(output, blockSize);
                    }
                };
            }

        };
    }

    /**
     * Creates the parallel engine. Its synthesizer rotates each inverse-FFT
     * frame by half a block and negates it.
     */
    @Override
    protected AbstractStftEngine newParallelStftEngine() {
        return new ParallelStftEngine() {

            @Override
            protected void createEngine(int blockSize, int inHopSize, int outHopSize, int numHops, float[] analysisWindow,
                    float[] modifiedWindow, float[] synthesisWindow) {
                analyzer = new StftAnalyzer(blockSize, inHopSize, numHops, analysisWindow, false);
                synthesizer = new StftSynthesizer(blockSize, outHopSize, numHops, modifiedWindow, synthesisWindow) {
                    @Override
                    protected void backwardFFT(Cmplx[] spectrum, float[] output) {
                        super.backwardFFT(spectrum, output);
                        // Rotate buffer by blockSize / 2 (with sign flip)
                        HalfTempo.swapHalvesNegated(output, getBlockSize());
                    }
                };
                this.output = new float[synthesizer.getOutputSize()];
            }

        };
    }

}
