#include "Halide.h"

#include <algorithm>

namespace {

using namespace Halide;

// Halide generator that renders a synthetic shallow depth-of-field image
// from a left/right image pair.
//
// Pipeline stages (see generate()):
//   1. Build a per-pixel matching cost for each horizontal displacement
//      ("slice") between the left and right images.
//   2. Weight the cost by a confidence term (variance across slices) and
//      smooth it with a push-pull pyramid.
//   3. Take the arg-min slice as the depth, and derive a bokeh radius from
//      the distance to focus_depth.
//   4. Accumulate randomly placed aperture samples around each pixel and
//      normalize by the accumulated weight.
class LensBlur : public Halide::Generator<LensBlur> {
public:
    // Left and right input images (x, y, channel).
    Input<Buffer<uint8_t>> left_im{"left_im", 3};
    Input<Buffer<uint8_t>> right_im{"right_im", 3};
    // The number of displacements to consider
    Input<int> slices{"slices", 32, 1, 64};
    // The depth to focus on
    Input<int> focus_depth{"focus_depth", 13, 1, 32};
    // The increase in blur radius with misfocus depth
    Input<float> blur_radius_scale{"blur_radius_scale", 0.5f, 0.0f, 1.0f};
    // The number of samples of the aperture to use
    Input<int> aperture_samples{"aperture_samples", 32, 1, 64};

    // Blurred output image (x, y, channel).
    Output<Buffer<float>> final{"final", 3};

