From 664000f8487ac070aaf33d64588433f949947df5 Mon Sep 17 00:00:00 2001 From: Tim Overbeek <158390905+Bleachfuel@users.noreply.github.com> Date: Fri, 7 Mar 2025 03:01:23 +0100 Subject: [PATCH] Improve derive(Event) and simplify macro code (#18083) # Objective simplify some code and improve Event macro Closes https://github.com/bevyengine/bevy/issues/14336, # Showcase you can now write derive Events like so ```rust #[derive(event)] #[event(auto_propagate, traversal = MyType)] struct MyEvent; ``` --- crates/bevy_ecs/macros/Cargo.toml | 2 +- crates/bevy_ecs/macros/src/component.rs | 310 +++++++++++------------- crates/bevy_ecs/macros/src/lib.rs | 2 +- crates/bevy_ecs/src/event/base.rs | 16 +- crates/bevy_ecs/src/observer/mod.rs | 9 +- examples/ecs/observer_propagation.rs | 26 +- 6 files changed, 174 insertions(+), 191 deletions(-) diff --git a/crates/bevy_ecs/macros/Cargo.toml b/crates/bevy_ecs/macros/Cargo.toml index 3325a102de..28605a5d67 100644 --- a/crates/bevy_ecs/macros/Cargo.toml +++ b/crates/bevy_ecs/macros/Cargo.toml @@ -11,7 +11,7 @@ proc-macro = true [dependencies] bevy_macro_utils = { path = "../../bevy_macro_utils", version = "0.16.0-dev" } -syn = { version = "2.0", features = ["full"] } +syn = { version = "2.0.99", features = ["full", "extra-traits"] } quote = "1.0" proc-macro2 = "1.0" [lints] diff --git a/crates/bevy_ecs/macros/src/component.rs b/crates/bevy_ecs/macros/src/component.rs index 48a7715b85..17fed7fa44 100644 --- a/crates/bevy_ecs/macros/src/component.rs +++ b/crates/bevy_ecs/macros/src/component.rs @@ -9,12 +9,18 @@ use syn::{ punctuated::Punctuated, spanned::Spanned, token::{Comma, Paren}, - Data, DataStruct, DeriveInput, Expr, ExprCall, ExprClosure, ExprPath, Field, Fields, Ident, - Index, LitStr, Member, Path, Result, Token, Type, Visibility, + Data, DataEnum, DataStruct, DeriveInput, Expr, ExprCall, ExprClosure, ExprPath, Field, Fields, + Ident, LitStr, Member, Path, Result, Token, Type, Visibility, }; +pub const EVENT: &str = "event"; +pub const AUTO_PROPAGATE: &str = "auto_propagate"; +pub const TRAVERSAL: &str = "traversal"; + pub fn derive_event(input: TokenStream) -> TokenStream { let mut ast = parse_macro_input!(input as DeriveInput); + let mut auto_propagate = false; + let mut traversal: Type = parse_quote!(()); let bevy_ecs_path: Path = crate::bevy_ecs_path(); ast.generics @@ -22,13 +28,30 @@ pub fn derive_event(input: TokenStream) -> TokenStream { .predicates .push(parse_quote! { Self: Send + Sync + 'static }); + if let Some(attr) = ast.attrs.iter().find(|attr| attr.path().is_ident(EVENT)) { + if let Err(e) = attr.parse_nested_meta(|meta| match meta.path.get_ident() { + Some(ident) if ident == AUTO_PROPAGATE => { + auto_propagate = true; + Ok(()) + } + Some(ident) if ident == TRAVERSAL => { + traversal = meta.value()?.parse()?; + Ok(()) + } + Some(ident) => Err(meta.error(format!("unsupported attribute: {}", ident))), + None => Err(meta.error("expected identifier")), + }) { + return e.to_compile_error().into(); + } + } + let struct_name = &ast.ident; let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); TokenStream::from(quote! { impl #impl_generics #bevy_ecs_path::event::Event for #struct_name #type_generics #where_clause { - type Traversal = (); - const AUTO_PROPAGATE: bool = false; + type Traversal = #traversal; + const AUTO_PROPAGATE: bool = #auto_propagate; } }) } @@ -51,8 +74,6 @@ pub fn derive_resource(input: TokenStream) -> TokenStream { }) } -const ENTITIES_ATTR: &str = "entities"; - pub fn derive_component(input: TokenStream) -> TokenStream { let mut ast = parse_macro_input!(input as DeriveInput); let bevy_ecs_path: Path = crate::bevy_ecs_path(); @@ -283,6 +304,8 @@ pub fn derive_component(input: TokenStream) -> TokenStream { }) } +const ENTITIES: &str = "entities"; + fn visit_entities( data: &Data, bevy_ecs_path: &Path, @@ -291,152 +314,106 @@ fn visit_entities( ) -> TokenStream2 { match data { Data::Struct(DataStruct { fields, .. }) => { - let mut visited_fields = Vec::new(); - let mut visited_indices = Vec::new(); + let mut visit = Vec::with_capacity(fields.len()); + let mut visit_mut = Vec::with_capacity(fields.len()); - if is_relationship { - let field = match relationship_field(fields, "VisitEntities", fields.span()) { - Ok(f) => f, - Err(e) => return e.to_compile_error(), - }; - - match field.ident { - Some(ref ident) => visited_fields.push(ident.clone()), - None => visited_indices.push(Index::from(0)), - } - } - match fields { - Fields::Named(fields) => { - for field in &fields.named { - if field - .attrs - .iter() - .any(|a| a.meta.path().is_ident(ENTITIES_ATTR)) - { - if let Some(ident) = field.ident.clone() { - visited_fields.push(ident); - } - } - } - } - Fields::Unnamed(fields) => { - for (index, field) in fields.unnamed.iter().enumerate() { - if index == 0 && is_relationship_target { - visited_indices.push(Index::from(0)); - } else if field - .attrs - .iter() - .any(|a| a.meta.path().is_ident(ENTITIES_ATTR)) - { - visited_indices.push(Index::from(index)); - } - } - } - Fields::Unit => {} - } - if visited_fields.is_empty() && visited_indices.is_empty() { - TokenStream2::new() + let relationship = if is_relationship || is_relationship_target { + relationship_field(fields, "VisitEntities", fields.span()).ok() } else { - let visit = visited_fields - .iter() - .map(|field| quote!(this.#field.visit_entities(&mut func);)) - .chain( - visited_indices - .iter() - .map(|index| quote!(this.#index.visit_entities(&mut func);)), - ); - let visit_mut = visited_fields - .iter() - .map(|field| quote!(this.#field.visit_entities_mut(&mut func);)) - .chain( - visited_indices - .iter() - .map(|index| quote!(this.#index.visit_entities_mut(&mut func);)), - ); - quote!( - fn visit_entities(this: &Self, mut func: impl FnMut(Entity)) { - use #bevy_ecs_path::entity::VisitEntities; - #(#visit)* - } + None + }; + fields + .iter() + .enumerate() + .filter(|(_, field)| { + field.attrs.iter().any(|a| a.path().is_ident(ENTITIES)) + || relationship.is_some_and(|relationship| relationship == *field) + }) + .for_each(|(index, field)| { + let field_member = field + .ident + .clone() + .map_or(Member::from(index), Member::Named); - fn visit_entities_mut(this: &mut Self, mut func: impl FnMut(&mut Entity)) { - use #bevy_ecs_path::entity::VisitEntitiesMut; - #(#visit_mut)* - } - ) - } - } - Data::Enum(data_enum) => { - let mut has_visited_fields = false; - let mut visit_variants = Vec::with_capacity(data_enum.variants.len()); - let mut visit_variants_mut = Vec::with_capacity(data_enum.variants.len()); - for variant in &data_enum.variants { - let mut variant_fields = Vec::new(); - let mut variant_fields_mut = Vec::new(); - - let mut visit_variant_fields = Vec::new(); - let mut visit_variant_fields_mut = Vec::new(); - - for (index, field) in variant.fields.iter().enumerate() { - if field - .attrs - .iter() - .any(|a| a.meta.path().is_ident(ENTITIES_ATTR)) - { - has_visited_fields = true; - let field_member = ident_or_index(field.ident.as_ref(), index); - let field_ident = format_ident!("field_{}", field_member); - - variant_fields.push(quote!(#field_member: #field_ident)); - variant_fields_mut.push(quote!(#field_member: #field_ident)); - - visit_variant_fields.push(quote!(#field_ident.visit_entities(&mut func);)); - visit_variant_fields_mut - .push(quote!(#field_ident.visit_entities_mut(&mut func);)); - } + visit.push(quote!(this.#field_member.visit_entities(&mut func);)); + visit_mut.push(quote!(this.#field_member.visit_entities_mut(&mut func);)); + }); + if visit.is_empty() { + return quote!(); + }; + quote!( + fn visit_entities(this: &Self, mut func: impl FnMut(#bevy_ecs_path::entity::Entity)) { + use #bevy_ecs_path::entity::VisitEntities; + #(#visit)* } + fn visit_entities_mut(this: &mut Self, mut func: impl FnMut(&mut #bevy_ecs_path::entity::Entity)) { + use #bevy_ecs_path::entity::VisitEntitiesMut; + #(#visit_mut)* + } + ) + } + Data::Enum(DataEnum { variants, .. }) => { + let mut visit = Vec::with_capacity(variants.len()); + let mut visit_mut = Vec::with_capacity(variants.len()); + + for variant in variants.iter() { + let field_members = variant + .fields + .iter() + .enumerate() + .filter(|(_, field)| field.attrs.iter().any(|a| a.path().is_ident(ENTITIES))) + .map(|(index, field)| { + field + .ident + .clone() + .map_or(Member::from(index), Member::Named) + }) + .collect::>(); + let ident = &variant.ident; - visit_variants.push(quote!(Self::#ident {#(#variant_fields,)* ..} => { - #(#visit_variant_fields)* - })); - visit_variants_mut.push(quote!(Self::#ident {#(#variant_fields_mut,)* ..} => { - #(#visit_variant_fields_mut)* - })); - } - if has_visited_fields { - quote!( - fn visit_entities(this: &Self, mut func: impl FnMut(Entity)) { - use #bevy_ecs_path::entity::VisitEntities; - match this { - #(#visit_variants,)* - _ => {} - } - } + let field_idents = field_members + .iter() + .map(|member| format_ident!("__self_{}", member)) + .collect::>(); - fn visit_entities_mut(this: &mut Self, mut func: impl FnMut(&mut Entity)) { - use #bevy_ecs_path::entity::VisitEntitiesMut; - match this { - #(#visit_variants_mut,)* - _ => {} - } - } - ) - } else { - TokenStream2::new() + visit.push( + quote!(Self::#ident {#(#field_members: #field_idents,)* ..} => { + #(#field_idents.visit_entities(&mut func);)* + }), + ); + visit_mut.push( + quote!(Self::#ident {#(#field_members: #field_idents,)* ..} => { + #(#field_idents.visit_entities_mut(&mut func);)* + }), + ); } + + if visit.is_empty() { + return quote!(); + }; + quote!( + fn visit_entities(this: &Self, mut func: impl FnMut(#bevy_ecs_path::entity::Entity)) { + use #bevy_ecs_path::entity::VisitEntities; + match this { + #(#visit,)* + _ => {} + } + } + + fn visit_entities_mut(this: &mut Self, mut func: impl FnMut(&mut #bevy_ecs_path::entity::Entity)) { + use #bevy_ecs_path::entity::VisitEntitiesMut; + match this { + #(#visit_mut,)* + _ => {} + } + } + ) } - Data::Union(_) => TokenStream2::new(), + Data::Union(_) => quote!(), } } -pub(crate) fn ident_or_index(ident: Option<&Ident>, index: usize) -> Member { - ident.map_or_else( - || Member::Unnamed(index.into()), - |ident| Member::Named(ident.clone()), - ) -} - pub const COMPONENT: &str = "component"; pub const STORAGE: &str = "storage"; pub const REQUIRE: &str = "require"; @@ -664,10 +641,15 @@ fn hook_register_function_call( }) } +mod kw { + syn::custom_keyword!(relationship_target); + syn::custom_keyword!(relationship); + syn::custom_keyword!(linked_spawn); +} + impl Parse for Relationship { fn parse(input: syn::parse::ParseStream) -> Result { - syn::custom_keyword!(relationship_target); - input.parse::()?; + input.parse::()?; input.parse::()?; Ok(Relationship { relationship_target: input.parse::()?, @@ -677,34 +659,30 @@ impl Parse for Relationship { impl Parse for RelationshipTarget { fn parse(input: syn::parse::ParseStream) -> Result { - let mut relationship_type: Option = None; - let mut linked_spawn_exists = false; - syn::custom_keyword!(relationship); - syn::custom_keyword!(linked_spawn); - let mut done = false; - loop { - if input.peek(relationship) { - input.parse::()?; + let mut relationship: Option = None; + let mut linked_spawn: bool = false; + + while !input.is_empty() { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::linked_spawn) { + input.parse::()?; + linked_spawn = true; + } else if lookahead.peek(kw::relationship) { + input.parse::()?; input.parse::()?; - relationship_type = Some(input.parse()?); - } else if input.peek(linked_spawn) { - input.parse::()?; - linked_spawn_exists = true; + relationship = Some(input.parse()?); } else { - done = true; + return Err(lookahead.error()); } - if input.peek(Token![,]) { + if !input.is_empty() { input.parse::()?; } - if done { - break; - } } - - let relationship = relationship_type.ok_or_else(|| syn::Error::new(input.span(), "RelationshipTarget derive must specify a relationship via #[relationship_target(relationship = X)"))?; Ok(RelationshipTarget { - relationship, - linked_spawn: linked_spawn_exists, + relationship: relationship.ok_or_else(|| { + syn::Error::new(input.span(), "Missing `relationship = X` attribute") + })?, + linked_spawn, }) } } @@ -730,8 +708,7 @@ fn derive_relationship( }; let field = relationship_field(fields, "Relationship", struct_token.span())?; - let relationship_member: Member = field.ident.clone().map_or(Member::from(0), Member::Named); - + let relationship_member = field.ident.clone().map_or(Member::from(0), Member::Named); let members = fields .members() .filter(|member| member != &relationship_member); @@ -787,7 +764,6 @@ fn derive_relationship_target( return Err(syn::Error::new(field.span(), "The collection in RelationshipTarget must be private to prevent users from directly mutating it, which could invalidate the correctness of relationships.")); } let collection = &field.ty; - let relationship_member = field.ident.clone().map_or(Member::from(0), Member::Named); let members = fields @@ -838,7 +814,7 @@ fn relationship_field<'a>( field .attrs .iter() - .any(|attr| attr.path().is_ident("relationship")) + .any(|attr| attr.path().is_ident(RELATIONSHIP)) }).ok_or(syn::Error::new( span, format!("{derive} derive expected named structs with a single field or with a field annotated with #[relationship].") diff --git a/crates/bevy_ecs/macros/src/lib.rs b/crates/bevy_ecs/macros/src/lib.rs index 9887f1fabe..9ae02bfb3a 100644 --- a/crates/bevy_ecs/macros/src/lib.rs +++ b/crates/bevy_ecs/macros/src/lib.rs @@ -585,7 +585,7 @@ pub(crate) fn bevy_ecs_path() -> syn::Path { BevyManifest::shared().get_path("bevy_ecs") } -#[proc_macro_derive(Event)] +#[proc_macro_derive(Event, attributes(event))] pub fn derive_event(input: TokenStream) -> TokenStream { component::derive_event(input) } diff --git a/crates/bevy_ecs/src/event/base.rs b/crates/bevy_ecs/src/event/base.rs index 42443d38b2..5105c786ac 100644 --- a/crates/bevy_ecs/src/event/base.rs +++ b/crates/bevy_ecs/src/event/base.rs @@ -18,10 +18,22 @@ use core::{ /// /// Events can also be "triggered" on a [`World`], which will then cause any [`Observer`] of that trigger to run. /// -/// This trait can be derived. -/// /// Events must be thread-safe. /// +/// ## Derive +/// This trait can be derived. +/// Adding `auto_propagate` sets [`Self::AUTO_PROPAGATE`] to true. +/// Adding `traversal = "X"` sets [`Self::Traversal`] to be of type "X". +/// +/// ``` +/// use bevy_ecs::prelude::*; +/// +/// #[derive(Event)] +/// #[event(auto_propagate)] +/// struct MyEvent; +/// ``` +/// +/// /// [`World`]: crate::world::World /// [`ComponentId`]: crate::component::ComponentId /// [`Observer`]: crate::observer::Observer diff --git a/crates/bevy_ecs/src/observer/mod.rs b/crates/bevy_ecs/src/observer/mod.rs index 0c40a7f5e3..4bbd82c85b 100644 --- a/crates/bevy_ecs/src/observer/mod.rs +++ b/crates/bevy_ecs/src/observer/mod.rs @@ -894,15 +894,10 @@ mod tests { } } - #[derive(Component)] + #[derive(Component, Event)] + #[event(traversal = &'static ChildOf, auto_propagate)] struct EventPropagating; - impl Event for EventPropagating { - type Traversal = &'static ChildOf; - - const AUTO_PROPAGATE: bool = true; - } - #[test] fn observer_order_spawn_despawn() { let mut world = World::new(); diff --git a/examples/ecs/observer_propagation.rs b/examples/ecs/observer_propagation.rs index 1ec1a69eab..1acf5efa90 100644 --- a/examples/ecs/observer_propagation.rs +++ b/examples/ecs/observer_propagation.rs @@ -42,23 +42,23 @@ fn setup(mut commands: Commands) { } // This event represents an attack we want to "bubble" up from the armor to the goblin. -#[derive(Clone, Component)] +// +// We enable propagation by adding the event attribute and specifying two important pieces of information. +// +// - **traversal:** +// Which component we want to propagate along. In this case, we want to "bubble" (meaning propagate +// from child to parent) so we use the `ChildOf` component for propagation. The component supplied +// must implement the `Traversal` trait. +// +// - **auto_propagate:** +// We can also choose whether or not this event will propagate by default when triggered. If this is +// false, it will only propagate following a call to `Trigger::propagate(true)`. +#[derive(Clone, Component, Event)] +#[event(traversal = &'static ChildOf, auto_propagate)] struct Attack { damage: u16, } -// We enable propagation by implementing `Event` manually (rather than using a derive) and specifying -// two important pieces of information: -impl Event for Attack { - // 1. Which component we want to propagate along. In this case, we want to "bubble" (meaning propagate - // from child to parent) so we use the `ChildOf` component for propagation. The component supplied - // must implement the `Traversal` trait. - type Traversal = &'static ChildOf; - // 2. We can also choose whether or not this event will propagate by default when triggered. If this is - // false, it will only propagate following a call to `Trigger::propagate(true)`. - const AUTO_PROPAGATE: bool = true; -} - /// An entity that can take damage. #[derive(Component, Deref, DerefMut)] struct HitPoints(u16);