From cb2a1fdddb11cf69f90c0da55262324efd08eeee Mon Sep 17 00:00:00 2001 From: JMS55 <47158642+JMS55@users.noreply.github.com> Date: Sun, 6 Jul 2025 11:20:31 -0400 Subject: [PATCH] Cleanup GI --- .../bevy_solari/src/realtime/restir_gi.wgsl | 123 ++++++++++-------- 1 file changed, 70 insertions(+), 53 deletions(-) diff --git a/crates/bevy_solari/src/realtime/restir_gi.wgsl b/crates/bevy_solari/src/realtime/restir_gi.wgsl index aae0db1942..189121cb9a 100644 --- a/crates/bevy_solari/src/realtime/restir_gi.wgsl +++ b/crates/bevy_solari/src/realtime/restir_gi.wgsl @@ -43,61 +43,13 @@ fn initial_and_temporal(@builtin(global_invocation_id) global_id: vec3) { let base_color = pow(unpack4x8unorm(gpixel.r).rgb, vec3(2.2)); let diffuse_brdf = base_color / PI; - let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal, &rng); - - let ray_direction = sample_uniform_hemisphere(world_normal, &rng); - let ray_hit = trace_ray(world_position, ray_direction, RAY_T_MIN, RAY_T_MAX, RAY_FLAG_NONE); - if ray_hit.kind == RAY_QUERY_INTERSECTION_NONE { - gi_reservoirs_b[pixel_index] = temporal_reservoir; - return; - } - let sample_point = resolve_ray_hit_full(ray_hit); - if all(sample_point.material.emissive != vec3(0.0)) { - gi_reservoirs_b[pixel_index] = temporal_reservoir; - return; - } - let sample_point_diffuse_brdf = sample_point.material.base_color / PI; - let direct_lighting = sample_random_light(sample_point.world_position, sample_point.world_normal, &rng); - let sample_point_radiance = direct_lighting.radiance * sample_point_diffuse_brdf; - - let cos_theta = dot(ray_direction, world_normal); - let inverse_uniform_hemisphere_pdf = PI_2; - - var combined_reservoir = empty_reservoir(); - combined_reservoir.confidence_weight = 1.0 + temporal_reservoir.confidence_weight; - - let mis_weight_denominator = 1.0 / combined_reservoir.confidence_weight; - - let new_mis_weight = mis_weight_denominator; - let new_target_function = luminance(sample_point_radiance * diffuse_brdf * cos_theta); - let new_inverse_pdf = direct_lighting.inverse_pdf * inverse_uniform_hemisphere_pdf; - let new_resampling_weight = new_mis_weight * (new_target_function * new_inverse_pdf); - - let temporal_mis_weight = temporal_reservoir.confidence_weight * mis_weight_denominator; - let temporal_cos_theta = dot(normalize(temporal_reservoir.sample_point_world_position - world_position), world_normal); - let temporal_target_function = luminance(temporal_reservoir.radiance * diffuse_brdf * temporal_cos_theta); - let temporal_resampling_weight = temporal_mis_weight * (temporal_target_function * temporal_reservoir.unbiased_contribution_weight); - - combined_reservoir.weight_sum = new_resampling_weight + temporal_resampling_weight; - - if rand_f(&rng) < temporal_resampling_weight / combined_reservoir.weight_sum { - combined_reservoir.sample_point_world_position = temporal_reservoir.sample_point_world_position; - combined_reservoir.radiance = temporal_reservoir.radiance; - - let inverse_target_function = select(0.0, 1.0 / temporal_target_function, temporal_target_function > 0.0); - combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function; - } else { - combined_reservoir.sample_point_world_position = sample_point.world_position; - combined_reservoir.radiance = sample_point_radiance; - - let inverse_target_function = select(0.0, 1.0 / new_target_function, new_target_function > 0.0); - combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function; - } + let initial_reservoir = generate_initial_reservoir(world_position, world_normal, diffuse_brdf, &rng); + let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal, diffuse_brdf, &rng); + let combined_reservoir = merge_reservoirs(initial_reservoir, temporal_reservoir, &rng); gi_reservoirs_b[pixel_index] = combined_reservoir; } - @compute @workgroup_size(8, 8, 1) fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3) { if any(global_id.xy >= vec2u(view.viewport.zw)) { return; } @@ -127,7 +79,39 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3) { textureStore(view_output, global_id.xy, pixel_color); } -fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3, rng: ptr) -> Reservoir { +fn generate_initial_reservoir(world_position: vec3, world_normal: vec3, diffuse_brdf: vec3, rng: ptr) -> Reservoir{ + var reservoir = empty_reservoir(); + + let ray_direction = sample_uniform_hemisphere(world_normal, rng); + let ray_hit = trace_ray(world_position, ray_direction, RAY_T_MIN, RAY_T_MAX, RAY_FLAG_NONE); + + if ray_hit.kind == RAY_QUERY_INTERSECTION_NONE { + return reservoir; + } + + let sample_point = resolve_ray_hit_full(ray_hit); + + if all(sample_point.material.emissive != vec3(0.0)) { + return reservoir; + } + + reservoir.sample_point_world_position = sample_point.world_position; + reservoir.confidence_weight = 1.0; + + let sample_point_diffuse_brdf = sample_point.material.base_color / PI; + let direct_lighting = sample_random_light(sample_point.world_position, sample_point.world_normal, rng); + reservoir.radiance = direct_lighting.radiance * sample_point_diffuse_brdf; + + let inverse_uniform_hemisphere_pdf = PI_2; + reservoir.unbiased_contribution_weight = direct_lighting.inverse_pdf * inverse_uniform_hemisphere_pdf; + + let cos_theta = dot(ray_direction, world_normal); + reservoir.target_function = luminance(reservoir.radiance * diffuse_brdf * cos_theta); + + return reservoir; +} + +fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3, diffuse_brdf: vec3, rng: ptr) -> Reservoir { let motion_vector = textureLoad(motion_vectors, pixel_id, 0).xy; let temporal_pixel_id_float = vec2(pixel_id) - (motion_vector * view.viewport.zw); @@ -181,6 +165,9 @@ fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3 temporal_reservoir.confidence_weight = min(temporal_reservoir.confidence_weight, CONFIDENCE_WEIGHT_CAP); + let temporal_cos_theta = dot(normalize(temporal_reservoir.sample_point_world_position - world_position), world_normal); + temporal_reservoir.target_function = luminance(temporal_reservoir.radiance * diffuse_brdf * temporal_cos_theta); + return temporal_reservoir; } @@ -224,7 +211,7 @@ struct Reservoir { radiance: vec3, confidence_weight: f32, unbiased_contribution_weight: f32, - padding1: f32, + target_function: f32, padding2: f32, padding3: f32, } @@ -241,3 +228,33 @@ fn empty_reservoir() -> Reservoir { 0.0, ); } + +fn merge_reservoirs(canonical_reservoir: Reservoir,other_reservoir: Reservoir,rng: ptr) -> Reservoir { + // TODO: Balance heuristic MIS weights + let mis_weight_denominator = 1.0 / (canonical_reservoir.confidence_weight + other_reservoir.confidence_weight); + + let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator; + let canonical_resampling_weight = canonical_mis_weight * (canonical_reservoir.target_function * canonical_reservoir.unbiased_contribution_weight); + + let other_mis_weight = other_reservoir.confidence_weight * mis_weight_denominator; + let other_resampling_weight = other_mis_weight * (other_reservoir.target_function * other_reservoir.unbiased_contribution_weight); + + var combined_reservoir = empty_reservoir(); + combined_reservoir.weight_sum = canonical_resampling_weight + other_resampling_weight; + combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight; + + if rand_f(rng) < other_resampling_weight / combined_reservoir.weight_sum { + combined_reservoir.sample_point_world_position = other_reservoir.sample_point_world_position; + combined_reservoir.radiance = other_reservoir.radiance; + combined_reservoir.target_function = other_reservoir.target_function; + } else { + combined_reservoir.sample_point_world_position = canonical_reservoir.sample_point_world_position; + combined_reservoir.radiance = canonical_reservoir.radiance; + combined_reservoir.target_function = canonical_reservoir.target_function; + } + + let inverse_target_function = select(0.0, 1.0 / combined_reservoir.target_function, combined_reservoir.target_function > 0.0); + combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function; + + return combined_reservoir; +}