/*
*
* Supplemental code for "High-Performance Elliptical Cone Tracing"
* Authors  Umut Emre, Aryan Kanak, Shlomi Steinberg
*
*/

#pragma once

#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>

#include <immintrin.h>
#include <glm/glm.hpp>

// Short aliases for the GLM vector types used by this header.
// Note: these are declared at global scope (outside namespace util).
// NOTE(review): the "u32" in vec2u32_t presumes glm::uvec2 holds 32-bit unsigned ints — confirm for target platforms.
using vec2 = glm::vec2;
using vec3 = glm::vec3;
using vec2u32_t = glm::uvec2;
using vec3b = glm::bvec3;

namespace util {

/// Positive single-precision infinity.
/// `inline constexpr` (C++17) yields a single definition shared by every translation unit that
/// includes this header, instead of a separate internal-linkage copy per TU (`static`).
inline constexpr auto inf = std::numeric_limits<float>::infinity();

class frame_t {
public:
    vec3 t,b,n;

    [[nodiscard]] inline vec3 to_local(const vec3& v) const noexcept {
        return {
            glm::dot(v, t),
            glm::dot(v, b),
            glm::dot(v, n)
        };
    }

    [[nodiscard]] inline vec3 to_world(const vec3& v) const noexcept {
        return t*v.x + b*v.y + n*v.z;
    }

    /**
     * @brief Vectorized 4x transform to local space. Input/output is assumed to be in metres.
     */
    inline void to_local(__m128& x,
                         __m128& y,
                         __m128& z) const noexcept {
        const auto tx = _mm_set1_ps(t.x);
        const auto ty = _mm_set1_ps(t.y);
        const auto tz = _mm_set1_ps(t.z);
        const auto bx = _mm_set1_ps(b.x);
        const auto by = _mm_set1_ps(b.y);
        const auto bz = _mm_set1_ps(b.z);
        const auto nx = _mm_set1_ps(n.x);
        const auto ny = _mm_set1_ps(n.y);
        const auto nz = _mm_set1_ps(n.z);

        const auto vtz = _mm_mul_ps(tz,z);
        const auto vbz = _mm_mul_ps(bz,z);
        const auto vnz = _mm_mul_ps(nz,z);
        const auto vtyz = _mm_fmadd_ps(ty, y, vtz);
        const auto vbyz = _mm_fmadd_ps(by, y, vbz);
        const auto vnyz = _mm_fmadd_ps(ny, y, vnz);
        const auto vt = _mm_fmadd_ps(tx, x, vtyz);
        const auto vb = _mm_fmadd_ps(bx, x, vbyz);
        const auto vn = _mm_fmadd_ps(nx, x, vnyz);

        x = vt;
        y = vb;
        z = vn;
    }

