Skip to content

Commit 7b1c89f

Browse files
committed
added PrintTAFn flag for autodiff
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
1 parent 066ae4c commit 7b1c89f

File tree

6 files changed

+42
-4
lines changed

6 files changed

+42
-4
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ fn thin_lto(
587587
}
588588

589589
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
590-
for &val in ad {
590+
for val in ad {
591591
// We intentionally don't use a wildcard, to not forget handling anything new.
592592
match val {
593593
config::AutoDiff::PrintPerf => {
@@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
599599
config::AutoDiff::PrintTA => {
600600
llvm::set_print_type(true);
601601
}
602+
config::AutoDiff::PrintTAFn(fun) => {
603+
llvm::set_print_type(true); // Enable general type printing
604+
llvm::set_print_type_fun(&fun); // Set specific function to analyze
605+
}
602606
config::AutoDiff::Inline => {
603607
llvm::set_inline(true);
604608
}

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*;
5757

5858
#[cfg(llvm_enzyme)]
5959
pub(crate) mod Enzyme_AD {
60+
use std::ffi::{CString, c_char};
61+
6062
use libc::c_void;
63+
6164
unsafe extern "C" {
6265
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
66+
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
6367
}
6468
unsafe extern "C" {
6569
static mut EnzymePrintPerf: c_void;
6670
static mut EnzymePrintActivity: c_void;
6771
static mut EnzymePrintType: c_void;
72+
static mut EnzymeFunctionToAnalyze: c_void;
6873
static mut EnzymePrint: c_void;
6974
static mut EnzymeStrictAliasing: c_void;
7075
static mut looseTypeAnalysis: c_void;
@@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD {
8691
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
8792
}
8893
}
94+
pub(crate) fn set_print_type_fun(fun_name: &str) {
95+
let c_fun_name = CString::new(fun_name).unwrap();
96+
unsafe {
97+
EnzymeSetCLString(
98+
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
99+
c_fun_name.as_ptr() as *const c_char,
100+
);
101+
}
102+
}
89103
pub(crate) fn set_print(print: bool) {
90104
unsafe {
91105
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD {
132146
pub(crate) fn set_print_type(print: bool) {
133147
unimplemented!()
134148
}
149+
pub(crate) fn set_print_type_fun(fun_name: &str) {
150+
unimplemented!()
151+
}
135152
pub(crate) fn set_print(print: bool) {
136153
unimplemented!()
137154
}

compiler/rustc_session/src/config.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,15 @@ pub enum CoverageLevel {
227227
}
228228

229229
/// The different settings that the `-Z autodiff` flag can have.
230-
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
230+
#[derive(Clone, PartialEq, Hash, Debug)]
231231
pub enum AutoDiff {
232232
/// Enable the autodiff opt pipeline
233233
Enable,
234234

235235
/// Print TypeAnalysis information
236236
PrintTA,
237+
/// Print TypeAnalysis information for a specific function
238+
PrintTAFn(String),
237239
/// Print ActivityAnalysis Information
238240
PrintAA,
239241
/// Print Performance Warnings from Enzyme

compiler/rustc_session/src/options.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ mod desc {
711711
pub(crate) const parse_list: &str = "a space-separated list of strings";
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
714+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715715
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716716
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717717
pub(crate) const parse_number: &str = "a number";
@@ -1351,9 +1351,22 @@ pub mod parse {
13511351
let mut v: Vec<&str> = v.split(",").collect();
13521352
v.sort_unstable();
13531353
for &val in v.iter() {
1354-
let variant = match val {
1354+
// Split each entry on '=' if it has an argument
1355+
let (key, arg) = match val.split_once('=') {
1356+
Some((k, a)) => (k, Some(a)),
1357+
None => (val, None),
1358+
};
1359+
1360+
let variant = match key {
13551361
"Enable" => AutoDiff::Enable,
13561362
"PrintTA" => AutoDiff::PrintTA,
1363+
"PrintTAFn" => {
1364+
if let Some(fun) = arg {
1365+
AutoDiff::PrintTAFn(fun.to_string())
1366+
} else {
1367+
return false;
1368+
}
1369+
}
13571370
"PrintAA" => AutoDiff::PrintAA,
13581371
"PrintPerf" => AutoDiff::PrintPerf,
13591372
"PrintSteps" => AutoDiff::PrintSteps,

src/doc/rustc-dev-guide/src/autodiff/flags.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi
66

77
```text
88
PrintTA // Print TypeAnalysis information
9+
PrintTAFn // Print TypeAnalysis information for a specific function
910
PrintAA // Print ActivityAnalysis information
1011
Print // Print differentiated functions while they are being generated and optimized
1112
PrintPerf // Print AD related Performance warnings

src/doc/unstable-book/src/compiler-flags/autodiff.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are:
1010

1111
`Enable` - Required flag to enable autodiff
1212
`PrintTA` - print Type Analysis Information
13+
`PrintTAFn` - print Type Analysis Information for a specific function
1314
`PrintAA` - print Activity Analysis Information
1415
`PrintPerf` - print Performance Warnings from Enzyme
1516
`PrintSteps` - prints all intermediate transformations

0 commit comments

Comments
 (0)