use rwlock in graph executor instead of cloning

Carter Anderson 2020-05-05 13:33:47 -07:00
parent 2b8789dc8c
commit c388598996

@@ -4,7 +4,7 @@ use bevy_render::{
     renderer::RenderResources,
 };
 use legion::prelude::{Resources, World};
-use std::{collections::HashMap, sync::Arc};
+use std::{collections::HashMap, sync::{RwLock, Arc}};
 
 pub struct WgpuRenderGraphExecutor {
     pub max_thread_count: usize,
@@ -24,7 +24,7 @@ impl WgpuRenderGraphExecutor {
             .context
             .downcast_mut::<WgpuRenderResourceContext>()
             .unwrap();
-        let mut node_outputs: HashMap<NodeId, ResourceSlots> = HashMap::new();
+        let node_outputs: Arc<RwLock<HashMap<NodeId, ResourceSlots>>> = Default::default();
         for stage in stages.iter_mut() {
            // TODO: sort jobs and slice by "amount of work" / weights
            // stage.jobs.sort_by_key(|j| j.node_states.len());
@@ -33,17 +33,16 @@ impl WgpuRenderGraphExecutor {
             let chunk_size = (stage.jobs.len() + self.max_thread_count - 1) / self.max_thread_count; // divide ints rounding remainder up
             let mut actual_thread_count = 0;
             crossbeam_utils::thread::scope(|s| {
+                let node_outputs = &node_outputs;
                 for jobs_chunk in stage.jobs.chunks_mut(chunk_size) {
                     let sender = sender.clone();
                     let world = &*world;
                     actual_thread_count += 1;
                     let device = device.clone();
                     let wgpu_render_resources = wgpu_render_resources.clone();
-                    let node_outputs = node_outputs.clone();
                     s.spawn(move |_| {
                         let mut render_context =
                             WgpuRenderContext::new(device, wgpu_render_resources);
-                        let mut local_node_outputs = HashMap::new();
                         for job in jobs_chunk.iter_mut() {
                             for node_state in job.node_states.iter_mut() {
                                 // bind inputs from connected node outputs
@@ -56,13 +55,10 @@ impl WgpuRenderGraphExecutor {
                                 ..
                             } = node_state.edges.get_input_slot_edge(i).unwrap()
                             {
+                                let node_outputs = node_outputs.read().unwrap();
                                 let outputs =
                                     if let Some(outputs) = node_outputs.get(output_node) {
                                         outputs
-                                    } else if let Some(outputs) =
-                                        local_node_outputs.get(output_node)
-                                    {
-                                        outputs
                                     } else {
                                         panic!("node inputs not set")
                                     };
@@ -83,12 +79,12 @@ impl WgpuRenderGraphExecutor {
                                     &mut node_state.output_slots,
                                 );
-                                local_node_outputs
+                                node_outputs.write().unwrap()
                                     .insert(node_state.id, node_state.output_slots.clone());
                             }
                         }
 
                         sender
-                            .send((render_context.finish(), local_node_outputs))
+                            .send(render_context.finish())
                             .unwrap();
                     });
                 }
@@ -97,12 +93,10 @@ impl WgpuRenderGraphExecutor {
             let mut command_buffers = Vec::new();
             for _i in 0..actual_thread_count {
-                let (command_buffer, mut local_node_outputs) = receiver.recv().unwrap();
+                let command_buffer = receiver.recv().unwrap();
                 if let Some(command_buffer) = command_buffer {
                     command_buffers.push(command_buffer);
                 }
-                node_outputs.extend(local_node_outputs.drain());
             }
 
             queue.submit(command_buffers.drain(..));
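
Below is a minimal, self-contained sketch of the sharing pattern the new code moves to, for readers who want to try it in isolation. It is not Bevy's API: NodeId, ResourceSlots, and the job layout here are stand-ins, and std::thread::scope (Rust 1.63+) takes the place of the crossbeam_utils::thread::scope used above. The point is that every worker reads and writes one Arc<RwLock<HashMap<NodeId, ResourceSlots>>> through short-lived lock guards, instead of each worker collecting outputs in a local map that has to be merged back on the main thread afterwards.

// Sketch only: illustrative types and job layout, not Bevy's.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

type NodeId = u32;
type ResourceSlots = Vec<String>; // stand-in for the real slot type

fn main() {
    // Shared map of node outputs, readable and writable from every worker through the lock.
    let node_outputs: Arc<RwLock<HashMap<NodeId, ResourceSlots>>> = Default::default();

    // Pretend node 0 was produced by an earlier stage.
    node_outputs
        .write()
        .unwrap()
        .insert(0, vec!["swapchain_texture".to_string()]);

    let jobs: Vec<NodeId> = vec![1, 2, 3, 4];

    // std scoped threads stand in for crossbeam_utils::thread::scope.
    thread::scope(|s| {
        let node_outputs = &node_outputs; // each move closure copies this reference
        for chunk in jobs.chunks(2) {
            s.spawn(move || {
                for &node in chunk {
                    // Read an upstream node's outputs, holding the read lock only briefly.
                    let input = {
                        let outputs = node_outputs.read().unwrap();
                        outputs.get(&0).expect("node inputs not set").clone()
                    };

                    // "Run" the node, then publish its outputs for downstream nodes and stages.
                    let produced = vec![format!("{}-derived-from-{:?}", node, input)];
                    node_outputs.write().unwrap().insert(node, produced);
                }
            });
        }
    });

    // All four workers published through the same map; no per-thread merge step needed.
    assert_eq!(node_outputs.read().unwrap().len(), 5);
}

With scoped threads the Arc is not strictly required (a plain &RwLock<HashMap<...>> would also work); it is kept here to mirror the Arc<RwLock<...>> type the commit introduces.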