package ch.tachyon.sonics.gui.file.view.spectrogram;

import java.awt.*;
import java.awt.image.*;
import java.io.*;
import java.lang.ref.*;
import java.util.*;
import org.corebounce.common.audio.*;
import org.corebounce.common.dsp.*;
import org.corebounce.common.gui.*;
import org.corebounce.common.math.*;

import ch.tachyon.sonics.data.audio.*;
import ch.tachyon.sonics.data.stats.*;
import ch.tachyon.sonics.gui.file.view.*;
import ch.tachyon.tunnel.utils.*;


// TODO: refresh slightly more than the modified range after an operation
/**
 * Off-screen buffer holding the spectrogram of one audio channel together with cached images used to paint it.
 * <p>
 * Three layers are maintained, each with its own {@link DirtyArea}:
 * <ol>
 * <li>{@link #spectrogram}: one column of dB magnitudes per x coordinate, computed by {@link #refresh}</li>
 * <li>an unscaled {@link BufferedImage} of size {@code width x nbBins}, colored with the current gradient</li>
 * <li>a {@link VolatileImage} of size {@code width x height}, scaled to the view height and used for blitting</li>
 * </ol>
 * All three are circular buffers horizontally (see {@link #offlineOffset}) so that scrolling does not move pixels.
 * <p>
 * Locking: {@code paintLock} guards the images and their dirty areas; {@code dataLock} guards
 * {@link #dirtySpectrogram}. When both are taken, {@code paintLock} is always acquired first.
 */
public class SpectrogramOfflineBuffer implements IChannelOfflineDataBuffer {

    /** dB added per octave above {@link #MID_FREQ} (subtracted below it) in logarithmic mode. */
    private final static float DB_BOOST_PER_OCTAVE = 3.0f;
    /** When true, a square-root curve is applied to the level before gradient lookup. */
    private final static boolean ENHANCE_TOWARD_ZERO = false;
    /** Extra dB of head-room added to the dynamic range (see {@link #nbDb()}). */
    private final static int BOOST_DB = 10;
    private final static int SCALING_QUALITY = 1;
    private final static int SPECTRUM_QUALITY = 1;
    /** Offset added to dB values so that they fit in a signed byte. */
    protected final static float SHIFT_DB = 40.0f;

    // NOTE(review): presumably the Nyquist frequency of 44.1 kHz audio; the actual sample rate is not consulted
    private final static float MAX_FREQ = 22050.0f;
    private final static float MID_FREQ = 2000.0f;
    
    // Computation
    protected final boolean logarithmic;
    protected final int width;
    protected final int fftSize;
    protected final int nbBins;
    protected final float correction;
    protected final BooFFT fft;
    protected final float[] window;
    protected final float[] buffer;
    protected final Cmplx[] spectrum;
    // Stored as round(dB + SHIFT_DB), clamped to the byte range (dB is usually negative)
    protected final byte[][] spectrogram;

    // Memoized results of scale(view, y) for y in [0, nbBins]; NaN means "not computed yet"
    private final float[] lookupCache;
    // Per-bin dB boost (only non-zero in logarithmic mode)
    private final float[] dbBoost;

    // Rendering
    private final Color backgroundColor;
    private final Color busyColor;
    private final int[] gradient; // Colors from 0dB to -80dB
    private GradientType gradientType;
    private SoftReference<BufferedImage> unscaledImageRef = new SoftReference<BufferedImage>(null);
    private SoftReference<VolatileImage> imageRef = new SoftReference<VolatileImage>(null);
    private final DirtyArea dirtySpectrogram = new DirtyArea();
    private final DirtyArea dirtyUnscaled = new DirtyArea();
    private final DirtyArea dirtyImage = new DirtyArea();
    private final int scalingQuality;
    private final int spectrumQuality;
    
    // Last observed acceleration state of the volatile image (null until first creation)
    private Boolean accelerated;
    
    /**
     * To allow fast scrolling, the following are based on a circular buffer whose offset is given by this field:
     * <ul>
     * <li>{@link #spectrogram}
     * <li>{@link #unscaledImageRef} (the underlying image)
     * <li>{@link #imageRef} (the underlying image)
     * </ul>
     * On the other hand, {@link #dirtyImage}, {@link #dirtySpectrogram} and {@link #dirtyUnscaled} are based on actual
     * coordinates and not on circular buffer's coordinates
     */
    private int offlineOffset;

