/*
*
* Supplemental code for "High-Performance Elliptical Cone Tracing"
* Authors  Umut Emre, Aryan Kanak, Shlomi Steinberg
*
*/

#pragma once

#include <optional>

#include <glm/glm.hpp>
#include <glm/gtc/constants.hpp>
#include <glm/gtc/reciprocal.hpp>
#include <glm/gtc/ulp.hpp>

#include "util.hpp"

namespace intersect {

// Interval type (with .min/.max) used for distances and local z ranges throughout this header.
using range_t = util::range_t;

/**
 * @brief A ray with origin `o` and direction `d`.
 *        The component-wise reciprocal of the direction is cached in `invd` at construction;
 *        under IEEE arithmetic a zero direction component yields an infinite reciprocal.
 */
struct ray_t {
    vec3 o;     // origin
    vec3 d;     // direction
    vec3 invd;  // cached component-wise 1/d

    constexpr ray_t(const vec3& origin, const vec3& direction)
        : o{ origin },
          d{ direction },
          invd{ 1.0f/direction }
    {}
};

class elliptic_cone_t {
public:
    ray_t r;                        // centre ray
    vec3 tangent;                   // tangent
    float initial_x_length = 0.0f;  // the initial cone footprint (w.r.t. the major axis) at dist=0 propagation.

    float one_over_e;               // sqrt(1-eccentricity²)    == minor/major
    float e;                        // 1/sqrt(1-eccentricity²)  == major/minor

    float tan_alpha;                // tan half opening angle

    float z_apex;                   // z position of apex point, can be -∞

    [[nodiscard]] const auto& ray() const noexcept { return r; }
    [[nodiscard]] const auto& o() const noexcept { return r.o; }
    [[nodiscard]] const auto& d() const noexcept { return r.d; }
    [[nodiscard]] const auto& x() const noexcept { return tangent; }
    [[nodiscard]] const auto y() const noexcept { return vec3{ glm::cross(r.d,x()) }; }

    [[nodiscard]] constexpr auto x0() const { return initial_x_length; }
    [[nodiscard]] constexpr auto get_tan_alpha() const { return tan_alpha; }
    [[nodiscard]] constexpr auto get_e() const { return e; }
    [[nodiscard]] constexpr auto get_one_over_e() const { return one_over_e; }
    [[nodiscard]] constexpr auto get_z_apex() const { return z_apex; }

    /**
     * @brief Returns true if the elliptic cone is a ray: \alpha=0 && x0==0
     */
    [[nodiscard]] inline bool is_ray() const noexcept {
        return tan_alpha==0 && initial_x_length==0.0f;
    }

    [[nodiscard]] inline util::frame_t frame() const noexcept {
        return { x(),y(),d() };
    }

    /**
     * @brief Returns the major and minor axes (x and y), in local frame, of the elliptic cone cross-section after propagation a distance of z.
     * @param z distance of propagation
     * @return axes length
     */
    [[nodiscard]] inline vec2 axes(const float z) const noexcept {
        const auto r = tan_alpha*z + initial_x_length;
        return r*vec2{ 1,one_over_e };
    }

    /**
     * @brief Vectorized contains_local. Input assumed to be in metres.
     *
     * @param range restrict z distance to this range
     */
    [[nodiscard]] inline auto contains_local(const __m128& x,
                                             const __m128& y,
                                             const __m128& z,
                                             const range_t& range = { 0, +util::inf }) const noexcept {
        const auto x2  = _mm_mul_ps(x,x);
        const auto ey  = _mm_mul_ps(y, _mm_set1_ps(e));
        const auto ztx = _mm_fmadd_ps(z, _mm_set1_ps(tan_alpha), _mm_set1_ps(initial_x_length));

        const auto ey2 = _mm_mul_ps(ey, ey);
        const auto ztx2 = _mm_mul_ps(ztx,ztx);

        const auto cond1 = _mm_cmp_ps(_mm_set1_ps(z_apex), z, _CMP_LE_OQ);
        const auto cond2 = _mm_cmp_ps(_mm_set1_ps(range.min), z, _CMP_LE_OQ);
        const auto cond3 = _mm_cmp_ps(z, _mm_set1_ps(range.max), _CMP_LE_OQ);
        const auto cond4 = _mm_cmp_ps(_mm_add_ps(x2, ey2), ztx2, _CMP_LE_OQ);

        return _mm_and_ps(_mm_and_ps(cond1,cond2), _mm_and_ps(cond3,cond4));
    }

    /**
     * @brief Vectorized contains_local. Input assumed to be in metres.
     *
     * @param range restrict z distance to this range
     */
    [[nodiscard]] inline auto contains_local(const __m256& x,
                                             const __m256& y,
                                             const __m256& z,
                                             const range_t& range = { 0, +util::inf }) const noexcept {
        const auto x2  = _mm256_mul_ps(x,x);
        const auto ey  = _mm256_mul_ps(y, _mm256_set1_ps(e));
        const auto ztx = _mm256_fmadd_ps(z, _mm256_set1_ps(tan_alpha), _mm256_set1_ps(initial_x_length));

        const auto ey2 = _mm256_mul_ps(ey, ey);
        const auto ztx2 = _mm256_mul_ps(ztx,ztx);

        const auto cond1 = _mm256_cmp_ps(_mm256_set1_ps(z_apex), z, _CMP_LE_OQ);
        const auto cond2 = _mm256_cmp_ps(_mm256_set1_ps(range.min), z, _CMP_LE_OQ);
        const auto cond3 = _mm256_cmp_ps(z, _mm256_set1_ps(range.max), _CMP_LE_OQ);
        const auto cond4 = _mm256_cmp_ps(_mm256_add_ps(x2, ey2), ztx2, _CMP_LE_OQ);

        return _mm256_and_ps(_mm256_and_ps(cond1,cond2), _mm256_and_ps(cond3,cond4));
    }

};

/**
 * @brief Axis-aligned bounding box, stored as its minimal and maximal corners.
 */
struct aabb_t {
    vec3 min;   // minimal corner
    vec3 max;   // maximal corner

    aabb_t() noexcept = default;

