import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
import { MeshSurfaceSampler } from 'three/addons/math/MeshSurfaceSampler.js';
import { GLTFLoader } from 'three/examples/jsm/loaders/GLTFLoader.js';
import { DRACOLoader } from 'three/examples/jsm/loaders/DRACOLoader.js';
import gsap from 'gsap'
import ScrollTrigger from 'gsap/ScrollTrigger'
import { sections_count, sections_details } from '../data/sections';

// Coarse user-agent check. Selects the mobile simulation shader in loadShaders()
// and touch-based pointer tracking in addEventListeners().
const isMobile = /iPhone|iPad|iPod|Android|BlackBerry|Windows Phone/i.test(navigator.userAgent);
// Simulation texture size in texels (one particle per texel). These are defaults;
// init() overwrites them with simParams.width/height.
let SIM_WIDTH = 512;
let SIM_HEIGHT = 512;
// three.js objects owned by the caller, stored by init().
let scene, camera, renderer;
// Result of setupTextureResources(): the two compute render targets.
let textures;
// Result of setupShaderMaterials(): { simShaderMaterial, pointsRenderShaderMaterial }.
let materials;
// Off-screen { scene, camera } used for the simulation pass.
let particlesComputeProgram;
// { pointsMesh }: the THREE.Points object that draws the particles.
let pointsRenderProgram;
let controls;
// Index (0 or 1) of the compute render target written by the next update().
let simStep = 0;
// Set at the end of init(); update() skips the simulation until then.
let isReady = false;
const raycaster = new THREE.Raycaster();
// Pointer in normalized device coordinates, written by the pointer/touch handlers.
const pointer = new THREE.Vector2();
const clock = new THREE.Clock();
// Invisible raycast target; animateModel() picks it from `meshes`.
let currentMesh = undefined;
let primaryMesh = undefined;
let secondaryMesh = undefined;
// NOTE: this copies the current values (both undefined), not the bindings above;
// setParticlesMeshAndData() fills the slots explicitly.
const meshes = [primaryMesh, secondaryMesh];
// Smoothed cursor in world space; animate3DCursor() eases it towards targetCursor3D.
const cursor3D = new THREE.Vector3();
const targetCursor3D = new THREE.Vector3();
// Scratch objects used by updateRaycaster() to move the cursor into mesh-local space.
const cursor3DVec4 = new THREE.Vector4();
const cursor3DVec3Sim = new THREE.Vector3();
const inverseMeshWorldMatrix = new THREE.Matrix4();
// Seconds between update() calls. Stays 1 until update() first runs.
let dt = 1;

// Scratch vectors reused by updateRaycaster() instead of allocating new ones per call.
const rayOrigin = new THREE.Vector3();
const rayDir = new THREE.Vector3();

// Time baseline for the dt computed in update().
let prevFrameTime = clock.getElapsedTime();

gsap.registerPlugin(ScrollTrigger)

// Last section index reported by the ScrollTrigger callback in init().
let previous_section = null;

/**
 * Drives the particle mesh pose and the simulation uniforms from the scroll position.
 *
 * Each entry of `sections_details` describes a target pose (position, rotation, scale)
 * and simulation parameters. Targets are interpolated between the previous and the
 * current section by `localScrollProgress`. Position and rotation then ease towards
 * that target with a frame-rate dependent factor, so scroll changes take a few frames
 * to settle and feel weighty. Scale and uniforms follow the scroll position directly.
 *
 * Does nothing until `pointsRenderProgram` exists.
 *
 * @param {number} current_section - 1-based index of the active section.
 * @param {number} globalScrollProgress - Whole-page scroll progress (currently unused).
 * @param {number} localScrollProgress - Progress (0..1) within the active section.
 * @param {number} dt - Seconds since the previous frame; scales the easing factor.
 */
