address feedback

This commit is contained in:
Christian Hughes 2025-07-17 20:12:35 -05:00
parent 807bad6b80
commit e133d94eeb
5 changed files with 108 additions and 145 deletions

View File

@ -13,24 +13,24 @@ use core::{
use indexmap::IndexMap;
use smallvec::SmallVec;
use crate::schedule::graph::node::{DirectedGraphNodeId, GraphNodeId, GraphNodeIdPair};
use crate::schedule::graph::node::GraphNodeId;
use Direction::{Incoming, Outgoing};
/// A `Graph` with undirected edges.
/// A `Graph` with undirected edges of some [`GraphNodeId`] `N`.
///
/// For example, an edge between *1* and *2* is equivalent to an edge between
/// *2* and *1*.
pub type UnGraph<N, S = FixedHasher> = Graph<false, N, S>;
/// A `Graph` with directed edges.
/// A `Graph` with directed edges of some [`GraphNodeId`] `N`.
///
/// For example, an edge from *1* to *2* is distinct from an edge from *2* to
/// *1*.
pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
/// `Graph<DIRECTED>` is a graph datastructure using an associative array
/// of its node weights `NodeId`.
/// of its node weights of some [`GraphNodeId`].
///
/// It uses a combined adjacency list and sparse adjacency matrix
/// representation, using **O(|N| + |E|)** space, and allows testing for edge
@ -40,6 +40,7 @@ pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
///
/// - Constant generic bool `DIRECTED` determines whether the graph edges are directed or
/// undirected.
/// - The `GraphNodeId` type `N`, which is used as the node weight.
/// - The `BuildHasher` `S`.
///
/// You can use the type aliases `UnGraph` and `DiGraph` for convenience.
@ -50,8 +51,8 @@ pub struct Graph<const DIRECTED: bool, N: GraphNodeId, S = FixedHasher>
where
S: BuildHasher,
{
nodes: IndexMap<N, Vec<N::Directed>, S>,
edges: HashSet<N::Pair, S>,
nodes: IndexMap<N, Vec<N::Adjacent>, S>,
edges: HashSet<N::Edge, S>,
}
impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> fmt::Debug for Graph<DIRECTED, N, S> {
@ -74,10 +75,10 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
/// Use their natural order to map the node pair (a, b) to a canonical edge id.
#[inline]
fn edge_key(a: N, b: N) -> N::Pair {
fn edge_key(a: N, b: N) -> N::Edge {
let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
N::Pair::new(a, b)
N::Edge::from((a, b))
}
/// Return the number of nodes in the graph.
@ -103,7 +104,7 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
return;
};
let links = links.into_iter().map(N::Directed::unwrap);
let links = links.into_iter().map(N::Adjacent::into);
for (succ, dir) in links {
let edge = if dir == Outgoing {
@ -133,18 +134,18 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
self.nodes
.entry(a)
.or_insert_with(|| Vec::with_capacity(1))
.push(N::Directed::new(b, Outgoing));
.push(N::Adjacent::from((b, Outgoing)));
if a != b {
// self loops don't have the Incoming entry
self.nodes
.entry(b)
.or_insert_with(|| Vec::with_capacity(1))
.push(N::Directed::new(a, Incoming));
.push(N::Adjacent::from((a, Incoming)));
}
}
}
/// Remove edge relation from a to b
/// Remove edge relation from a to b.
///
/// Return `true` if it did exist.
fn remove_single_edge(&mut self, a: N, b: N, dir: Direction) -> bool {
@ -155,7 +156,7 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
let Some(index) = sus
.iter()
.copied()
.map(N::Directed::unwrap)
.map(N::Adjacent::into)
.position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
else {
return false;
@ -198,7 +199,7 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
};
iter.copied()
.map(N::Directed::unwrap)
.map(N::Adjacent::into)
.filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
}
@ -216,7 +217,7 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
};
iter.copied()
.map(N::Directed::unwrap)
.map(N::Adjacent::into)
.filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
}
@ -249,7 +250,7 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
/// Return an iterator over all edges of the graph with their weight in arbitrary order.
pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (N, N)> + '_ {
self.edges.iter().copied().map(N::Pair::unwrap)
self.edges.iter().copied().map(N::Edge::into)
}
pub(crate) fn to_index(&self, ix: N) -> usize {
@ -266,29 +267,38 @@ impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S>
where
S: Default,
{
// Converts the node key and every adjacency list entry from `N` to `T`.
fn try_convert_node<N: GraphNodeId, T: GraphNodeId + TryFrom<N>>(
(key, adj): (N, Vec<N::Adjacent>),
) -> Result<(T, Vec<T::Adjacent>), T::Error> {
let key = key.try_into()?;
let adj = adj
.into_iter()
.map(|node| {
let (id, dir) = node.into();
Ok(T::Adjacent::from((id.try_into()?, dir)))
})
.collect::<Result<_, T::Error>>()?;
Ok((key, adj))
}
// Unpacks the edge pair, converts the nodes from `N` to `T`, and repacks them.
fn try_convert_edge<N: GraphNodeId, T: GraphNodeId + TryFrom<N>>(
edge: N::Edge,
) -> Result<T::Edge, T::Error> {
let (a, b) = edge.into();
Ok(T::Edge::from((a.try_into()?, b.try_into()?)))
}
let nodes = self
.nodes
.into_iter()
.map(|(k, v)| {
Ok((
k.try_into()?,
v.into_iter()
.map(|v| {
let (id, dir) = v.unwrap();
Ok(T::Directed::new(id.try_into()?, dir))
})
.collect::<Result<Vec<T::Directed>, T::Error>>()?,
))
})
.collect::<Result<IndexMap<T, Vec<T::Directed>, S>, T::Error>>()?;
.map(try_convert_node::<N, T>)
.collect::<Result<_, T::Error>>()?;
let edges = self
.edges
.into_iter()
.map(|e| {
let (a, b) = e.unwrap();
Ok(T::Pair::new(a.try_into()?, b.try_into()?))
})
.collect::<Result<HashSet<T::Pair, S>, T::Error>>()?;
.map(try_convert_edge::<N, T>)
.collect::<Result<_, T::Error>>()?;
Ok(Graph { nodes, edges })
}
}

