Improve par_iter and Parallel (#12904)
# Objective - bevy usually use `Parallel::scope` to collect items from `par_iter`, but `scope` will be called with every satifified items. it will cause a lot of unnecessary lookup. ## Solution - similar to Rayon ,we introduce `for_each_init` for `par_iter` which only be invoked when spawn a task for a group of items. --- ## Changelog - added `for_each_init` ## Performance `check_visibility ` in `many_foxes `  ~40% performance gain in `check_visibility`. --------- Co-authored-by: James Liu <contact@jamessliu.com>
This commit is contained in:
parent
b1ab036329
commit
0f27500e46
@ -41,54 +41,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
|
|
||||||
/// from a table.
|
|
||||||
///
|
|
||||||
/// # Safety
|
|
||||||
/// - all `rows` must be in `[0, table.entity_count)`.
|
|
||||||
/// - `table` must match D and F
|
|
||||||
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
|
|
||||||
#[inline]
|
|
||||||
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
|
||||||
pub(super) unsafe fn for_each_in_table_range<Func>(
|
|
||||||
&mut self,
|
|
||||||
func: &mut Func,
|
|
||||||
table: &'w Table,
|
|
||||||
rows: Range<usize>,
|
|
||||||
) where
|
|
||||||
Func: FnMut(D::Item<'w>),
|
|
||||||
{
|
|
||||||
// SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F
|
|
||||||
// and all indices in rows are in range.
|
|
||||||
unsafe {
|
|
||||||
self.fold_over_table_range((), &mut |_, item| func(item), table, rows);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
|
|
||||||
/// from an archetype.
|
|
||||||
///
|
|
||||||
/// # Safety
|
|
||||||
/// - all `indices` must be in `[0, archetype.len())`.
|
|
||||||
/// - `archetype` must match D and F
|
|
||||||
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
|
|
||||||
#[inline]
|
|
||||||
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
|
||||||
pub(super) unsafe fn for_each_in_archetype_range<Func>(
|
|
||||||
&mut self,
|
|
||||||
func: &mut Func,
|
|
||||||
archetype: &'w Archetype,
|
|
||||||
rows: Range<usize>,
|
|
||||||
) where
|
|
||||||
Func: FnMut(D::Item<'w>),
|
|
||||||
{
|
|
||||||
// SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F
|
|
||||||
// and all indices in rows are in range.
|
|
||||||
unsafe {
|
|
||||||
self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
|
/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
|
||||||
/// from an table.
|
/// from an table.
|
||||||
///
|
///
|
||||||
@ -752,7 +704,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
|
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
|
||||||
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_for_each_unchecked_manual
|
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual
|
||||||
/// # Safety
|
/// # Safety
|
||||||
/// `tables` and `archetypes` must belong to the same world that the [`QueryIterationCursor`]
|
/// `tables` and `archetypes` must belong to the same world that the [`QueryIterationCursor`]
|
||||||
/// was initialized for.
|
/// was initialized for.
|
||||||
|
|||||||
@ -35,8 +35,52 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
|
|||||||
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
|
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
|
pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
|
||||||
|
self.for_each_init(|| {}, |_, item| func(item));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs `func` on each query result in parallel on a value returned by `init`.
|
||||||
|
///
|
||||||
|
/// `init` may be called multiple times per thread, and the values returned may be discarded between tasks on any given thread.
|
||||||
|
/// Callers should avoid using this function as if it were a a parallel version
|
||||||
|
/// of [`Iterator::fold`].
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use bevy_utils::Parallel;
|
||||||
|
/// use crate::{bevy_ecs::prelude::Component, bevy_ecs::system::Query};
|
||||||
|
/// #[derive(Component)]
|
||||||
|
/// struct T;
|
||||||
|
/// fn system(query: Query<&T>){
|
||||||
|
/// let mut queue: Parallel<usize> = Parallel::default();
|
||||||
|
/// // queue.borrow_local_mut() will get or create a thread_local queue for each task/thread;
|
||||||
|
/// query.par_iter().for_each_init(|| queue.borrow_local_mut(),|local_queue,item| {
|
||||||
|
/// **local_queue += 1;
|
||||||
|
/// });
|
||||||
|
///
|
||||||
|
/// // collect value from every thread
|
||||||
|
/// let entity_count: usize = queue.iter_mut().map(|v| *v).sum();
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
|
||||||
|
/// initialized and run from the ECS scheduler, this should never panic.
|
||||||
|
///
|
||||||
|
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
|
||||||
|
#[inline]
|
||||||
|
pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN)
|
||||||
|
where
|
||||||
|
FN: Fn(&mut T, QueryItem<'w, D>) + Send + Sync + Clone,
|
||||||
|
INIT: Fn() -> T + Sync + Send + Clone,
|
||||||
|
{
|
||||||
|
let func = |mut init, item| {
|
||||||
|
func(&mut init, item);
|
||||||
|
init
|
||||||
|
};
|
||||||
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
|
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
|
||||||
{
|
{
|
||||||
|
let init = init();
|
||||||
// SAFETY:
|
// SAFETY:
|
||||||
// This method can only be called once per instance of QueryParIter,
|
// This method can only be called once per instance of QueryParIter,
|
||||||
// which ensures that mutable queries cannot be executed multiple times at once.
|
// which ensures that mutable queries cannot be executed multiple times at once.
|
||||||
@ -46,25 +90,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
|
|||||||
unsafe {
|
unsafe {
|
||||||
self.state
|
self.state
|
||||||
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
|
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
|
||||||
.for_each(func);
|
.fold(init, func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
||||||
{
|
{
|
||||||
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
|
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
|
||||||
if thread_count <= 1 {
|
if thread_count <= 1 {
|
||||||
|
let init = init();
|
||||||
// SAFETY: See the safety comment above.
|
// SAFETY: See the safety comment above.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.state
|
self.state
|
||||||
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
|
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
|
||||||
.for_each(func);
|
.fold(init, func);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Need a batch size of at least 1.
|
// Need a batch size of at least 1.
|
||||||
let batch_size = self.get_batch_size(thread_count).max(1);
|
let batch_size = self.get_batch_size(thread_count).max(1);
|
||||||
// SAFETY: See the safety comment above.
|
// SAFETY: See the safety comment above.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.state.par_for_each_unchecked_manual(
|
self.state.par_fold_init_unchecked_manual(
|
||||||
|
init,
|
||||||
self.world,
|
self.world,
|
||||||
batch_size,
|
batch_size,
|
||||||
func,
|
func,
|
||||||
|
|||||||
@ -1394,19 +1394,20 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
|
|||||||
///
|
///
|
||||||
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
|
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
|
||||||
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
|
||||||
pub(crate) unsafe fn par_for_each_unchecked_manual<
|
pub(crate) unsafe fn par_fold_init_unchecked_manual<'w, T, FN, INIT>(
|
||||||
'w,
|
|
||||||
FN: Fn(D::Item<'w>) + Send + Sync + Clone,
|
|
||||||
>(
|
|
||||||
&self,
|
&self,
|
||||||
|
init_accum: INIT,
|
||||||
world: UnsafeWorldCell<'w>,
|
world: UnsafeWorldCell<'w>,
|
||||||
batch_size: usize,
|
batch_size: usize,
|
||||||
func: FN,
|
func: FN,
|
||||||
last_run: Tick,
|
last_run: Tick,
|
||||||
this_run: Tick,
|
this_run: Tick,
|
||||||
) {
|
) where
|
||||||
|
FN: Fn(T, D::Item<'w>) -> T + Send + Sync + Clone,
|
||||||
|
INIT: Fn() -> T + Sync + Send + Clone,
|
||||||
|
{
|
||||||
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
|
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
|
||||||
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
|
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter,QueryState::par_fold_init_unchecked_manual
|
||||||
use arrayvec::ArrayVec;
|
use arrayvec::ArrayVec;
|
||||||
|
|
||||||
bevy_tasks::ComputeTaskPool::get().scope(|scope| {
|
bevy_tasks::ComputeTaskPool::get().scope(|scope| {
|
||||||
@ -1423,19 +1424,27 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
|
|||||||
}
|
}
|
||||||
let queue = std::mem::take(queue);
|
let queue = std::mem::take(queue);
|
||||||
let mut func = func.clone();
|
let mut func = func.clone();
|
||||||
|
let init_accum = init_accum.clone();
|
||||||
scope.spawn(async move {
|
scope.spawn(async move {
|
||||||
#[cfg(feature = "trace")]
|
#[cfg(feature = "trace")]
|
||||||
let _span = self.par_iter_span.enter();
|
let _span = self.par_iter_span.enter();
|
||||||
let mut iter = self.iter_unchecked_manual(world, last_run, this_run);
|
let mut iter = self.iter_unchecked_manual(world, last_run, this_run);
|
||||||
|
let mut accum = init_accum();
|
||||||
for storage_id in queue {
|
for storage_id in queue {
|
||||||
if D::IS_DENSE && F::IS_DENSE {
|
if D::IS_DENSE && F::IS_DENSE {
|
||||||
let id = storage_id.table_id;
|
let id = storage_id.table_id;
|
||||||
let table = &world.storages().tables.get(id).debug_checked_unwrap();
|
let table = &world.storages().tables.get(id).debug_checked_unwrap();
|
||||||
iter.for_each_in_table_range(&mut func, table, 0..table.entity_count());
|
accum = iter.fold_over_table_range(
|
||||||
|
accum,
|
||||||
|
&mut func,
|
||||||
|
table,
|
||||||
|
0..table.entity_count(),
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
let id = storage_id.archetype_id;
|
let id = storage_id.archetype_id;
|
||||||
let archetype = world.archetypes().get(id).debug_checked_unwrap();
|
let archetype = world.archetypes().get(id).debug_checked_unwrap();
|
||||||
iter.for_each_in_archetype_range(
|
accum = iter.fold_over_archetype_range(
|
||||||
|
accum,
|
||||||
&mut func,
|
&mut func,
|
||||||
archetype,
|
archetype,
|
||||||
0..archetype.len(),
|
0..archetype.len(),
|
||||||
@ -1449,21 +1458,23 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
|
|||||||
let submit_single = |count, storage_id: StorageId| {
|
let submit_single = |count, storage_id: StorageId| {
|
||||||
for offset in (0..count).step_by(batch_size) {
|
for offset in (0..count).step_by(batch_size) {
|
||||||
let mut func = func.clone();
|
let mut func = func.clone();
|
||||||
|
let init_accum = init_accum.clone();
|
||||||
let len = batch_size.min(count - offset);
|
let len = batch_size.min(count - offset);
|
||||||
let batch = offset..offset + len;
|
let batch = offset..offset + len;
|
||||||
scope.spawn(async move {
|
scope.spawn(async move {
|
||||||
#[cfg(feature = "trace")]
|
#[cfg(feature = "trace")]
|
||||||
let _span = self.par_iter_span.enter();
|
let _span = self.par_iter_span.enter();
|
||||||
|
let accum = init_accum();
|
||||||
if D::IS_DENSE && F::IS_DENSE {
|
if D::IS_DENSE && F::IS_DENSE {
|
||||||
let id = storage_id.table_id;
|
let id = storage_id.table_id;
|
||||||
let table = world.storages().tables.get(id).debug_checked_unwrap();
|
let table = world.storages().tables.get(id).debug_checked_unwrap();
|
||||||
self.iter_unchecked_manual(world, last_run, this_run)
|
self.iter_unchecked_manual(world, last_run, this_run)
|
||||||
.for_each_in_table_range(&mut func, table, batch);
|
.fold_over_table_range(accum, &mut func, table, batch);
|
||||||
} else {
|
} else {
|
||||||
let id = storage_id.archetype_id;
|
let id = storage_id.archetype_id;
|
||||||
let archetype = world.archetypes().get(id).debug_checked_unwrap();
|
let archetype = world.archetypes().get(id).debug_checked_unwrap();
|
||||||
self.iter_unchecked_manual(world, last_run, this_run)
|
self.iter_unchecked_manual(world, last_run, this_run)
|
||||||
.for_each_in_archetype_range(&mut func, archetype, batch);
|
.fold_over_archetype_range(accum, &mut func, archetype, batch);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -595,8 +595,10 @@ pub fn extract_meshes_for_cpu_building(
|
|||||||
)>,
|
)>,
|
||||||
>,
|
>,
|
||||||
) {
|
) {
|
||||||
meshes_query.par_iter().for_each(
|
meshes_query.par_iter().for_each_init(
|
||||||
|(
|
|| render_mesh_instance_queues.borrow_local_mut(),
|
||||||
|
|queue,
|
||||||
|
(
|
||||||
entity,
|
entity,
|
||||||
view_visibility,
|
view_visibility,
|
||||||
transform,
|
transform,
|
||||||
@ -621,23 +623,19 @@ pub fn extract_meshes_for_cpu_building(
|
|||||||
no_automatic_batching,
|
no_automatic_batching,
|
||||||
);
|
);
|
||||||
|
|
||||||
render_mesh_instance_queues.scope(|queue| {
|
let transform = transform.affine();
|
||||||
let transform = transform.affine();
|
queue.push((
|
||||||
queue.push((
|
entity,
|
||||||
entity,
|
RenderMeshInstanceCpu {
|
||||||
RenderMeshInstanceCpu {
|
transforms: MeshTransforms {
|
||||||
transforms: MeshTransforms {
|
transform: (&transform).into(),
|
||||||
transform: (&transform).into(),
|
previous_transform: (&previous_transform.map(|t| t.0).unwrap_or(transform))
|
||||||
previous_transform: (&previous_transform
|
.into(),
|
||||||
.map(|t| t.0)
|
flags: mesh_flags.bits(),
|
||||||
.unwrap_or(transform))
|
|
||||||
.into(),
|
|
||||||
flags: mesh_flags.bits(),
|
|
||||||
},
|
|
||||||
shared,
|
|
||||||
},
|
},
|
||||||
));
|
shared,
|
||||||
});
|
},
|
||||||
|
));
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -683,8 +681,10 @@ pub fn extract_meshes_for_gpu_building(
|
|||||||
)>,
|
)>,
|
||||||
>,
|
>,
|
||||||
) {
|
) {
|
||||||
meshes_query.par_iter().for_each(
|
meshes_query.par_iter().for_each_init(
|
||||||
|(
|
|| render_mesh_instance_queues.borrow_local_mut(),
|
||||||
|
|queue,
|
||||||
|
(
|
||||||
entity,
|
entity,
|
||||||
view_visibility,
|
view_visibility,
|
||||||
transform,
|
transform,
|
||||||
@ -713,17 +713,15 @@ pub fn extract_meshes_for_gpu_building(
|
|||||||
let lightmap_uv_rect =
|
let lightmap_uv_rect =
|
||||||
lightmap::pack_lightmap_uv_rect(lightmap.map(|lightmap| lightmap.uv_rect));
|
lightmap::pack_lightmap_uv_rect(lightmap.map(|lightmap| lightmap.uv_rect));
|
||||||
|
|
||||||
render_mesh_instance_queues.scope(|queue| {
|
queue.push((
|
||||||
queue.push((
|
entity,
|
||||||
entity,
|
RenderMeshInstanceGpuBuilder {
|
||||||
RenderMeshInstanceGpuBuilder {
|
shared,
|
||||||
shared,
|
transform: (&transform.affine()).into(),
|
||||||
transform: (&transform.affine()).into(),
|
lightmap_uv_rect,
|
||||||
lightmap_uv_rect,
|
mesh_flags,
|
||||||
mesh_flags,
|
},
|
||||||
},
|
));
|
||||||
));
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@ -453,52 +453,53 @@ pub fn check_visibility<QF>(
|
|||||||
|
|
||||||
let view_mask = maybe_view_mask.copied().unwrap_or_default();
|
let view_mask = maybe_view_mask.copied().unwrap_or_default();
|
||||||
|
|
||||||
visible_aabb_query.par_iter_mut().for_each(|query_item| {
|
visible_aabb_query.par_iter_mut().for_each_init(
|
||||||
let (
|
|| thread_queues.borrow_local_mut(),
|
||||||
entity,
|
|queue, query_item| {
|
||||||
inherited_visibility,
|
let (
|
||||||
mut view_visibility,
|
entity,
|
||||||
maybe_entity_mask,
|
inherited_visibility,
|
||||||
maybe_model_aabb,
|
mut view_visibility,
|
||||||
transform,
|
maybe_entity_mask,
|
||||||
no_frustum_culling,
|
maybe_model_aabb,
|
||||||
) = query_item;
|
transform,
|
||||||
|
no_frustum_culling,
|
||||||
|
) = query_item;
|
||||||
|
|
||||||
// Skip computing visibility for entities that are configured to be hidden.
|
// Skip computing visibility for entities that are configured to be hidden.
|
||||||
// ViewVisibility has already been reset in `reset_view_visibility`.
|
// ViewVisibility has already been reset in `reset_view_visibility`.
|
||||||
if !inherited_visibility.get() {
|
if !inherited_visibility.get() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
|
let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
|
||||||
if !view_mask.intersects(&entity_mask) {
|
if !view_mask.intersects(&entity_mask) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an aabb, do frustum culling
|
// If we have an aabb, do frustum culling
|
||||||
if !no_frustum_culling {
|
if !no_frustum_culling {
|
||||||
if let Some(model_aabb) = maybe_model_aabb {
|
if let Some(model_aabb) = maybe_model_aabb {
|
||||||
let model = transform.affine();
|
let model = transform.affine();
|
||||||
let model_sphere = Sphere {
|
let model_sphere = Sphere {
|
||||||
center: model.transform_point3a(model_aabb.center),
|
center: model.transform_point3a(model_aabb.center),
|
||||||
radius: transform.radius_vec3a(model_aabb.half_extents),
|
radius: transform.radius_vec3a(model_aabb.half_extents),
|
||||||
};
|
};
|
||||||
// Do quick sphere-based frustum culling
|
// Do quick sphere-based frustum culling
|
||||||
if !frustum.intersects_sphere(&model_sphere, false) {
|
if !frustum.intersects_sphere(&model_sphere, false) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Do aabb-based frustum culling
|
// Do aabb-based frustum culling
|
||||||
if !frustum.intersects_obb(model_aabb, &model, true, false) {
|
if !frustum.intersects_obb(model_aabb, &model, true, false) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
view_visibility.set();
|
view_visibility.set();
|
||||||
thread_queues.scope(|queue| {
|
|
||||||
queue.push(entity);
|
queue.push(entity);
|
||||||
});
|
},
|
||||||
});
|
);
|
||||||
|
|
||||||
visible_entities.clear::<QF>();
|
visible_entities.clear::<QF>();
|
||||||
thread_queues.drain_into(visible_entities.get_mut::<QF>());
|
thread_queues.drain_into(visible_entities.get_mut::<QF>());
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use core::cell::Cell;
|
use std::{cell::RefCell, ops::DerefMut};
|
||||||
use thread_local::ThreadLocal;
|
use thread_local::ThreadLocal;
|
||||||
|
|
||||||
/// A cohesive set of thread-local values of a given type.
|
/// A cohesive set of thread-local values of a given type.
|
||||||
@ -6,9 +6,10 @@ use thread_local::ThreadLocal;
|
|||||||
/// Mutable references can be fetched if `T: Default` via [`Parallel::scope`].
|
/// Mutable references can be fetched if `T: Default` via [`Parallel::scope`].
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Parallel<T: Send> {
|
pub struct Parallel<T: Send> {
|
||||||
locals: ThreadLocal<Cell<T>>,
|
locals: ThreadLocal<RefCell<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A scope guard of a `Parallel`, when this struct is dropped ,the value will writeback to its `Parallel`
|
||||||
impl<T: Send> Parallel<T> {
|
impl<T: Send> Parallel<T> {
|
||||||
/// Gets a mutable iterator over all of the per-thread queues.
|
/// Gets a mutable iterator over all of the per-thread queues.
|
||||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
|
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
|
||||||
@ -26,12 +27,17 @@ impl<T: Default + Send> Parallel<T> {
|
|||||||
///
|
///
|
||||||
/// If there is no thread-local value, it will be initialized to its default.
|
/// If there is no thread-local value, it will be initialized to its default.
|
||||||
pub fn scope<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
|
pub fn scope<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
|
||||||
let cell = self.locals.get_or_default();
|
let mut cell = self.locals.get_or_default().borrow_mut();
|
||||||
let mut value = cell.take();
|
let ret = f(cell.deref_mut());
|
||||||
let ret = f(&mut value);
|
|
||||||
cell.set(value);
|
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mutably borrows the thread-local value.
|
||||||
|
///
|
||||||
|
/// If there is no thread-local value, it will be initialized to it's default.
|
||||||
|
pub fn borrow_local_mut(&self) -> impl DerefMut<Target = T> + '_ {
|
||||||
|
self.locals.get_or_default().borrow_mut()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, I> Parallel<I>
|
impl<T, I> Parallel<I>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user