    private final Object paintLock = new Object();
    private final Object dataLock = new Object();


    /**
     * Creates a buffer of {@code width} columns, each holding {@code fftSize / 2 + 1} frequency bins.
     * The whole buffer starts out dirty.
     * <p>
     * NOTE(review): this constructor indirectly calls the overridable {@link #dynamicRange()} (through
     * {@link #nbDb()}), as well as {@link #setGradientType} and {@link #markDirty}; subclasses must not rely on
     * their own fields in those overrides.
     *
     * @param width number of columns (pixels) of the buffer
     * @param fftSize FFT length in samples
     * @param logarithmic whether the frequency axis is logarithmic
     * @param gradientType color gradient used to render levels
     * @param backgroundColor color for regions outside the file
     * @param busyColor color for regions not computed yet
     */
    public SpectrogramOfflineBuffer(int width, int fftSize, boolean logarithmic, GradientType gradientType,
            Color backgroundColor, Color busyColor) {
        this.width = width;
        this.fftSize = fftSize;
        this.logarithmic = logarithmic;
        this.buffer = new float[fftSize];
        this.window = new float[fftSize];
        this.fft = BooFFT.getInstance(fftSize / 2);
        this.nbBins = fftSize / 2 + 1;
        this.correction = fftSize;
        this.spectrum = Cmplx.newArray(nbBins);
        this.spectrogram = new byte[nbBins][width];
        this.lookupCache = new float[nbBins + 1];
        Arrays.fill(lookupCache, Float.NaN);
        this.dbBoost = new float[nbBins];
        initBoost();
        this.backgroundColor = backgroundColor;
        this.busyColor = busyColor;
        this.gradient = new int[nbDb()];
        this.offlineOffset = 0;
        this.scalingQuality = SCALING_QUALITY;
        this.spectrumQuality = SPECTRUM_QUALITY;
        Windows.fillWindow(window, Windows.HannCoefs);
        setGradientType(gradientType);
        markDirty(0, width);
    }

    /** Displayed dynamic range in dB (excluding {@link #BOOST_DB}). */
    protected int dynamicRange() {
        return 80;
    }

    /** Number of dB steps in the gradient: dynamic range plus boost head-room. */
    protected final int nbDb() {
        return dynamicRange() + BOOST_DB;
    }

    /** Lowest level (dB) stored in the spectrogram; lower levels are clamped to it. */
    protected final float minDb() {
        return -nbDb();
    }

    /**
     * Fills {@link #dbBoost} in logarithmic mode: {@link #DB_BOOST_PER_OCTAVE} dB per octave relative to
     * {@link #MID_FREQ}. Bin 0 uses the frequency of bin 1 to avoid log(0).
     */
    private void initBoost() {
        if (!logarithmic)
            return;

        for (int y = 0; y < nbBins; y++) {
            // Boost
            int y1 = Math.max(y, 1);
            float freq = (float) y1 * MAX_FREQ / nbBins;
            float octaves = (float) (Math.log(freq / MID_FREQ) / Math.log(2.0));
            dbBoost[y] = octaves * DB_BOOST_PER_OCTAVE;
        }
    }

    public GradientType getGradientType() {
        return gradientType;
    }

    /** Changes the color gradient, re-initializing the gradient table only if the type actually changes. */
    public void setGradientType(GradientType gradientType) {
        if (this.gradientType != gradientType) {
            this.gradientType = gradientType;
            gradientType.initGradient(gradient);
        }
    }

    public int getWidth() {
        return this.width;
    }

    /**
     * Maps a logical x coordinate to a physical column of the circular buffers.
     *
     * @param x logical coordinate
     * @param hiBound if true, {@code x} is an exclusive upper bound: a result equal to {@code width} is kept
     *            instead of wrapping to 0
     * @return physical column
     */
    private int rotate(int x, boolean hiBound) {
        if (hiBound) {
            // if result matches width, keep it instead of returning 0
            while (x < 0)
                x += width;
            x += offlineOffset;
            if (x > width)
                x = x % width;
            return x;
        } else {
            // Should also work with x == -1, hence addition of width before modulo
            x = (x + offlineOffset + width) % width;
            return x;
        }
    }
    
    private int rotate(int x) {
        return rotate(x, false);
    }