View File

@ -17,7 +17,7 @@ mod node;
mod tarjan_scc;
pub use graph_map::{DiGraph, Direction, UnGraph};
pub use node::{DirectedGraphNodeId, GraphNodeId, GraphNodeIdPair};
pub use node::GraphNodeId;
/// Specifies what kind of edge should be added to the dependency graph.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]

View File

@ -7,56 +7,10 @@ use crate::schedule::graph::Direction;
/// [`DiGraph`]: crate::schedule::graph::DiGraph
/// [`UnGraph`]: crate::schedule::graph::UnGraph
pub trait GraphNodeId: Copy + Eq + Hash + Ord + Debug {
/// This [`GraphNodeId`] and a [`Direction`].
type Directed: DirectedGraphNodeId<Id = Self>;
/// Two of these [`GraphNodeId`]s.
type Pair: GraphNodeIdPair<Id = Self>;
}
/// Types that are a [`GraphNodeId`] with a [`Direction`].
pub trait DirectedGraphNodeId: Copy + Debug {
/// The type of [`GraphNodeId`] a [`Direction`] is paired with.
type Id: GraphNodeId;
/// Packs a [`GraphNodeId`] and a [`Direction`] into a single type.
fn new(id: Self::Id, direction: Direction) -> Self;
/// Unpacks a [`GraphNodeId`] and a [`Direction`] from this type.
fn unwrap(self) -> (Self::Id, Direction);
}
/// Types that are a pair of [`GraphNodeId`]s.
pub trait GraphNodeIdPair: Copy + Eq + Hash + Debug {
/// The type of [`GraphNodeId`] for each element of the pair.
type Id: GraphNodeId;
/// Packs two [`GraphNodeId`]s into a single type.
fn new(a: Self::Id, b: Self::Id) -> Self;
/// Unpacks two [`GraphNodeId`]s from this type.
fn unwrap(self) -> (Self::Id, Self::Id);
}
impl<N: GraphNodeId> DirectedGraphNodeId for (N, Direction) {
type Id = N;
fn new(id: N, direction: Direction) -> Self {
(id, direction)
}
fn unwrap(self) -> (N, Direction) {
self
}
}
impl<N: GraphNodeId> GraphNodeIdPair for (N, N) {
type Id = N;
fn new(a: N, b: N) -> Self {
(a, b)
}
fn unwrap(self) -> (N, N) {
self
}
/// The type that packs and unpacks this [`GraphNodeId`] with a [`Direction`].
/// This is used to save space in the graph's adjacency list.
type Adjacent: Copy + Debug + From<(Self, Direction)> + Into<(Self, Direction)>;
/// The type that packs and unpacks this [`GraphNodeId`] with another
/// [`GraphNodeId`]. This is used to save space in the graph's edge list.
type Edge: Copy + Eq + Hash + Debug + From<(Self, Self)> + Into<(Self, Self)>;
}

View File

