diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index ee46b49a094c6..9c62244f3c9ff 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -587,7 +587,7 @@ fn thin_lto( } fn enable_autodiff_settings(ad: &[config::AutoDiff]) { - for &val in ad { + for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. match val { config::AutoDiff::PrintPerf => { @@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { config::AutoDiff::PrintTA => { llvm::set_print_type(true); } + config::AutoDiff::PrintTAFn(fun) => { + llvm::set_print_type(true); // Enable general type printing + llvm::set_print_type_fun(&fun); // Set specific function to analyze + } config::AutoDiff::Inline => { llvm::set_inline(true); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 2ad39fc853819..b94716b89d611 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*; #[cfg(llvm_enzyme)] pub(crate) mod Enzyme_AD { + use std::ffi::{CString, c_char}; + use libc::c_void; + unsafe extern "C" { pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); } unsafe extern "C" { static mut EnzymePrintPerf: c_void; static mut EnzymePrintActivity: c_void; static mut EnzymePrintType: c_void; + static mut EnzymeFunctionToAnalyze: c_void; static mut EnzymePrint: c_void; static mut EnzymeStrictAliasing: c_void; static mut looseTypeAnalysis: c_void; @@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); } } + pub(crate) fn set_print_type_fun(fun_name: &str) { + let c_fun_name = CString::new(fun_name).unwrap(); + unsafe { + EnzymeSetCLString( + std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), + c_fun_name.as_ptr() as *const c_char, + ); + } + } pub(crate) fn set_print(print: bool) { unsafe { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); @@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD { pub(crate) fn set_print_type(print: bool) { unimplemented!() } + pub(crate) fn set_print_type_fun(fun_name: &str) { + unimplemented!() + } pub(crate) fn set_print(print: bool) { unimplemented!() } diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 60e1b465ba96d..0f3e53d03d653 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -227,13 +227,15 @@ pub enum CoverageLevel { } /// The different settings that the `-Z autodiff` flag can have. -#[derive(Clone, Copy, PartialEq, Hash, Debug)] +#[derive(Clone, PartialEq, Hash, Debug)] pub enum AutoDiff { /// Enable the autodiff opt pipeline Enable, /// Print TypeAnalysis information PrintTA, + /// Print TypeAnalysis information for a specific function + PrintTAFn(String), /// Print ActivityAnalysis Information PrintAA, /// Print Performance Warnings from Enzyme diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 12fa05118caf5..513a006e7905b 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -711,7 +711,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1351,9 +1351,22 @@ pub mod parse { let mut v: Vec<&str> = v.split(",").collect(); v.sort_unstable(); for &val in v.iter() { - let variant = match val { + // Split each entry on '=' if it has an argument + let (key, arg) = match val.split_once('=') { + Some((k, a)) => (k, Some(a)), + None => (val, None), + }; + + let variant = match key { "Enable" => AutoDiff::Enable, "PrintTA" => AutoDiff::PrintTA, + "PrintTAFn" => { + if let Some(fun) = arg { + AutoDiff::PrintTAFn(fun.to_string()) + } else { + return false; + } + } "PrintAA" => AutoDiff::PrintAA, "PrintPerf" => AutoDiff::PrintPerf, "PrintSteps" => AutoDiff::PrintSteps, diff --git a/src/doc/rustc-dev-guide/src/autodiff/flags.md b/src/doc/rustc-dev-guide/src/autodiff/flags.md index 65287d9ba4c19..efbb9ea3497cb 100644 --- a/src/doc/rustc-dev-guide/src/autodiff/flags.md +++ b/src/doc/rustc-dev-guide/src/autodiff/flags.md @@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi ```text PrintTA // Print TypeAnalysis information +PrintTAFn // Print TypeAnalysis information for a specific function PrintAA // Print ActivityAnalysis information Print // Print differentiated functions while they are being generated and optimized PrintPerf // Print AD related Performance warnings diff --git a/src/doc/unstable-book/src/compiler-flags/autodiff.md b/src/doc/unstable-book/src/compiler-flags/autodiff.md index 95c188d1f3b29..28d2ece1468f7 100644 --- a/src/doc/unstable-book/src/compiler-flags/autodiff.md +++ b/src/doc/unstable-book/src/compiler-flags/autodiff.md @@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are: `Enable` - Required flag to enable autodiff `PrintTA` - print Type Analysis Information +`PrintTAFn` - print Type Analysis Information for a specific function `PrintAA` - print Activity Analysis Information `PrintPerf` - print Performance Warnings from Enzyme `PrintSteps` - prints all intermediate transformations diff --git a/src/tools/enzyme b/src/tools/enzyme index a35f4f773118c..b5098d515d5e1 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit a35f4f773118ccfbd8d05102eb12a34097b1ee55 +Subproject commit b5098d515d5e1bd0f5470553bc0d18da9794ca8b