Refactor bundle derive (#19749)

# Objective

- Splitted off from  #19491
- Make adding generated code to the `Bundle` derive macro easier
- Fix a bug when multiple fields are `#[bundle(ignore)]`

## Solution

- Instead of accumulating the code for each method in a different `Vec`,
accumulate only the names of non-ignored fields and their types, then
use `quote` to generate the code for each of them in the method body.
- To fix the bug, change the code populating the `BundleFieldKind` to
push only one of them per-field (previously each `#[bundle(ignore)]`
resulted in pushing twice, once for the correct
`BundleFieldKind::Ignore` and then again unconditionally for
`BundleFieldKind::Component`)

## Testing

- Added a regression test for the bug that was fixed
This commit is contained in:
Giacomo Stevanato 2025-06-20 18:36:08 +02:00 committed by GitHub
parent 8e1d0051d2
commit 35166d9029
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 72 deletions

View File

@ -16,7 +16,7 @@ use crate::{
use bevy_macro_utils::{derive_label, ensure_no_collision, get_struct_fields, BevyManifest}; use bevy_macro_utils::{derive_label, ensure_no_collision, get_struct_fields, BevyManifest};
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::{Ident, Span}; use proc_macro2::{Ident, Span};
use quote::{format_ident, quote}; use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
ConstParam, Data, DataStruct, DeriveInput, GenericParam, Index, TypeParam, ConstParam, Data, DataStruct, DeriveInput, GenericParam, Index, TypeParam,
@ -79,6 +79,8 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
let mut field_kind = Vec::with_capacity(named_fields.len()); let mut field_kind = Vec::with_capacity(named_fields.len());
for field in named_fields { for field in named_fields {
let mut kind = BundleFieldKind::Component;
for attr in field for attr in field
.attrs .attrs
.iter() .iter()
@ -86,7 +88,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
{ {
if let Err(error) = attr.parse_nested_meta(|meta| { if let Err(error) = attr.parse_nested_meta(|meta| {
if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) { if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
field_kind.push(BundleFieldKind::Ignore); kind = BundleFieldKind::Ignore;
Ok(()) Ok(())
} else { } else {
Err(meta.error(format!( Err(meta.error(format!(
@ -98,7 +100,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
} }
} }
field_kind.push(BundleFieldKind::Component); field_kind.push(kind);
} }
let field = named_fields let field = named_fields
@ -111,82 +113,33 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
.map(|field| &field.ty) .map(|field| &field.ty)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut field_component_ids = Vec::new(); let mut active_field_types = Vec::new();
let mut field_get_component_ids = Vec::new(); let mut active_field_tokens = Vec::new();
let mut field_get_components = Vec::new(); let mut inactive_field_tokens = Vec::new();
let mut field_from_components = Vec::new();
let mut field_required_components = Vec::new();
for (((i, field_type), field_kind), field) in field_type for (((i, field_type), field_kind), field) in field_type
.iter() .iter()
.enumerate() .enumerate()
.zip(field_kind.iter()) .zip(field_kind.iter())
.zip(field.iter()) .zip(field.iter())
{ {
let field_tokens = match field {
Some(field) => field.to_token_stream(),
None => Index::from(i).to_token_stream(),
};
match field_kind { match field_kind {
BundleFieldKind::Component => { BundleFieldKind::Component => {
field_component_ids.push(quote! { active_field_types.push(field_type);
<#field_type as #ecs_path::bundle::Bundle>::component_ids(components, &mut *ids); active_field_tokens.push(field_tokens);
});
field_required_components.push(quote! {
<#field_type as #ecs_path::bundle::Bundle>::register_required_components(components, required_components);
});
field_get_component_ids.push(quote! {
<#field_type as #ecs_path::bundle::Bundle>::get_component_ids(components, &mut *ids);
});
match field {
Some(field) => {
field_get_components.push(quote! {
self.#field.get_components(&mut *func);
});
field_from_components.push(quote! {
#field: <#field_type as #ecs_path::bundle::BundleFromComponents>::from_components(ctx, &mut *func),
});
}
None => {
let index = Index::from(i);
field_get_components.push(quote! {
self.#index.get_components(&mut *func);
});
field_from_components.push(quote! {
#index: <#field_type as #ecs_path::bundle::BundleFromComponents>::from_components(ctx, &mut *func),
});
}
}
} }
BundleFieldKind::Ignore => { BundleFieldKind::Ignore => inactive_field_tokens.push(field_tokens),
field_from_components.push(quote! {
#field: ::core::default::Default::default(),
});
}
} }
} }
let generics = ast.generics; let generics = ast.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let struct_name = &ast.ident; let struct_name = &ast.ident;
let from_components = attributes.impl_from_components.then(|| quote! { let bundle_impl = quote! {
// SAFETY:
// - ComponentId is returned in field-definition-order. [from_components] uses field-definition-order
#[allow(deprecated)]
unsafe impl #impl_generics #ecs_path::bundle::BundleFromComponents for #struct_name #ty_generics #where_clause {
#[allow(unused_variables, non_snake_case)]
unsafe fn from_components<__T, __F>(ctx: &mut __T, func: &mut __F) -> Self
where
__F: FnMut(&mut __T) -> #ecs_path::ptr::OwningPtr<'_>
{
Self{
#(#field_from_components)*
}
}
}
});
let attribute_errors = &errors;
TokenStream::from(quote! {
#(#attribute_errors)*
// SAFETY: // SAFETY:
// - ComponentId is returned in field-definition-order. [get_components] uses field-definition-order // - ComponentId is returned in field-definition-order. [get_components] uses field-definition-order
// - `Bundle::get_components` is exactly once for each member. Rely's on the Component -> Bundle implementation to properly pass // - `Bundle::get_components` is exactly once for each member. Rely's on the Component -> Bundle implementation to properly pass
@ -196,27 +149,27 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
fn component_ids( fn component_ids(
components: &mut #ecs_path::component::ComponentsRegistrator, components: &mut #ecs_path::component::ComponentsRegistrator,
ids: &mut impl FnMut(#ecs_path::component::ComponentId) ids: &mut impl FnMut(#ecs_path::component::ComponentId)
){ ) {
#(#field_component_ids)* #(<#active_field_types as #ecs_path::bundle::Bundle>::component_ids(components, ids);)*
} }
fn get_component_ids( fn get_component_ids(
components: &#ecs_path::component::Components, components: &#ecs_path::component::Components,
ids: &mut impl FnMut(Option<#ecs_path::component::ComponentId>) ids: &mut impl FnMut(Option<#ecs_path::component::ComponentId>)
){ ) {
#(#field_get_component_ids)* #(<#active_field_types as #ecs_path::bundle::Bundle>::get_component_ids(components, &mut *ids);)*
} }
fn register_required_components( fn register_required_components(
components: &mut #ecs_path::component::ComponentsRegistrator, components: &mut #ecs_path::component::ComponentsRegistrator,
required_components: &mut #ecs_path::component::RequiredComponents required_components: &mut #ecs_path::component::RequiredComponents
){ ) {
#(#field_required_components)* #(<#active_field_types as #ecs_path::bundle::Bundle>::register_required_components(components, required_components);)*
} }
} }
};
#from_components let dynamic_bundle_impl = quote! {
#[allow(deprecated)] #[allow(deprecated)]
impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause { impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause {
type Effect = (); type Effect = ();
@ -226,9 +179,36 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
self, self,
func: &mut impl FnMut(#ecs_path::component::StorageType, #ecs_path::ptr::OwningPtr<'_>) func: &mut impl FnMut(#ecs_path::component::StorageType, #ecs_path::ptr::OwningPtr<'_>)
) { ) {
#(#field_get_components)* #(<#active_field_types as #ecs_path::bundle::DynamicBundle>::get_components(self.#active_field_tokens, &mut *func);)*
} }
} }
};
let from_components_impl = attributes.impl_from_components.then(|| quote! {
// SAFETY:
// - ComponentId is returned in field-definition-order. [from_components] uses field-definition-order
#[allow(deprecated)]
unsafe impl #impl_generics #ecs_path::bundle::BundleFromComponents for #struct_name #ty_generics #where_clause {
#[allow(unused_variables, non_snake_case)]
unsafe fn from_components<__T, __F>(ctx: &mut __T, func: &mut __F) -> Self
where
__F: FnMut(&mut __T) -> #ecs_path::ptr::OwningPtr<'_>
{
Self {
#(#active_field_tokens: <#active_field_types as #ecs_path::bundle::BundleFromComponents>::from_components(ctx, &mut *func),)*
#(#inactive_field_tokens: ::core::default::Default::default(),)*
}
}
}
});
let attribute_errors = &errors;
TokenStream::from(quote! {
#(#attribute_errors)*
#bundle_impl
#from_components_impl
#dynamic_bundle_impl
}) })
} }

View File

@ -2397,4 +2397,13 @@ mod tests {
assert_eq!(world.resource::<Count>().0, 3); assert_eq!(world.resource::<Count>().0, 3);
} }
#[derive(Bundle)]
#[expect(unused, reason = "tests the output of the derive macro is valid")]
struct Ignore {
#[bundle(ignore)]
foo: i32,
#[bundle(ignore)]
bar: i32,
}
} }