On Thu, Dec 11, 2025 at 2:30 PM Gary Guo <gary@kernel.org> wrote:
From: Gary Guo <gary@garyguo.net>
This allows significant cleanups.
Signed-off-by: Gary Guo <gary@garyguo.net>
 rust/macros/kunit.rs | 274 +++++++++++++++++++------------------------
 rust/macros/lib.rs   |   6 +-
 2 files changed, 123 insertions(+), 157 deletions(-)
diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs index 7427c17ee5f5c..516219f5b1356 100644 --- a/rust/macros/kunit.rs +++ b/rust/macros/kunit.rs @@ -4,81 +4,50 @@ //! //! Copyright (c) 2023 José Expósito jose.exposito89@gmail.com
-use std::collections::HashMap; -use std::fmt::Write;
-use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
-pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
- let attr = attr.to_string();
- if attr.is_empty() {
panic!("Missing test name in `#[kunit_tests(test_name)]` macro")- }
- if attr.len() > 255 {
panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")+use std::ffi::CString;
+use proc_macro2::TokenStream; +use quote::{
- format_ident,
- quote,
- ToTokens, //
+}; +use syn::{
- parse_quote,
- Error,
- Ident,
- Item,
- ItemMod,
- LitCStr,
- Result, //
+};
+pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
- if test_suite.to_string().len() > 255 {
return Err(Error::new_spanned(test_suite,"test suite names cannot exceed the maximum length of 255 bytes", }));
- let mut tokens: Vec<_> = ts.into_iter().collect();
- // Scan for the `mod` keyword.
- tokens
.iter().find_map(|token| match token {TokenTree::Ident(ident) => match ident.to_string().as_str() {"mod" => Some(true),_ => None,},_ => None,}).expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");- // Retrieve the main body. The main body should be the last token tree.
- let body = match tokens.pop() {
Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,_ => panic!("Cannot locate main body of module"),
- // We cannot handle modules that defer to another file (e.g. `mod foo;`).
- let Some((module_brace, module_items)) = module.content.take() else {
Err(Error::new_spanned(module,"`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", };))?
- // Get the functions set as tests. Search for `[test]` -> `fn`.
- let mut body_it = body.stream().into_iter();
- let mut tests = Vec::new();
- let mut attributes: HashMap<String, TokenStream> = HashMap::new();
- while let Some(token) = body_it.next() {
match token {TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {// Collect attributes because we need to find which are tests. We also// need to copy `cfg` attributes so tests can be conditionally enabled.attributes.entry(name.to_string()).or_default().extend([token, TokenTree::Group(g)]);}continue;}_ => (),},TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {if let Some(TokenTree::Ident(test_name)) = body_it.next() {tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))}}_ => (),}attributes.clear();- }
- // Make the entire module gated behind `CONFIG_KUNIT`.
- module
.attrs.insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
Does this need to be the first attribute? I think it can just be pushed to the end.
- // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
- let config_kunit = "#[cfg(CONFIG_KUNIT="y")]".to_owned().parse().unwrap();
- tokens.insert(
0,TokenTree::Group(Group::new(Delimiter::None, config_kunit)),- );
let mut processed_items = Vec::new();
let mut test_cases = Vec::new();
// Generate the test KUnit test suite and a test case for each `#[test]`.
// // The code generated for the following test module: // // ```
@@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { // // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); // ```
- let mut kunit_macros = "".to_owned();
- let mut test_cases = "".to_owned();
- let mut assert_macros = "".to_owned();
- let path = crate::helpers::file();
- let num_tests = tests.len();
- for (test, cfg_attr) in tests {
let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");// Append any `cfg` attributes the user might have written on their tests so we don't// attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce// the length of the assert message.let kunit_wrapper = format!(r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit){{(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;{cfg_attr} {{(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;use ::kernel::kunit::is_test_result_ok;assert!(is_test_result_ok({test}()));
- //
- // Non-function items (e.g. imports) are preserved.
- for item in module_items {
let Item::Fn(mut f) = item else {processed_items.push(item);continue;};// TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.// Remove `#[test]` attributes applied on the function and count if any.
What does "count if any" mean here?
if !f.attrs.iter().any(|attr| attr.path().is_ident("test")) {processed_items.push(Item::Fn(f));continue;}f.attrs.retain(|attr| !attr.path().is_ident("test"));
Can this code be something like this:
    let before = f.attrs.len();
    f.attrs.retain(|attr| !attr.path().is_ident("test"));
    let after = f.attrs.len();

    if after == before {
        processed_items.push(Item::Fn(f));
        continue;
    }
let test = f.sig.ident.clone();// Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.let cfg_attrs: Vec<_> = f.attrs.iter().filter(|attr| attr.path().is_ident("cfg")).cloned().collect();// Before the test, override usual `assert!` and `assert_eq!` macros with ones that call// KUnit instead.let test_str = test.to_string();let path = crate::helpers::file();processed_items.push(parse_quote! {#[allow(unused)]macro_rules! assert {($cond:expr $(,)?) => {{kernel::kunit_assert!(#test_str, #path, 0, $cond);}}}});processed_items.push(parse_quote! {#[allow(unused)]macro_rules! assert_eq {($left:expr, $right:expr $(,)?) => {{kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); }}
}}"#,
}});
Am I reading this right that the macros will be repeatedly redefined before each test? Could we put them inside each test body instead?
// Add back the test item.processed_items.push(Item::Fn(f));let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");let test_cstr = LitCStr::new(&CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),test.span(), );
writeln!(kunit_macros, "{kunit_wrapper}").unwrap();writeln!(test_cases," ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),").unwrap();writeln!(assert_macros,r#"-/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert {{
- ($cond:expr $(,)?) => {{{{
kernel::kunit_assert!("{test}", "{path}", 0, $cond);- }}}}
-}}
-/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert_eq {{
- ($left:expr, $right:expr $(,)?) => {{{{
kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);- }}}}
-}}
"#).unwrap();- }
processed_items.push(parse_quote! {unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
- writeln!(kunit_macros).unwrap();
- writeln!(
kunit_macros,"static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",num_tests + 1- )
- .unwrap();
- writeln!(
kunit_macros,"::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"- )
- .unwrap();
- // Remove the `#[test]` macros.
- // We do this at a token level, in order to preserve span information.
- let mut new_body = vec![];
- let mut body_it = body.stream().into_iter();
- while let Some(token) = body_it.next() {
match token {TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),Some(next) => {new_body.extend([token, next]);}_ => {new_body.push(token);
// Append any `cfg` attributes the user might have written on their tests so we// don't attempt to call them when they are `cfg`'d out. An extra `use` is used// here to reduce the length of the assert message.#(#cfg_attrs)*{(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;use ::kernel::kunit::is_test_result_ok;assert!(is_test_result_ok(#test())); }
},_ => {new_body.push(token); }}- }
- let mut final_body = TokenStream::new();
- final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
- final_body.extend(new_body);
- final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
});
- tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
test_cases.push(quote!(::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)));- }
- tokens.into_iter().collect()
- let num_tests_plus_1 = test_cases.len() + 1;
- processed_items.push(parse_quote! {
static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [#(#test_cases,)*::kernel::kunit::kunit_case_null(),];- });
- processed_items.push(parse_quote! {
::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);- });
- module.content = Some((module_brace, processed_items));
- Ok(module.to_token_stream())
} diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index bb2dfd4a4dafc..9cfac9fce0d36 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -453,6 +453,8 @@ pub fn paste(input: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
- kunit::kunit_tests(attr.into(), ts.into()).into()
+pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream {
- kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input))
.unwrap_or_else(|e| e.into_compile_error()).into()}
2.51.2