function animateModel(current_section, globalScrollProgress, localScrollProgress, dt) {
    if (!pointsRenderProgram) {
        return;
    }

    const lerp = THREE.MathUtils.lerp;
    const mesh = pointsRenderProgram.pointsMesh;
    const uniforms = materials.simShaderMaterial.uniforms;
    // Section 1 has no predecessor, so it interpolates between entry 0 and itself.
    const prevData = sections_details[current_section - 2] || sections_details[0];
    const curData = sections_details[current_section - 1];

    // Value between the previous and the current section for this scroll position.
    const sectionValue = (from, to) => lerp(from, to, localScrollProgress);

    // Easing factor, clamped to 1: lerp() with t > 1 extrapolates past the target.
    // Unclamped, long frames overshoot, and so does the very first call (dt starts at 1, i.e. t = 7).
    const easing = Math.min(7 * dt, 1);
    const ease = (current, from, to) => lerp(current, sectionValue(from, to), easing);

    // Positions
    mesh.position.set(
        ease(mesh.position.x, prevData.position.x, curData.position.x),
        ease(mesh.position.y, prevData.position.y, curData.position.y),
        ease(mesh.position.z, prevData.position.z, curData.position.z)
    );

    // Rotation
    mesh.rotation.set(
        ease(mesh.rotation.x, prevData.rotation.x, curData.rotation.x),
        ease(mesh.rotation.y, prevData.rotation.y, curData.rotation.y),
        ease(mesh.rotation.z, prevData.rotation.z, curData.rotation.z)
    );

    // Scale
    const scale = sectionValue(prevData.scale, curData.scale);
    mesh.scale.set(scale, scale, scale);

    // Noise scale
    uniforms.uNoiseScale.value = sectionValue(prevData.noiseScale, curData.noiseScale);

    // Spawn point mix. Rounded, it is also the index into `meshes` of the raycast target.
    const modelMix = sectionValue(prevData.mix, curData.mix);
    uniforms.uOriginPointMix.value = modelMix;
    currentMesh = meshes[Math.round(modelMix)] || meshes[0];

    // Set directly (not eased): extra noise peaks at mix 0.5, on top of the section's base.
    uniforms.uNoiseMagnitude.value = Math.sin(modelMix * Math.PI) * .05 + curData.baseNoiseMagnitude;

    uniforms.uPointerDisplacementMagnitude.value = sectionValue(
        prevData.pointerDisplacementMagnitude,
        curData.pointerDisplacementMagnitude
    );
}

/**
 * Loaders
 *
 * gltfLoader hands Draco-compressed geometry to dracoLoader, which fetches its
 * decoder files from the relative 'draco/' path.
 */
const dracoLoader = new DRACOLoader().setDecoderPath('draco/');
const gltfLoader = new GLTFLoader().setDRACOLoader(dracoLoader);

/**
 * Fetches the GLSL sources of the simulation and points-render programs.
 *
 * The four requests run in parallel. On mobile (isMobile) the simulation uses
 * sim_fragment_mobile.glsl instead of sim_fragment.glsl.
 *
 * @returns {Promise<{simVertex: string, simFragment: string, pointsVertex: string, pointsFragment: string}>}
 * @throws {Error} If a request answers with a non-2xx status. Without this check an
 *   HTML error page would be handed to the GLSL compiler as shader source.
 */
async function loadShaders() {
    const simFragmentSrcPath = isMobile ? '../shaders/sim_fragment_mobile.glsl' : '../shaders/sim_fragment.glsl';

    const fetchText = async (url) => {
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(`Failed to load shader "${url}": HTTP ${response.status}`);
        }
        return response.text();
    };

    const [simVertex, simFragment, pointsVertex, pointsFragment] = await Promise.all([
        fetchText('../shaders/sim_vertex.glsl'),
        fetchText(simFragmentSrcPath),
        fetchText('../shaders/points_vertex.glsl'),
        fetchText('../shaders/points_fragment.glsl'),
    ]);

    return {
        simVertex,
        simFragment,
        pointsVertex,
        pointsFragment,
    };
}

/**
 * Samples `width * height` random points on the surface of `mesh`.
 *
 * Both results are RGBA float arrays with 4 floats per sample:
 * - surfacePoints: xyz = sampled position, w = random initial lifetime in [0, 1).
 * - surfaceNormals: xyz = sampled normal, w = 0.
 *
 * Side effect: when `mesh.material` is not a MeshBasicMaterial it is replaced by a new
 * red MeshBasicMaterial (the previous material is not disposed).
 *
 * @param {number} width - Sample grid width.
 * @param {number} height - Sample grid height.
 * @param {THREE.Mesh} mesh - Mesh to sample.
 * @returns {{surfacePoints: Float32Array, surfaceNormals: Float32Array}|undefined}
 *   undefined (after logging an error) when `mesh` is missing.
 */
function sampleMeshSurface(width, height, mesh) {
    if (!mesh) {
        console.error('Mesh is undefined!');
        return;
    }

    // TODO: ensure this works for .glbs
    if (!(mesh.material instanceof THREE.MeshBasicMaterial)) {
        mesh.material = new THREE.MeshBasicMaterial({ color: 0xff0000 });
    }

    // Faces are picked using the geometry's 'color' attribute as the sampling weight.
    const sampler = new MeshSurfaceSampler(mesh).setWeightAttribute('color').build();

    // Reused output vectors. sampler.sample() also accepts color and uv targets if needed.
    const samplePosition = new THREE.Vector3();
    const sampleNormal = new THREE.Vector3();

    const particleCount = width * height;
    const surfacePoints = new Float32Array(particleCount * 4);
    const surfaceNormals = new Float32Array(particleCount * 4);

    for (let particle = 0; particle < particleCount; particle++) {
        const base = particle * 4;
        sampler.sample(samplePosition, sampleNormal);

        samplePosition.toArray(surfacePoints, base);
        surfacePoints[base + 3] = Math.random(); // initial lifetime
        sampleNormal.toArray(surfaceNormals, base);
    }

    return { surfacePoints, surfaceNormals };
}

