#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <functional>
#include <iostream>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "mathutil.hpp"

#define JC_VORONOI_IMPLEMENTATION
#include "jc_voronoi.h"

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "FImage.hpp"
#include "pyramids.hpp"

// Uncomment this line to have inter-step output images
// #define DEBUG_IMAGES

// Uncomment this line to have debug output
// #define DEBUG_OUTPUT

// trace(x) streams x to std::cout and flush() flushes std::cout, but only
// when DEBUG_OUTPUT is defined; otherwise both expand to a statement that
// does nothing. Each macro is wrapped in do { } while(0) so that a call
// followed by ';' is exactly one statement, which keeps it safe inside an
// unbraced if/else.
#ifdef DEBUG_OUTPUT
#define trace(x) do { std::cout << x; } while(0)
#define flush() do { std::cout.flush(); } while(0)
#else
#define trace(x) do { } while(0)
#define flush() do { } while(0)
#endif

// Laplacian remapping: evaluates the power-law curve a * x^b.
//   x : value to remap
//   a : multiplicative scale
//   b : exponent applied to x
double radiusMap(double x, double a, double b)
{
    const double powered = std::pow(x, b);
    return a * powered;
}

// Calls f once for every site of the diagram, in the order of the array
// returned by jcv_diagram_get_sites. Each call receives a pointer to one site.
void forEachSite(jcv_diagram *diagram, const std::function<void(const jcv_site*)> &f)
{
    const jcv_site *sites = jcv_diagram_get_sites(diagram);

    int idx = 0;
    while(idx < diagram->numsites)
    {
        f(&sites[idx]);
        ++idx;
    }
}

// Builds an h-by-w image that is zero everywhere except at the positions in
// p, where it holds the matching value from f, then returns the result of
// poissonIntegrate on that image.
// p and f are read in parallel: f[k] is written at row p[k].y, column p[k].x.
// Assumes the Poisson process has already been prepared
FImage reconstructPSF(std::vector<jcv_point> &p, std::vector<float> &f, int w, int h)
{
    FImage impulses(h, w);
    impulses.setZero();

    const unsigned int count = p.size();
    for(unsigned int k = 0; k != count; ++k)
    {
        const jcv_point &pos = p[k];
        impulses(pos.y, pos.x) = f[k];
    }

    return poissonIntegrate(impulses);
}

