import React, { useRef, useState, useMemo } from 'react'
import { Canvas, useFrame, extend, useThree, invalidate } from '@react-three/fiber'
import {
    TextureLoader, ImageLoader, ShaderMaterial, Texture, DataTexture, WebGLRenderTarget,
    WebGLRenderer
} from 'three'
import * as THREE from 'three'
import { ExtractMaskShader, ApplyLayerAdjustmentsShader, ExpandShader, StaticBillboardShader } from "./Shaders"
import { ShaderPass } from "three/examples/jsm/postprocessing/ShaderPass";
import { RenderPass } from 'three/examples/jsm/postprocessing/RenderPass';
import { AddPass } from './Effects/AddPass'
import { HorizontalBlurShader } from 'three/examples/jsm/shaders/HorizontalBlurShader'
import { VerticalBlurShader } from 'three/examples/jsm/shaders/VerticalBlurShader'
import { EffectComposer } from "three/examples/jsm/postprocessing/EffectComposer";
import { CustomSlider } from '../Components/CustomSlider'

// Make the post-processing classes known to react-three-fiber's JSX catalogue.
extend({ ShaderPass, EffectComposer, RenderPass });

// Static image assets: the colour image and the label map whose pixel values are layer ids.
const colorTexture = "/img/groupshot.png";
const labelsTexture = "/img/labels.png";

/**
 * Build a ShaderMaterial from a shader definition object.
 *
 * @param {{uniforms: object, vertexShader: string, fragmentShader: string}} shaderClass
 *   Shader definition, e.g. one of the objects exported from ./Shaders.
 * @returns {ShaderMaterial} A new material whose uniforms are a clone of the definition's,
 *   so changing one material's uniforms never touches the shared definition.
 */
function getMaterialFromShader(shaderClass) {
    const { uniforms, vertexShader, fragmentShader } = shaderClass;
    const ownUniforms = { ...THREE.UniformsUtils.clone(uniforms) };
    return new ShaderMaterial({ uniforms: ownUniforms, vertexShader, fragmentShader });
}


/**
 * Off-screen compositing pipeline for the layering lab; renders no scene objects itself.
 *
 * The pipeline has one EffectComposer per label value 0..nLabels. Each composer:
 * - extracts that label's mask from the label image,
 * - expands and blurs the mask,
 * - applies the label's colour adjustments.
 * A final composer then combines the original image (weight 0.25) with every layer
 * output through AddPass.
 *
 * The work runs in useFrame at priority 1. In react-three-fiber a non-zero priority
 * disables the default scene render, so this component owns the on-screen output.
 *
 * Props:
 *  - nLabels / setNLabels: highest label value in the label image (-1 until it is loaded).
 *  - canvasDims / setCanvasDims: canvas size; the height is adapted to the label image's aspect ratio.
 *  - parameters: per-label adjustment values keyed by label index
 *    ({ br, contrast, hue, sat, feather, expansion }).
 *  - hoveredIndex / setHoveredIndex: label value under the pointer, -1 for none.
 *  - pointerPos: pointer offset in pixels ({ x, y }); -1/-1 when the pointer left the canvas.
 */