/**
 * Wraps `geometry` in a throwaway white MeshBasicMaterial mesh and samples its
 * surface with sampleMeshSurface(). Backs the resampleTo* helpers below, which
 * provide placeholder shapes when no model assets are available.
 */
function resampleGeometry(width, height, geometry) {
    const mesh = new THREE.Mesh(geometry, new THREE.MeshBasicMaterial({ color: 0xffffff }));
    return sampleMeshSurface(width, height, mesh);
}

/** Samples a torus knot (radius 1, tube 0.25, 100 x 16 segments). */
// eslint-disable-next-line no-unused-vars
function resampleToTorusKnot(width, height) {
    return resampleGeometry(width, height, new THREE.TorusKnotGeometry(1, .25, 100, 16));
}

/** Samples a cone (radius 0.5, height 1, 32 radial segments). */
// eslint-disable-next-line no-unused-vars
function resampleToCone(width, height) {
    return resampleGeometry(width, height, new THREE.ConeGeometry(.5, 1, 32));
}

/** Samples a unit cube. */
// eslint-disable-next-line no-unused-vars
function resampleToBox(width, height) {
    return resampleGeometry(width, height, new THREE.BoxGeometry(1, 1, 1));
}

/**
 * Samples the first mesh of a loaded glTF model into simulation data textures.
 *
 * Direct children of `glbModel.scene` are searched first (as before). If none of them
 * is a mesh, the whole scene graph is traversed depth-first, so models whose meshes
 * sit inside groups/nodes also work.
 *
 * @param {object} glbModel - GLTFLoader result (uses its `scene`).
 * @returns {{mesh: THREE.Mesh, originalPositionsDataTexture: THREE.DataTexture, originalNormalsDataTexture: THREE.DataTexture}}
 *   SIM_WIDTH x SIM_HEIGHT RGBA float textures holding the sampled positions
 *   (w = initial lifetime) and normals, plus the sampled mesh.
 * @throws {Error} If the model contains no mesh.
 */
function glbToMeshSurfacePoints(glbModel) {
    const width = SIM_WIDTH;
    const height = SIM_HEIGHT;
    const root = glbModel.scene;

    let mesh = root.children.find(child => child instanceof THREE.Mesh);
    if (!mesh) {
        root.traverse((object) => {
            if (!mesh && object instanceof THREE.Mesh) {
                mesh = object;
            }
        });
    }
    if (!mesh) {
        throw new Error('glbToMeshSurfacePoints: the model does not contain a mesh.');
    }

    const data = sampleMeshSurface(width, height, mesh);

    const originalPositionsDataTexture = new THREE.DataTexture(data.surfacePoints, width, height, THREE.RGBAFormat, THREE.FloatType);
    // NOTE(review): LinearFilter on a FloatType texture needs float-linear filtering
    // support on the device (OES_texture_float_linear) — confirm on mobile targets.
    const originalNormalsDataTexture = new THREE.DataTexture(
        data.surfaceNormals, width, height, THREE.RGBAFormat, THREE.FloatType,
        undefined,
        undefined,
        undefined,
        THREE.LinearFilter,
        THREE.LinearFilter
    );
    originalPositionsDataTexture.needsUpdate = true;
    originalNormalsDataTexture.needsUpdate = true;

    return {
        mesh,
        originalPositionsDataTexture,
        originalNormalsDataTexture
    };
}

/**
 * Creates the GPU resources the particle simulation ping-pongs between.
 *
 * Only the two compute render targets are created here. The origin position/normal
 * data textures are not part of the result, so the keys setupShaderMaterials() reads
 * for them are undefined until a model is installed.
 *
 * @param {{width: number, height: number}} params - Simulation texture size in texels.
 * @returns {{computeRenderTargets: THREE.WebGLRenderTarget[]}} Two render targets with
 *   identical settings; update() alternates between them every frame.
 */
