#define GLFW_DLL
// #define TINYGLTF_NOEXCEPTION // optional. disable exception handling.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <Eigen/Eigen>
#include "imgui.h"
#include "imgui_impl_glfw_gl3.h"
#include <glad/glad.h>
#include <GLFW/glfw3.h>
// Define these only in *one* .cpp file.
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "pyramids.hpp"
#include "utils.h"

using namespace Eigen;

/**
 * GLFW error callback: logs the error code and its description through the
 * `trace` macro.
 *
 * @param error       Error code reported by GLFW.
 * @param description Human-readable description supplied by GLFW.
 */
static void glfw_error_callback(int error, const char *description)
{
    trace("Error " << error << " : " << description);
}

/**
 * GLFW key callback.
 *
 * Every key event is first handed to the ImGui GLFW backend. Afterwards, an
 * Escape key press flags the window for closing, unless ImGui reports that it
 * wants to capture the keyboard.
 */
void keyCallback(GLFWwindow *window, int key, int scanCode, int action, int mods)
{
    // ImGui gets the event first, whatever the key is.
    ImGui_ImplGlfw_KeyCallback(window, key, scanCode, action, mods);

    const bool escapePressed = (key == GLFW_KEY_ESCAPE) && (action == GLFW_PRESS);
    if(!escapePressed)
        return;

    // Leave Escape to ImGui while it is capturing keyboard input.
    if(ImGui::GetIO().WantCaptureKeyboard)
        return;

    glfwSetWindowShouldClose(window, true);
}

/**
 * GLFW mouse-button callback: hands the event to the ImGui GLFW backend and
 * does nothing else.
 */
void mouseButtonCallback(GLFWwindow *window, int button, int action, int mods)
{
    ImGui_ImplGlfw_MouseButtonCallback(window, button, action, mods);
}

/**
 * One record of the PSF data set.
 *
 * Arrays of this struct are read verbatim from "<prefix>_data.bin" and
 * uploaded as-is into a shader storage buffer, so the layout has to remain
 * three consecutive floats.
 *
 * NOTE(review): x/y presumably are a 2D offset and f a weight -- confirm
 * against the tool that writes the files and against splattingCompute.glsl.
 */
struct Spreadlet
{
    float x;
    float y;
    float f;
};

static std::vector<Vector3i> cubeFaces({ Vector3i(0, 1, 2), Vector3i(0, 2, 3),
    Vector3i(4, 6, 5), Vector3i(4, 7, 6),
    Vector3i(0, 3, 7), Vector3i(0, 7, 4),
    Vector3i(1, 6, 2), Vector3i(1, 5, 6),
    Vector3i(2, 6, 7), Vector3i(2, 7, 3),
    Vector3i(1, 0, 4), Vector3i(1, 4, 5)
});

static std::vector<Vector3f> cubeVertices({ Vector3f(-.5f, -.5f, .5f),
    Vector3f(.5f, -.5f, .5f),
    Vector3f(.5f, .5f, .5f),
    Vector3f(-.5f, .5f, .5f),
    Vector3f(-.5f, -.5f, -.5f),
    Vector3f(.5f, -.5f, -.5f),
    Vector3f(.5f, .5f, -.5f),
    Vector3f(-.5f, .5f, -.5f)
});

static std::vector<Vector2f> planeVertices({ Vector2f(-1.f, -1.f),
    Vector2f(1.f, -1.f),
    Vector2f(1.f, 1.f),
    Vector2f(-1.f, -1.f),
    Vector2f(1.f, 1.f),
    Vector2f(-1.f, 1.f)
});

/**
 * Overwrites `p` with a perspective projection matrix.
 *
 * The matrix is cleared before being filled in, so the result no longer
 * depends on the previous content of `p` (passing an identity matrix, as
 * formerly required, still produces exactly the same matrix).
 *
 * @param p     Matrix receiving the projection.
 * @param fov   Full field of view, in degrees.
 * @param ratio Aspect ratio; it scales the (negated) Y term.
 * @param near  Distance of the near clipping plane.
 * @param far   Distance of the far clipping plane.
 */
