diff --git a/googletest_macro/src/lib.rs b/googletest_macro/src/lib.rs index 7569a737..52ae75b5 100644 --- a/googletest_macro/src/lib.rs +++ b/googletest_macro/src/lib.rs @@ -14,8 +14,9 @@ use quote::quote; use syn::{ - parse_macro_input, punctuated::Punctuated, spanned::Spanned, Attribute, DeriveInput, Expr, - ExprLit, FnArg, ItemFn, Lit, MetaNameValue, PatType, ReturnType, Signature, Type, + parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, + DeriveInput, Expr, ExprLit, FnArg, ItemFn, Lit, MetaNameValue, PatType, ReturnType, Signature, + Type, }; /// Marks a test to be run by the Google Rust test runner. @@ -75,7 +76,7 @@ pub fn gtest( _args: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let ItemFn { attrs, sig, block, .. } = parse_macro_input!(input as ItemFn); + let ItemFn { mut attrs, sig, block, .. } = parse_macro_input!(input as ItemFn); let test_case_hash: u64 = { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -185,6 +186,13 @@ pub fn gtest( ) } }; + + if !attrs.iter().any(is_test_attribute) && !is_rstest_enabled { + let test_attr: Attribute = parse_quote! { + #[::core::prelude::v1::test] + }; + attrs.push(test_attr); + }; let function = quote! { #(#attrs)* #outer_sig -> #outer_return_type { @@ -200,17 +208,7 @@ pub fn gtest( #trailer } }; - - let output = if attrs.iter().any(is_test_attribute) || is_rstest_enabled { - function - } else { - quote! { - #[::core::prelude::v1::test] - #function - } - }; - - output.into() + function.into() } /// Extract the optional "expected" string literal from a `should_panic`