function setupTextureResources(params) {
    const { width, height } = params;

    // NOTE! Code below is for tests in absence of model assets. Uncomment if needed,
    // together with the matching keys in the returned object.

    // let data = params.data;
    // let altData = params.altData;

    // if (!data) {
    //     data = resampleToBox(width, height);
    //     // data = resampleToCone(width, height);
    // }

    // if (!altData) {
    //     altData = resampleToTorusKnot(width, height);
    // }

    // const originalPositionDataTexture = new THREE.DataTexture(data.surfacePoints, width, height, THREE.RGBAFormat, THREE.FloatType);
    // const originalNormalsDataTexture = new THREE.DataTexture(
    //     data.surfaceNormals, width, height, THREE.RGBAFormat, THREE.FloatType,
    //     undefined,
    //     undefined,
    //     undefined,
    //     THREE.LinearFilter,
    //     THREE.LinearMipmapLinearFilter
    // );
    // originalPositionDataTexture.needsUpdate = true;
    // originalNormalsDataTexture.needsUpdate = true;

    // const originalPositionDataTextureAlt = new THREE.DataTexture(altData.surfacePoints, width, height, THREE.RGBAFormat, THREE.FloatType);
    // const originalNormalsDataTextureAlt = new THREE.DataTexture(
    //     altData.surfaceNormals, width, height, THREE.RGBAFormat, THREE.FloatType,
    //     undefined,
    //     undefined,
    //     undefined,
    //     THREE.LinearFilter,
    //     THREE.LinearMipmapLinearFilter
    // );
    // originalPositionDataTextureAlt.needsUpdate = true;
    // originalNormalsDataTextureAlt.needsUpdate = true;

    // The compute targets may use THREE.FloatType or THREE.HalfFloatType. Half floats
    // (16-bit) can be faster in some cases.
    const renderTargetOptions = {
        minFilter: THREE.NearestFilter,
        magFilter: THREE.NearestFilter,
        format: THREE.RGBAFormat,
        type: THREE.HalfFloatType
    };
    const createComputeTarget = () => new THREE.WebGLRenderTarget(width, height, renderTargetOptions);

    return {
        // originalPositionDataTexture,
        // originalNormalsDataTexture,
        // originalPositionDataTextureAlt,
        // originalNormalsDataTextureAlt,
        computeRenderTargets: [createComputeTarget(), createComputeTarget()]
    };
}

/**
 * Builds the two ShaderMaterials of the particle pipeline.
 *
 * - simShaderMaterial: the "compute" pass (an analogue of a compute shader) that
 *   calculates the particle positions for the next simulation step.
 * - pointsRenderShaderMaterial: takes the positions computed by the simulation and
 *   applies them to the vertices of the THREE.Points mesh (additive, no depth write).
 *
 * @param {object} shaders - GLSL sources as returned by loadShaders().
 * @param {object} textures - Initial data textures; missing entries leave the matching
 *   sampler uniforms undefined until they are assigned later.
 * @returns {{simShaderMaterial: THREE.ShaderMaterial, pointsRenderShaderMaterial: THREE.ShaderMaterial}}
 */
function setupShaderMaterials(shaders, textures) {
    // Sampler uniforms carry the `type: 't'` hint used by the original declarations.
    // NOTE(review): current three.js infers uniform types from the value, so the hint
    // looks vestigial — confirm before removing it.
    const samplerUniform = (texture) => ({ type: 't', value: texture });

    const simUniforms = {
        uTime: { value: 0 },
        uPointerPos: { value: new THREE.Vector3(0) },
        uDt: { value: 0 },
        uParticlesLifetime: { value: 1 },
        uNoiseScale: { value: 1 },
        uNoiseMagnitude: { value: 1 },
        uPointerDisplacementMagnitude: { value: 1 },
        uOriginPointMix: { value: 0 },
        uParticlesOriginPosition: samplerUniform(textures.originalPositionDataTexture),
        uParticlesOriginNormal: samplerUniform(textures.originalNormalsDataTexture),
        uParticlesOriginPositionAlt: samplerUniform(textures.originalPositionDataTextureAlt),
        uParticlesOriginNormalAlt: samplerUniform(textures.originalNormalsDataTextureAlt),
        // Starts at the origin positions; update() later feeds back the last render target.
        uParticlesPositions: samplerUniform(textures.originalPositionDataTexture),
    };
    const simShaderMaterial = new THREE.ShaderMaterial({
        vertexShader: shaders.simVertex,
        fragmentShader: shaders.simFragment,
        uniforms: simUniforms,
    });

    const pointsUniforms = {
        uTime: { value: 0 },
        uPointerPos: { value: new THREE.Vector3(0) },
        uParticlesLifetime: { value: 1 },
        uParticleStartColor: { value: new THREE.Color(0x8c2eff) },
        uParticleEndColor: { value: new THREE.Color(0x6bdef5) },
        uParticleTouchColor: { value: new THREE.Color(0xff0000) },
        // Assigned every frame by update() with the latest simulation output.
        uParticlesOutput: samplerUniform(null),
    };
    const pointsRenderShaderMaterial = new THREE.ShaderMaterial({
        vertexShader: shaders.pointsVertex,
        fragmentShader: shaders.pointsFragment,
        uniforms: pointsUniforms,
        blending: THREE.AdditiveBlending,
        transparent: true,
        depthWrite: false
    });

    return { simShaderMaterial, pointsRenderShaderMaterial };
}