    /**
     * Returns the boosted level (dB) to display for spectrum row {@code y} at physical column {@code dx}.
     * The row is mapped to the fractional bin range [scale(y), scale(y + 1)). When it spans more than one bin the
     * maximum of those bins is returned; otherwise the value is cubic-interpolated from the 4 nearest bins.
     *
     * @return the level in dB, or {@link Float#NEGATIVE_INFINITY} if the row lies outside [0, nbBins)
     */
    private float getValue(AudioChannelDataRange view, int dx, int y) {
        float y0 = scale(view, y);
        float y1 = scale(view, y + 1);
        if (spectrumQuality < 1) {
            // Nearest neighbour
            int sy = (int) ((y0 + y1) / 2.0f + 0.5f);
            return getBoostedValue0(dx, sy);
        } else {
            // Maximum (scaling down) and cubic interpolation (scaling up)
            // Check if out of bounds
            if (y1 <= 0.0f || y0 >= nbBins)
                return Float.NEGATIVE_INFINITY; // Out of file
            // Get value
            if (y1 - y0 > 1.0f) {
                // Max
                int sy = (int) y0;
                int ey = (int) y1;
                assert ey > sy;
                float result = Float.NEGATIVE_INFINITY;
                for (int i = sy; i < ey; i++)
                    result = FastMath.max(result, getBoostedValue0(dx, i));
                return result;
            } else {
                // Interpolation
                float index = (y0 + y1) / 2.0f - 0.5f;
                int index1 = (int) index;
                int index0 = index1 - 1;
                int index2 = index1 + 1;
                int index3 = index1 + 2;
                double mu = (index - index1);
                double ya = getBoostedValue0(dx, index0);
                double yb = getBoostedValue0(dx, index1);
                double yc = getBoostedValue0(dx, index2);
                double yd = getBoostedValue0(dx, index3);
                double result = CubicInterpolator.splineInterpolateY(ya, yb, yc, yd, mu);
                return (float) result;
            }
        }
    }

    /**
     * Convert y coordinate (0..nbBins) according to linear/logarithmic and to view ceil and floor.
     * Results are memoized in {@link #lookupCache} until {@link #verticalChange} clears it.
     */
    private float scale(AudioChannelDataRange view, int y) {
        float result = lookupCache[y];
        if (Float.isNaN(result)) {
            result = scaleVertical(view, y);
            if (logarithmic)
                result = logLookup(result);
            lookupCache[y] = result;
        }
        return result;
    }

    /**
     * Maps a linear index in [0, nbBins - 1] to a bin index along an exponential curve of base
     * {@code (MAX_FREQ / MID_FREQ)^2}, so that 0 and nbBins - 1 are fixed points.
     */
    private float logLookup(float y) {
        final float base = (MAX_FREQ / MID_FREQ) * (MAX_FREQ / MID_FREQ);

        double norm = (double) y / (nbBins - 1); // [0..1]
        double logNorm = (Math.pow(base, norm) - 1.0) / (base - 1.0); // 0 - 1
        float index = (float) (logNorm * (nbBins - 1));
        return index;
    }

    /**
     * Applies the view's vertical zoom (floor/ceil) to a coordinate in [0, nbBins].
     * The identity is returned for the default [-1, 1] range.
     */
    private float scaleVertical(AudioChannelDataRange view, float value) {
        float floor = view.getVerticalFloor();
        float ceil = view.getVerticalCeil();
        if (floor == -1.0f && ceil == 1.0f)
            return value;
        float vertSpan = (ceil - floor) / 2.0f;
        // Normalize value to the [-1 .. 1] range
        float norm = value * 2.0f / (float) nbBins - 1.0f;
        // Apply vertical "zoom"
        float scaledNorm = norm * vertSpan;
        // Apply vertical "translate"
        float shifted = scaledNorm + swapShiftVert((floor + ceil) / 2.0f);
        // Denormalize back to [0 .. nbBins]
        return (shifted + 1.0f) * nbBins / 2.0f;
    }

    /** Returns the stored level of bin {@code y} (clamped to [0, nbBins - 1]) plus its boost, in dB. */
    private float getBoostedValue0(int dx, int y) {
        if (y < 0)
            y = 0;
        else if (y >= nbBins)
            y = nbBins - 1;
        return spectrogram[y][dx] + dbBoost[y] - SHIFT_DB;
    }