    constexpr aabb_t(const vec3& min, const vec3& max) noexcept : min(min),max(max) {}
    /** @brief Degenerate box whose both corners are v. */
    explicit constexpr aabb_t(const vec3 &v) : min(v), max(v) {}

    aabb_t(const aabb_t&) = default;
    aabb_t& operator=(const aabb_t&) = default;

    /**
     * @brief Half-open containment: min <= p < max, component-wise.
     *        A point on a max face (and thus any point w.r.t. a degenerate box) is not contained.
     */
    [[nodiscard]] constexpr inline bool contains(const vec3& p) const noexcept {
        const bool not_below_min = glm::all(glm::greaterThanEqual(p, min));
        const bool below_max     = glm::all(glm::lessThan(p, max));
        return not_below_min && below_max;
    }

    /** @brief Box centre. */
    [[nodiscard]] constexpr inline auto centre() const noexcept {
        return 0.5f*(min+max);
    }

    /** @brief Full edge lengths, max-min. */
    [[nodiscard]] constexpr inline auto extent() const noexcept {
        return max-min;
    }

    /** @brief Smallest box enclosing all given points (union via operator|). */
    [[nodiscard]] constexpr static inline aabb_t from_points(const std::convertible_to<vec3> auto& ...pts) {
        return ((aabb_t{ static_cast<vec3>(pts) }) | ...);
    }
};

// intersection return types

struct intersect_cone_edge_ret_t {
    vec3 p0,p1;
    range_t range;
    int pts;
};

struct intersect_edge_circle_ret_t {
    int points = 0;
    vec2 u1,u2;
    float t1={},t2={};
};

/** Result of a ray--triangle intersection (see intersect_ray_tri). */
struct intersect_ray_tri_ret_t {
    vec3 p{};                                             // intersection point
    float dist = std::numeric_limits<float>::infinity();  // ray parameter: p == o + dist*d