/**
 * Builds the off-screen "compute" pass: a scene holding one full-screen quad drawn
 * with `materials.simShaderMaterial`, plus an orthographic camera framing [-1, 1]^2.
 *
 * @param {{materials: object}} pipelineParams - Needs `materials.simShaderMaterial`;
 *   other fields (e.g. width/height) are accepted but unused.
 * @returns {{scene: THREE.Scene, camera: THREE.OrthographicCamera}}
 */
function setupParticlesComputePorgram(pipelineParams = {}) {
    const { materials } = pipelineParams;
    const scene = new THREE.Scene();

    // NOTE(review): the near plane is 2^-53 (numerically ~0) while the quad lies at z = 0.
    // The original author questioned this value ("why 2^53??") — confirm what depends on
    // it before changing it.
    const near = 1 / Math.pow(2, 53);
    const camera = new THREE.OrthographicCamera(-1, 1, 1, -1, near, 1);

    // Two triangles covering [-1, 1]^2 at z = 0, with UVs spanning [0, 1]^2.
    const quadGeometry = new THREE.BufferGeometry();
    quadGeometry.setAttribute('position', new THREE.BufferAttribute(new Float32Array([
        -1, -1, 0, 1, -1, 0, 1, 1, 0,
        1, 1, 0, -1, 1, 0, -1, -1, 0,
    ]), 3));
    quadGeometry.setAttribute('uv', new THREE.BufferAttribute(new Float32Array([
        0, 0, 1, 0, 1, 1,
        1, 1, 0, 1, 0, 0,
    ]), 2));
    const quadMesh = new THREE.Mesh(quadGeometry, materials.simShaderMaterial);

    scene.add(camera);
    scene.add(quadMesh);

    return { scene, camera };
}

/**
 * Creates the THREE.Points mesh with one vertex per simulation texel.
 *
 * Vertices do not hold real positions. Each one stores lookup coordinates into the
 * simulation output: x = (index % width) / width, y = index / width / height, z = 0.
 * The points shader presumably reads the particle position from uParticlesOutput at
 * these coordinates — confirm against points_vertex.glsl.
 *
 * @param {{width: number, height: number, materials: object}} pipelineParams - Uses
 *   `materials.pointsRenderShaderMaterial`.
 * @returns {{pointsMesh: THREE.Points}}
 */
function setupPointsRenderProgram(pipelineParams = {}) {
    const { width, height, materials } = pipelineParams;
    const pointsGeometry = new THREE.BufferGeometry();
    const particleCount = width * height;
    const lookupCoords = new Float32Array(particleCount * 3);

    // NOTE(review): y = index / width / height equals (row + column / width) / height rather
    // than row / height; it still stays inside the row's texel band. Confirm against the
    // points shader before "fixing" it.
    for (let index = 0; index < particleCount; index++) {
        const offset = index * 3;
        lookupCoords[offset] = (index % width) / width;
        lookupCoords[offset + 1] = index / width / height;
    }

    pointsGeometry.setAttribute('position', new THREE.BufferAttribute(lookupCoords, 3));

    return {
        pointsMesh: new THREE.Points(pointsGeometry, materials.pointsRenderShaderMaterial),
    };
}

/**
 * Installs a loaded glTF model as one of the two particle origin shapes.
 *
 * The model's mesh is sampled with glbToMeshSurfacePoints(). The mesh is added to the
 * scene as an invisible raycast target, and the resulting textures are bound to the
 * simulation uniforms of the chosen slot:
 *   - slot 0 (primary, meshIndex === 0): uParticlesOriginPosition/uParticlesOriginNormal;
 *     uParticlesPositions is also reset to the new origin positions.
 *   - slot 1 (secondary, any other meshIndex): uParticlesOriginPositionAlt/uParticlesOriginNormalAlt.
 * A mesh previously installed in the same slot is removed from the scene, and its geometry
 * and material are disposed. The origin data textures being replaced are disposed too.
 *
 * @param {object} glbModel - GLTFLoader result.
 * @param {number} [meshIndex=0] - 0 selects the primary slot; any other value the secondary one.
 */
