From 065172eee1508bc8f109fd9327224cef75750b04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 4 Jun 2025 08:10:51 +0000 Subject: [PATCH 1/4] Lower autodiff functions using intrinsics --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 3 +++ compiler/rustc_hir_analysis/src/check/intrinsic.rs | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index e8629aeebb95a..ca46314a94936 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -197,6 +197,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { Some(instance), ) } + _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + return Err(ty::Instance::new_raw(def_id, instance.args)); + } sym::is_val_statically_known => { let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx); let kind = self.type_kind(intrinsic_type); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 54bb3ac411304..ea02ca7fec52f 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -174,6 +174,8 @@ pub(crate) fn check_intrinsic_type( }; let name_str = intrinsic_name.as_str(); + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -229,6 +231,17 @@ pub(crate) fn check_intrinsic_type( // // so: two type params, 0 lifetime param, 0 const params, two inputs, no return (2, 0, 0, vec![param(0), param(1)], param(1), hir::Safety::Safe) + } else if has_autodiff { + let sig = tcx.fn_sig(intrinsic_id.to_def_id()); + let sig = sig.skip_binder(); + let n_tps = generics.own_counts().types; + let n_lts =
generics.own_counts().lifetimes; + let n_cts = generics.own_counts().consts; + + let inputs = sig.skip_binder().inputs().to_vec(); + let output = sig.skip_binder().output(); + + (n_tps, n_lts, n_cts, inputs, output, hir::Safety::Safe) } else { let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let (n_tps, n_cts, inputs, output) = match intrinsic_name { From 3ce4d9c53e32aa04ec67cce30ef5e7e7fdc44a60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 5 Jun 2025 17:46:21 +0000 Subject: [PATCH 2/4] Macro expansion with `rustc_intrinsic` WARNING: ad functions defined in traits are broken --- compiler/rustc_builtin_macros/src/autodiff.rs | 20 ++- tests/pretty/autodiff/autodiff_forward.pp | 136 +++++------------- tests/pretty/autodiff/autodiff_forward.rs | 1 + tests/pretty/autodiff/autodiff_reverse.pp | 43 ++---- tests/pretty/autodiff/autodiff_reverse.rs | 5 +- tests/pretty/autodiff/inherent_impl.pp | 3 +- tests/pretty/autodiff/inherent_impl.rs | 1 + 7 files changed, 67 insertions(+), 142 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index dc3bb8ab52a5d..2e97a125fc067 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -330,7 +330,9 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // UNUSED + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); @@ -342,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: Some(d_body), + body: None, define_opaque: None, }); let mut rustc_ad_attr = @@ -429,12 +431,18 @@ mod llvm_enzyme { tokens: ts, }); + let rustc_intrinsic_attr = +
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr, intrinsic_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -444,13 +452,15 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index a2525abc83207..787c2e517492c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -36,78 +37,44 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[inline(never)] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f64, f64)>::default()) -} +#[rustc_intrinsic] +pub fn 
df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f2(x, y)) -} +#[rustc_intrinsic] +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[inline(never)] -pub fn df4() -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4() -> (); #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[inline(never)] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((by_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - 
::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -115,84 +82,47 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[inline(never)] -pub fn df6() -> DoesNotImplDefault { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f6()); - ::core::hint::black_box(()); - ::core::hint::black_box(f6()) -} +#[rustc_intrinsic] +pub fn df6() -> DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[inline(never)] -pub fn df7(x: f32) -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f7(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df7(x: f32) -> (); #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[inline(never)] +#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 5usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 5usize]>::default()) -} +-> [f32; 5usize]; #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[inline(never)] +#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 4usize] { - unsafe { asm!("NOP", options(pure, 
nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 4usize]>::default()) -} +-> [f32; 4usize]; #[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[inline(never)] -fn f8_1(x: &f32, bx_0: &f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) -} +#[rustc_intrinsic] +fn f8_1(x: &f32, bx_0: &f32) -> f32; pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[inline(never)] - fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f32, f32)>::default()) - } + #[rustc_intrinsic] + fn d_inner_2(x: f32, bx_0: f32) + -> (f32, f32); #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[inline(never)] - fn d_inner_1(x: f32, bx_0: f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) - } + #[rustc_intrinsic] + fn d_inner_1(x: f32, bx_0: f32) + -> f32; } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[inline(never)] +#[rustc_intrinsic] pub fn d_square + - Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10::(x)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10::(x)) -} +Copy>(x: &T, dx_0: &mut T, dret: T) -> T; fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e23a1b3e241e9..b003d87dccfa7 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,6 +1,7 @@ //@ 
needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index e67c3443ddef1..6f368c74f1a26 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -29,58 +30,36 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f1(x, y)) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[inline(never)] -pub fn df2() { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df2(); #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[inline(never)] -pub fn df4(x: f32) { 
- unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4(x: f32); #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[inline(never)] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dy_0)); -} +#[rustc_intrinsic] +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index d37e5e3eb4cec..fc95ba2e5a63e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp @@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 { unimplemented!() } -enum Foo { Reverse } +enum Foo { + Reverse, +} use Foo::Reverse; // What happens if we already have Reverse in type (enum variant decl) and value (enum variant // constructor) namespace? > It's expected to work normally. 
diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index d18061b2dbdef..4bc8dac0dc758 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -31,7 +32,7 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[inline(never)] + #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { unsafe { asm!("NOP", options(pure, nomem)); }; ::core::hint::black_box(self.f(x)); diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 11ff209f9d89e..9f00ff5eb02c1 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp From fbd9b77f788ef8dfd4afd9dc76edb83a7934ce92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 17 Jun 2025 19:52:00 +0000 Subject: [PATCH 3/4] Lowering draft --- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 2e97a125fc067..ecd68e0015851 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -344,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: None, + body: None, // This leads to an error when the ad function is inside a trait define_opaque: None, }); let mut rustc_ad_attr = diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs
b/compiler/rustc_codegen_llvm/src/intrinsic.rs index ca46314a94936..5813bdf8435e2 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -198,6 +198,21 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + // NOTE(Sa4dUs): This is a hacky way to get the autodiff items + // so we can focus on the lowering of the intrinsic call + + // `diff_items` is empty even when autodiff is enabled, and if we're here, + // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr + let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + + // this shouldn't happen? + if diff_items.is_empty() { + bug!("no autodiff items found for {def_id:?}"); + } + + // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + + // Just gen the fallback body for now return Err(ty::Instance::new_raw(def_id, instance.args)); } sym::is_val_statically_known => { From 0f537d68a302be32317e16b1321462397795b006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 23 Jun 2025 12:17:53 +0000 Subject: [PATCH 4/4] Naive impl of intrinsic codegen Note(Sa4dUs): Most tests are still broken due to `sret` and how funcs are searched in the current logic --- .../src/builder/autodiff.rs | 62 ++++++------------- compiler/rustc_codegen_llvm/src/intrinsic.rs | 60 ++++++++++++++---- .../rustc_hir_analysis/src/check/intrinsic.rs | 1 + tests/codegen/autodiff/scalar.rs | 1 + 4 files changed, 70 insertions(+), 54 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index c5c13ac097a27..9bfb631f631d9 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -4,13 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, 
DiffActivit use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -19,7 +19,7 @@ use crate::llvm::{Metadata, True}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; -fn get_params(fnc: &Value) -> Vec<&Value> { +fn _get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; let mut fnc_args: Vec<&Value> = vec![]; fnc_args.reserve(param_num); @@ -49,9 +49,9 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], @@ -289,11 +289,14 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. 
-fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_fn: &'ll Value, + fn_args: &[OperandRef<'tcx, &'ll Value>], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -366,14 +369,6 @@ fn generate_enzyme_call<'ll>( let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); @@ -389,10 +384,10 @@ fn generate_enzyme_call<'ll>( } let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); match_args_from_caller_to_enzyme( &cx, - &builder, + builder, attrs.width, &mut args, &attrs.input_activity, @@ -400,29 +395,9 @@ fn generate_enzyme_call<'ll>( has_sret, ); - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. 
- // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. - let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); + builder.store_to_place(call, dest.val); if cx.val_ty(call) == cx.type_void() || has_sret { if has_sret { @@ -445,10 +420,10 @@ fn generate_enzyme_call<'ll>( llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); - } else { - builder.ret(call); } + builder.store_to_place(call, dest.val); + // Let's crash in case that we messed something up above and generated invalid IR. 
llvm::LLVMRustVerifyFunction( outer_fn, @@ -463,6 +438,7 @@ pub(crate) fn differentiate<'ll>( diff_items: Vec, _config: &ModuleConfig, ) -> Result<(), FatalError> { + // TODO(Sa4dUs): delete all this logic for item in &diff_items { trace!("{}", item); } @@ -482,7 +458,7 @@ pub(crate) fn differentiate<'ll>( for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { + let Some(_fn_def) = fn_def else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -494,7 +470,7 @@ pub(crate) fn differentiate<'ll>( }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { + let Some(_fn_target) = fn_target else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -505,7 +481,7 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 5813bdf8435e2..575ebd3b8dcd8 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,17 +9,19 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use 
rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::{HasTargetSpec, PanicStrategy}; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -200,20 +202,56 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { // NOTE(Sa4dUs): This is a hacky way to get the autodiff items // so we can focus on the lowering of the intrinsic call + let mut source_id = None; + let mut diff_attrs = None; + let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect(); + + // Hacky way of getting primal-diff pair, only works for code with 1 autodiff call + for target_id in &items { + let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else { + continue; + }; - // `diff_items` is empty even when autodiff is enabled, and if we're here, - // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr - let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + if target_attrs.is_source() { + source_id = Some(*target_id); + } else { + diff_attrs = Some(target_attrs); + } + } - // this shouldn't happen? 
- if diff_items.is_empty() { - bug!("no autodiff items found for {def_id:?}"); + if source_id.is_none() || diff_attrs.is_none() { + bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}"); } - // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + let diff_attrs = diff_attrs.unwrap().clone(); + + // Get source fn + let source_id = source_id.unwrap(); + let fn_source = Instance::mono(tcx, source_id); + let source_symbol = + symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + // Declare target fn + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty()); + let outer_fn: &'ll Value = + self.cx.declare_fn(&target_symbol, fn_abi, Some(instance)); + + // Build body + generate_enzyme_call( + self, + self.cx, + fn_to_diff, + outer_fn, + args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore + diff_attrs.clone(), + result, + ); - // Just gen the fallback body for now - return Err(ty::Instance::new_raw(def_id, instance.args)); + return Ok(()); } sym::is_val_statically_known => { let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index ea02ca7fec52f..f235344de8fc3 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -241,6 +241,7 @@ pub(crate) fn check_intrinsic_type( let inputs = sig.skip_binder().inputs().to_vec(); let output = sig.skip_binder().output(); + // TODO(Sa4dUs): We can also have unsafe ad functions (n_tps, n_lts, n_cts, inputs, output, 
hir::Safety::Safe) } else { let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); diff --git a/tests/codegen/autodiff/scalar.rs b/tests/codegen/autodiff/scalar.rs index 096b4209e84ad..c2bca7e9c81ef 100644 --- a/tests/codegen/autodiff/scalar.rs +++ b/tests/codegen/autodiff/scalar.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse;