diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 2f918faaf752b..0766d7ba09376 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,6 +6,8 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; +use rustc_span::Span; + use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -85,6 +87,7 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, + pub span: Span, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -276,8 +279,8 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { - AutoDiffItem { source, target, attrs: self } + pub fn into_item(self, source: String, target: String, span: Span) -> AutoDiffItem { + AutoDiffItem { source, target, attrs: self, span } } } diff --git a/compiler/rustc_codegen_ssa/messages.ftl b/compiler/rustc_codegen_ssa/messages.ftl index 2bd8644e0d7fc..1ecb138f36bee 100644 --- a/compiler/rustc_codegen_ssa/messages.ftl +++ b/compiler/rustc_codegen_ssa/messages.ftl @@ -8,7 +8,9 @@ codegen_ssa_aix_strip_not_used = using host's `strip` binary to cross-compile to codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$error} -codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto +codegen_ssa_autodiff_lib_unsupported = using the autodiff feature with library builds is not yet supported + +codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto. codegen_ssa_bare_instruction_set = `#[instruction_set]` requires an argument diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index c3bfe4c13cdf7..c2714676178e0 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -41,7 +41,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use super::symbol_export::symbol_name_for_instance_in_crate; -use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; +use crate::errors::{AutodiffLibraryBuild, AutodiffWithoutLto, ErrorCreatingRemarkDir}; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -419,7 +419,13 @@ fn generate_lto_work( } else { if !autodiff.is_empty() { let dcx = cgcx.create_dcx(); - dcx.handle().emit_fatal(AutodiffWithoutLto {}); + let span = autodiff[0].span; + if cgcx.crate_types.contains(&CrateType::Rlib) { + dcx.handle().span_fatal(AutodiffLibraryBuild { span }); + } + if cgcx.lto != Lto::Fat { + dcx.handle().emit_fatal(AutodiffWithoutLto { span }); + } } assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) @@ -1456,6 +1462,7 @@ fn start_executing_work( if needs_fat_lto.is_empty() && needs_thin_lto.is_empty() && lto_import_only_modules.is_empty() + && autodiff_items.is_empty() { // Nothing more to do! break; @@ -1469,13 +1476,14 @@ fn start_executing_work( assert!(!started_lto); started_lto = true; + let autodiff_items = mem::take(&mut autodiff_items); let needs_fat_lto = mem::take(&mut needs_fat_lto); let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); for (work, cost) in generate_lto_work( &cgcx, - autodiff_items.clone(), + autodiff_items, needs_fat_lto, needs_thin_lto, import_only_modules, diff --git a/compiler/rustc_codegen_ssa/src/errors.rs b/compiler/rustc_codegen_ssa/src/errors.rs index 72e71b97a1743..80e074b1dcc1b 100644 --- a/compiler/rustc_codegen_ssa/src/errors.rs +++ b/compiler/rustc_codegen_ssa/src/errors.rs @@ -39,7 +39,17 @@ pub(crate) struct CguNotRecorded<'a> { #[derive(Diagnostic)] #[diag(codegen_ssa_autodiff_without_lto)] -pub struct AutodiffWithoutLto; +pub struct AutodiffWithoutLto { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(codegen_ssa_autodiff_lib_unsupported)] +pub struct AutodiffLibraryBuild { + #[primary_span] + pub span: Span, +} #[derive(Diagnostic)] #[diag(codegen_ssa_unknown_reuse_kind)] diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 22d593b80b895..245c50ba75595 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -120,7 +120,8 @@ pub(crate) fn find_autodiff_source_functions<'tcx>( None => continue, }; - debug!("source_id: {:?}", inst.def_id()); + let source_def_id = inst.def_id(); + debug!("source_id: {:?}", source_def_id); let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized()); assert!(fn_ty.is_fn()); adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); @@ -128,7 +129,7 @@ pub(crate) fn find_autodiff_source_functions<'tcx>( let mut new_target_attrs = target_attrs.clone(); new_target_attrs.input_activity = input_activities; - let itm = new_target_attrs.into_item(symb, target_symbol); + let itm = new_target_attrs.into_item(symb, target_symbol, tcx.def_span(source_def_id)); autodiff_items.push(itm); } diff --git a/tests/ui/autodiff/autodiff_in_rlib.rs b/tests/ui/autodiff/autodiff_in_rlib.rs new file mode 100644 index 0000000000000..1a96a6ff1a77c --- /dev/null +++ b/tests/ui/autodiff/autodiff_in_rlib.rs @@ -0,0 +1,17 @@ +#![feature(autodiff)] +#![crate_type = "rlib"] +//@ needs-enzyme +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ build-fail + +// We test that we fail to compile if a user applied an autodiff_ macro in src/lib.rs, +// since autodiff doesn't work in libraries yet. In the past we used to just return zeros in the +// autodiffed functions, which is obviously confusing and wrong, so erroring is an improvement. + +use std::autodiff::autodiff_reverse; +//~? ERROR: using the autodiff feature with library builds is not yet supported + +#[autodiff_reverse(d_square, Duplicated, Active)] +pub fn square(x: &f64) -> f64 { + *x * *x +} diff --git a/tests/ui/autodiff/autodiff_in_rlib.stderr b/tests/ui/autodiff/autodiff_in_rlib.stderr new file mode 100644 index 0000000000000..a7902c252c2c1 --- /dev/null +++ b/tests/ui/autodiff/autodiff_in_rlib.stderr @@ -0,0 +1,4 @@ +error: using the autodiff feature with library builds is not yet supported + +error: aborting due to 1 previous error + diff --git a/tests/ui/autodiff/no_lto_flag.no_lto.stderr b/tests/ui/autodiff/no_lto_flag.no_lto.stderr new file mode 100644 index 0000000000000..17580f2cdfa25 --- /dev/null +++ b/tests/ui/autodiff/no_lto_flag.no_lto.stderr @@ -0,0 +1,4 @@ +error: using the autodiff feature requires using fat-lto. + +error: aborting due to 1 previous error + diff --git a/tests/ui/autodiff/no_lto_flag.rs b/tests/ui/autodiff/no_lto_flag.rs new file mode 100644 index 0000000000000..1cb1150aac518 --- /dev/null +++ b/tests/ui/autodiff/no_lto_flag.rs @@ -0,0 +1,31 @@ +//@ needs-enzyme +//@ no-prefer-dynamic +//@ revisions: with_lto no_lto +//@[with_lto] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@[no_lto] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=thin + +#![feature(autodiff)] +//@[no_lto] build-fail +//@[with_lto] build-pass + +// Autodiff requires users to enable lto=fat (for now). +// In the past, autodiff did not run if users forget to enable fat-lto, which caused functions to +// returning zero-derivatives. That's obviously wrong and confusing to users. We now added a check +// which will abort compilation instead. + +use std::autodiff::autodiff_reverse; +//[no_lto]~? ERROR using the autodiff feature requires using fat-lto. + +#[autodiff_reverse(d_square, Duplicated, Active)] +fn square(x: &f64) -> f64 { + *x * *x +} + +fn main() { + let xf64: f64 = std::hint::black_box(3.0); + + let mut df_dxf64: f64 = std::hint::black_box(0.0); + + let _output_f64 = d_square(&xf64, &mut df_dxf64, 1.0); + assert_eq!(6.0, df_dxf64); +}