function Effects(props) {
    const { nLabels, setNLabels, setCanvasDims, parameters,
        hoveredIndex, setHoveredIndex, pointerPos } = props;
    const { gl, size, invalidate: requestFrame } = useThree();
    const origImageTextureRef = useRef();
    const origAlphaTextureRef = useRef();
    const origAlphaTextureDataRef = useRef();

    const {
        layerComposers, finalComposer, extractMaskMaterial, applyAdjustmentsMaterial,
        addPass, blurPasses, expandPasses,
    } = useMemo(() => {
        // These three materials are shared by every layer composer (ShaderPass uses a passed-in
        // ShaderMaterial directly), so their uniforms are rewritten right before each layer renders.
        const extractMaskMaterial = getMaterialFromShader(ExtractMaskShader);
        const expandMaskMaterial = getMaterialFromShader(ExpandShader);
        const applyAdjustmentsMaterial = getMaterialFromShader(ApplyLayerAdjustmentsShader);

        const layerComposers = [];
        const blurPasses = [];
        const expandPasses = [];
        // label values run from 0 to nLabels inclusive, hence nLabels + 1 composers
        for (let i = 0; i < nLabels + 1; i++) {
            // off-screen target: per-layer results are only consumed by the final add pass
            const offscreenTarget = new WebGLRenderTarget(size.width, size.height);
            const composer = new EffectComposer(gl, offscreenTarget);
            composer.renderToScreen = false;

            const extractMaskPass = new ShaderPass(extractMaskMaterial);
            const expandPass = new ShaderPass(expandMaskMaterial);
            // built from shader definitions, so each blur pass owns its own uniforms
            const hblurPass = new ShaderPass(HorizontalBlurShader);
            const vblurPass = new ShaderPass(VerticalBlurShader);
            const applyAdjustmentsPass = new ShaderPass(applyAdjustmentsMaterial);

            composer.addPass(extractMaskPass);
            composer.addPass(expandPass);
            composer.addPass(hblurPass);
            composer.addPass(vblurPass);
            composer.addPass(applyAdjustmentsPass);

            blurPasses.push({ hBlur: hblurPass, vBlur: vblurPass });
            expandPasses.push(expandPass);
            layerComposers.push(composer);
        }

        // final, on-screen composer combining all layers
        const addPass = new AddPass(size);
        const finalComposer = new EffectComposer(gl);
        const finalMaterial = getMaterialFromShader(StaticBillboardShader);
        finalComposer.addPass(addPass);
        finalComposer.addPass(new ShaderPass(finalMaterial));

        return {
            layerComposers, finalComposer, extractMaskMaterial, applyAdjustmentsMaterial,
            addPass, blurPasses, expandPasses,
        };
    }, [size.width, size.height, nLabels]);

    // Release the GPU resources of a pipeline once it is replaced (resize / label count change)
    // or the component unmounts. Disposing a shared material more than once is harmless.
    React.useEffect(() => () => {
        [...layerComposers, finalComposer].forEach((composer) => {
            composer.renderTarget1.dispose();
            composer.renderTarget2.dispose();
            composer.passes.forEach((pass) => {
                if (pass.material) {
                    pass.material.dispose();
                }
            });
        });
    }, [layerComposers, finalComposer]);

    // Load the colour image and the label image once.
    React.useEffect(() => {
        let active = true;

        new TextureLoader().load(colorTexture, (texture) => {
            if (!active) {
                texture.dispose();
                return;
            }
            origImageTextureRef.current = texture;
            requestFrame();
        }, undefined, (err) => console.error(err));

        // ImageLoader.load(url, onLoad, onProgress, onError): the error handler is the 4th argument
        new ImageLoader().load(labelsTexture, (image) => {
            if (!active) {
                return;
            }
            const w = image.width;
            const h = image.height;
            const canvas = document.createElement('canvas');
            canvas.width = w;
            canvas.height = h;
            const context = canvas.getContext('2d');
            context.drawImage(image, 0, 0);

            // keep the current canvas width, match the label image's aspect ratio
            const ratio = w / h;
            setCanvasDims((dims) => ({ width: dims.width, height: dims.width / ratio }));

            const imageData = context.getImageData(0, 0, w, h).data;
            origAlphaTextureDataRef.current = { data: imageData, w, h };
            const alphaTexture = new Texture(image);
            alphaTexture.needsUpdate = true;
            origAlphaTextureRef.current = alphaTexture;

            // label ids are read from the red channel of the RGBA pixel data (stride 4)
            let maxVal = 0;
            for (let ind = 0; ind < imageData.length; ind += 4) {
                if (imageData[ind] > maxVal) {
                    maxVal = imageData[ind];
                }
            }
            setNLabels(maxVal);
            requestFrame();
        }, undefined, (err) => console.error(err));

        return () => {
            active = false;
            if (origImageTextureRef.current) {
                origImageTextureRef.current.dispose();
            }
            if (origAlphaTextureRef.current) {
                origAlphaTextureRef.current.dispose();
            }
        };
    }, []);

    // With frameloop="demand", a new frame must be requested explicitly. This component
    // renders no three.js objects, so changes to the props read inside useFrame (or a
    // rebuilt pipeline) are not guaranteed to schedule a frame on their own.
    React.useEffect(() => {
        requestFrame();
    }, [requestFrame, parameters, pointerPos, hoveredIndex, layerComposers, finalComposer]);

    useFrame(() => {
        // compute which label is under the pointer
        const labelData = origAlphaTextureDataRef.current;
        if (labelData) {
            const xNorm = Math.min(Math.max(pointerPos.x / size.width, 0), 1);
            const yNorm = Math.min(Math.max(pointerPos.y / size.height, 0), 1);
            let hoveredId = -1;
            // true only for 0 < norm < 1: clamped values (incl. the -1 sentinel) mean "outside"
            if (Math.abs(xNorm * 2 - 1) < 1 && Math.abs(yNorm * 2 - 1) < 1) {
                const x = Math.floor(xNorm * (labelData.w - 1));
                const y = Math.floor(yNorm * (labelData.h - 1));
                hoveredId = labelData.data[(y * labelData.w + x) * 4];
            }
            if (hoveredId !== hoveredIndex) {
                setHoveredIndex(hoveredId);
            }
        }

        const alphaTexture = origAlphaTextureRef.current;
        if (alphaTexture) {
            // inputs that are the same for every layer
            extractMaskMaterial.uniforms.tMasks.value = alphaTexture;
            applyAdjustmentsMaterial.uniforms.tOrigImage.value = origImageTextureRef.current;
        }
        // the materials are shared between layers, so a layer without its own parameters
        // must get defaults; otherwise it would inherit the previous layer's uniform values
        const fallbackParams = getDefaultSliderValues();

        layerComposers.forEach((composer, i) => {
            if (alphaTexture) {
                extractMaskMaterial.uniforms.maskId.value = i;

                const params = parameters[i] || fallbackParams;

                // mask post-processing: feather is in pixels, converted to texel steps per axis
                blurPasses[i].hBlur.uniforms.h.value = params.feather / size.width;
                blurPasses[i].vBlur.uniforms.v.value = params.feather / size.height;
                expandPasses[i].uniforms.radius.value = Math.round(params.expansion);
                expandPasses[i].uniforms.stepSize.value = 1.0 / size.width;

                // colour adjustments; brightness above 1 is stretched 5x, the hovered layer is boosted
                const br = params.br;
                const hoverBoost = hoveredIndex === i ? 1.2 : 1;
                const brightnessGamma = br < 1 ? br : (br - 1.0) * 5 + 1;
                applyAdjustmentsMaterial.uniforms.gamma.value = hoverBoost * brightnessGamma;
                applyAdjustmentsMaterial.uniforms.contrast.value = Math.pow(params.contrast, 2.0);
                applyAdjustmentsMaterial.uniforms.hue.value = params.hue;
                applyAdjustmentsMaterial.uniforms.sat.value = params.sat;
            }
            composer.render();
        });

        // per-layer outputs (plus the faded original) are the inputs of the final add pass
        addPass.texturesToAdd = [
            { texture: origImageTextureRef.current, weight: 0.25 },
            ...layerComposers.map((composer) => ({ texture: composer.readBuffer.texture, weight: 1.0 })),
        ];
        finalComposer.render();
    }, 1);

    return null;
}