    /** Flips an image row index so that low frequencies end up at the bottom. */
    protected int swapCoordVert(int y) {
        return nbBins - y - 1;
    }

    /** Hook allowing subclasses to invert the vertical translation; identity by default. */
    protected float swapShiftVert(float v) {
        return v;
    }

    /** Marks [startX, stopX) dirty in all three layers (spectrogram, unscaled image, scaled image). */
    public void markDirty(int startX, int stopX) {
        synchronized (paintLock) {
            synchronized (dataLock) {
                dirtySpectrogram.markDirty(startX, stopX);
            }
            dirtyUnscaled.markDirty(startX, stopX);
            dirtyImage.markDirty(startX, stopX);
        }
    }

    /**
     * Scrolls all layers by {@code deltaX} pixels by moving the circular buffer offset.
     *
     * @return false if the scroll amount is too large to be handled incrementally
     */
    public boolean scroll(int deltaX) {
        if (Math.abs(deltaX) >= width - 1)
            return false; // Too large
        synchronized (paintLock) {
            synchronized (dataLock) {
                // Scroll dirty regions
                dirtyUnscaled.scroll(deltaX, width);
                dirtyImage.scroll(deltaX, width);
                dirtySpectrogram.scroll(deltaX, width);
                // Scroll spectrogram, unscaled image and scaled image by modifying the offline offset
                offlineOffset = (offlineOffset + width - deltaX) % width;
                assert offlineOffset >= 0 && offlineOffset < width;
            }
            return true;
        }
    }

    /**
     * Computes the spectrum columns [startX, stopX) (clamped to [0, width)).
     * <p>
     * For each column, the FFT is centred on the middle of the column's sample bucket. When the bucket is wider than
     * about 2/3 of the FFT size, it is centred on the sample returned by {@code getExtremumSample} instead. When the
     * unscaled image exists and its dirty area starts at this column, the matching image column is rebuilt
     * immediately.
     * Stops early if {@code observer} returns false.
     */
    public void refresh(AudioChannel data, AudioChannelDataRange view, int startX, int stopX, long fileStartPos, double reduction,
            int height, IRefreshObserver observer) throws IOException {
        if (startX < 0)
            startX = 0;
        if (stopX > width)
            stopX = width;
        int[] rgb = new int[nbBins];
        for (int x = startX; x < stopX; x++) {
            long midPos = fileStartPos + (long) ((double) x * reduction + 0.5);
            long bucketStart = fileStartPos + (long) (((double) x - 0.5) * reduction + 0.5);
            if (bucketStart < 0)
                bucketStart = 0;
            long bucketStop = fileStartPos + (long) (((double) x + 0.5) * reduction + 0.5);
            if (bucketStop > data.getLength())
                bucketStop = data.getLength();
            long bucketRange = bucketStop - bucketStart;
            double sampling = (double) fftSize / (double) bucketRange;
            int dx = rotate(x);
            long pos;
            if (sampling < 1.5 && bucketRange > 0) {
                pos = data.getExtremumSample(bucketStart, bucketStop, data, new LoudestExtremum());
            } else {
                pos = midPos;
            }
            computeSpectrumAt(data, pos, dx);
            synchronized (dataLock) {
                dirtySpectrogram.markClean(x);
            }
            // Optimization (leaves less work to the EDT):
            synchronized (paintLock) {
                // Rebuild corresponding column of unscaled image if available and dirty from the same point
                BufferedImage unscaled = unscaledImageRef.get();
                if (unscaled != null && dirtyUnscaled.startX == x) {
                    rebuildRgbRange(view, unscaled, x, x + 1, dirtySpectrogram, 1, rgb);
                    dirtyUnscaled.markClean(x);
                }
            }
            boolean cont = observer.refreshed(this, x, midPos, false);
            if (!cont) {
                break;
            }
        }
    }

