Skip to content

Commit b85c27f

Browse files
authored
Unrolled build for #142809
Rollup merge of #142809 - KMJ-007:ad-type-analysis-flag, r=ZuseZ4 Add PrintTAFn flag for targeted type analysis printing ## Summary This PR adds a new `PrintTAFn` flag to the `-Z autodiff` option that allows printing type analysis information for a specific function, rather than all functions. ## Changes ### New Flag - Added `PrintTAFn=<function_name>` option to `-Z autodiff` - Usage: `-Z autodiff=Enable,PrintTAFn=my_function_name` ### Implementation Details - **Rust side**: Added `PrintTAFn(String)` variant to `AutoDiff` enum - **Parser**: Updated `parse_autodiff` to handle `PrintTAFn=<function_name>` syntax with proper error handling - **FFI**: Added `set_print_type_fun` function to interface with Enzyme's `FunctionToAnalyze` command line option - **Documentation**: Updated help text and documentation for the new flag ### Files Modified - `compiler/rustc_session/src/config.rs`: Added `PrintTAFn(String)` variant - `compiler/rustc_session/src/options.rs`: Updated parser and help text (now shows `PrintTAFn` in the list) - `compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs`: Added FFI function and static variable - `compiler/rustc_codegen_llvm/src/back/lto.rs`: Added handling for new flag - `src/doc/rustc-dev-guide/src/autodiff/flags.md`: Updated documentation - `src/doc/unstable-book/src/compiler-flags/autodiff.md`: Updated documentation ## Testing The flag can be tested with: ```bash rustc +enzyme -Z autodiff=Enable,PrintTAFn=square test.rs ``` This will print type analysis information only for the function named "square" instead of all functions. ## Error Handling The parser includes proper error handling: - Missing argument: `PrintTAFn` without `=<function_name>` will show an error - Unknown options: Invalid autodiff options will be reported r? ```@ZuseZ4```
2 parents 0fa4ec6 + 7b1c89f commit b85c27f

File tree

7 files changed

+43
-5
lines changed

7 files changed

+43
-5
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ fn thin_lto(
587587
}
588588

589589
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
590-
for &val in ad {
590+
for val in ad {
591591
// We intentionally don't use a wildcard, to not forget handling anything new.
592592
match val {
593593
config::AutoDiff::PrintPerf => {
@@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
599599
config::AutoDiff::PrintTA => {
600600
llvm::set_print_type(true);
601601
}
602+
config::AutoDiff::PrintTAFn(fun) => {
603+
llvm::set_print_type(true); // Enable general type printing
604+
llvm::set_print_type_fun(&fun); // Set specific function to analyze
605+
}
602606
config::AutoDiff::Inline => {
603607
llvm::set_inline(true);
604608
}

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*;
5757

5858
#[cfg(llvm_enzyme)]
5959
pub(crate) mod Enzyme_AD {
60+
use std::ffi::{CString, c_char};
61+
6062
use libc::c_void;
63+
6164
unsafe extern "C" {
6265
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
66+
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
6367
}
6468
unsafe extern "C" {
6569
static mut EnzymePrintPerf: c_void;
6670
static mut EnzymePrintActivity: c_void;
6771
static mut EnzymePrintType: c_void;
72+
static mut EnzymeFunctionToAnalyze: c_void;
6873
static mut EnzymePrint: c_void;
6974
static mut EnzymeStrictAliasing: c_void;
7075
static mut looseTypeAnalysis: c_void;
@@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD {
8691
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
8792
}
8893
}
94+
pub(crate) fn set_print_type_fun(fun_name: &str) {
95+
let c_fun_name = CString::new(fun_name).unwrap();
96+
unsafe {
97+
EnzymeSetCLString(
98+
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
99+
c_fun_name.as_ptr() as *const c_char,
100+
);
101+
}
102+
}
89103
pub(crate) fn set_print(print: bool) {
90104
unsafe {
91105
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD {
132146
pub(crate) fn set_print_type(print: bool) {
133147
unimplemented!()
134148
}
149+
pub(crate) fn set_print_type_fun(fun_name: &str) {
150+
unimplemented!()
151+
}
135152
pub(crate) fn set_print(print: bool) {
136153
unimplemented!()
137154
}

compiler/rustc_session/src/config.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,15 @@ pub enum CoverageLevel {
227227
}
228228

229229
/// The different settings that the `-Z autodiff` flag can have.
230-
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
230+
#[derive(Clone, PartialEq, Hash, Debug)]
231231
pub enum AutoDiff {
232232
/// Enable the autodiff opt pipeline
233233
Enable,
234234

235235
/// Print TypeAnalysis information
236236
PrintTA,
237+
/// Print TypeAnalysis information for a specific function
238+
PrintTAFn(String),
237239
/// Print ActivityAnalysis Information
238240
PrintAA,
239241
/// Print Performance Warnings from Enzyme

compiler/rustc_session/src/options.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ mod desc {
725725
pub(crate) const parse_list: &str = "a space-separated list of strings";
726726
pub(crate) const parse_list_with_polarity: &str =
727727
"a comma-separated list of strings, with elements beginning with + or -";
728-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
728+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
729729
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
730730
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
731731
pub(crate) const parse_number: &str = "a number";
@@ -1365,9 +1365,22 @@ pub mod parse {
13651365
let mut v: Vec<&str> = v.split(",").collect();
13661366
v.sort_unstable();
13671367
for &val in v.iter() {
1368-
let variant = match val {
1368+
// Split each entry on '=' if it has an argument
1369+
let (key, arg) = match val.split_once('=') {
1370+
Some((k, a)) => (k, Some(a)),
1371+
None => (val, None),
1372+
};
1373+
1374+
let variant = match key {
13691375
"Enable" => AutoDiff::Enable,
13701376
"PrintTA" => AutoDiff::PrintTA,
1377+
"PrintTAFn" => {
1378+
if let Some(fun) = arg {
1379+
AutoDiff::PrintTAFn(fun.to_string())
1380+
} else {
1381+
return false;
1382+
}
1383+
}
13711384
"PrintAA" => AutoDiff::PrintAA,
13721385
"PrintPerf" => AutoDiff::PrintPerf,
13731386
"PrintSteps" => AutoDiff::PrintSteps,

src/doc/rustc-dev-guide/src/autodiff/flags.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi
66

77
```text
88
PrintTA // Print TypeAnalysis information
9+
PrintTAFn // Print TypeAnalysis information for a specific function
910
PrintAA // Print ActivityAnalysis information
1011
Print // Print differentiated functions while they are being generated and optimized
1112
PrintPerf // Print AD related Performance warnings

src/doc/unstable-book/src/compiler-flags/autodiff.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are:
1010

1111
`Enable` - Required flag to enable autodiff
1212
`PrintTA` - print Type Analysis Information
13+
`PrintTAFn` - print Type Analysis Information for a specific function
1314
`PrintAA` - print Activity Analysis Information
1415
`PrintPerf` - print Performance Warnings from Enzyme
1516
`PrintSteps` - prints all intermediate transformations

src/tools/enzyme

Submodule enzyme updated 113 files

0 commit comments

Comments
 (0)