// Data handed from vs_main to fs_main.
//   vert_pos: required builtin clip-space position of the vertex.
//   pos:      the same clip-space position, passed again as a user varying. The builtin
//             position reads as framebuffer coordinates in the fragment stage, so this
//             copy lets fs_main see the interpolated [-1, 1] quad coordinates instead.
struct VertexOutput {
    @builtin(position) vert_pos: vec4<f32>,
    @location(0) pos: vec4<f32>,
}

// Full-screen quad: 6 vertices (two triangles) spanning the clip-space square [-1, 1]^2.
// Each vertex writes its clip position both to the builtin and to the `pos` varying.
@vertex
fn vs_main(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
    var corners = array<vec2f, 6>(
        // first triangle
        vec2f(-1.0,  1.0),
        vec2f(-1.0, -1.0),
        vec2f( 1.0, -1.0),
        // second triangle
        vec2f( 1.0, -1.0),
        vec2f( 1.0,  1.0),
        vec2f(-1.0,  1.0),
    );

    let clip = vec4f(corners[vertex_index], 0.0, 1.0);
    return VertexOutput(clip, clip);
}

// Host-supplied scalar, added to spheres[0].center in fs_main to animate it.
// NOTE(review): presumably elapsed time; units are set by the host, confirm.
@group(0) @binding(0) var<uniform> time: f32;
// Host-supplied scalar; fs_main multiplies the pixel's x coordinate by it.
// NOTE(review): presumably viewport width / height, confirm against the host code.
@group(0) @binding(1) var<uniform> aspect: f32;

// Per-pixel entry point. It animates the first sphere, builds a primary ray from a camera
// at (0, 0, 1) through the pixel's point on the z = 0 plane (x scaled by `aspect`), and
// writes the traced color with alpha 1.
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
    // Must run before tracing: ray_color reads `spheres`.
    spheres[0].center += time;

    let eye = vec3<f32>(0.0, 0.0, 1.0);
    let pixel = vec3<f32>(input.pos.x * aspect, input.pos.y, 0.0);
    let color = ray_color(Ray(eye, pixel - eye));
    return vec4<f32>(color, 1.0);
}

// Parametric ray origin + t * direction.
// `direction` is not normalized by its producers (fs_main, ray_color).
struct Ray {
    origin: vec3f,
    direction: vec3f,
}

// Sphere given by its center point and radius.
struct Sphere {
    center: vec3f,
    radius: f32,
}

// Returns `v` scaled to length 1. A zero-length `v` divides by zero.
fn unit_vector(v: vec3<f32>) -> vec3<f32> {
    let magnitude = length(v);
    return v / magnitude;
}

// Squared Euclidean length of `v`, i.e. v . v.
// Computed with dot() directly. Going through length() would take a sqrt only to square
// it again, costing an extra instruction and adding rounding error.
fn length_squared(v: vec3<f32>) -> f32 {
    return dot(v, v);
}

// Ray/sphere intersection using the half-b form of the quadratic
//   a t^2 - 2 h t + c = 0,  a = |d|^2,  h = d . (center - origin),  c = |center - origin|^2 - radius^2.
// Returns the smaller root (h - sqrt(disc)) / a when the discriminant is strictly positive,
// otherwise -1.0. The returned root may itself be negative (sphere behind, or origin inside).
fn hit_sphere(r: Ray, s: Sphere) -> f32 {
    let to_center = s.center - r.origin;
    let a = length_squared(r.direction);
    let half_b = dot(r.direction, to_center);
    let c = length_squared(to_center) - s.radius * s.radius;
    let disc = half_b * half_b - a * c;

    // Written as !(disc > 0.0) rather than disc <= 0.0 so a NaN discriminant still misses.
    if (!(disc > 0.0)) {
        return -1.0;
    }
    return (half_b - sqrt(disc)) / a;
}

// Point on ray `r` at parameter `t`: origin + t * direction.
// (Parameter named `t` so it does not shadow the module-scope `time` uniform.)
fn at(r: Ray, t: f32) -> vec3<f32> {
    let offset = r.direction * t;
    return r.origin + offset;
}