function setParticlesMeshAndData(glbModel, meshIndex = 0)
{
    const {
        mesh,
        originalPositionsDataTexture,
        originalNormalsDataTexture
    } = glbToMeshSurfacePoints(glbModel);

    const uniforms = materials.simShaderMaterial.uniforms;
    const isPrimary = meshIndex === 0;
    const previousMesh = isPrimary ? primaryMesh : secondaryMesh;

    if (previousMesh)
    {
        if (currentMesh === previousMesh)
        {
            currentMesh = mesh;
        }

        scene.remove(previousMesh);
        previousMesh.material.dispose();
        previousMesh.geometry.dispose();
    }

    mesh.visible = false;
    scene.add(mesh);

    if (isPrimary)
    {
        primaryMesh = mesh;
        // Only reassigned, never disposed: this uniform may hold a compute render-target texture.
        uniforms.uParticlesPositions.value = originalPositionsDataTexture;
        replaceDataTexture(uniforms.uParticlesOriginPosition, originalPositionsDataTexture);
        replaceDataTexture(uniforms.uParticlesOriginNormal, originalNormalsDataTexture);
    }
    else
    {
        secondaryMesh = mesh;
        replaceDataTexture(uniforms.uParticlesOriginPositionAlt, originalPositionsDataTexture);
        replaceDataTexture(uniforms.uParticlesOriginNormalAlt, originalNormalsDataTexture);
    }

    // animateModel() reads meshes[Math.round(mix)], so store by slot (0/1), not by raw meshIndex.
    meshes[isPrimary ? 0 : 1] = mesh;
}

/**
 * Points `uniform` at `texture`, disposing the texture the uniform held before (if any),
 * so replaced origin data textures do not leak GPU memory.
 */
function replaceDataTexture(uniform, texture)
{
    const previous = uniform.value;
    if (previous && previous !== texture)
    {
        previous.dispose();
    }
    uniform.value = texture;
}

/**
 * Installs a model as a particle origin shape via setParticlesMeshAndData().
 *
 * @param {object} params
 * @param {string} [params.modelPath] - URL of a glTF/GLB file, loaded with gltfLoader.
 * @param {object} [params.glb] - An already loaded GLTFLoader result (used when no modelPath is given).
 * @param {number} [params.meshIndex] - Slot forwarded to setParticlesMeshAndData().
 * @returns {Promise<void>} Resolves once the model is installed, or immediately when neither
 *   modelPath nor glb is given. For `modelPath`, rejects if loading or installing fails
 *   (previously the promise never settled). Errors while installing `params.glb` are still
 *   thrown synchronously, as before.
 */
export function setParticlesMesh(params)
{
    if (params.modelPath)
    {
        return new Promise((resolve, reject) => {
            gltfLoader.load(
                params.modelPath,
                (glbModel) => {
                    try
                    {
                        setParticlesMeshAndData(glbModel, params.meshIndex);
                        resolve();
                    }
                    catch (error)
                    {
                        reject(error);
                    }
                },
                undefined,
                reject
            );
        });
    }

    if (params.glb)
    {
        setParticlesMeshAndData(params.glb, params.meshIndex);
    }

    return Promise.resolve();
}

/**
 * Sets up the particle system and returns the objects the caller renders.
 *
 * Steps:
 * 1. Loads the shaders.
 * 2. Registers the scroll-driven animation (ScrollTrigger on "#main") and the input listeners.
 * 3. Allocates the simulation resources.
 * 4. If given, installs `simParams.glbModels[0]` / `[1]` as the primary/secondary particle shapes.
 * 5. Sets isReady, so update() starts simulating.
 *
 * @param {object} threeParams - `{ scene, camera, renderer }` owned by the caller.
 * @param {object} simParams - `{ width, height, glbModels }`. `width`/`height` size the
 *   simulation textures and default to the current SIM_WIDTH/SIM_HEIGHT (initially 512).
 * @returns {Promise<{pointsMesh, particlesComputeProgram, pointsRenderProgram, materials}>}
 */