    /**
     * Reads {@code fftSize} samples centred on {@code pos} (zero-padded outside the file) and stores their spectrum
     * into physical column {@code dx}.
     */
    private void computeSpectrumAt(AudioChannel data, long pos, int dx) throws IOException {
        long startPos = pos - fftSize / 2;
        long stopPos = startPos + fftSize;
        long readStart = Math.max(startPos, 0);
        long readStop = Math.min(stopPos, data.getLength());
        // Quick check if entirely outside of file
        if (readStop <= readStart) {
            // Store the lowest representable level (below minDb(), hence rendered with the darkest gradient color)
            for (int y = 0; y < nbBins; y++)
                spectrogram[y][dx] = Byte.MIN_VALUE;
            return;
        }
        // Read data; parts of the window before the start / after the end of the file stay zero
        Arrays.fill(buffer, 0.0f);
        data.read(readStart, buffer, (int) (readStart - startPos), (int) (readStop - readStart));
        if (readStart > startPos)
            Arrays.fill(buffer, 0, (int) (readStart - startPos), 0.0f);
        if (stopPos > readStop)
            Arrays.fill(buffer, (int) (readStop - startPos), buffer.length, 0.0f);
        // Compute spectrum
        computeSpectrum(dx);
    }

    /**
     * Windows {@link #buffer}, runs the FFT and stores each bin's level into physical column {@code dx}. Each level
     * is stored as {@code round(dB + SHIFT_DB)}, with dB clamped below at {@link #minDb()} and the result clamped to
     * the byte range.
     */
    protected void computeSpectrum(int dx) {
        for (int i = 0; i < fftSize; i++)
            buffer[i] *= window[i];
        fft.forwR2C(buffer, spectrum);
        // Get magnitude as dB
        float minDb = minDb();
        for (int y = 0; y < nbBins; y++) {
            float pMag = spectrum[y].powerMag() * correction;
            float db = (float) AudioMath.powerLevelToDb(pMag);
            if (db < minDb)
                db = minDb;
            // Round to nearest (also for positive values) and clamp so the byte cast never wraps around
            int stored = Math.round(db + SHIFT_DB);
            if (stored > Byte.MAX_VALUE)
                stored = Byte.MAX_VALUE;
            else if (stored < Byte.MIN_VALUE)
                stored = Byte.MIN_VALUE;
            spectrogram[y][dx] = (byte) stored;
        }
    }

    /**
     * Returns the unscaled image, creating it if it was collected, and rebuilds the columns that are dirty in it
     * but already computed in the spectrogram. Must be called with {@code paintLock} held.
     */
    private BufferedImage lookupUnscaledImage(AudioChannelDataRange view) {
        BufferedImage image = unscaledImageRef.get();
        if (image == null) {
            // When using setRGB or setDataElements (ImageUtils.setPixels), it seems
            // that using a manually created image is faster than a compatible image
            // (especially on old Java versions and VESA)
            image = new BufferedImage(width, nbBins, BufferedImage.TYPE_INT_RGB);
            unscaledImageRef = new SoftReference<BufferedImage>(image);
            dirtyUnscaled.markDirty(0, width);
        }
        DirtyArea dirtySpectrogram;
        synchronized (dataLock) {
            // Copy current state of dirty spectrogram as it may change concurrently (refresh marks columns clean
            // while holding dataLock only)
            dirtySpectrogram = this.dirtySpectrogram.clone();
        }
        if (dirtyUnscaled.isDirty() && !dirtyUnscaled.isInside(dirtySpectrogram)) {
            // Recompute available region
            int startX = dirtyUnscaled.startX;
            int stopX = dirtyUnscaled.stopX;
            // Split at the wrap-around point of the circular buffer
            int cutX = width - offlineOffset;
            if (cutX > startX && cutX < stopX) {
                rebuildRgbRange(view, image, startX, cutX, dirtySpectrogram);
                rebuildRgbRange(view, image, cutX, stopX, dirtySpectrogram);
            } else if (stopX > startX) {
                rebuildRgbRange(view, image, startX, stopX, dirtySpectrogram);
            }
            dirtyUnscaled.copyFrom(dirtySpectrogram);
        }
        return image;
    }

    /** Rebuilds columns [startX, stopX) of {@code image}, allocating a fresh pixel buffer. */
    private void rebuildRgbRange(AudioChannelDataRange view, BufferedImage image, int startX, int stopX, DirtyArea dirtySpectrogram) {
        int dirtyWidth = stopX - startX;
        int[] rgb = new int[dirtyWidth * nbBins];
        rebuildRgbRange(view, image, startX, stopX, dirtySpectrogram, dirtyWidth, rgb);
    }