    /**
     * @brief Vectorized 8x transform to local space. Input/output is assumed to be in metres.
     */
    inline void to_local(__m256& x,
                         __m256& y,
                         __m256& z) const noexcept {
        const auto tx = _mm256_set1_ps(t.x);
        const auto ty = _mm256_set1_ps(t.y);
        const auto tz = _mm256_set1_ps(t.z);
        const auto bx = _mm256_set1_ps(b.x);
        const auto by = _mm256_set1_ps(b.y);
        const auto bz = _mm256_set1_ps(b.z);
        const auto nx = _mm256_set1_ps(n.x);
        const auto ny = _mm256_set1_ps(n.y);
        const auto nz = _mm256_set1_ps(n.z);

        const auto vtz = _mm256_mul_ps(tz,z);
        const auto vbz = _mm256_mul_ps(bz,z);
        const auto vnz = _mm256_mul_ps(nz,z);
        const auto vtyz = _mm256_fmadd_ps(ty, y, vtz);
        const auto vbyz = _mm256_fmadd_ps(by, y, vbz);
        const auto vnyz = _mm256_fmadd_ps(ny, y, vnz);
        const auto vt = _mm256_fmadd_ps(tx, x, vtyz);
        const auto vb = _mm256_fmadd_ps(bx, x, vbyz);
        const auto vn = _mm256_fmadd_ps(nx, x, vnyz);

        x = vt;
        y = vb;
        z = vn;
    }
};

/// Returns the square of t (t*t), for any type with a multiplication operator.
template <typename Value>
constexpr inline auto sqr(Value t) noexcept {
    const auto product = t * t;
    return product;
}

/// Largest component of v. Written out with the same comparisons std::max uses
/// ((a < b) ? b : a), so ties and NaN components resolve exactly as nested std::max would.
constexpr inline auto max_element(const vec3& v) noexcept {
    const auto larger_yz = (v.y < v.z) ? v.z : v.y;
    return (v.x < larger_yz) ? larger_yz : v.x;
}

/// Smallest component of v. Written out with the same comparisons std::min uses
/// ((b < a) ? b : a), so ties and NaN components resolve exactly as nested std::min would.
constexpr inline auto min_element(const vec3& v) noexcept {
    const auto smaller_yz = (v.z < v.y) ? v.z : v.y;
    return (smaller_yz < v.x) ? smaller_yz : v.x;
}

/// Maximum of three values via glm::max (component-wise when T is a GLM vector).
template <typename T>
constexpr inline auto max(const T& q1, const T& q2, const T& q3) noexcept {
    const auto max_of_first_two = glm::max(q1, q2);
    return glm::max(max_of_first_two, q3);
}

/// Minimum of three values via glm::min (component-wise when T is a GLM vector).
template <typename T>
constexpr inline auto min(const T& q1, const T& q2, const T& q3) noexcept {
    const auto min_of_first_two = glm::min(q1, q2);
    return glm::min(min_of_first_two, q3);
}

/**
 * @brief Fused multiply-add: returns q1*q2 + q3 rounded once.
 *
 * Callers such as difference_of_products() depend on the product NOT being rounded
 * before the addition. The plain expression `q1*q2+q3` is only fused if the compiler
 * decides to contract it (e.g. -ffp-contract settings), so std::fma is called explicitly.
 * std::fma is not guaranteed to be constexpr before C++23, so during C++20 constant
 * evaluation the unfused expression is used instead.
 * NOTE: std::fma is a single instruction only when the target has hardware FMA;
 * otherwise it falls back to a (slower) exact software routine.
 */
constexpr inline auto fma(float q1, float q2, float q3) noexcept {
#if defined(__cpp_lib_is_constant_evaluated)
    if (std::is_constant_evaluated())
        return q1*q2+q3;
#endif
    return std::fma(q1, q2, q3);
}

/**
 * @brief Computes a*b - c*d using Kahan's compensated scheme.
 *
 * `fma(-c, d, cd)` recovers the rounding error of `cd = c*d` and is added back at the end.
 * This only works if util::fma rounds once (a true fused multiply-add). If it were evaluated
 * as an unfused q1*q2+q3, the error term would be exactly zero and this would degrade to the
 * naive a*b - c*d.
 */
constexpr inline auto difference_of_products(float a, float b, float c, float d) noexcept {
    const auto rounded_cd = c*d;
    const auto cd_rounding_error = fma(-c, d, rounded_cd);
    const auto ab_minus_cd = fma(a, b, -rounded_cd);
    return ab_minus_cd + cd_rounding_error;
}

/**
 * @brief Returns TRUE if point 'p' lies inside triangle abc, boundary included.
 *        Preconditions: 'p' lies in the triangle's plane, and 'a','b','c' are NOT co-linear.
 *        Writes p = a + alpha*(b-a) + beta*(c-a) and accepts when alpha >= 0, beta >= 0 and
 *        alpha + beta <= 1. NaN weights make every comparison false, so such points are rejected.
 */
[[nodiscard]] inline bool is_point_in_triangle(const vec3& p,
                                               const vec3& a, const vec3& b, const vec3& c) noexcept {
    // Edge vectors out of vertex a, and the offset of p from a.
    const auto edge_ab = b - a;
    const auto edge_ac = c - a;
    const auto rel_p   = p - a;

    // Entries of the 2x2 normal-equation system for (alpha, beta).
    const auto ab_ab = glm::dot(edge_ab, edge_ab);
    const auto ab_ac = glm::dot(edge_ab, edge_ac);
    const auto ac_ac = glm::dot(edge_ac, edge_ac);
    const auto p_ab  = glm::dot(rel_p, edge_ab);
    const auto p_ac  = glm::dot(rel_p, edge_ac);

    // Cramer's rule; the determinant and numerators go through difference_of_products.
    const auto inv_det = 1 / difference_of_products(ab_ab, ac_ac, ab_ac, ab_ac);
    const auto alpha   = difference_of_products(ac_ac, p_ab, ab_ac, p_ac) * inv_det;  // weight of b
    const auto beta    = difference_of_products(ab_ab, p_ac, ab_ac, p_ab) * inv_det;  // weight of c

    // Kept as >= / <= (not negated <, >) so that NaN weights still reject the point.
    const bool both_weights_nonnegative = alpha >= 0 && beta >= 0;
    return both_weights_nonnegative && alpha + beta <= 1;
}

constexpr inline auto all_elts_lte(const vec3& o1, const vec3& o2) noexcept {
    vec3b ret;
    for (auto i=0ul;i<3;++i)
        ret[i] = o1[i]<=o2[i];
    return ret;
}

constexpr inline auto all_elts_gte(const vec3& o1, const vec3& o2) noexcept {
    vec3b ret;
    for (auto i=0ul;i<3;++i)
        ret[i] = o1[i]>=o2[i];
    return ret;
}

struct range_t {
    float min, max;