@ -17,9 +17,9 @@ use smallvec::SmallVec;
/// Returns each strongly strongly connected component (scc).
/// The order of node ids within each scc is arbitrary, but the order of
/// the sccs is their postorder (reverse topological sort).
pub(crate) fn new_tarjan_scc<Id: GraphNodeId, S: BuildHasher>(
graph: &DiGraph<Id, S>,
) -> impl Iterator<Item = SmallVec<[Id; 4]>> + '_ {
pub(crate) fn new_tarjan_scc<N: GraphNodeId, S: BuildHasher>(
graph: &DiGraph<N, S>,
) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
// Create a list of all nodes we need to visit.
let unchecked_nodes = graph.nodes();
@ -47,9 +47,9 @@ pub(crate) fn new_tarjan_scc<Id: GraphNodeId, S: BuildHasher>(
}
}
struct NodeData<N: Iterator<Item: GraphNodeId>> {
struct NodeData<Neighbors: Iterator<Item: GraphNodeId>> {
root_index: Option<NonZeroUsize>,
neighbors: N,
neighbors: Neighbors,
}
/// A state for computing the *strongly connected components* using [Tarjan's algorithm][1].
@ -59,15 +59,15 @@ struct NodeData<N: Iterator<Item: GraphNodeId>> {
/// [1]: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
/// [`petgraph`]: https://docs.rs/petgraph/0.6.5/petgraph/
/// [`TarjanScc`]: https://docs.rs/petgraph/0.6.5/petgraph/algo/struct.TarjanScc.html
struct TarjanScc<'graph, Id, Hasher, AllNodes, Neighbors>
struct TarjanScc<'graph, N, Hasher, AllNodes, Neighbors>
where
Id: GraphNodeId,
N: GraphNodeId,
Hasher: BuildHasher,
AllNodes: Iterator<Item = Id>,
Neighbors: Iterator<Item = Id>,
AllNodes: Iterator<Item = N>,
Neighbors: Iterator<Item = N>,
{
/// Source of truth [`DiGraph`]
graph: &'graph DiGraph<Id, Hasher>,
graph: &'graph DiGraph<N, Hasher>,
/// An [`Iterator`] of [`GraphNodeId`]s from the `graph` which may not have been visited yet.
unchecked_nodes: AllNodes,
/// The index of the next SCC
@ -78,17 +78,22 @@ where
/// [`Iterator`] of possibly unvisited neighbors.
nodes: Vec<NodeData<Neighbors>>,
/// A stack of [`GraphNodeId`]s where a SCC will be found starting at the top of the stack.
stack: Vec<Id>,
stack: Vec<N>,
/// A stack of [`GraphNodeId`]s which need to be visited to determine which SCC they belong to.
visitation_stack: Vec<(Id, bool)>,
visitation_stack: Vec<(N, bool)>,
/// An index into the `stack` indicating the starting point of a SCC.
start: Option<usize>,
/// An adjustment to the `index` which will be applied once the current SCC is found.
index_adjustment: Option<usize>,
}
impl<'graph, Id: GraphNodeId, S: BuildHasher, A: Iterator<Item = Id>, N: Iterator<Item = Id>>
TarjanScc<'graph, Id, S, A, N>
impl<
'graph,
N: GraphNodeId,
S: BuildHasher,
A: Iterator<Item = N>,
Neighbors: Iterator<Item = N>,
> TarjanScc<'graph, N, S, A, Neighbors>
{
/// Compute the next *strongly connected component* using Algorithm 3 in
/// [A Space-Efficient Algorithm for Finding Strongly Connected Components][1] by David J. Pierce,
@ -101,7 +106,7 @@ impl<'graph, Id: GraphNodeId, S: BuildHasher, A: Iterator<Item = Id>, N: Iterato
/// Returns `Some` for each strongly strongly connected component (scc).
/// The order of node ids within each scc is arbitrary, but the order of
/// the sccs is their postorder (reverse topological sort).
fn next_scc(&mut self) -> Option<&[Id]> {
fn next_scc(&mut self) -> Option<&[N]> {
// Cleanup from possible previous iteration
if let (Some(start), Some(index_adjustment)) =
(self.start.take(), self.index_adjustment.take())
@ -141,7 +146,7 @@ impl<'graph, Id: GraphNodeId, S: BuildHasher, A: Iterator<Item = Id>, N: Iterato
/// If a visitation is required, this will return `None` and mark the required neighbor and the
/// current node as in need of visitation again.
/// If no SCC can be found in the current visitation stack, returns `None`.
fn visit_once(&mut self, v: Id, mut v_is_local_root: bool) -> Option<usize> {
fn visit_once(&mut self, v: N, mut v_is_local_root: bool) -> Option<usize> {
let node_v = &mut self.nodes[self.graph.to_index(v)];
if node_v.root_index.is_none() {
@ -205,13 +210,18 @@ impl<'graph, Id: GraphNodeId, S: BuildHasher, A: Iterator<Item = Id>, N: Iterato
}
}
impl<'graph, Id: GraphNodeId, S: BuildHasher, A: Iterator<Item = Id>, N: Iterator<Item = Id>>
Iterator for TarjanScc<'graph, Id, S, A, N>
impl<
'graph,
N: GraphNodeId,
S: BuildHasher,
A: Iterator<Item = N>,
Neighbors: Iterator<Item = N>,
> Iterator for TarjanScc<'graph, N, S, A, Neighbors>
{
// It is expected that the `DiGraph` is sparse, and as such wont have many large SCCs.
// Returning a `SmallVec` allows this iterator to skip allocation in cases where that
// assumption holds.
type Item = SmallVec<[Id; 4]>;
type Item = SmallVec<[N; 4]>;
fn next(&mut self) -> Option<Self::Item> {
let next = SmallVec::from_slice(self.next_scc()?);

View File

@ -14,7 +14,7 @@ use crate::{
prelude::{SystemIn, SystemSet},
query::FilteredAccessSet,
schedule::{
graph::{DirectedGraphNodeId, Direction, GraphNodeId, GraphNodeIdPair},
graph::{Direction, GraphNodeId},
BoxedCondition, InternedSystemSet,
},
system::{
@ -256,8 +256,8 @@ new_key_type! {
}
impl GraphNodeId for SystemKey {
type Directed = (SystemKey, Direction);
type Pair = (SystemKey, SystemKey);
type Adjacent = (SystemKey, Direction);
type Edge = (SystemKey, SystemKey);
}
impl TryFrom<NodeId> for SystemKey {
@ -322,8 +322,8 @@ impl NodeId {
}
impl GraphNodeId for NodeId {
type Directed = CompactNodeIdAndDirection;
type Pair = CompactNodeIdPair;
type Adjacent = CompactNodeIdAndDirection;
type Edge = CompactNodeIdPair;
}
impl From<SystemKey> for NodeId {
@ -348,14 +348,13 @@ pub struct CompactNodeIdAndDirection {
impl Debug for CompactNodeIdAndDirection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.unwrap().fmt(f)
let tuple: (_, _) = (*self).into();
tuple.fmt(f)
}
}
impl DirectedGraphNodeId for CompactNodeIdAndDirection {
type Id = NodeId;
fn new(id: NodeId, direction: Direction) -> Self {
impl From<(NodeId, Direction)> for CompactNodeIdAndDirection {
fn from((id, direction): (NodeId, Direction)) -> Self {
let key = match id {
NodeId::System(key) => key.data(),
NodeId::Set(key) => key.data(),
@ -368,20 +367,16 @@ impl DirectedGraphNodeId for CompactNodeIdAndDirection {
direction,
}
}
}
fn unwrap(self) -> (NodeId, Direction) {
let Self {
key,
is_system,
direction,
} = self;
let node = match is_system {
true => NodeId::System(key.into()),
false => NodeId::Set(key.into()),
impl From<CompactNodeIdAndDirection> for (NodeId, Direction) {
fn from(value: CompactNodeIdAndDirection) -> Self {
let node = match value.is_system {
true => NodeId::System(value.key.into()),
false => NodeId::Set(value.key.into()),
};
(node, direction)
(node, value.direction)
}
}
@ -396,14 +391,13 @@ pub struct CompactNodeIdPair {
impl Debug for CompactNodeIdPair {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.unwrap().fmt(f)
let tuple: (_, _) = (*self).into();
tuple.fmt(f)
}
}
impl GraphNodeIdPair for CompactNodeIdPair {
type Id = NodeId;
fn new(a: NodeId, b: NodeId) -> Self {
impl From<(NodeId, NodeId)> for CompactNodeIdPair {
fn from((a, b): (NodeId, NodeId)) -> Self {
let key_a = match a {
NodeId::System(index) => index.data(),
NodeId::Set(index) => index.data(),
@ -423,23 +417,18 @@ impl GraphNodeIdPair for CompactNodeIdPair {
is_system_b,
}
}
}
fn unwrap(self) -> (NodeId, NodeId) {
let Self {
key_a,
key_b,
is_system_a,
is_system_b,
} = self;
let a = match is_system_a {
true => NodeId::System(key_a.into()),
false => NodeId::Set(key_a.into()),
impl From<CompactNodeIdPair> for (NodeId, NodeId) {
fn from(value: CompactNodeIdPair) -> Self {
let a = match value.is_system_a {
true => NodeId::System(value.key_a.into()),
false => NodeId::Set(value.key_a.into()),
};
let b = match is_system_b {
true => NodeId::System(key_b.into()),
false => NodeId::Set(key_b.into()),
let b = match value.is_system_b {
true => NodeId::System(value.key_b.into()),
false => NodeId::Set(value.key_b.into()),
};
(a, b)