    /**
     * Colors columns [startX, stopX) of {@code image} (logical coordinates, which must not wrap around the
     * circular buffer). Columns still dirty in {@code dirtySpectrogram} get the busy color, rows outside the
     * spectrum get the background color, and everything else gets a gradient color.
     *
     * @param rgb scratch buffer of at least {@code dirtyWidth * nbBins} entries
     */
    private void rebuildRgbRange(AudioChannelDataRange view, BufferedImage image, int startX, int stopX, DirtyArea dirtySpectrogram,
            int dirtyWidth, int[] rgb) {
        final int backgroundRgb = backgroundColor.getRGB();
        final int busyRgb = busyColor.getRGB();
        for (int y = 0; y < nbBins; y++) {
            int yInv = swapCoordVert(y);
            for (int x = startX; x < stopX; x++) {
                int dx = rotate(x);
                int rgbOff = y * dirtyWidth + x - startX;
                if (!dirtySpectrogram.contains(x)) {
                    float value = getValue(view, dx, yInv);
                    if (value != Float.NEGATIVE_INFINITY) {
                        // Gradient index: 0 for +BOOST_DB dB, increasing toward quieter levels
                        value = -value;
                        if (ENHANCE_TOWARD_ZERO)
                            value = (float) Math.sqrt(value / nbDb()) * nbDb();
                        value += BOOST_DB;
                        int index = (int) (value + 0.5f);
                        if (index < 0)
                            index = 0;
                        else if (index >= gradient.length)
                            index = gradient.length - 1;
                        rgb[rgbOff] = gradient[index];
                    } else
                        rgb[rgbOff] = backgroundRgb; // Outside of file
                } else
                    rgb[rgbOff] = busyRgb; // Not computed yet
            }
        }
        assert rotate(stopX, true) > rotate(startX, false) : "" + rotate(startX, false) + " - " + rotate(stopX, true);
        ImageUtils.setPixels(image, rotate(startX), 0, stopX - startX, nbBins, rgb);
    }

    /**
     * Returns the scaled volatile image for {@code gc} and {@code height}. The image is recreated when missing,
     * incompatible or of the wrong size, and its dirty columns are repainted from the unscaled image. Retries while
     * the image contents are lost. Must be called with {@code paintLock} held.
     */
    private VolatileImage lookupImage(AudioChannelDataRange view, GraphicsConfiguration gc, int height) {
        VolatileImage image;
        boolean contentsLost;
        do {
            // Get image from cache
            image = imageRef.get();
            int imageStatus = VolatileImage.IMAGE_OK;
            if (image != null)
                imageStatus = image.validate(gc);
            if (image == null || image.getWidth() != width || image.getHeight() != height
                    || imageStatus == VolatileImage.IMAGE_INCOMPATIBLE) {
                // Rebuild image
                if (image != null)
                    image.flush();
                // This image is used for on-screen blitting. Use a compatible image
                image = gc.createCompatibleVolatileImage(width, height);
                imageRef = new SoftReference<VolatileImage>(image);
                boolean accel = image.getCapabilities().isAccelerated();
                if (!Utils.eq(accelerated, accel)) {
                    Debug.info("Image acceleration: {0}", accel);
                    accelerated = accel;
                }
                // Force full repaint
                dirtyImage.markDirty(0, width);
            } else if (imageStatus == VolatileImage.IMAGE_RESTORED) {
                // Force full repaint
                dirtyImage.markDirty(0, width);
            }
            
            // Repaint dirty regions
            if (dirtyImage.isDirty()) {
                // Repaint unpainted region
                int minX = dirtyImage.startX;
                int maxX = dirtyImage.stopX;
                if (maxX > minX) {
                    // Get unscaled image
                    BufferedImage unscaled = lookupUnscaledImage(view);
                    int cutX = width - offlineOffset;
                    Graphics2D g = image.createGraphics();
                    try {
                        if (cutX > minX && cutX < maxX) {
                            scaleRange(unscaled, gc, height, g, rotate(minX, false), rotate(cutX, true));
                            scaleRange(unscaled, gc, height, g, rotate(cutX, false), rotate(maxX, true));
                        } else {
                            scaleRange(unscaled, gc, height, g, rotate(minX, false), rotate(maxX, true));
                        }
                    } finally {
                        g.dispose();
                    }
    
                    // Mark as painted
                    dirtyImage.copyFrom(dirtyUnscaled);
                }
            }
            contentsLost = image.contentsLost();
            if (contentsLost)
                dirtyImage.markDirty(0, width); // Force full repaint
        } while (contentsLost);
        return image;
    }