/// Entry point: approximates the input image by a sparse set of weighted
/// points ("spreadlets") and appends them to two binary output files.
///
/// Expected arguments (argc must be 6):
///   argv[1]  input image path (read with loadImage)
///   argv[2]  number of color channels (not read by this function)
///   argv[3]  x center, parsed with std::stoi
///   argv[4]  y center, parsed with std::stoi
///   argv[5]  output prefix; "_data.bin" and "_sizes.bin" are appended to it
///
/// Returns 0 in all cases, including when the usage message is printed.
int main(int argc, char *argv[])
{
    if(argc != 6)
    {
        std::cout << "Usage : ig3da_precompute <input image> <nb of color channels> <x center> <y center> <output file>" << std::endl;
        return 0;
    }
    
    srand(time(NULL));
    FImage base = loadImage(argv[1]);
    // Normalize so that the brightest coefficient is 1
    base /= base.maxCoeff();
    jcv_point center;
    center.x = std::stoi(argv[3]);
    center.y = std::stoi(argv[4]);
    
    /// Laplacian calculation + gaussian smoothing
    Array33f laplacianKernel; laplacianKernel << 0, -1, 0, -1, 4, -1, 0, -1, 0;
    Array<float, 5, 1> gaussianKernel; gaussianKernel << 1, 4, 6, 4, 1;
    Array<float, 1, 5> gaussianKernelT = gaussianKernel.transpose();
    
    // The 1-4-6-4-1 kernel sums to 16, so the separable pass is divided by 16 * 16 = 256
    FImage laplacian = convolve<3, 3>(base, laplacianKernel),
        smoothedLaplacian = abs(convolveSep<5, 5>(laplacian, gaussianKernelT, gaussianKernel) / 256),
        dartsImg(laplacian.rows(), laplacian.cols());
    
    // Per-pixel weight used by the Lloyd relaxations: the smoothed laplacian.
    // Out-of-range coordinates are only reported (when tracing is enabled).
    auto weightIntegrator = [&](int x, int y)
    {
        if(x < 0 || x >= smoothedLaplacian.cols() ||
            y < 0 || y >= smoothedLaplacian.rows())
            trace(x << ", " << y << std::endl);
        return smoothedLaplacian(y, x);
    };
    
    // Only read when DEBUG_IMAGES is defined; set before each cell is integrated
    uint8_t color;
    
    // Per-pixel value used when collapsing a cell: the raw laplacian
    auto valueIntegrator = [&](int x, int y)
    {
#ifdef DEBUG_IMAGES
        dartsImg(y, x) = color;
#endif
        return laplacian(y, x);
    };
        
    double maxIntensity = smoothedLaplacian.maxCoeff(),
        b = 1,
        a = pow(maxIntensity, b);
    
    /// 2D variable-radius Poisson disk creation
    // Create a Poisson disk with dart-throwing, stop after 10k failed attempts
    std::vector<jcv_point> darts;
    
    int failedAttempts = 0;
    
    trace("Throwing darts for 10,000 fails" << std::endl);
    
    while(failedAttempts++ < 10000)
    {
        jcv_point dart;
        dart.x = (rand() % laplacian.cols()) + 0.5;
        dart.y = (rand() % laplacian.rows()) + 0.5;
        
        double minDist = laplacian.cols() * laplacian.cols() + laplacian.rows() * laplacian.rows();
        for(jcv_point &p : darts)
            minDist = min(minDist, distance(&dart, &p));
        
        // Reject the dart if it lands too close to an already kept one; the
        // exclusion radius is derived from the inverse smoothed laplacian
        if(minDist < radiusMap(1 / smoothedLaplacian(dart.y, dart.x), a, b))
            continue;
        
        trace(".");
        flush();
        // A dart was kept: restart the count of consecutive failures
        failedAttempts = 0;
        darts.push_back(dart);
    }
    trace(std::endl << "Kept " << darts.size() << " dart(s)" << std::endl);
    
    // Output the darts at the initial step
#ifdef DEBUG_IMAGES
    for(jcv_point &p : darts)
        dartsImg(p.y, p.x) = 0xff;
    save(dartsImg, "darts_pre.png");
#endif
    
    std::vector<jcv_point> spreadlets_p;
    std::vector<float> spreadlets_f;
    
    // Try for fast-tracking, in case the actual image is so small that a stochastic
    // approximation is sure to do poorly: one spreadlet per non-zero laplacian pixel
    for(int y = 0; y < laplacian.rows(); y++)
    {
        for(int x = 0; x < laplacian.cols(); x++)
        {
            float l = laplacian(y, x);
            if(l != 0)
            {
                jcv_point p; p.x = x + 0.5; p.y = y + 0.5;
                spreadlets_p.push_back(p);
                spreadlets_f.push_back(l);
            }
        }
    }
    
    // More than 200 non-zero pixels: discard the exact list and use the darts instead
    if(spreadlets_p.size() > 200)
    {
        spreadlets_p.assign(darts.begin(), darts.end());
        spreadlets_f.resize(darts.size());
        
        /// 50 weighted Lloyd relaxations
        // Use the smoothed laplacian for weighting the sites
        jcv_rect aabb;
        aabb.min.x = aabb.min.y = 0;
        aabb.max.x = laplacian.cols() - 1;
        aabb.max.y = laplacian.rows() - 1;
        
        jcv_diagram diagram;
        memset(&diagram, 0, sizeof(jcv_diagram));
        
        trace("Running 50 weighted Lloyd relaxations" << std::endl);
        
        for(int k = 0; k < 50; k++)
        {
            trace(".");
            flush();
            
            jcv_diagram_generate(spreadlets_p.size(), &spreadlets_p[0], &aabb, &diagram);
            forEachSite(&diagram, [&](const jcv_site *site)
            {
                jcv_point sum;
                double weight = cellWeightedCentroid(site, weightIntegrator, sum);
                // Ignore the site if the laplacian integrates to 0
                if(weight > 0)
                {
                    spreadlets_p[site->index].x = sum.x;
                    spreadlets_p[site->index].y = sum.y;
                }
            });
        }
        
        // Rebuild the diagram for the relaxed positions
        jcv_diagram_generate(spreadlets_p.size(), &spreadlets_p[0], &aabb, &diagram);
        
        trace(std::endl << "Done" << std::endl);
        
        /// Collapse Voronoi cells into spreadlets
        // Also draw the diagram for fun
        {
            dartsImg.setZero();
            
            forEachSite(&diagram, [&](const jcv_site *site)
            {
                color = (rand() & 0x7f) + 64;
                spreadlets_f[site->index] = integrateCell(site, valueIntegrator);
                dartsImg(site->p.y, site->p.x) = 0xff;
            });
        }
        
#ifdef DEBUG_IMAGES
        save(dartsImg, "voronoi.png");
#endif
        
        trace("Running 400 steps of simulated annealing" << std::endl);
        
        poissonPrepare(base.cols(), base.rows());
        
        /// 400 steps of simulated annealing on the spreadlets
        {
            FImage reconstructed(base.rows(), base.cols());
            // Cost of a state: sum of absolute differences between the input
            // image and the image reconstructed from the spreadlets
            auto calcCost = [&](std::vector<jcv_point> &state_p, std::vector<float> &state_f)
            {
                reconstructed = reconstructPSF(state_p, state_f, base.cols(), base.rows());
                
                return abs(base - reconstructed).sum();
            };
            
            std::default_random_engine g;
            std::bernoulli_distribution half_bern(0.5);
            double cost = calcCost(spreadlets_p, spreadlets_f), bestCost = cost;
            trace("Initial cost : " << cost << std::endl);
            // NOTE: T starts at 0 and is only ever multiplied by alpha, so it
            // stays 0 and the acceptance test below reduces to "accept only
            // if the cost decreased"
            double T = 0, alpha = 0.89125; // try that I guess
            
            std::vector<jcv_point> old_p;
            std::vector<float> old_f;
            // Indices already perturbed during the current step
            std::vector<int> done;
            done.reserve(spreadlets_p.size() / 100);
            
            for(int timer = 400; timer > 0; timer--)
            {
                trace(".");
                flush();
                
                // Snapshot the state so a rejected step can be rolled back
                old_p.assign(spreadlets_p.begin(), spreadlets_p.end());
                old_f.assign(spreadlets_f.begin(), spreadlets_f.end());
                
                // Change 1% of all positions
                done.clear();
                for(unsigned int i = 0; i < spreadlets_p.size() / 100; i++)
                {
                    // Draw an index that has not been perturbed yet in this
                    // step. This terminates because done always holds fewer
                    // entries than there are spreadlets.
                    int idx;
                    do
                    {
                        idx = rand() % spreadlets_p.size();
                    } while(std::find(done.begin(), done.end(), idx) != done.end());
                    done.push_back(idx);
                    
                    // Move by one pixel along a random axis, in a random direction
                    jcv_point moved = spreadlets_p[idx];
                    if(half_bern(g))
                        moved.x += half_bern(g) * 2 - 1;
                    else
                        moved.y += half_bern(g) * 2 - 1;
                    // Discard moves that leave the image: the position is used
                    // as a pixel index when the image is reconstructed
                    if(moved.x >= 0 && moved.x < laplacian.cols() &&
                        moved.y >= 0 && moved.y < laplacian.rows())
                        spreadlets_p[idx] = moved;
                    
                    // Check for duplicates
                    {
                        int sx = spreadlets_p[idx].x, sy = spreadlets_p[idx].y;
                        for(unsigned int k = 0; k < spreadlets_p.size(); k++)
                        {
                            if(k != idx && sx == (int)spreadlets_p[k].x && sy == (int)spreadlets_p[k].y)
                            {
                                // Another spreadlet already sits on this pixel:
                                // drop the moved one and mark the survivor as done
                                spreadlets_p.erase(spreadlets_p.begin() + idx);
                                spreadlets_f.erase(spreadlets_f.begin() + idx);
                                done.pop_back();
                                done.push_back(k);
                                // Indices past the erased element shift down by one
                                for(int &d : done)
                                    d -= d > idx;
                                break;
                            }
                        }
                    }
                }
                // Regenerate the diagram and recompute the spreadlets
                jcv_diagram_generate(spreadlets_p.size(), &spreadlets_p[0], &aabb, &diagram);
                forEachSite(&diagram, [&](const jcv_site *site)
                {
                    spreadlets_f[site->index] = integrateCell(site, [&](int x, int y)
                    {
                        return laplacian(y, x);
                    });
                });
                
                double delta = calcCost(spreadlets_p, spreadlets_f) - cost;
                
                double p_accept = T == 0 ? delta < 0 : exp(-delta / T);
                if(p_accept >= 1 || std::bernoulli_distribution(p_accept)(g))
                {
                    cost += delta;
                    
                    if(cost < bestCost)
                        bestCost = cost;
                }
                else
                {
                    // Rejected: restore the snapshot taken at the start of the step
                    spreadlets_p = old_p;
                    spreadlets_f = old_f;
                }

                T *= alpha;
            }
            
            trace("Final cost : " << calcCost(spreadlets_p, spreadlets_f) << std::endl);
        }
        
        trace(std::endl << "Done" << std::endl);
        
        // The last annealing step may have been rejected, in which case the
        // diagram still describes the rolled-back state; rebuild it from the
        // final positions before using it
        jcv_diagram_generate(spreadlets_p.size(), &spreadlets_p[0], &aabb, &diagram);
        
        // Collapse the cells once again
        {
            dartsImg.setZero();
            
            forEachSite(&diagram, [&](const jcv_site *site)
            {
                color = (rand() & 0x7f) + 64;
                spreadlets_f[site->index] = integrateCell(site, valueIntegrator);
                dartsImg(site->p.y, site->p.x) = 0xff;
            });
        }
        
        jcv_diagram_free(&diagram);
        
#ifdef DEBUG_IMAGES
        save(dartsImg, "annealing.png");
#endif
    }
    
    // Normalization step : make it so the reconstructed PSF has unit energy
    {
        poissonPrepare(base.cols(), base.rows());
        
        float factor = reconstructPSF(spreadlets_p, spreadlets_f, base.cols(), base.rows()).sum();
        
        trace("Approx. sums at " << factor << std::endl);
        
        for(unsigned int i = 0; i < spreadlets_f.size(); i++)
            spreadlets_f[i] /= factor;
    }
    
    // Output the spreadlets to a binary file
    std::string outpath(argv[5]);
    int count = 0;
    // Data : x, y, f as floats, with x and y relative to the requested center.
    // The file is opened in append mode, so existing content is kept.
    outpath += "_data.bin";
    {
        std::ofstream dataOut(outpath.c_str(), std::ios_base::app | std::ios_base::binary);
        for(unsigned int k = 0; k < spreadlets_p.size(); k++)
        {
            count++;
            float x = spreadlets_p[k].x - center.x, y = spreadlets_p[k].y - center.y;
            
            dataOut.write((char*)&x, sizeof(float));
            dataOut.write((char*)&y, sizeof(float));
            dataOut.write((char*)&spreadlets_f[k], sizeof(float));
        }
        dataOut.close();
    }
    
    outpath = argv[5];
    outpath += "_sizes.bin";
    // Sizes : number of spreadlets written above, as an int (also appended)
    {
        std::ofstream sizeOut(outpath.c_str(), std::ios_base::app | std::ios_base::binary);
        sizeOut.write((char*)&count, sizeof(int));
        sizeOut.close();
    }
    
    return 0;
}
