#version 460 core

// Shader interface. main() builds one camera ray per invocation from `pos`
// and writes the traced color to `color`.
// NOTE(review): `pos` presumably carries the interpolated screen-space
// position from the vertex shader (likely in [-1, 1]) — confirm against the
// vertex shader.
in vec2 pos;
// RGBA result written by main().
out vec4 color;

// main() multiplies pos.x by this value. Presumably width / height of the
// viewport — confirm on the host side.
layout(location = 0) uniform float aspect_ratio;
// Y coordinate of the center of spheres[0].
layout(location = 1) uniform float sphere_y;

// Shader storage buffer at binding 0.
// NOTE(review): `ray_count` is not read or written by any code visible in
// this section — confirm whether it is still needed.
layout(std430, binding = 0) buffer data
{
    int ray_count;
};

// A ray: the points origin + t * direction for a scalar parameter t.
// `direction` need not be unit length; main() and ray_color() build it from
// point differences without normalizing.
struct Ray {
    vec3 origin;
    vec3 direction;
};

// A sphere described by its center point and radius.
struct Sphere {
    vec3 center;
    float radius;
};

// Returns v scaled to unit length.
// Uses the built-in normalize() rather than the hand-rolled v / length(v).
// Results may differ from the hand-rolled form in the last ulp.
// As before, the result is undefined for a zero-length vector.
vec3 unit_vector(vec3 v) {
    return normalize(v);
}

// Returns the squared Euclidean length of v, computed directly as dot(v, v).
// The previous version computed length(v) and squared it. That costs a sqrt
// which is immediately undone, and it loses precision in the round trip.
// It also declared a local named `length`, shadowing the built-in length().
float length_squared(vec3 v) {
    return dot(v, v);
}

// Intersects ray r with sphere s by solving
//     |r.origin + t * r.direction - s.center|^2 = s.radius^2
// for t, using the reduced ("half b") form of the quadratic.
//
// Returns the smaller root t. That root may be <= 0 when the sphere is
// behind the origin or contains it, so callers must filter it.
// Returns -1.0 when the discriminant is not strictly positive; tangent rays
// therefore count as misses.
float hit_sphere(Ray r, Sphere s) {
    vec3 to_center = s.center - r.origin;

    float qa   = length_squared(r.direction);
    float proj = dot(r.direction, to_center);
    float qc   = length_squared(to_center) - s.radius * s.radius;
    float disc = proj * proj - qa * qc;

    // Only the selected operand of ?: is evaluated, so sqrt never sees a
    // negative (or NaN-rejected) discriminant.
    return disc > 0.0 ? (proj - sqrt(disc)) / qa : -1.0;
}

// Point on ray r at parameter `time`: r.origin + time * r.direction.
vec3 at(Ray r, float time) {
    vec3 offset = r.direction * time;
    return r.origin + offset;
}

// Visualizes the outward surface normal of sphere s at point `pos`.
// Each component of the unit normal is mapped from [-1, 1] to [0, 1] so it
// can be shown as an RGB color.
// `r` is accepted for interface compatibility but is not used.
vec3 get_normal_color(Ray r, vec3 pos, Sphere s) {
    vec3 n = unit_vector(pos - s.center);
    return 0.5 * (n + 1.0);
}

// Scene geometry tested by hit_spheres(); indices are reported in HitInfo.
//   [0] radius 0.5,  center (0, sphere_y, -1) — y driven by a uniform
//   [1] radius 0.25, center (2, 0, 0)
//   [2] radius 100,  center (0, -100.5, -1) — its top is at y = -0.5,
//       presumably acting as a ground plane
// NOTE(review): this global is initialized from a uniform, which is a
// non-constant initializer. Some GLSL variants (e.g. GLSL ES) require constant
// global initializers — confirm if this shader must be portable.
Sphere spheres[] = {
    Sphere(vec3(0, sphere_y, -1), 0.5),
    Sphere(vec3(2, 0, 0), 0.25),
    Sphere(vec3(0, -100.5, -1), 100.0)
};

// Result of hit_spheres().
struct HitInfo {
    vec3 pos;   // hit point along the ray; vec3(0.0) when hit is false
    bool hit;   // true if any sphere was hit
    int index;  // index into `spheres` of the closest hit; only meaningful when hit is true
};

// Finds the closest sphere in `spheres` hit by ray r.
//
// For each sphere, only the nearer root reported by hit_sphere() is
// considered. That root must exceed a small epsilon (in ray-parameter units).
// Shadow rays in ray_color() start exactly on a sphere surface, and rounding
// can make that same sphere report t ~ 0. Such a hit would win the
// closest-hit test and hide a genuine occluder further along the ray.
//
// Returns a HitInfo with:
//   pos   - at(r, t) for the closest hit, or vec3(0.0) on a miss
//   hit   - whether any sphere was hit
//   index - index into `spheres` of the closest hit, or -1 on a miss
HitInfo hit_spheres(Ray r) {
    const float t_epsilon = 1e-4;

    float t_min = 0.0;
    int index = -1;   // well-defined even when nothing is hit
    bool hit = false;

    for (int i = 0; i < spheres.length(); i++) {
        float t = hit_sphere(r, spheres[i]);
        if (t > t_epsilon && (!hit || t < t_min)) {
            t_min = t;
            index = i;
            hit = true;
        }
    }

    return HitInfo(hit ? at(r, t_min) : vec3(0.0), hit, index);
}

// Shades ray r.
//
// Miss: returns a sky gradient driven by the normalized direction's y.
// It is white when looking straight down (y = -1) and (0.5, 0.7, 1.0) when
// looking straight up (y = 1).
//
// Hit: returns the hit sphere's normal color. Every channel is darkened by
// 0.4 when a ray from the hit point toward (10, 10, 10) hits a *different*
// sphere. The darkened value is not clamped. Any hit along that ray counts,
// including hits beyond (10, 10, 10), since hit_spheres() has no upper t limit.
vec3 ray_color(Ray r) {
    HitInfo primary = hit_spheres(r);

    if (!primary.hit) {
        vec3 dir = unit_vector(r.direction);
        float blend = 0.5 * (dir.y + 1.0);
        return (1.0 - blend) * vec3(1.0, 1.0, 1.0) + blend * vec3(0.5, 0.7, 1.0);
    }

    vec3 to_light = vec3(10, 10, 10) - primary.pos;
    HitInfo blocker = hit_spheres(Ray(primary.pos, to_light));

    bool shadowed = blocker.hit && blocker.index != primary.index;
    float darkening = shadowed ? 0.4 : 0.0;

    Sphere target = spheres[primary.index];
    return get_normal_color(r, primary.pos, target) - darkening;
}

// Builds one camera ray and writes its color.
// The eye sits at (0, 0, 1) and looks through an image plane at z = 0, so
// the ray's direction has z = -1. pos.x is scaled by aspect_ratio
// (presumably width / height) to keep the image from being stretched.
void main() {
    vec3 eye = vec3(0.0, 0.0, 1.0);
    vec3 on_plane = vec3(pos.x * aspect_ratio, pos.y, 0.0);

    Ray primary = Ray(eye, on_plane - eye);
    color = vec4(ray_color(primary), 1.0);
}