void perspective(Matrix4f &p, float fov, float ratio, float near, float far)
{
    // Cotangent of half the field of view, after converting degrees to radians.
    float d = 1 / tan(fov * M_PI / 180 / 2);
    // Shared 1 / (near - far) factor of the two depth terms.
    float ir = 1. / (near - far);

    // Every entry not assigned below must be zero.
    p.setZero();

    p(0, 0) = d;
    // Note the sign: the Y axis is flipped by this projection.
    p(1, 1) = -d * ratio;
    p(2, 2) = (near + far) * ir;
    p(2, 3) = 2 * near * far * ir;
    // w receives -z (perspective divide).
    p(3, 2) = -1;
    p(3, 3) = 0;
}

extern "C" {
    /**
     * Tell the Nvidia driver to make itself useful.
     *
     * Exported symbol looked up by NVIDIA Optimus drivers; a value of 1 asks
     * for the discrete GPU. NVIDIA documents it as a 32-bit DWORD, hence the
     * fixed 32-bit type.
     *
     * The export syntax is selected on _WIN32 rather than on the `linux`
     * macro: `linux` is not predefined in strict -std=c++NN modes nor on other
     * non-Windows platforms, which made the build fall back to __declspec.
     */
    #if defined(_WIN32)
    __declspec(dllexport)
    #else
    __attribute__((visibility("default")))
    #endif
    uint32_t NvOptimusEnablement = 0x00000001;
}

// Deferred rendering objects. The names are generated in _main() and the
// storage is (re)allocated by setupDeferred().
static GLuint gBuffer;   // Framebuffer object the scene is rendered into
static GLuint texSpeed;  // RG16F texture, color attachment 0
static GLuint texColor;  // RGBA8 texture, color attachment 1
static GLuint texDepth;  // Depth texture, depth attachment

/**
 * Sets up internals (framebuffer and textures) for deferred rendering given
 * framebuffer dimensions.
 */
void setupDeferred(int w, int h)
{
    glBindFramebuffer(GL_FRAMEBUFFER, gBuffer);
    
    glBindTexture(GL_TEXTURE_2D, texSpeed);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RG16F, w, h, 0, GL_RG, GL_FLOAT, NULL);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, texSpeed, 0);
    
    glBindTexture(GL_TEXTURE_2D, texColor);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA8, w, h, 0, GL_RGBA, GL_FLOAT, NULL);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, GL_TEXTURE_2D, texColor, 0);
    
    glBindTexture(GL_TEXTURE_2D, texDepth);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT, w, h, 0, GL_DEPTH_COMPONENT, GL_FLOAT, NULL);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, texDepth, 0);
    
    const static unsigned int attachments[2] = { GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1 };
    glDrawBuffers(2, attachments);
    
    glBindFramebuffer(GL_FRAMEBUFFER, 0);
}

/**
 * GLFW framebuffer-size callback: re-runs setupDeferred() with the new
 * dimensions. The window pointer is not used.
 */
static void framebufferSizeCallback(GLFWwindow*, int width, int height)
{
    setupDeferred(width, height);
}

// Indices into the array of buffer object names created in _main().
// NB_VBO is not a buffer: it is the number of entries, used to size the array.
enum
{
    SCENE_VBO     = 0, // Cube geometry and per-instance attributes
    QUAD_VBO      = 1, // Vertices of the display quad
    PSF_VBO       = 2, // Spreadlet records (shader storage)
    PSF_SIZES_VBO = 3, // Cumulative PSF sizes (shader storage)
    NB_VBO        = 4
};