    // Defines the algorithm, the size estimates, and the schedule.
    void generate() {
        /* THE ALGORITHM */
        // Largest blur radius any pixel can have, given the slice range.
        Expr maximum_blur_radius =
            cast<int>(max(slices - focus_depth, focus_depth) * blur_radius_scale);

        Func left = BoundaryConditions::repeat_edge(left_im);
        Func right = BoundaryConditions::repeat_edge(right_im);

        // Per-channel absolute difference between the left pixel and the
        // right pixel at displacement 2z (or 2z + 1, whichever is closer).
        Func diff;
        diff(x, y, z, c) = min(absd(left(x, y, c), right(x + 2 * z, y, c)),
                               absd(left(x, y, c), right(x + 2 * z + 1, y, c)));

        // Sum of squared differences over the three color channels.
        Func cost;
        cost(x, y, z) = (pow(cast<float>(diff(x, y, z, 0)), 2) +
                         pow(cast<float>(diff(x, y, z, 1)), 2) +
                         pow(cast<float>(diff(x, y, z, 2)), 2));

        // Compute confidence of cost estimate at each pixel by taking the
        // variance across the stack.
        Func cost_confidence;
        {
            // Scoped so the name `r` can be reused by later reductions.
            RDom r(0, slices);
            Expr a = sum(pow(cost(x, y, r), 2)) / slices;
            Expr b = pow(sum(cost(x, y, r) / slices), 2);
            cost_confidence(x, y) = a - b;
        }

        // Do a push-pull thing to blur the cost volume with an
        // exponential-decay type thing to inpaint over regions with low
        // confidence.
        // Channel 0 carries confidence-weighted cost, channel 1 the weight.
        Func cost_pyramid_push[8];
        cost_pyramid_push[0](x, y, z, c) =
            mux(c, {cost(x, y, z) * cost_confidence(x, y), cost_confidence(x, y)});

        Expr w = left_im.dim(0).extent(), h = left_im.dim(1).extent();
        for (int i = 1; i < 8; i++) {
            cost_pyramid_push[i](x, y, z, c) = downsample(cost_pyramid_push[i - 1])(x, y, z, c);
            // Each level halves the extent used for the boundary condition.
            w /= 2;
            h /= 2;
            cost_pyramid_push[i] = BoundaryConditions::repeat_edge(cost_pyramid_push[i], {{0, w}, {0, h}});
        }

        Func cost_pyramid_pull[8];
        cost_pyramid_pull[7](x, y, z, c) = cost_pyramid_push[7](x, y, z, c);
        for (int i = 6; i >= 0; i--) {
            // Equal-weight blend of the upsampled coarser level with this
            // level's pushed data.
            cost_pyramid_pull[i](x, y, z, c) = lerp(upsample(cost_pyramid_pull[i + 1])(x, y, z, c),
                                                    cost_pyramid_push[i](x, y, z, c),
                                                    0.5f);
        }

        // Divide the weighted cost by the weight to recover a plain cost.
        Func filtered_cost;
        filtered_cost(x, y, z) = (cost_pyramid_pull[0](x, y, z, 0) /
                                  cost_pyramid_pull[0](x, y, z, 1));

        // Assume the minimum cost slice is the correct depth.
        Func depth;
        {
            RDom r(0, slices);
            depth(x, y) = argmin(filtered_cost(x, y, r))[0];
        }

        Func bokeh_radius;
        bokeh_radius(x, y) = abs(depth(x, y) - focus_depth) * blur_radius_scale;

        Func bokeh_radius_squared;
        bokeh_radius_squared(x, y) = pow(bokeh_radius(x, y), 2);

        // Take a max filter of the bokeh radius to determine the
        // worst-case bokeh radius to consider at each pixel. Makes the
        // sampling more efficient below.
        Func worst_case_bokeh_radius_y;
        Func worst_case_bokeh_radius;
        {
            // Separable max filter: first along y, then along x.
            RDom r(-maximum_blur_radius, 2 * maximum_blur_radius + 1);
            worst_case_bokeh_radius_y(x, y) = maximum(bokeh_radius(x, y + r));
            worst_case_bokeh_radius(x, y) = maximum(worst_case_bokeh_radius_y(x + r, y));
        }

        // Color channels plus a constant alpha of 1 in channel 3, so that
        // channel 3 of the accumulation below is the total sample weight.
        Func input_with_alpha;
        input_with_alpha(x, y, c) = mux(c, {cast<float>(left(x, y, 0)),
                                            cast<float>(left(x, y, 1)),
                                            cast<float>(left(x, y, 2)),
                                            1.0f});

        // Render a blurred image
        Func output;
        output(x, y, c) = input_with_alpha(x, y, c);

        // The sample locations are a random function of x, y, and sample
        // number (not c).
        Expr worst_radius = worst_case_bokeh_radius(x, y);
        Expr sample_u = (random_float() - 0.5f) * 2 * worst_radius;
        Expr sample_v = (random_float() - 0.5f) * 2 * worst_radius;
        sample_u = clamp(cast<int>(sample_u), -maximum_blur_radius, maximum_blur_radius);
        sample_v = clamp(cast<int>(sample_v), -maximum_blur_radius, maximum_blur_radius);
        Func sample_locations;
        sample_locations(x, y, z) = {sample_u, sample_v};

        RDom s(0, aperture_samples);
        sample_u = sample_locations(x, y, z)[0];
        sample_v = sample_locations(x, y, z)[1];
        Expr sample_x = x + sample_u, sample_y = y + sample_v;
        Expr r_squared = sample_u * sample_u + sample_v * sample_v;

        // We use this sample if it's from a pixel whose bokeh influences
        // this output pixel. Here's a crude approximation that ignores
        // some subtleties of occlusion edges and inpaints behind objects.
        Expr sample_is_within_bokeh_of_this_pixel =
            r_squared < bokeh_radius_squared(x, y);

        Expr this_pixel_is_within_bokeh_of_sample =
            r_squared < bokeh_radius_squared(sample_x, sample_y);

        Expr sample_is_in_front_of_this_pixel =
            depth(sample_x, sample_y) < depth(x, y);

        Func sample_weight;
        sample_weight(x, y, z) =
            select((sample_is_within_bokeh_of_this_pixel ||
                    sample_is_in_front_of_this_pixel) &&
                       this_pixel_is_within_bokeh_of_sample,
                   1.0f, 0.0f);

        // Re-express the sample position in terms of the reduction
        // variable for the accumulation step.
        sample_x = x + sample_locations(x, y, s)[0];
        sample_y = y + sample_locations(x, y, s)[1];
        output(x, y, c) += sample_weight(x, y, s) * input_with_alpha(sample_x, sample_y, c);

        // Normalize
        final(x, y, c) = output(x, y, c) / output(x, y, 3);

        /* ESTIMATES */
        // (This can be useful in conjunction with RunGen and benchmarks as well
        // as auto-schedule, so we do it in all cases.)
        // Provide estimates on the input image
        left_im.set_estimates({{0, 192}, {0, 320}, {0, 3}});
        right_im.set_estimates({{0, 192}, {0, 320}, {0, 3}});
        // Provide estimates on the parameters (same as the declared defaults)
        slices.set_estimate(32);
        focus_depth.set_estimate(13);
        blur_radius_scale.set_estimate(0.5f);
        aperture_samples.set_estimate(32);
        // Provide estimates on the pipeline output
        final.set_estimates({{0, 192}, {0, 320}, {0, 3}});

        /* THE SCHEDULE */
        if (auto_schedule) {
            // nothing
        } else if (get_target().has_gpu_feature()) {
            // Manual GPU schedule
            Var xi("xi"), yi("yi"), zi("zi");
            cost_pyramid_push[0]
                .compute_root()
                .reorder(c, z, x, y)
                .bound(c, 0, 2)
                .unroll(c)
                .gpu_tile(x, y, xi, yi, 16, 16);
            cost.compute_at(cost_pyramid_push[0], xi);
            cost_confidence.compute_at(cost_pyramid_push[0], xi);

            for (int i = 1; i < 8; i++) {
                cost_pyramid_push[i].compute_root().gpu_tile(x, y, z, xi, yi, zi, 8, 8, 8);
                cost_pyramid_pull[i].compute_root().gpu_tile(x, y, z, xi, yi, zi, 8, 8, 8);
            }

            depth.compute_root()
                .gpu_tile(x, y, xi, yi, 16, 16);
            input_with_alpha.compute_root()
                .reorder(c, x, y)
                .unroll(c)
                .gpu_tile(x, y, xi, yi, 16, 16);
            worst_case_bokeh_radius_y.compute_root()
                .gpu_tile(x, y, xi, yi, 16, 16);
            worst_case_bokeh_radius.compute_root()
                .gpu_tile(x, y, xi, yi, 16, 16);
            final.compute_root()
                .reorder(c, x, y)
                .bound(c, 0, 3)
                .unroll(c)
                .gpu_tile(x, y, xi, yi, 16, 16);

            output.compute_at(final, xi);
            output.update().reorder(c, x, s).unroll(c);
            sample_weight.compute_at(output, x);
            sample_locations.compute_at(output, x);
        } else {
            // Manual CPU schedule
            cost_pyramid_push[0]
                .compute_root()
                .reorder(c, z, x, y)
                .bound(c, 0, 2)
                .unroll(c)
                .vectorize(x, 16)
                .parallel(y, 4);
            cost.compute_at(cost_pyramid_push[0], x)
                .vectorize(x);
            cost_confidence.compute_at(cost_pyramid_push[0], x)
                .vectorize(x);

            Var xi, yi, t;
            for (int i = 1; i < 8; i++) {
                cost_pyramid_push[i].compute_at(cost_pyramid_pull[1], t).vectorize(x, 8);
                // Level 1 of the pull pyramid is the compute_at target, so
                // only the coarser levels are nested inside it.
                if (i > 1) {
                    cost_pyramid_pull[i]
                        .compute_at(cost_pyramid_pull[1], t)
                        .tile(x, y, xi, yi, 8, 2)
                        .vectorize(xi)
                        .unroll(yi);
                }
            }

            cost_pyramid_pull[1]
                .compute_root()
                .fuse(z, c, t)
                .parallel(t)
                .tile(x, y, xi, yi, 8, 2)
                .vectorize(xi)
                .unroll(yi);
            depth.compute_root()
                .tile(x, y, xi, yi, 8, 2)
                .vectorize(xi)
                .unroll(yi)
                .parallel(y, 8);
            input_with_alpha.compute_root()
                .reorder(c, x, y)
                .unroll(c)
                .vectorize(x, 8)
                .parallel(y, 8);
            worst_case_bokeh_radius_y
                .compute_at(final, y)
                .vectorize(x, 8);
            final.compute_root()
                .reorder(c, x, y)
                .bound(c, 0, 3)
                .unroll(c)
                .vectorize(x, 8)
                .parallel(y);
            worst_case_bokeh_radius
                .compute_at(final, y)
                .vectorize(x, 8);
            output.compute_at(final, x)
                .vectorize(x);
            output.update()
                .reorder(c, x, s)
                .vectorize(x)
                .unroll(c);
            sample_weight.compute_at(output, x).unroll(x);
            sample_locations.compute_at(output, x).vectorize(x);
        }
    }

private:
    // Pure variables shared by generate() and the resampling helpers.
    Var x, y, z, c;