// Colors a point by the unit direction from the sphere's center to `pos`. Each component
// is remapped from [-1, 1] to [0, 1] so it can be used directly as RGB.
// `r` is accepted but not used.
fn get_normal_color(r: Ray, pos: vec3<f32>, s: Sphere) -> vec3<f32> {
    let outward = pos - s.center;
    let n = outward / length(outward);
    return 0.5 * (n + vec3<f32>(1.0));
}

// Result of hit_spheres.
//   pos:   hit point (zero vector when `hit` is false)
//   hit:   whether any sphere was hit at t > 0
//   index: position in `spheres` of the closest hit sphere (0 when `hit` is false)
struct HitInfo {
    pos: vec3f,
    hit: bool,
    index: i32,
}

// Number of scene spheres; sizes `spheres` and bounds the loops that scan it.
const sphere_count = 5;

// Scene contents. `private` storage is per-invocation, so fs_main's edit to spheres[0]
// only affects the invocation that makes it.
// The last, very large sphere sits just below the others and acts as the ground.
var<private> spheres = array<Sphere, sphere_count>(
    Sphere(vec3f( 0.0,    0.0, -1.0),   0.5),
    Sphere(vec3f( 2.0,    0.0,  0.0),   0.25),
    Sphere(vec3f( 1.0,    1.0, -1.0),   0.25),
    Sphere(vec3f(-1.0,    0.5, -0.8),   0.25),
    Sphere(vec3f( 0.0, -100.5, -1.0), 100.0),
);

// Finds the sphere in `spheres` that `r` hits closest in front of its origin (t > 0).
// Returns hit = false, pos = (0, 0, 0), index = 0 when nothing is hit. Otherwise pos is
// the hit point and index the sphere's position in `spheres`.
fn hit_spheres(r: Ray) -> HitInfo {
    var result = HitInfo(vec3<f32>(0.0), false, 0);
    var nearest_t: f32 = 0.0;

    for (var i = 0; i < sphere_count; i++) {
        let t = hit_sphere(r, spheres[i]);
        let is_closer = !result.hit || t < nearest_t;
        if (t > 0.0 && is_closer) {
            nearest_t = t;
            result.hit = true;
            result.index = i;
        }
    }

    if (result.hit) {
        result.pos = at(r, nearest_t);
    }
    return result;
}

// Returns true when `shadow_ray` hits, at some t > 0, any sphere except spheres[skip_index].
// Every other sphere is tested individually. A nearest-hit query would not work here: the
// shadow ray starts on the shaded sphere's surface, so rounding can yield a self-hit at
// t ~ 0 that would hide a real occluder further along the ray.
fn ray_blocked_by_other_sphere(shadow_ray: Ray, skip_index: i32) -> bool {
    for (var i = 0; i < sphere_count; i++) {
        if (i != skip_index && hit_sphere(shadow_ray, spheres[i]) > 0.0) {
            return true;
        }
    }
    return false;
}

// Color seen along primary ray `r`.
//  - Hit: the sphere's normal color, darkened by 0.4 when another sphere lies between the
//    hit point and the light at (-10, 10, 10). A sphere never shadows itself. The result is
//    not clamped and may go below 0.
//  - Miss: a vertical background gradient from white (direction pointing down) to light
//    blue (direction pointing up).
fn ray_color(r: Ray) -> vec3<f32> {
    let info = hit_spheres(r);
    if info.hit {
        let sun_dir = vec3<f32>(-10.0, 10.0, 10.0) - info.pos;
        let sun_ray = Ray(info.pos, sun_dir);
        let shadowed = ray_blocked_by_other_sphere(sun_ray, info.index);

        return get_normal_color(r, info.pos, spheres[info.index])
            - select(0.0, 0.4, shadowed);
    }

    let unit_direction = unit_vector(r.direction);
    // Map direction.y from [-1, 1] to a blend factor in [0, 1].
    let a = 0.5 * (unit_direction.y + 1.0);
    return (1.0 - a) * vec3<f32>(1.0, 1.0, 1.0) + a * vec3<f32>(0.5, 0.7, 1.0);
}