    /** Scales physical columns [minDX, maxDX) of {@code unscaled} to {@code height} and draws them at minDX. */
    private void scaleRange(BufferedImage unscaled, GraphicsConfiguration gc, int height, Graphics2D g, int minDX, int maxDX) {
        assert minDX < maxDX;
        // Crop to dirty region
        BufferedImage cropped = unscaled.getSubimage(minDX, 0, maxDX - minDX, nbBins);
        // Create scaled image
        Image scaled = ImageUtils.getScaledInstance(cropped, maxDX - minDX, height, gc, scalingQuality);
        cropped.flush();
        // Copy
        g.drawImage(scaled, minDX, 0, null);
        scaled.flush();
    }
    
    /**
     * Paints logical columns [startX, stopX) onto {@code g}. Columns whose spectrum is not computed yet are filled
     * with the busy color; all other columns are copied from the scaled image.
     */
    public void paint(AudioChannelDataRange view, Graphics2D g, int startX, int stopX, int height) {
        final GraphicsConfiguration gc = g.getDeviceConfiguration();
        synchronized (paintLock) {
            boolean contentsLost;
            do {
                VolatileImage image = lookupImage(view, gc, height);
                DirtyArea dirty;
                synchronized (dataLock) {
                    // Snapshot: refresh() marks columns clean while holding dataLock only
                    dirty = dirtySpectrogram.clone();
                }
                if (dirty.isDirty()) {
                    // Intersection of the dirty area with [startX, stopX); empty (minX == maxX) if disjoint
                    int minX = Math.min(Math.max(dirty.startX, startX), stopX);
                    int maxX = Math.max(Math.min(dirty.stopX, stopX), minX);
                    // Paint dirty area
                    if (maxX > minX) {
                        g.setColor(busyColor);
                        g.fillRect(minX, 0, maxX - minX, height);
                    }
                    // Image, left of dirty area
                    if (minX > startX)
                        drawCircularImage(g, image, startX, 0, minX, height);
                    // Image, right of dirty area
                    if (stopX > maxX)
                        drawCircularImage(g, image, maxX, 0, stopX, height);
                } else {
                    // Straight copy
                    drawCircularImage(g, image, startX, 0, stopX, height);
                }
                contentsLost = image.contentsLost();
                if (contentsLost)
                    dirtyImage.markDirty(0, width); // Force full repaint
            } while (contentsLost);
        }
    }

    /**
     * Draws logical region [sx, ex) x [sy, ey) of a circular {@code image}, splitting the copy in two at the
     * wrap-around point when needed.
     */
    private void drawCircularImage(Graphics2D g, Image image, int sx, int sy, int ex, int ey) {
        int cx = width - offlineOffset;
        if (cx > sx && cx < ex) {
            g.drawImage(image, sx, sy, cx, ey, rotate(sx, false), sy, rotate(cx, true), ey, null);
            g.drawImage(image, cx, sy, ex, ey, rotate(cx, false), sy, rotate(ex, true), ey, null);
        } else {
            g.drawImage(image, sx, sy, ex, ey, rotate(sx, false), sy, rotate(ex, true), ey, null);
        }
    }

    /**
     * Drops the cached images (only the scaled one when {@code heightOnly}) and the vertical lookup cache so that
     * they are rebuilt on the next paint.
     *
     * @return always false
     */
    public boolean verticalChange(boolean heightOnly) {
        synchronized (paintLock) {
            synchronized (dataLock) {
                // Force rebuilding the image next time it is needed
                VolatileImage image = imageRef.get();
                if (image != null)
                    image.flush();
                imageRef.clear();
                if (!heightOnly) {
                    BufferedImage unscaled = unscaledImageRef.get();
                    if (unscaled != null)
                        unscaled.flush();
                    unscaledImageRef.clear();
                    
                }
                Arrays.fill(lookupCache, Float.NaN);
            }
        }
        return false;
    }

    /** Releases the resources of the cached images and drops the references to them. */
    public void dispose() {
        BufferedImage unscaled = unscaledImageRef.get();
        if (unscaled != null)
            unscaled.flush();
        unscaledImageRef.clear();
        VolatileImage image = imageRef.get();
        if (image != null)
            image.flush();
        imageRef.clear();
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            dispose();
        } finally {
            super.finalize();
        }
    }
    
}