export async function init(threeParams = {}, simParams = {}) {
    const shaders = await loadShaders();

    // TODO: move this outside.
    // Maps scroll over "#main" to a 1-based section index plus the progress within that
    // section, and drives animateModel() with it.
    ScrollTrigger.create({
        trigger: "#main",
        start: "top bottom",
        end: "bottom bottom",
        markers: false,
        onUpdate: self => {
            const current_section = Math.max(Math.ceil(self.progress * sections_count), 1);
            let localScrollProgress = (self.progress * sections_count) % 1;
            // Do not remove this: at the very end of the page the modulo above wraps to 0,
            // which would snap the last section back to its start.
            if (self.progress === 1)
            {
                localScrollProgress = 1;
            }

            animateModel(current_section, self.progress, localScrollProgress, dt);

            if (current_section !== previous_section) {
                previous_section = current_section;
            }
        }
    });

    addEventListeners();

    scene = threeParams.scene;
    camera = threeParams.camera;
    renderer = threeParams.renderer;

    // Fall back to the module size instead of letting an omitted width/height propagate
    // `undefined` into every buffer and texture allocation below.
    const { width = SIM_WIDTH, height = SIM_HEIGHT } = simParams;
    SIM_WIDTH = width;
    SIM_HEIGHT = height;

    const canvas = renderer.domElement;
    controls = new OrbitControls(camera, canvas);

    textures = setupTextureResources({ width, height });
    materials = setupShaderMaterials(shaders, textures);
    particlesComputeProgram = setupParticlesComputePorgram({
        width,
        height,
        materials,
    });
    pointsRenderProgram = setupPointsRenderProgram({ width, height, materials });

    if (simParams.glbModels)
    {
        // TODO: this should be able to handle more than two models.
        setParticlesMesh({ glb: simParams.glbModels[0], meshIndex: 0 });
        setParticlesMesh({ glb: simParams.glbModels[1], meshIndex: 1 });
    }

    // Apply the first section's pose before any scroll event arrives.
    animateModel(1, 0, 0, dt);

    isReady = true;

    return {
        pointsMesh: pointsRenderProgram.pointsMesh,
        particlesComputeProgram,
        pointsRenderProgram,
        materials
    };
}

/**
 * Updates the 3D pointer passed to both shader programs.
 *
 * The invisible `currentMesh` is moved onto the particle mesh's transform and used as
 * the raycast target:
 * - On a hit, the hit point becomes the cursor target.
 * - Otherwise, the target is placed along the pointer ray, at the camera's current
 *   distance to the cursor.
 *
 * The smoothed cursor (`cursor3D`) is given to the points shader in world space. The
 * simulation shader receives it in `currentMesh`'s local space. Does nothing while
 * there is no `currentMesh`.
 */
function updateRaycaster() {
    if (!currentMesh) {
        return;
    }

    const { pointsMesh } = pointsRenderProgram;
    currentMesh.position.copy(pointsMesh.position);
    currentMesh.rotation.copy(pointsMesh.rotation);
    currentMesh.scale.copy(pointsMesh.scale);
    // position/rotation/scale do not refresh matrixWorld on their own. Recompute it so
    // the raycast and the inverse transform below use this frame's transform, not the
    // one from the previous render (or identity on the first frame).
    currentMesh.updateMatrixWorld();

    raycaster.setFromCamera(pointer, camera);
    const [hit] = raycaster.intersectObject(currentMesh);

    if (hit) {
        targetCursor3D.copy(hit.point);
    } else {
        // Keep the cursor at its current distance from the camera, along the pointer ray.
        const currentIntersection = materials.pointsRenderShaderMaterial.uniforms.uPointerPos.value;
        const dist = camera.position.distanceTo(currentIntersection);
        rayOrigin.copy(raycaster.ray.origin);
        rayDir.copy(raycaster.ray.direction).multiplyScalar(dist);
        targetCursor3D.copy(rayOrigin.add(rayDir));
    }

    materials.pointsRenderShaderMaterial.uniforms.uPointerPos.value = cursor3D;

    // Transform the cursor into mesh-local space for the simulation.
    inverseMeshWorldMatrix.copy(currentMesh.matrixWorld).invert();
    cursor3DVec4.set(cursor3D.x, cursor3D.y, cursor3D.z, 1.0).applyMatrix4(inverseMeshWorldMatrix);
    cursor3DVec3Sim.set(cursor3DVec4.x, cursor3DVec4.y, cursor3DVec4.z);
    materials.simShaderMaterial.uniforms.uPointerPos.value = cursor3DVec3Sim;
}

/**
 * Advances the particle system by one frame; meant to be called once per animation frame.
 *
 * Always refreshes `dt`. Once init() has finished (isReady):
 * - updates controls and the 3D pointer;
 * - runs one simulation pass into the current compute render target and exposes that
 *   target to both materials (ping-pong with the other target on the next frame);
 * - eases the cursor and slowly spins the particle mesh.
 *
 * NOTE(review): prevFrameTime only advances while isReady, so the first ready frame's
 * dt spans the whole loading time — confirm this is acceptable for uDt.
 */