    // barycentric weights of the triangle's first two vertices (a, b); {-1,-1} when unset
    // TODO uemre: maybe util bary
    vec2 bary{ -1,-1 };
};

/** Result of a cone--plane intersection (see intersect_cone_plane). */
struct intersect_cone_plane_ret_t {
    range_t range;  // z range of the intersection; empty if there is none
    vec3 near;      // point associated with range.min
    vec3 far;       // point associated with range.max
};

/** Result of a cone--triangle intersection (see intersect_cone_tri). */
struct intersect_cone_tri_ret_t {
    range_t range;  // z range of the intersection
    vec3 near;      // nearest intersection point
    vec3 far;       // farthest intersection point
};

/**
 * @brief Edge-ellipse (axis-aligned, centred at the origin) intersection test.
 *        Finds where the segment point0--point1 crosses the ellipse x²/rx² + y²/ry² = 1.
 * @param point0,point1 segment end points
 * @param rx,ry         ellipse semi-axes along x and y
 * @return `points` is the number of crossings within the segment (0, 1 or 2) and u1/u2 are those crossings.
 *         t1/t2 are the line parameters of both roots (t=0 at point0, t=1 at point1).
 *         When points==1, u1/t1 refer to the crossing inside the segment.
 *         Tangential contact and a degenerate (zero-length) segment report no intersection.
 */
inline intersect_edge_circle_ret_t intersect_edge_ellipse(
        vec2 point0,
        vec2 point1,
        const float rx, const float ry) noexcept {
    // scale space so that the ellipse becomes the unit circle
    const auto scale = vec2{ rx,ry };
    const auto recp_scale = 1.0f / scale;
    const auto p0 = point0 * recp_scale;
    const auto p1 = point1 * recp_scale;
    
    const auto d = p1-p0;

    // |p0 + t*d|² = 1  <=>  a*t² + b*t + c = 0
    const auto a = glm::dot(d,d);
    const auto b = 2*glm::dot(p0,d);
    const auto c = glm::dot(p0,p0) - 1;

    const auto det2 = b*b - 4*a*c;
    if (det2<=0 || a==0) return {};

    // Cancellation-free root t1, then Vieta (t1*t2 = c/a) for t2.
    // The sign factor must be ±1: glm::sign(0)==0 would collapse both roots to 0 when b==0.
    const auto recp_a = 1/a;
    const auto det = glm::sqrt(det2);
    const float sgn_b = b>=0 ? 1.0f : -1.0f;
    auto t1 = 0.5f*(-b - sgn_b*det)*recp_a;
    auto t2 = t1==0 ? -b*recp_a : c*recp_a/t1;
    if (t1>t2)
        std::swap(t1,t2);

    const bool u1valid = (t1>=0&&1>=t1);
    const bool u2valid = (t2>=0&&1>=t2);

    intersect_edge_circle_ret_t ret;
    ret.t1=t1;
    ret.t2=t2;

    if (!u1valid && !u2valid) {
        ret.points = 0;
        return ret;
    }

    if (u1valid && u2valid) {
        ret.points = 2;
        ret.u1 = (p0+t1*d) * scale;
        ret.u2 = (p0+t2*d) * scale;
        return ret;
    }
    
    // exactly one crossing within the segment: report it first
    ret.points = 1;
    ret.u1 = (u1valid ? p0+t1*d : p0+t2*d) * scale;
    ret.t1 = u1valid ? t1 : t2;
    ret.t2 = u1valid ? t2 : t1;
    return ret;
}

/**
 * @brief Intersects the segment p0--p1 with the plane through pp with normal n.
 * @return the intersection point; std::nullopt when both end points lie on the same side of the plane
 *         (or both on it), or when the segment is parallel to the plane.
 */
inline std::optional<vec3> intersect_edge_plane(
        const vec3& p0,
        const vec3& p1,
        const vec3& pp,
        const vec3& n) noexcept {
    const auto edge = p1-p0;
    const auto edge_n = glm::dot(edge,n);
    // signed plane offsets, measured along n, from each end point
    const auto s0 = glm::dot(pp-p0,n);
    const auto s1 = glm::dot(pp-p1,n);

    const bool straddles = glm::sign(s0)!=glm::sign(s1);
    if (!straddles || edge_n==0)
        return std::nullopt;

    const auto t = s0 / edge_n;
    if (t>=0 && t<=1)
        return p0 + t*edge;
    return std::nullopt;
}

/**
 * @brief Line-plane intersection.
 * @return the parameter t such that p0 + t*(p1-p0) lies on the plane through pp with normal n;
 *         std::nullopt when the line is parallel to the plane.
 */
inline std::optional<float> intersect_line_plane(
        const vec3& p0,
        const vec3& p1,
        const vec3& pp,
        const vec3& n) noexcept {
    const auto dir_n = glm::dot(p1-p0,n);
    if (dir_n==0.0f)
        return std::nullopt;
    const auto offset_n = glm::dot(pp-p0,n);
    return static_cast<float>(offset_n / dir_n);
}

/**
 * @brief Ray-triangle boolean intersection test.
 *        Möller–Trumbore ray-triangle intersection, 1997, 10.1080/10867651.1997.10487468
 * @param range accepted ray parameters t (hit point o + t*d)
 * @param tol   slack allowed on the barycentric bounds (u>=-tol, v>=-tol, u+v<=1+tol)
 */
inline bool test_ray_tri(const ray_t& r, 
                         const vec3& a,
                         const vec3& b,
                         const vec3& c,
                         const range_t& range = range_t::positive(),
                         const float tol = 0) noexcept {
    const auto edge1 = b-a;
    const auto edge2 = c-a;
    const auto pvec = glm::cross(r.d, edge2);
    const auto det = glm::dot(edge1, pvec);
    if (det==0.0f)
        return false;   // ray parallel to the triangle's plane
    const auto inv_det = 1.0f/det;

    const auto tvec = r.o-a;
    const auto qvec = glm::cross(tvec, edge1);

    // barycentric coordinates and distance, still scaled by det
    const auto u_scaled = glm::dot(tvec, pvec);
    const auto v_scaled = glm::dot(r.d, qvec);
    const auto t_scaled = glm::dot(qvec, edge2);

    const bool u_ok  = u_scaled*inv_det >= -tol;
    const bool v_ok  = v_scaled*inv_det >= -tol;
    const bool uv_ok = (u_scaled+v_scaled)*inv_det <= 1+tol;
    return u_ok && v_ok && uv_ok && range.contains(t_scaled*inv_det);
}

/**
 * @brief Ray-triangle intersection.
 *        Möller–Trumbore ray-triangle intersection, 1997, 10.1080/10867651.1997.10487468
 */
inline std::optional<intersect_ray_tri_ret_t> intersect_ray_tri(
        const ray_t& r, 
        const vec3& a,
        const vec3& b,
        const vec3& c,
        const range_t& range = range_t::positive()) noexcept {
    const auto ray = r.o-a;
    const auto e1 = b-a;
    const auto e2 = c-a;
    const auto crs = glm::cross(r.d, e2);
    auto det = glm::dot(e1, crs);
    if (det==0.0f)
        return std::nullopt;

    const float sdet = det>=0.0f?1:-1;
    det *= sdet;

    const auto q = glm::cross(ray, e1);
    const auto qe2 = sdet*glm::dot(q, e2);

    auto bx = sdet*glm::dot(ray, crs);
    auto by = sdet*glm::dot(r.d, q);
    if (bx>=0.0f && by>=0.0f && bx+by<=det && (det*range).contains(qe2)) {
        const auto recp_det = 1.0f/det;
        const auto dist = qe2 * recp_det;

        const auto bux = bx * recp_det;
        const auto buy = by * recp_det;

        return intersect_ray_tri_ret_t{ r.o+dist*r.d, dist, vec2{ 1-bux-buy,bux } };
    }

    return std::nullopt;
}

/**
 * @brief Ray-AABB (slab) intersection test. Returns the range of ray parameters t (points o + t*d) inside
 *        the box. If range is empty, no intersection occurs.
 *        An axis along which the direction is exactly zero imposes no constraint when the origin lies inside
 *        that axis' slab, and rules out any intersection otherwise.
 */
inline range_t intersect_ray_aabb(
        const ray_t& r,
        aabb_t aabb) noexcept {
    const auto C = aabb.centre();
    const auto E = aabb.extent()/2.0f;
    const auto lo = C-E;
    const auto hi = C+E;

    // per-axis slab entry/exit parameters: (C -/+ E - o)/d
    const auto a = r.invd*(r.o-C);
    const auto b = glm::abs(r.invd)*E;
    auto t1 = -b-a;
    auto t2 =  b-a;

    // avoid NaNs: a zero direction component yields inf*0 above. Such an axis is decided solely
    // by whether the origin lies within that axis' slab (not by containment in the whole box).
    for (int i=0; i<3; ++i) {
        if (r.d[i]!=0)
            continue;
        const bool in_slab = r.o[i]>=lo[i] && r.o[i]<=hi[i];
        t1[i] = in_slab ? -util::inf : +util::inf;
        t2[i] = in_slab ? +util::inf : -util::inf;
    }

    return { util::max_element(t1), util::min_element(t2) };
}

/**
 * @brief Ray-AABB intersection test. 
 */
inline bool test_ray_aabb(const ray_t& r,
                          aabb_t aabb,
                          const range_t& range = range_t::positive()) noexcept {
    const auto ret = intersect_ray_aabb(r,aabb);
    return !ret.empty() && ret.overlaps(range);
}

/**
 * @brief Edge-cone intersection test.
 *        Solves for the points of the line through p0,p1 on the cone surface, discards roots behind the apex,
 *        clips the resulting z interval (z is the distance along the cone's centre ray) to `range`, and keeps
 *        the parameters that fall within the segment.
 * @tparam in_local TRUE if p0 and p1 are given in cone's local frame
 * @return std::nullopt if nothing remains; otherwise the point(s) in the same frame as p0/p1 (`pts` tells
 *         whether p1 is set; p0.z <= p1.z in local z) and their z range.
 */
template <bool in_local=false>
inline std::optional<intersect_cone_edge_ret_t> intersect_cone_edge(
        const elliptic_cone_t& cone,
        const vec3& p0,
        const vec3& p1,
        const range_t& range = range_t::positive()) noexcept {
    vec3 localp0, localp1;
    const auto& frame = cone.frame();
    if constexpr (!in_local) {
        localp0 = frame.to_local(p0-cone.o());
        localp1 = frame.to_local(p1-cone.o());
    } else {
        localp0 = p0; localp1 = p1;
    }

    // parametrize from the end point with the smaller local z, so that l.z>=0
    const bool p0closer = localp1.z>localp0.z;
    if (!p0closer) std::swap(localp0,localp1);

    const auto p = localp0, l = localp1-localp0;
    const auto x0 = cone.x0();
    const auto ta = cone.get_tan_alpha();
    const auto e  = cone.get_e();

    // p+t*l on the cone surface: x² + (e*y)² = (z*tan_alpha + x0)²  <=>  a*t² + b*t + c = 0.
    // The factor 2 applies to the whole linear coefficient (x, y and z terms).
    const auto c = util::sqr(p.x) + util::sqr(e*p.y) - util::sqr(p.z*ta + x0);
    const auto b = 2*(p.x*l.x + util::sqr(e)*p.y*l.y - l.z*ta*(p.z*ta + x0));
    const auto a = util::sqr(l.x) + util::sqr(e*l.y) - util::sqr(l.z*ta);

    const auto D = b*b - 4*a*c;
    if (D<0.0f)
        return std::nullopt;

    const auto sqrtD = glm::sqrt(D);
    auto t1 = static_cast<float>(b>=0.0f ? (-b-sqrtD)/(2*a) : (-b+sqrtD)/(2*a));
    auto t2 = static_cast<float>(-b/a)-t1;

    // discard roots on the mirrored nappe behind the apex
    const auto zapex = cone.get_z_apex();
    if (p.z + t1*l.z<=zapex)
        t1 = util::inf;
    if (p.z + t2*l.z<zapex)
        t2 = util::inf;

    if (t2<t1) std::swap(t1,t2);
    auto z1 = t1<util::inf ? p.z + t1*l.z : -util::inf;
    auto z2 = t2<util::inf ? p.z + t2*l.z : util::inf;
    assert(z2>=z1);

    if (z1>range.max || z2<range.min || (!std::isfinite(z1) && !std::isfinite(z2)))
        return std::nullopt;

    // clip to the near plane z=range.min
    if (range.min>zapex && z1<range.min) {
        if (const auto tmin = intersect_line_plane(p, p+l, vec3{ 0.0f ,0.0f ,range.min }, vec3{ 0,0,1 }); tmin) {
            t1 = *tmin;
            z1 = range.min;
        }
    } else {
        assert(std::isfinite(t1));
    }

    // clip to the far plane z=range.max
    if (z2>range.max) {
        if (const auto tmax = intersect_line_plane(p, p+l, vec3{ 0,0,range.max }, vec3{ 0,0,1 }); tmax) {
            t2 = *tmax;
            z2 = range.max;
        }
    }

    // keep parameters within the segment, evaluated on the caller's (unswapped) end points
    std::optional<vec3> v1,v2;
    if ((t1>=0 && (1>=t1))) v1 = (p0closer ? p0 : p1) + t1 * (p0closer ? p1-p0 : p0-p1);
    else z1=z2;
    if ((t2>=0 && (1>=t2))) v2 = (p0closer ? p0 : p1) + t2 * (p0closer ? p1-p0 : p0-p1);
    else z2=z1;
    if (!v1 && !v2)
        return std::nullopt;

    intersect_cone_edge_ret_t ret;
    ret.range = { z1,z2 };
    ret.pts = v1&&v2 ? 2 : 1;
    ret.p0 = v1 ? *v1 : *v2;
    if (v1&&v2)
        ret.p1 = *v2;

    return ret;
}

/**
 * @brief Ray-AABB intersection test.
 * @tparam test_points_in_cone if FALSE, assume the caller knows a priori that ALL points are outside the cone.
 */
template <bool in_local=false>
inline bool test_cone_edge(const elliptic_cone_t& cone,
                        const vec3& p0,
                        const vec3& p1,
                        const range_t& range = range_t::positive()) noexcept {
    return !!intersect_cone_edge<in_local>(cone, p0, p1, range);
}

/**
 * @brief Cone-plane intersection test.
 *        The plane is the set of points p with dot(n,p) = d. Returns the z range (distance along the cone's
 *        centre ray) over which the plane meets the cone, clipped to `range`. If range is empty, no intersection occurs.
 *        `near`/`far` are the points at range.min/range.max; unclamped ones lie on the cone boundary.
 *        A point clamped to range.min/range.max is instead produced by closest_point_plane_plane_intersection
 *        below and need not lie on the boundary. NOTE(review): that helper projects onto the extremal direction
 *        `v`, so for non-circular cross-sections the clamped point may not lie exactly on the plane; confirm intent.
 *        Finite points are returned in world space unless in_local is TRUE. An infinite range.max leaves `far`
 *        as vec3(inf).
 * @tparam in_local TRUE if n and d are given in cone's local frame
 */
template <bool in_local=false>
inline intersect_cone_plane_ret_t intersect_cone_plane(
        const elliptic_cone_t& cone,
        vec3 n, float d,
        const range_t& range = { 0.0f, util::inf }) noexcept {
    const auto& frame = cone.frame();
    if constexpr (!in_local) {
        // re-express the plane relative to the cone origin, in the local frame
        d -= glm::dot(cone.o(),n);
        n = frame.to_local(n);
    }

    const auto x0 = cone.x0();
    // NOTE: `e` here holds one_over_e (minor/major), not get_e()
    const auto e = cone.get_one_over_e();
    
    // cross sectional cone position where intersection occurs:
    // the cross-section at z is r(z)*(cos θ, e*sin θ) with r(z) = tan_alpha*z + x0. dot(n.xy, ·) is extremal
    // at (cos θ, sin θ) = v, i.e. at the boundary points ±r(z)*u with u = v*(1,e), where nu = dot(n.xy,u) >= 0.
    const auto v_denom2 = util::sqr(n.x)+util::sqr(e*n.y);
    const auto v = v_denom2>0 ?
        vec2{ n.x, e*n.y } / glm::sqrt(v_denom2) :
        vec2{ 0,0 };
    const auto u =  v * vec2{1,e};
    const auto nu = glm::dot(n,vec3{ u,0 });

    // z where the plane touches the +u boundary line (n.z*z + r(z)*nu = d) and the -u one (n.z*z - r(z)*nu = d)
    const auto zapex = cone.get_z_apex();
    auto z01 = (d - x0*nu) / (n.z + cone.get_tan_alpha()*nu);
    auto z02 = (d + x0*nu) / (n.z - cone.get_tan_alpha()*nu);

    // classify and order: candidates behind the apex (or NaN) are replaced by +inf
    const auto has_z01 = z01>=zapex && !std::isnan(z01);
    const auto has_z02 = z02>=zapex && !std::isnan(z02);
    if (!has_z01) z01 = util::inf;
    if (!has_z02) z02 = util::inf;
    // positions at intersection candidates
    auto p1 = has_z01 ? vec3{ (z01*cone.get_tan_alpha() + x0) *   u , z01 } : vec3(util::inf);
    auto p2 = has_z02 ? vec3{ (z02*cone.get_tan_alpha() + x0) * (-u), z02 } : vec3(util::inf);
    // reorder if needed
    if (z01>z02) {
        std::swap(z01,z02);
        std::swap(p1,p2);
    }

    // with a single valid candidate the range is [z, inf): the intersection is unbounded
    auto rng = range_t{ z01,z02 };
    const bool empty = (!has_z01 && !has_z02) || (rng & range).empty();
    if (empty)
        return { .range=range_t::null(), };

    // utility function to compute the point closest to the cone's mean on the plane-plane intersection:
    // solves n.x*x + n.y*y = d - n.z*z for a point (dividing by the larger of |n.x|,|n.y|), then projects it onto u
    static constexpr auto closest_point_plane_plane_intersection = 
        [](const auto z, const auto u, const auto n, const auto d){
        float x0,y0;
        if (glm::abs(n.y)>glm::abs(n.x)) {
            y0 = (d-n.z*z)/n.y;
            x0 = n.x!=0.0f ? (d-n.z*z-n.y*y0)/n.x : 0.0f;
        } else {
            x0 = (d-n.z*z)/n.x;
            y0 = n.y!=0.0f ? (d-n.z*z-n.x*x0)/n.y : 0.0f;
        }

        return vec3{ (x0*u.x+y0*u.y)*u, z };
    };

    // clamp and transform to world, if needed
    if (std::isfinite(rng.min)) {
        if (rng.min<range.min) {
            // compute point on plane at z=range.min closest to x=0,y=0
            p1 = closest_point_plane_plane_intersection(range.min,v,n,d);
            rng.min = range.min;
        }
        if constexpr (!in_local)
            p1 = cone.o() + frame.to_world(p1);
    }
    if (std::isfinite(rng.max)) {
        if (rng.max>range.max) {
            // compute point on plane at z=range.max closest to x=0,y=0
            p2 = closest_point_plane_plane_intersection(range.max,v,n,d);
            rng.max = range.max;
        }
        if constexpr (!in_local)
            p2 = cone.o() + frame.to_world(p2);
    }

    return {
        .range = rng,
        .near  = p1,
        .far   = p2,
    };
}

/**
 * @brief Cone-plane boolean intersection test: true when intersect_cone_plane yields a non-empty z range.
 * @tparam in_local TRUE if n and d are given in cone's local frame
 */
template <bool in_local=false>
inline bool test_cone_plane(
        const elliptic_cone_t& cone,
        vec3 n, float d,
        const range_t& range = { 0.0f, util::inf }) noexcept {
    const auto hit = intersect_cone_plane<in_local>(cone, n, d, range);
    return !hit.range.empty();
}

inline bool fast_check_if_intersection_possible_cone_aabb(
                const elliptic_cone_t& cone,
                const aabb_t& aabb,
                const range_t& range = range_t::positive()) noexcept {
    const auto& f = cone.frame();

    const auto c = aabb.centre()-cone.o();
    const auto e = aabb.extent()/2.0f;
    const auto rz = glm::abs(glm::dot(e,glm::abs(vec3{ f.n })));
    const auto rx = glm::abs(glm::dot(e,glm::abs(vec3{ f.t })));
    const auto ry = glm::abs(glm::dot(e,glm::abs(vec3{ f.b })));

    const auto minz = glm::dot(c,f.n) - rz;
    const auto maxz = glm::dot(c,f.n) + rz;
    const auto axes = cone.axes(maxz);

    const auto x = glm::dot(c,f.t);
    const auto y = glm::dot(c,f.b);
    return range_t{ minz,maxz }.overlaps(range) &&
           range_t{ x-rx,x+rx }.overlaps(range_t{ -axes.x,+axes.x }) &&
           range_t{ y-ry,y+ry }.overlaps(range_t{ -axes.y,+axes.y });
}

/**
 * @brief Cone-AABB intersection test. 
 *        Fast, conservative approximation
 */
inline bool test_cone_aabb(const elliptic_cone_t& cone,
                           const aabb_t& aabb,
                           const range_t& range_input = range_t::positive()) noexcept {
    // testing faces is slow. Instead, grow range by the AABB size: conservative approximation that avoids face checks.
    const auto grow = glm::abs(glm::dot(aabb.extent(),cone.d()));
    const auto range = range_input.grow(grow);

    // Various fast accepts
    if (aabb.contains(cone.o() + range.min * cone.d()) ||
        aabb.contains(cone.o() + range.max * cone.d()))
        return true;
    if (test_ray_aabb(cone.ray(), aabb, range))
        return true;
    if (cone.is_ray())
        return false;

    // Fast reject: intersect per-axis cone AABB
    if (!fast_check_if_intersection_possible_cone_aabb(cone, aabb, range_input))
        return false;

    const auto& frame = cone.frame();
    const auto& o = cone.o();

    const auto ox = _mm256_set1_ps(o.x);
    const auto oy = _mm256_set1_ps(o.y);
    const auto oz = _mm256_set1_ps(o.z);

    const auto aabb_min_x = _mm256_set1_ps(aabb.min.x);
    const auto aabb_min_y = _mm256_set1_ps(aabb.min.y);
    const auto aabb_min_z = _mm256_set1_ps(aabb.min.z);
    const auto aabb_max_x = _mm256_set1_ps(aabb.max.x);
    const auto aabb_max_y = _mm256_set1_ps(aabb.max.y);
    const auto aabb_max_z = _mm256_set1_ps(aabb.max.z);

    auto vs_x = _mm256_blend_ps(aabb_min_x,aabb_max_x, 0xAA);
    auto vs_y = _mm256_blend_ps(aabb_min_y,aabb_max_y, 0xCC);
    auto vs_z = _mm256_blend_ps(aabb_min_z,aabb_max_z, 0xF0);

    vs_x = _mm256_sub_ps(vs_x, ox);
    vs_y = _mm256_sub_ps(vs_y, oy);
    vs_z = _mm256_sub_ps(vs_z, oz);

    frame.to_local(vs_x,vs_y,vs_z);

    const auto contains = cone.contains_local(vs_x,vs_y,vs_z,range);

    for (int i=0;i<8;++i)
        if (!(contains[i]==0)) return true;

    // test all edges
    // TODO: vectorize
    static constexpr vec2u32_t edges[12] = {
        { 0,1 },{ 1,3 },{ 2,3 },{ 0,2 },
        { 4,5 },{ 5,7 },{ 6,7 },{ 4,6 },
        { 0,4 },{ 1,5 },{ 3,7 },{ 2,6 },
    };
    for (const auto& edge : edges) {
        const auto& p0 = vec3{ vs_x[edge.x], vs_y[edge.x], vs_z[edge.x] };
        const auto& p1 = vec3{ vs_x[edge.y], vs_y[edge.y], vs_z[edge.y] };
        if (test_cone_edge<true>(cone,p0,p1, range))
            return true;
    }

    return false;
}

/**
 * @brief Cone-AABB intersection test. Returns the z range (distance along the cone's centre ray) over which the
 *        cone meets the box, clipped to `range`. If range is empty, no intersection occurs.
 *        The result is the union of these contributions:
 *          - box corners inside the cone;
 *          - centre-ray end points inside the box;
 *          - box edges crossing the cone surface;
 *          - the extremal points of each face's cone--plane intersection that lie within that face.
 *        The union is intersected with the z range spanned by the box corners.
 */
inline range_t intersect_cone_aabb(
        const elliptic_cone_t& cone,
        const aabb_t& aabb,
        const range_t& range = range_t::positive()) noexcept {
    if (cone.is_ray())
        return intersect_ray_aabb(cone.ray(), aabb) & range;

    // Fast reject: intersect per-axis cone AABB
    if (!fast_check_if_intersection_possible_cone_aabb(cone, aabb, range))
        return range_t::null();

    const auto& frame = cone.frame();
    const auto& o = cone.o();

    // transform to local; corner k (SIMD lane k) takes max.x if (k&1), max.y if (k&2), max.z if (k&4)
    const auto ox = _mm256_set1_ps(o.x);
    const auto oy = _mm256_set1_ps(o.y);
    const auto oz = _mm256_set1_ps(o.z);

    const auto aabb_min_x = _mm256_set1_ps(aabb.min.x);
    const auto aabb_min_y = _mm256_set1_ps(aabb.min.y);
    const auto aabb_min_z = _mm256_set1_ps(aabb.min.z);
    const auto aabb_max_x = _mm256_set1_ps(aabb.max.x);
    const auto aabb_max_y = _mm256_set1_ps(aabb.max.y);
    const auto aabb_max_z = _mm256_set1_ps(aabb.max.z);

    auto vs_x = _mm256_blend_ps(aabb_min_x,aabb_max_x, 0xAA);
    auto vs_y = _mm256_blend_ps(aabb_min_y,aabb_max_y, 0xCC);
    auto vs_z = _mm256_blend_ps(aabb_min_z,aabb_max_z, 0xF0);

    vs_x = _mm256_sub_ps(vs_x, ox);
    vs_y = _mm256_sub_ps(vs_y, oy);
    vs_z = _mm256_sub_ps(vs_z, oz);
    frame.to_local(vs_x,vs_y,vs_z);

    // find points in cone
    const auto contains = cone.contains_local(vs_x,vs_y,vs_z,range);

    // min/max z: lanes 0-3 of *_lh reduce over all corners, lanes 4-7 over the corners inside the cone
    // (corners outside are replaced by +inf for the min and -inf for the max)
    const auto pinf = _mm256_set1_ps(+util::inf);
    const auto minf = _mm256_set1_ps(-util::inf);
    const auto z_or_pinf = _mm256_blendv_ps(pinf,vs_z, contains);
    const auto z_or_minf = _mm256_blendv_ps(minf,vs_z, contains);

    const auto maxz_h = _mm256_permute2f128_ps(vs_z, z_or_minf, 0x21);
    const auto maxz_l = _mm256_permute2f128_ps(vs_z, z_or_minf, 0x30);
    const auto minz_h = _mm256_permute2f128_ps(vs_z, z_or_pinf, 0x21);
    const auto minz_l = _mm256_permute2f128_ps(vs_z, z_or_pinf, 0x30);
    const auto maxz_lh = _mm256_max_ps(maxz_l, maxz_h);
    const auto minz_lh = _mm256_min_ps(minz_l, minz_h);

    // z range of vertices
    auto possible_range = range_t{ 
        .min = glm::min(glm::min(minz_lh[0], minz_lh[1]), glm::min(minz_lh[2], minz_lh[3])),
        .max = glm::max(glm::max(maxz_lh[0], maxz_lh[1]), glm::max(maxz_lh[2], maxz_lh[3]))
    };
    // z range of vertices in cone
    auto ret = range_t{ 
        .min = glm::min(glm::min(minz_lh[4], minz_lh[5]), glm::min(minz_lh[6], minz_lh[7])),
        .max = glm::max(glm::max(maxz_lh[4], maxz_lh[5]), glm::max(maxz_lh[6], maxz_lh[7]))
    };

    possible_range &= range;
    if (possible_range.empty())
        return range_t::null();

    // centre-ray end points inside the box
    if (aabb.contains(cone.o() + range.min * cone.d()))
        ret |= range_t::range(range.min);
    if (range.max<util::inf && aabb.contains(cone.o() + range.max * cone.d()))
        ret |= range_t::range(range.max);

    // test all edges (an edge with both corners inside the cone cannot cross its surface)
    static constexpr vec2u32_t edges[12] = {
        { 0,1 },{ 1,3 },{ 2,3 },{ 0,2 },
        { 4,5 },{ 5,7 },{ 6,7 },{ 4,6 },
        { 0,4 },{ 1,5 },{ 3,7 },{ 2,6 },
    };
    for (const auto& edge : edges) {
        if (!(contains[edge.x]==0) && !(contains[edge.y]==0))
            continue;
        const auto& p0 = vec3{ vs_x[edge.x], vs_y[edge.x], vs_z[edge.x] };
        const auto& p1 = vec3{ vs_x[edge.y], vs_y[edge.y], vs_z[edge.y] };
        const auto ice = intersect_cone_edge<true>(cone,p0,p1, range_t::all());
        if (ice)
            ret |= range_t{ ice->p0.z, ice->pts==1 ? ice->p0.z: ice->p1.z };
    }

    // test faces: face i has world-space outward normal normals[i] and contains corner facesv0[i]
    static constexpr int facesv0[6] = { 0,4,0,2,0,1 };
    static constexpr vec3 normals[6] = {
        vec3{ 0,0,-1 },
        vec3{ 0,0,+1 },
        vec3{ 0,-1,0 },
        vec3{ 0,+1,0 },
        vec3{ -1,0,0 },
        vec3{ +1,0,0 },
    };

    // is world point wp within the box along the axes tangent to the face with world normal wn?
    constexpr auto test_point_in_aabb = [](const auto& box, const auto& wp, const auto& wn) {
        for (int k=0;k<3;++k) {
            if (wn[k]!=0) continue;
            if (box.min[k]>wp[k] || box.max[k]<wp[k])
                return false;
        }
        return true;
    };

    for (int i=0;i<6;++i) {
        const auto& a = vec3{ vs_x[facesv0[i]], vs_y[facesv0[i]], vs_z[facesv0[i]] };
        const auto n = frame.to_local(normals[i]);
        const auto d = glm::dot(a,n);

        const auto icp = intersect_cone_plane<true>(cone, n, d, range);
        if (icp.range.empty()) continue;

        // check if points are in AABB (near/far are local; convert to world)
        if (test_point_in_aabb(aabb, frame.to_world(icp.near)+o, normals[i]))
            ret |= range_t::range(icp.range.min);
        if (icp.range.length()>0.0f && test_point_in_aabb(aabb, frame.to_world(icp.far)+o, normals[i]))
            ret |= range_t::range(icp.range.max);
    }

    return ret & possible_range;
}

/**
 * @brief Cone-triangle boolean intersection test.
 *        Checks, in order:
 *          1. the centre ray against the triangle;
 *          2. triangle vertices for containment;
 *          3. triangle edges against the cone surface;
 *          4. the triangle's cross-sections with the near (z=range.min, only when range.min>0) and far
 *             (z=range.max, only when finite) clip planes against the cone's elliptic cross-section there.
 */
inline bool test_cone_tri(const elliptic_cone_t& cone,
                          const vec3& a,
                          const vec3& b,
                          const vec3& c,
                          const range_t& range = range_t::positive()) noexcept {
    if (test_ray_tri(cone.ray(), a,b,c, range))
        return true;

    const auto& frame = cone.frame();
    const auto& o = cone.o();

    // transform to local
    const auto ox = _mm_set1_ps(o.x);
    const auto oy = _mm_set1_ps(o.y);
    const auto oz = _mm_set1_ps(o.z);

    auto vs_x = __m128{ a.x,b.x,c.x };
    auto vs_y = __m128{ a.y,b.y,c.y };
    auto vs_z = __m128{ a.z,b.z,c.z };
    vs_x = _mm_sub_ps(vs_x, ox);
    vs_y = _mm_sub_ps(vs_y, oy);
    vs_z = _mm_sub_ps(vs_z, oz);
    frame.to_local(vs_x,vs_y,vs_z);

    // find points in cone
    const auto contains = cone.contains_local(vs_x,vs_y,vs_z,range);

    // fast reject: all vertices before the near clip or beyond the far clip
    if (util::max(vs_z[0],vs_z[1],vs_z[2]) < range.min ||
        util::min(vs_z[0],vs_z[1],vs_z[2]) > range.max)
        return false;

    if (!(contains[0]==0) ||
        !(contains[1]==0) ||
        !(contains[2]==0) ||
        test_cone_edge<true>(cone, vec3{ vs_x[0],vs_y[0],vs_z[0] },  vec3{ vs_x[1],vs_y[1],vs_z[1] }, range) ||
        test_cone_edge<true>(cone, vec3{ vs_x[0],vs_y[0],vs_z[0] },  vec3{ vs_x[2],vs_y[2],vs_z[2] }, range) ||
        test_cone_edge<true>(cone, vec3{ vs_x[1],vs_y[1],vs_z[1] },  vec3{ vs_x[2],vs_y[2],vs_z[2] }, range))
        return true;

    // No clip plane to test: the near plane is only used for range.min>0 and the far plane only for a
    // finite range.max. A non-positive range.min must not skip the far plane.
    if (range.min<=0.0f && !(range.max<util::inf))
        return false;
    
    // finally, does the triangle intersect the near/far clip planes?
    // this is faster than testing a cone--plane intersection
    vec2 Ns[2], Fs[2];
    int ns=0,fs=0;
    for (int i=0;i<3;++i) {
        const auto j = (i+1)%3;
        const auto vsi = vec3{ vs_x[i],vs_y[i],vs_z[i] };
        const auto vsj = vec3{ vs_x[j],vs_y[j],vs_z[j] };

        const auto np = range.min>0.0f ?
            intersect_edge_plane(vsi,vsj, vec3{ 0,0,range.min }, vec3{ 0,0,1 }) : std::nullopt;
        const auto fp = range.max<util::inf ? 
            intersect_edge_plane(vsi,vsj, vec3{ 0,0,range.max }, vec3{ 0,0,1 }) : std::nullopt;
        if (np && ns<2) Ns[ns++] = vec2{ *np };
        if (fp && fs<2) Fs[fs++] = vec2{ *fp };
    }
    if (ns==2) {
        const auto axes = cone.axes(range.min);
        if (intersect_edge_ellipse(Ns[0], Ns[1], axes.x, axes.y).points>0)
            return true;
    }
    if (fs==2) {
        const auto axes = cone.axes(range.max);
        if (intersect_edge_ellipse(Fs[0], Fs[1], axes.x, axes.y).points>0)
            return true;
    }

    return false;
}

/**
 * @brief Cone-triangle intersection test. Returns the z range (distance along the cone's centre ray) over which
 *        the triangle meets the cone within `range`, and the nearest/farthest intersection points (world space).
 *        Returns std::nullopt if no intersection is found.
 * @param n triangle normal in world space; defines, with vertex a, the triangle's supporting plane.
 */
inline std::optional<intersect_cone_tri_ret_t> intersect_cone_tri(const elliptic_cone_t& cone,
                                                                  const vec3& a,
                                                                  const vec3& b,
                                                                  const vec3& c,
                                                                  const vec3& n,
                                                                  const range_t& range = range_t::positive()) noexcept {
    if (cone.is_ray()) {
        // degenerate case: ray-triangle intersection
        const auto cr = intersect_ray_tri(cone.ray(), a,b,c, range);
        if (cr)
            return intersect_cone_tri_ret_t{ .range = range_t::range(cr->dist), .near = cr->p, .far = cr->p };
        return std::nullopt;
    }

    const auto& frame = cone.frame();
    const auto& o = cone.o();

    // transform to local
    const auto ox = _mm_set1_ps(o.x);
    const auto oy = _mm_set1_ps(o.y);
    const auto oz = _mm_set1_ps(o.z);

    auto vs_x = __m128{ a.x,b.x,c.x };
    auto vs_y = __m128{ a.y,b.y,c.y };
    auto vs_z = __m128{ a.z,b.z,c.z };
    vs_x = _mm_sub_ps(vs_x, ox);
    vs_y = _mm_sub_ps(vs_y, oy);
    vs_z = _mm_sub_ps(vs_z, oz);
    frame.to_local(vs_x,vs_y,vs_z);

    // find points in cone
    const auto contains = cone.contains_local(vs_x,vs_y,vs_z,range);

    // fast reject: all points before near clip or beyond far clip
    const auto closest_z  = util::min(vs_z[0],vs_z[1],vs_z[2]);
    const auto farthest_z = util::max(vs_z[0],vs_z[1],vs_z[2]);
    if (farthest_z < range.min ||
        closest_z  > range.max)
        return std::nullopt;

    vec3 pn, pf;
    bool has_near = false, has_far = false;

    // the triangle's z-extreme vertices, when inside the cone, are the intersection's extremes
    for (int i=0;i<3;++i) {
        if (!(contains[i]==0)) {
            if (float(vs_z[i])==closest_z) {
                has_near = true;
                pn = vec3{ vs_x[i], vs_y[i], vs_z[i] };
            }
            if (float(vs_z[i])==farthest_z) {
                has_far = true;
                pf = vec3{ vs_x[i], vs_y[i], vs_z[i] };
            }
        }
    }
    if (has_near && has_far)
        return intersect_cone_tri_ret_t{ .range = range_t::range(pn.z,pf.z),
                                         .near  = frame.to_world(pn)+o,
                                         .far   = frame.to_world(pf)+o };

    // find closest point on cone-plane intersection conic section
    const auto ln = frame.to_local(n);
    const auto vs0 = vec3{ vs_x[0], vs_y[0], vs_z[0] };
    const auto vs1 = vec3{ vs_x[1], vs_y[1], vs_z[1] };
    const auto vs2 = vec3{ vs_x[2], vs_y[2], vs_z[2] };
    const auto icp = intersect_cone_plane<true>(cone, ln, glm::dot(vs0,ln), range);
    if (!icp.range.empty()) {
        if (!has_near &&
            util::is_point_in_triangle(icp.near, vs0, vs1, vs2)) {
            has_near = true;
            pn = icp.near;
        }
        if (!has_far &&
            std::isfinite(icp.range.max) &&
            util::is_point_in_triangle(icp.far, vs0, vs1, vs2)) {
            has_far = true;
            pf = icp.far;
        }
    }
    if (has_near && has_far)
        return intersect_cone_tri_ret_t{ .range = range_t::range(pn.z,pf.z),
                                         .near  = frame.to_world(pn)+o,
                                         .far   = frame.to_world(pf)+o };

    // test all edges (an edge with both end points inside the cone cannot cross its surface)
    for (int i=0;i<3;++i) {
        const auto j = (i+1)%3;
        if (!(contains[i]==0) && !(contains[j]==0)) continue;
        const auto ea = vec3{ vs_x[i], vs_y[i], vs_z[i] };
        const auto eb = vec3{ vs_x[j], vs_y[j], vs_z[j] };

        const auto cp = intersect_cone_edge<true>(cone, ea,eb, range);
        if (!cp) continue;
        // cp->p0 is always set; cp->p1 only holds a point when cp->pts==2
        if (!has_near || pn.z>cp->p0.z) {
            pn = cp->p0;
            has_near = true;
        }
        if (!has_far || pf.z<cp->p0.z) {
            pf = cp->p0;
            has_far = true;
        }
        if (cp->pts==2 && pf.z<cp->p1.z)
            pf = cp->p1;
    }
    
    if (!has_near && !has_far)
        return std::nullopt;

    // only one side found (can happen due to numerics): collapse to a single point
    if (!has_near) pn=pf;
    if (!has_far)  pf=pn;

    return intersect_cone_tri_ret_t{ .range = range_t::range(pn.z,pf.z),
                                     .near  = frame.to_world(pn)+o,
                                     .far   = frame.to_world(pf)+o };
}

}  // namespace intersect