#version 460 core

// Invocations per work group along x. main() treats each invocation as owning
// one cluster, so the y and z dimensions are left at their default of 1.
const uint LOCAL_SIZE = 128u;

layout(local_size_x = LOCAL_SIZE) in;

// One point light as stored in the lights SSBO (std430 layout). Field order and
// types must match the host-side struct. Under std430 the struct is aligned to
// 16 bytes (its vec4 members), so each array element spans 48 bytes:
// two vec4s + two floats + 8 bytes of tail padding.
struct point_light {
    vec4  position;  // multiplied by view_mat in test_sphere_aabb
    vec4  colour;    // not read by this shader
    float intensity; // not read by this shader
    float radius;    // sphere radius used for the cluster intersection test
};

// Capacity of each cluster's light_indices array; lights that would exceed it
// are not recorded for that cluster.
const uint MAX_NUM_LIGHTS = 100u;

// One cluster as stored in the cluster SSBO (std430 layout). Under std430,
// count sits at byte offset 32 and light_indices at 36, and the struct is
// padded to 448 bytes per element. The host-side struct must match.
struct cluster {
    vec4 min_point;                     // AABB min corner (only .xyz is used)
    vec4 max_point;                     // AABB max corner (only .xyz is used)
    uint count;                         // number of valid entries in light_indices
    uint light_indices[MAX_NUM_LIGHTS]; // indices into point_lights
};

// Cluster grid. Each compute invocation reads one element, rebuilds its light
// list, and writes the element back.
layout(std430, binding = 1) restrict buffer cluster_buf {
    cluster clusters[];
};

// Point lights supplied by the host. The number of lights is the runtime length
// of the unsized array (point_lights.length()).
// Declared readonly because this shader only reads it: this documents intent,
// lets the compiler/driver treat the binding as read-only data, and turns any
// accidental write into a compile error.
layout(std430, binding = 2) readonly restrict buffer lights_buf {
    point_light point_lights[];
};

// Matrix applied to each light's position before it is tested against the
// cluster AABBs (see test_sphere_aabb).
uniform mat4 view_mat;

// Forward declaration: main() calls this before its definition at the end of
// the file. The signature must match the definition exactly.
bool test_sphere_aabb(uint i, cluster c);

// Light-culling pass: each invocation owns one cluster and rebuilds that
// cluster's list of intersecting point lights from scratch.
//
// Dispatch: 1D, LOCAL_SIZE invocations per work group, at least one invocation
// per cluster. Invocations whose index falls past the end of `clusters`
// (e.g. when the cluster count is not a multiple of LOCAL_SIZE) exit without
// touching memory. At most MAX_NUM_LIGHTS lights are recorded per cluster;
// further intersecting lights are dropped.
void main() {
    // Equals gl_WorkGroupID.x * LOCAL_SIZE + gl_LocalInvocationID.x for this
    // 1D work-group layout.
    uint cluster_idx = gl_GlobalInvocationID.x;

    // Guard the grid tail: without this, surplus invocations read and write
    // beyond the end of the cluster buffer.
    if (cluster_idx >= uint(clusters.length())) {
        return;
    }

    uint    n_lights = uint(point_lights.length());
    cluster c        = clusters[cluster_idx];

    // Culling runs every frame, so the list is rebuilt rather than appended
    // to; otherwise the count would accumulate across frames.
    c.count = 0u;

    for (uint i = 0u; i < n_lights; ++i) {
        // light_indices has MAX_NUM_LIGHTS slots. Once it is full, no further
        // light can be recorded, so stop testing the remaining ones.
        if (c.count >= MAX_NUM_LIGHTS) {
            break;
        }
        if (test_sphere_aabb(i, c)) {
            c.light_indices[c.count] = i;
            c.count++;
        }
    }

    // Write the whole cluster back. min_point/max_point are unchanged, and
    // slots at or beyond `count` keep the values they held before.
    clusters[cluster_idx] = c;
}

// True when the sphere (center, radius) touches or overlaps the axis-aligned
// box [aabb_min, aabb_max]. The comparison is inclusive, so a sphere that just
// touches a face counts as intersecting. Per the GLSL spec, clamp() is
// undefined if aabb_min > aabb_max on any axis, so the box must be well-formed.
bool sphere_aabb_intersection(vec3 center, float radius, vec3 aabb_min, vec3 aabb_max) {
    // Vector from the sphere centre to the nearest point of the box; it is
    // the zero vector when the centre lies inside the box.
    vec3  to_box     = clamp(center, aabb_min, aabb_max) - center;
    float radius_sq  = radius * radius;
    float gap_sq     = dot(to_box, to_box);
    return radius_sq >= gap_sq;
}

// Adapter for sphere_aabb_intersection: tests point light `i` against the
// bounding box of cluster `c`. The light position is multiplied by view_mat
// first, so the cluster bounds are presumably stored in view space.
// TODO confirm against the pass that builds the clusters.
bool test_sphere_aabb(uint i, cluster c) {
    point_light light = point_lights[i];

    // NOTE(review): the full vec4 position, including w, is multiplied by
    // view_mat. If the host stores w = 0, the view translation is dropped.
    // Confirm the host writes w = 1.
    vec4 light_pos_view = view_mat * light.position;

    return sphere_aabb_intersection(light_pos_view.xyz, light.radius,
                                    c.min_point.xyz, c.max_point.xyz);
}