/**
 * Starting adjustment values for one layer.
 *
 * A fresh object is returned on every call, so layers never share the same parameter object.
 * @returns {{br: number, contrast: number, hue: number, sat: number, feather: number, expansion: number}}
 */
const getDefaultSliderValues = () => ({
    br: 1.0,
    contrast: 1.0,
    hue: 0.0,
    sat: 1.0,
    feather: 2.0,
    expansion: 0.0,
});

export function LayeringLab(props) {
    const [canvasDims, setCanvasDims] = useState({ width: 700, height: 700 });
    const [nLabels, setNLabels] = useState(-1);
    const [parameters, setParameters] = useState({});
    const [hoveredIndex, setHoveredIndex] = useState(-1);
    const [selectedIndex, setSelectedIndex] = useState(-1);
    const [pointerPos, setPointerPos] = useState({x:-1, y:-1});

    // initialize parameters for each layer
    const dummy = useMemo(() => {
        for (let i = 0; i < nLabels; i++) {
            if (!(i in parameters)) {
                parameters[i] = getDefaultSliderValues();
            }
        }
    }, [nLabels]);

    // helper for disabling parameter sliders when no layer is selected
    const pvalid = selectedIndex !== -1 && parameters[selectedIndex];

    return <div style={{ display: 'flex', flexDirection: "row" }}>
        <div style={{ width: "200px", padding: "10px", display: 'flex', flexDirection: "column" }}>
            <CustomSlider labelText="Brightness"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].br = val;
                        setParameters(newParams);
                    }
                }}
                step={0.01}
                min={0.1}
                value={pvalid ? parameters[selectedIndex].br : 0.0}
                disabled={pvalid ? false : true}
                max={1.9}
            />
            <CustomSlider labelText="Contrast"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].contrast = val;
                        setParameters(newParams);
                    }
                }}
                step={0.01}
                min={0.5}
                value={pvalid ? parameters[selectedIndex].contrast : 0.0}
                disabled={pvalid ? false : true}
                max={1.5}
            />
            <CustomSlider labelText="Hue"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].hue = val;
                        setParameters(newParams);
                    }
                }}
                step={0.01}
                min={-0.5}
                value={pvalid ? parameters[selectedIndex].hue : 0.0}
                disabled={pvalid ? false : true}
                max={0.5} />
            <CustomSlider labelText="Saturation"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].sat = val;
                        setParameters(newParams);
                    }
                }}
                step={0.01}
                min={0}
                value={pvalid ? parameters[selectedIndex].sat : 0.0}
                disabled={pvalid ? false : true}
                max={2.0}
            />
            <CustomSlider labelText="Mask expansion"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].expansion = val;
                        setParameters(newParams);
                    }
                }}
                step={1}
                min={-10}
                value={pvalid ? parameters[selectedIndex].expansion : 0.0}
                disabled={pvalid ? false : true}
                max={10}
            />
            <CustomSlider labelText="Mask feather"
                onChange={(ev, val) => {
                    if(selectedIndex != -1) {
                        const newParams = {...parameters};
                        newParams[selectedIndex].feather = val;
                        setParameters(newParams);
                    }
                }}
                step={0.01}
                min={0}
                value={pvalid ? parameters[selectedIndex].feather : 0.0}
                disabled={pvalid ? false : true}
                max={5.0}
            />
        </div>
        <div onClick={(e) => {
                if(selectedIndex == hoveredIndex) {
                    setSelectedIndex(-1);
                } else {
                    setSelectedIndex(hoveredIndex);
                }
            }}
            onPointerMove={(e) => {
                setPointerPos({x: e.nativeEvent.offsetX, y:e.nativeEvent.offsetY});
            }}
            onPointerOut={(e) => {
                setPointerPos({x: -1, y:-1});
                if (!hoveredIndex !== -1) {
                    setHoveredIndex(-1);
                }
            }}>
        <Canvas style={canvasDims} frameloop={'demand'}>
            <mesh />
            <Effects
                nLabels={nLabels}
                setNLabels={setNLabels}
                canvasDims={canvasDims}
                setCanvasDims={setCanvasDims}
                selectedIndex={selectedIndex}
                hoveredIndex={hoveredIndex}
                setHoveredIndex={setHoveredIndex}
                parameters={parameters}
                pointerPos={pointerPos}/>
        </Canvas>
        </div>
    </div>
        ;
}