int _main(int argc, char *argv[])
{
    if(argc != 2)
    {
        std::cout << "Usage : ig3da_compute <psf prefix>" << std::endl;
        return 0;
    }
    
    srand(123456789);
    
    setwd(argv);
    
    // Setup window
    glfwSetErrorCallback(glfw_error_callback);
    if (!glfwInit())
    {
        std::cout << "Couldn't initialize GLFW" << std::endl;
        return 1;
    }
    glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 4);
    glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
    glfwWindowHint(GLFW_RESIZABLE, GLFW_FALSE);
    GLFWwindow *window = glfwCreateWindow(1280, 720, "GLFW Window", NULL, NULL);
    glfwMakeContextCurrent(window);
    gladLoadGLLoader((GLADloadproc)glfwGetProcAddress);
    glfwSwapInterval(1); // Enable vsync
    
    // Setup ImGui binding
    ImGui::CreateContext();
    ImGuiIO &io = ImGui::GetIO();
    io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard;  // Enable Keyboard Controls
    //io.ConfigFlags |= ImGuiConfigFlags_NavEnableGamepad;   // Enable Gamepad Controls
    ImGui_ImplGlfwGL3_Init(window, true);
    glfwSetKeyCallback(window, keyCallback);
    glfwSetMouseButtonCallback(window, mouseButtonCallback);
    
    // Setup style
    ImGui::StyleColorsDark();
    // ImGui::StyleColorsClassic();
    
    // Open spreadlets files and compute a vector of cumulative sums
    std::vector<Spreadlet> spreadlets;
    std::vector<int> cumsum;
    cumsum.push_back(0);
    {
        std::string path = argv[1];
        path += "_sizes.bin";
        FILE *in = fopen(path.c_str(), "rb");
        int size;
        while(fread(&size, sizeof(int), 1, in) == 1)
            cumsum.push_back(size + cumsum.back());
        fclose(in);
        
        path = argv[1];
        path += "_data.bin";
        in = fopen(path.c_str(), "rb");
        spreadlets.resize(cumsum.back());
        fread(&spreadlets[0], sizeof(Spreadlet), cumsum.back(), in);
        fclose(in);
    }
    
    trace("Successfully read " << spreadlets.size() * sizeof(Spreadlet) << " bytes from file");
    
    int display_w, display_h;
    glfwGetFramebufferSize(window, &display_w, &display_h);
    glViewport(0, 0, display_w, display_h);
    
    // Generate VAOs and VBOs
    GLuint vao[3], vbo[NB_VBO];
    glGenVertexArrays(3, vao);
    glGenBuffers(NB_VBO, vbo);
    
    checkGLerror();
    
    // Setup deferred rendering
    glGenFramebuffers(1, &gBuffer);
    GLuint outtexs[3];
    glGenTextures(3, outtexs);
    texSpeed = outtexs[0]; texColor = outtexs[1]; texDepth = outtexs[2];
    setupDeferred(display_w, display_h);
    glfwSetFramebufferSizeCallback(window, framebufferSizeCallback);
    
    checkGLerror();
    
    glBindVertexArray(vao[1]);
    glBindBuffer(GL_ARRAY_BUFFER, vbo[QUAD_VBO]);
    glBufferData(GL_ARRAY_BUFFER, planeVertices.size() * sizeof(Vector2f), &planeVertices[0](0), GL_STATIC_DRAW);
    
    checkGLerror();
    
    GLuint deferredProgram = glCreateProgram(), deferredVertex = createShaderFromSource(GL_VERTEX_SHADER, "quadVertex.glsl"),
        deferredFragment = createShaderFromSource(GL_FRAGMENT_SHADER, "quadFragment.glsl");
    printShaderLog(deferredVertex);
    printShaderLog(deferredFragment);
    glAttachShader(deferredProgram, deferredVertex);
    glAttachShader(deferredProgram, deferredFragment);
    glLinkProgram(deferredProgram);
    glUseProgram(deferredProgram);
    
    checkGLerror();
    
    glEnableVertexAttribArray(0);
    glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, 0);
    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, texColor);
    glUniform1i(0, 0);
    glUniform1i(1, 0);
    
    checkGLerror();
    
    // Precompute cube unindexed vertices + normals
    std::vector<Vector3f> cubeUVertices;
    for(Vector3i &t : cubeFaces)
    {
        cubeUVertices.push_back(cubeVertices[t(0)]);
        cubeUVertices.push_back(cubeVertices[t(1)]);
        cubeUVertices.push_back(cubeVertices[t(2)]);
    }
    
    long int unindexedSize = cubeUVertices.size() * sizeof(Vector3f);
    
    std::vector<Vector3f> cubeUNormals;
    for(Vector3i &t : cubeFaces)
    {
        Vector3f &v1 = cubeVertices[t(0)], &v2 = cubeVertices[t(1)],
            &v3 = cubeVertices[t(2)];
        Vector3f e1 = v2 - v1, e2 = v3 - v1, n = e1.cross(e2);
        
        cubeUNormals.push_back(n);
        cubeUNormals.push_back(n);
        cubeUNormals.push_back(n);
    }
    
    // Setup scene rendering
    glBindVertexArray(vao[0]);
    glBindBuffer(GL_ARRAY_BUFFER, vbo[SCENE_VBO]);
    GLuint program = glCreateProgram(), vertex = createShaderFromSource(GL_VERTEX_SHADER, "vertex.glsl"),
        fragment = createShaderFromSource(GL_FRAGMENT_SHADER, "fragment.glsl");
    printShaderLog(vertex);
    printShaderLog(fragment);
    glAttachShader(program, vertex);
    glAttachShader(program, fragment);
    glLinkProgram(program);
    glUseProgram(program);
    
    GLuint dtLocation = glGetUniformLocation(program, "dt");
    
    checkGLerror();
    
    Matrix4f p = Matrix4f::Identity();
    perspective(p, 90, (float)display_w / display_h, 0.01, 50);
    glUniformMatrix4fv(0, 1, GL_FALSE, &p.data()[0]);
    glUniform3f(2, 1, 1, 1);
    glUniform1f(3, 1);
    
    const int instances = 1000;
    std::vector<Vector3f> centers, colors;
    std::vector<Vector2f> angularVelocities;
    const int floatSize = instances * sizeof(float);
    
    for(int k = 0; k < instances; k++)
    {
        // centers.push_back(Vector3f(cos(k * 2 * 3.14159259536 / instances) * 5,
        //     sin(k * 2 * 3.14159259536 / instances) * 5, k * 4 + 10));
        centers.push_back(Vector3f::Random() * 20 + Vector3f(0, 0, 25));
        angularVelocities.push_back(Vector2f::Random() + Vector2f(1.05, 1.05));
        colors.push_back(Vector3f::Random() * 0.5 + Vector3f(0.5, 0.5, 0.5));
    }
    
    glBufferData(GL_ARRAY_BUFFER, unindexedSize * 2 + floatSize * 8, NULL, GL_STATIC_DRAW);
    glBufferSubData(GL_ARRAY_BUFFER, 0, unindexedSize, &cubeUVertices[0]);
    glBufferSubData(GL_ARRAY_BUFFER, unindexedSize, unindexedSize, &cubeUNormals[0]);
    glBufferSubData(GL_ARRAY_BUFFER, unindexedSize * 2, floatSize * 3, &centers[0]);
    glBufferSubData(GL_ARRAY_BUFFER, unindexedSize * 2 + floatSize * 3, floatSize * 2, &angularVelocities[0]);
    glBufferSubData(GL_ARRAY_BUFFER, unindexedSize * 2 + floatSize * 5, floatSize * 3, &colors[0]);
    
    // Position + normals
    glEnableVertexAttribArray(0);
    glVertexAttribPointer(0, 3, GL_FLOAT, false, 0, 0);
    glEnableVertexAttribArray(1);
    glVertexAttribPointer(1, 3, GL_FLOAT, false, 0, reinterpret_cast<void*>(unindexedSize));
    // Instanced attributes
    glEnableVertexAttribArray(2);
    glVertexAttribPointer(2, 3, GL_FLOAT, false, 0, reinterpret_cast<void*>(unindexedSize * 2));
    glVertexAttribDivisor(2, 1);
    glEnableVertexAttribArray(3);
    glVertexAttribPointer(3, 2, GL_FLOAT, false, 0, reinterpret_cast<void*>(unindexedSize * 2 + floatSize * 3));
    glVertexAttribDivisor(3, 1);
    glEnableVertexAttribArray(4);
    glVertexAttribPointer(4, 3, GL_FLOAT, false, 0, reinterpret_cast<void*>(unindexedSize * 2 + floatSize * 5));
    glVertexAttribDivisor(4, 1);
    
    checkGLerror();
    
    // Setup PSF splatting on the GPU
    glBindVertexArray(vao[2]);
    GLuint splattingProgram = glCreateProgram(), splattingCompute =
        createShaderFromSource(GL_COMPUTE_SHADER, "splattingCompute.glsl");
    printShaderLog(splattingCompute);
    glAttachShader(splattingProgram, splattingCompute);
    glLinkProgram(splattingProgram);
    glUseProgram(splattingProgram);
    
    int texOutPadX = 153, texOutPadY = 155;
    glUniform1i(0, 0);
    glUniform1i(1, 1);
    glUniform1i(2, 2);
    glUniform1i(3, cumsum.size() - 1);
    glUniform2i(4, texOutPadX, texOutPadY);
    
    // Create and bind RGB output
    GLuint texOut[4];
    glGenTextures(4, texOut);
    for(int i = 0; i < 4; i++)
    {
        glActiveTexture(GL_TEXTURE0);
        glBindTexture(GL_TEXTURE_2D, texOut[i]);
        glTexImage2D(GL_TEXTURE_2D, 0, GL_R32F, display_w + texOutPadX * 2, display_h + texOutPadY * 2, 0, GL_RED, GL_FLOAT, NULL);
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
        glClearTexImage(texOut[i], 0, GL_RED, GL_FLOAT, NULL);
        glBindImageTexture(i, texOut[i], 0, GL_FALSE, 0, GL_READ_WRITE, GL_R32F);
    }
    poissonPrepare(display_w + texOutPadX * 2, display_h + texOutPadY * 2);
    
    // Expose PSF data
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, vbo[PSF_VBO]);
    glBufferData(GL_SHADER_STORAGE_BUFFER, spreadlets.size() * sizeof(Spreadlet), &spreadlets[0].x, GL_STATIC_DRAW);
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, vbo[PSF_SIZES_VBO]);
    glBufferData(GL_SHADER_STORAGE_BUFFER, cumsum.size() * sizeof(int), &cumsum[0], GL_STATIC_DRAW);
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, vbo[PSF_VBO]);
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, vbo[PSF_SIZES_VBO]);
    
    glClearColor(0, 0, 0, 0);
    glClearDepth(1.);
    
    glEnable(GL_DEPTH_TEST);
    
    trace("Entering drawing loop");
    
    double timeBase = glfwGetTime();
    
    bool freeze = false;
    
    GLuint texToDisplay = texColor;
    
    while (!glfwWindowShouldClose(window))
    {
        ImGui_ImplGlfwGL3_NewFrame();
        
        double now = glfwGetTime(), dt = now - timeBase;
        timeBase = now;
        now = 1874531;
        
        if(!freeze)
        {
            glBindFramebuffer(GL_FRAMEBUFFER, gBuffer);
            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
            
            glBindVertexArray(vao[0]);
            glUseProgram(program);
            glUniform1f(1, now);
            glUniform1f(dtLocation, dt);
            
            glDrawArraysInstanced(GL_TRIANGLES, 0, cubeUVertices.size(), instances);
        }
        
        glBindFramebuffer(GL_FRAMEBUFFER, 0);
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
        
        glBindVertexArray(vao[1]);
        glUseProgram(deferredProgram);
        glActiveTexture(GL_TEXTURE0);
        glBindTexture(GL_TEXTURE_2D, texToDisplay);
        glUniform1i(0, 0);
        
        glDrawArrays(GL_TRIANGLES, 0, planeVertices.size());
        
        ImGui::Begin("Debug", NULL, ImGuiWindowFlags_AlwaysAutoResize);
            ImGui::Text("Rendering %d cubes at %.1f FPS", instances, ImGui::GetIO().Framerate);
            
            if(!freeze && ImGui::Button("Post-process"))
            {
                // Splat the PSFs onto the image
                freeze = true;
                glBindVertexArray(vao[2]);
                glUseProgram(splattingProgram);
                glActiveTexture(GL_TEXTURE0);
                glBindTexture(GL_TEXTURE_2D, texSpeed);
                glActiveTexture(GL_TEXTURE1);
                glBindTexture(GL_TEXTURE_2D, texColor);
                glActiveTexture(GL_TEXTURE2);
                glBindTexture(GL_TEXTURE_2D, texDepth);
                
                glDispatchCompute(display_w, display_h, 1);
                
                FImage r(display_h + texOutPadY * 2, display_w + texOutPadX * 2),
                    g(display_h + texOutPadY * 2, display_w + texOutPadX * 2),
                    b(display_h + texOutPadY * 2, display_w + texOutPadX * 2),
                    a(display_h + texOutPadY * 2, display_w + texOutPadX * 2);
                glMemoryBarrier(GL_TEXTURE_UPDATE_BARRIER_BIT);
                // Fetch red, green, blue and alpha textures
                glActiveTexture(GL_TEXTURE3);
                glBindTexture(GL_TEXTURE_2D, texOut[0]);
                glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_FLOAT, r.data());
                glBindTexture(GL_TEXTURE_2D, texOut[1]);
                glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_FLOAT, g.data());
                glBindTexture(GL_TEXTURE_2D, texOut[2]);
                glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_FLOAT, b.data());
                glBindTexture(GL_TEXTURE_2D, texOut[3]);
                glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_FLOAT, a.data());
                trace("Running integration ...");
                unsigned long bleh = glfwGetTime();
                r = poissonIntegrate(r);
                g = poissonIntegrate(g);
                b = poissonIntegrate(b);
                trace("That took " << glfwGetTime() - bleh << " s");
                float *fusion = new float[display_w * display_h * 3];
                for(int y = 0; y < display_h; y++)
                {
                    for(int x = 0; x < display_w; x++)
                    {
                        int i = y * display_w + x, off = (y + texOutPadY) * (display_w + texOutPadX * 2) + x + texOutPadX;
                        fusion[i * 3] = r.data()[off];
                        fusion[i * 3 + 1] = g.data()[off];
                        fusion[i * 3 + 2] = b.data()[off];
                    }
                }
                glBindTexture(GL_TEXTURE_2D, texColor);
                glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, display_w, display_h, 0, GL_RGB, GL_FLOAT, fusion);
                delete[] fusion;
            }
            if(freeze)
            {
                if(ImGui::Button("Show frozen frame"))
                {
                    texToDisplay = texColor;
                    glUniform1i(1, 0);
                }
                if(ImGui::Button("Show red texture"))
                {
                    texToDisplay = texOut[0];
                    glUniform1i(1, 1);
                }
                if(ImGui::Button("Show green texture"))
                {
                    texToDisplay = texOut[1];
                    glUniform1i(1, 2);
                }
                if(ImGui::Button("Show blue texture"))
                {
                    texToDisplay = texOut[2];
                    glUniform1i(1, 3);
                }
            }
        ImGui::End();
        
        ImGui::Render();
        ImGui_ImplGlfwGL3_RenderDrawData(ImGui::GetDrawData());
        glfwSwapBuffers(window);
        glfwPollEvents();
    }
    
    poissonClean();
    
    glDeleteTextures(4, texOut);
    
    glDeleteTextures(3, outtexs);
    glDeleteFramebuffers(1, &gBuffer);
    
    glDeleteBuffers(NB_VBO, vbo);
    glDeleteVertexArrays(3, vao);
    
    trace("Exiting drawing loop");
    
    // Cleanup
    ImGui_ImplGlfwGL3_Shutdown();
    ImGui::DestroyContext();
    glfwTerminate();
    
    // trace("Remaining images : " << imageCount);
    
    return 0;
}

/**
 * Program entry point: runs _main() and returns its result. When a
 * std::exception escapes from it, its message is written to stderr and the
 * exception is rethrown.
 */
int main(int argc, char *argv[])
{
    int exitCode;
    try
    {
        exitCode = _main(argc, argv);
    }
    catch(const std::exception &error)
    {
        std::cerr << error.what();
        // Rethrown on purpose: the exception still leaves main() uncaught.
        throw;
    }
    return exitCode;
}
