diff --git a/crates/bevy_solari/src/realtime/restir_di.wgsl b/crates/bevy_solari/src/realtime/restir_di.wgsl index 6dae89d21b..7bc3d32fcf 100644 --- a/crates/bevy_solari/src/realtime/restir_di.wgsl +++ b/crates/bevy_solari/src/realtime/restir_di.wgsl @@ -48,7 +48,7 @@ fn initial_and_temporal(@builtin(global_invocation_id) global_id: vec3) { let diffuse_brdf = base_color / PI; let initial_reservoir = generate_initial_reservoir(world_position, world_normal, diffuse_brdf, &rng); - let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal); + let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal, &rng); let combined_reservoir = merge_reservoirs(initial_reservoir, temporal_reservoir, world_position, world_normal, diffuse_brdf, &rng); di_reservoirs_b[pixel_index] = combined_reservoir.merged_reservoir; @@ -119,14 +119,47 @@ fn generate_initial_reservoir(world_position: vec3, world_normal: vec3 return reservoir; } -fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3) -> Reservoir { +fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3, rng: ptr) -> Reservoir { let motion_vector = textureLoad(motion_vectors, pixel_id, 0).xy; - let temporal_pixel_id_float = round(vec2(pixel_id) - (motion_vector * view.viewport.zw)); - let temporal_pixel_id = vec2(temporal_pixel_id_float); + let temporal_pixel_id_float = vec2(pixel_id) - (motion_vector * view.viewport.zw); + + // Check if the current pixel was off screen during the previous frame (current pixel is newly visible), + // or if all temporal history should assumed to be invalid if any(temporal_pixel_id_float < vec2(0.0)) || any(temporal_pixel_id_float >= view.viewport.zw) || bool(constants.reset) { return empty_reservoir(); } + // https://en.wikipedia.org/wiki/Bilinear_interpolation#On_the_unit_square + let tl = vec2(temporal_pixel_id_float); + let tr = tl + vec2(1u, 0u); + let bl = tl + vec2(0u, 1u); + let br = tl + vec2(1u, 1u); + let f = fract(temporal_pixel_id_float); + let tl_w = (1.0 - f.x) * (1.0 - f.y); + let tr_w = f.x * (1.0 - f.y); + let bl_w = (1.0 - f.x) * f.y; + + // Choose a random pixel from the 2x2 quad, weighted by the bilinear weights + // This gives better results than always using the nearest pixel + var temporal_pixel_id = tl; + var weight_sum = tl_w; + let r = rand_f(rng); + if (r > weight_sum) { + temporal_pixel_id = tr; + weight_sum += tr_w; + } + if (r > weight_sum) { + temporal_pixel_id = bl; + weight_sum += bl_w; + } + if (r > weight_sum) { + temporal_pixel_id = br; + } + + // Clamp to view size, since 2x2 quad may go off screen + temporal_pixel_id = min(temporal_pixel_id, vec2(view.viewport.zw - 1.0)); + + // Check if the pixel features have changed heavily between the current and previous frame let temporal_depth = textureLoad(previous_depth_buffer, temporal_pixel_id, 0); let temporal_gpixel = textureLoad(previous_gbuffer, temporal_pixel_id, 0); let temporal_world_position = reconstruct_previous_world_position(temporal_pixel_id, temporal_depth); @@ -138,6 +171,7 @@ fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3 let temporal_pixel_index = temporal_pixel_id.x + temporal_pixel_id.y * u32(view.viewport.z); var temporal_reservoir = di_reservoirs_a[temporal_pixel_index]; + // Check if the light selected in the previous frame no longer exists in the current frame (e.g. entity despawned) temporal_reservoir.sample.light_id.x = previous_frame_light_id_translations[temporal_reservoir.sample.light_id.x]; if temporal_reservoir.sample.light_id.x == LIGHT_NOT_PRESENT_THIS_FRAME { return empty_reservoir(); diff --git a/crates/bevy_solari/src/realtime/restir_gi.wgsl b/crates/bevy_solari/src/realtime/restir_gi.wgsl index 8e090ca29c..aae0db1942 100644 --- a/crates/bevy_solari/src/realtime/restir_gi.wgsl +++ b/crates/bevy_solari/src/realtime/restir_gi.wgsl @@ -43,7 +43,7 @@ fn initial_and_temporal(@builtin(global_invocation_id) global_id: vec3) { let base_color = pow(unpack4x8unorm(gpixel.r).rgb, vec3(2.2)); let diffuse_brdf = base_color / PI; - let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal); + let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal, &rng); let ray_direction = sample_uniform_hemisphere(world_normal, &rng); let ray_hit = trace_ray(world_position, ray_direction, RAY_T_MIN, RAY_T_MAX, RAY_FLAG_NONE); @@ -127,14 +127,47 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3) { textureStore(view_output, global_id.xy, pixel_color); } -fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3) -> Reservoir { +fn load_temporal_reservoir(pixel_id: vec2, depth: f32, world_position: vec3, world_normal: vec3, rng: ptr) -> Reservoir { let motion_vector = textureLoad(motion_vectors, pixel_id, 0).xy; - let temporal_pixel_id_float = round(vec2(pixel_id) - (motion_vector * view.viewport.zw)); - let temporal_pixel_id = vec2(temporal_pixel_id_float); + let temporal_pixel_id_float = vec2(pixel_id) - (motion_vector * view.viewport.zw); + + // Check if the current pixel was off screen during the previous frame (current pixel is newly visible), + // or if all temporal history should assumed to be invalid if any(temporal_pixel_id_float < vec2(0.0)) || any(temporal_pixel_id_float >= view.viewport.zw) || bool(constants.reset) { return empty_reservoir(); } + // https://en.wikipedia.org/wiki/Bilinear_interpolation#On_the_unit_square + let tl = vec2(temporal_pixel_id_float); + let tr = tl + vec2(1u, 0u); + let bl = tl + vec2(0u, 1u); + let br = tl + vec2(1u, 1u); + let f = fract(temporal_pixel_id_float); + let tl_w = (1.0 - f.x) * (1.0 - f.y); + let tr_w = f.x * (1.0 - f.y); + let bl_w = (1.0 - f.x) * f.y; + + // Choose a random pixel from the 2x2 quad, weighted by the bilinear weights + // This gives better results than always using the nearest pixel + var temporal_pixel_id = tl; + var weight_sum = tl_w; + let r = rand_f(rng); + if (r > weight_sum) { + temporal_pixel_id = tr; + weight_sum += tr_w; + } + if (r > weight_sum) { + temporal_pixel_id = bl; + weight_sum += bl_w; + } + if (r > weight_sum) { + temporal_pixel_id = br; + } + + // Clamp to view size, since 2x2 quad may go off screen + temporal_pixel_id = min(temporal_pixel_id, vec2(view.viewport.zw - 1.0)); + + // Check if the pixel features have changed heavily between the current and previous frame let temporal_depth = textureLoad(previous_depth_buffer, temporal_pixel_id, 0); let temporal_gpixel = textureLoad(previous_gbuffer, temporal_pixel_id, 0); let temporal_world_position = reconstruct_previous_world_position(temporal_pixel_id, temporal_depth);