    [[nodiscard]] constexpr inline bool contains(float pt) const noexcept {
        if (pt<=max && min<=pt) return true;
        return false;
    }

    
    [[nodiscard]] inline bool overlaps(const range_t& r) const noexcept {
        return min<=r.max && r.min<=max;
    }

    // unions
    constexpr inline range_t& operator|=(const range_t& o) noexcept {
        *this = range_t{ glm::min(min,o.min), glm::max(max,o.max) };
        return *this;
    }

    // intersections
    constexpr inline range_t operator&(const range_t& o) const noexcept {
        return range_t{ glm::max(min,o.min), glm::min(max,o.max) };
    }
    
    constexpr inline range_t& operator&=(const range_t& o) noexcept {
        *this = *this&o;
        return *this;
    }

    [[nodiscard]] constexpr inline bool empty() const noexcept {
        // empty when min==max==±∞
        if (min==max && !std::isfinite(min))
            return true;

        return  min>max;
    }

    [[nodiscard]] constexpr inline auto length() const noexcept {
        return max-min;
    }

    [[nodiscard]] constexpr inline auto grow(const float extent) const noexcept {
        return range_t{ min-extent, max+extent };
    }

    [[nodiscard]] constexpr static inline auto range(const float min, const float max) noexcept {
        return range_t{ .min=min, .max=max, };
    }
    [[nodiscard]] constexpr static inline auto range(const float pt) noexcept {
        return range_t{ .min=pt, .max=pt, };
    }
    [[nodiscard]] constexpr static inline auto positive() noexcept {
        return range_t{
            .min = float{},
            .max = +util::inf
        };
    }
    [[nodiscard]] constexpr static inline auto all() noexcept {
        return range_t{
            .min = -util::inf,
            .max = +util::inf
        };
    }
    [[nodiscard]] constexpr static inline auto null() noexcept {
        return range_t{
            .min = +util::inf,
            .max = -util::inf
        };
    }

};

/**
 * @brief Scales a range by the factor f.
 *
 * A negative factor reverses the order of the endpoints, so they are swapped to keep
 * min <= max. Without the swap, -1*[1,2] would give the inverted (empty) range [-1,-2],
 * and -1*range_t::null() would turn an empty range into range_t::all().
 * For f >= 0 (and NaN f) the result is {range.min*f, range.max*f}, unchanged.
 */
[[nodiscard]] constexpr inline auto operator*(float f, const range_t& range) noexcept {
    const auto scaled_min = range.min*f;
    const auto scaled_max = range.max*f;
    if (f < 0)
        return range_t{ .min=scaled_max, .max=scaled_min };
    return range_t{ .min=scaled_min, .max=scaled_max };
}

}  // namespace util