export function update() {
    const elapsedTime = clock.getElapsedTime();
    dt = elapsedTime - prevFrameTime;

    // TODO: this is temporary, until loading shaders is resolved.
    if (!isReady) {
        return;
    }

    controls.update();
    updateRaycaster();

    const simUniforms = materials.simShaderMaterial.uniforms;
    simUniforms.uTime.value = elapsedTime;
    simUniforms.uDt.value = dt;

    const writeTarget = textures.computeRenderTargets[simStep];
    renderer.setRenderTarget(writeTarget);
    renderer.render(particlesComputeProgram.scene, particlesComputeProgram.camera);
    renderer.setRenderTarget(null);

    // The freshly written state is both what gets drawn and the next simulation input.
    materials.pointsRenderShaderMaterial.uniforms.uParticlesOutput.value = writeTarget.texture;
    simUniforms.uParticlesPositions.value = writeTarget.texture;
    simStep = (simStep + 1) % 2;
    prevFrameTime = elapsedTime;

    animate3DCursor(dt);

    // Constantly rotate the model.
    pointsRenderProgram.pointsMesh.rotation.y += 0.15 * dt;
}

/**
 * Eases the rendered 3D cursor (`cursor3D`) towards the raycast target (`targetCursor3D`).
 *
 * @param {number} dt - Seconds since the previous frame.
 */
function animate3DCursor(dt) {
    const lerp = THREE.MathUtils.lerp;
    // Frame-rate dependent easing factor, clamped to 1. Unclamped, a long frame
    // (e.g. after a tab switch, or the first frame after loading) gives t > 1, and lerp()
    // then extrapolates the cursor far past its target.
    const lerpFactor = Math.min(7 * dt, 1);

    cursor3D.set(
        lerp(cursor3D.x, targetCursor3D.x, lerpFactor),
        lerp(cursor3D.y, targetCursor3D.y, lerpFactor),
        lerp(cursor3D.z, targetCursor3D.z, lerpFactor)
    );
}

/**
 * `drop` handler: parses a dropped glTF/GLB file and uses it as the particle origin shape.
 *
 * The dropped model's mesh is sampled with glbToMeshSurfacePoints(). It is added to the
 * scene as the invisible raycast target (currentMesh), and its textures are bound to
 * uParticlesOriginPosition/uParticlesOriginNormal. Read and parse failures are logged.
 *
 * NOTE(review): unlike setParticlesMeshAndData(), this does not update
 * primaryMesh/meshes or dispose earlier meshes/textures, so animateModel() switches
 * currentMesh back on the next scroll update — confirm this is intended.
 */
function handleFileDrop(e) {
    e.preventDefault();
    e.stopPropagation();

    const file = e.dataTransfer.files[0];
    if (!(file instanceof Blob)) {
        return;
    }

    const reader = new FileReader();
    reader.onload = (readRes) => {
        gltfLoader.parse(
            readRes.target.result,
            null,
            (result) => {
                const {
                    mesh,
                    originalPositionsDataTexture,
                    originalNormalsDataTexture
                } = glbToMeshSurfacePoints(result);

                currentMesh = mesh;
                mesh.visible = false;
                scene.add(mesh);

                materials.simShaderMaterial.uniforms.uParticlesOriginPosition.value = originalPositionsDataTexture;
                materials.simShaderMaterial.uniforms.uParticlesOriginNormal.value = originalNormalsDataTexture;
            },
            (error) => {
                console.error('Failed to load the dropped model:', error);
            }
        );
    };
    reader.onerror = () => {
        console.error('Failed to read the dropped file:', reader.error);
    };
    reader.readAsArrayBuffer(file);
}

/**
 * `pointermove` handler: stores the pointer in normalized device coordinates. The
 * window's client area maps to [-1, 1] on both axes, with y pointing up.
 */
function handlePointermove(e) {
    const { clientX, clientY } = e;
    pointer.x = (clientX / window.innerWidth) * 2 - 1;
    pointer.y = 1 - (clientY / window.innerHeight) * 2;
}

/**
 * `touchmove` handler: like handlePointermove(), but reads the first active touch.
 */
function handleTouchmove(e) {
    const { clientX, clientY } = e.touches[0];
    pointer.x = (clientX / window.innerWidth) * 2 - 1;
    pointer.y = 1 - (clientY / window.innerHeight) * 2;
}

/**
 * Registers the window-level input listeners:
 * - drag-and-drop of model files;
 * - pointer tracking: touchmove on mobile, pointermove elsewhere.
 */
function addEventListeners() {
    window.addEventListener('drop', handleFileDrop);
    // Cancelling dragover is required for the window to receive the drop event.
    window.addEventListener('dragover', (e) => e.preventDefault());

    // Mobile uses touchmove because pointermove gets cancelled while scrolling.
    // Better ideas welcomed!
    const [moveEvent, moveHandler] = isMobile
        ? ['touchmove', handleTouchmove]
        : ['pointermove', handlePointermove];
    window.addEventListener(moveEvent, moveHandler);
}