    // Downsample with a 1 3 3 1 filter
    // Applies the filter along x and then along y, halving both; any
    // trailing dimensions of f are passed through via the implicit `_`.
    Func downsample(Func f) {
        using Halide::_;
        Func downx, downy;
        downx(x, y, _) = (f(2 * x - 1, y, _) + 3.0f * (f(2 * x, y, _) + f(2 * x + 1, y, _)) + f(2 * x + 2, y, _)) / 8.0f;
        downy(x, y, _) = (downx(x, 2 * y - 1, _) + 3.0f * (downx(x, 2 * y, _) + downx(x, 2 * y + 1, _)) + downx(x, 2 * y + 2, _)) / 8.0f;
        return downy;
    }

    // Upsample using bilinear interpolation
    // Each output sample is 0.75 of the source sample at x / 2 (resp. y / 2)
    // plus 0.25 of its neighbor on the side selected by the parity of the
    // output coordinate; trailing dimensions pass through via `_`.
    Func upsample(Func f) {
        using Halide::_;
        Func upx, upy;
        upx(x, y, _) = 0.25f * f((x / 2) - 1 + 2 * (x % 2), y, _) + 0.75f * f(x / 2, y, _);
        upy(x, y, _) = 0.25f * upx(x, (y / 2) - 1 + 2 * (y % 2), _) + 0.75f * upx(x, y / 2, _);
        return upy;
    }
};

}  // namespace

