Skip to content

Commit 905bb2c

Browse files
committed
working dupv for fwd mode
1 parent d4f880f commit 905bb2c

File tree

5 files changed

+45
-10
lines changed

5 files changed

+45
-10
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ pub enum DiffActivity {
5050
/// with it.
5151
Dual,
5252
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53+
/// with it. It expects the shadow argument to be `width` times larger than the original
54+
/// input/output.
55+
Dualv,
56+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
5357
/// with it. Drop the code which updates the original input/output for maximum performance.
5458
DualOnly,
5559
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -133,6 +137,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
133137
DiffMode::Source => false,
134138
DiffMode::Forward => {
135139
activity == DiffActivity::Dual
140+
|| activity == DiffActivity::Dualv
136141
|| activity == DiffActivity::DualOnly
137142
|| activity == DiffActivity::Const
138143
}
@@ -156,7 +161,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
156161
if matches!(activity, Const) {
157162
return true;
158163
}
159-
if matches!(activity, Dual | DualOnly) {
164+
if matches!(activity, Dual | DualOnly | Dualv) {
160165
return true;
161166
}
162167
// FIXME(ZuseZ4) We should make this more robust to also
@@ -173,7 +178,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
173178
DiffMode::Error => false,
174179
DiffMode::Source => false,
175180
DiffMode::Forward => {
176-
matches!(activity, Dual | DualOnly | Const)
181+
matches!(activity, Dual | DualOnly | Dualv | Const)
177182
}
178183
DiffMode::Reverse => {
179184
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
@@ -189,6 +194,7 @@ impl Display for DiffActivity {
189194
DiffActivity::Active => write!(f, "Active"),
190195
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
191196
DiffActivity::Dual => write!(f, "Dual"),
197+
DiffActivity::Dualv => write!(f, "Dualv"),
192198
DiffActivity::DualOnly => write!(f, "DualOnly"),
193199
DiffActivity::Duplicated => write!(f, "Duplicated"),
194200
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
@@ -220,6 +226,7 @@ impl FromStr for DiffActivity {
220226
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
221227
"Const" => Ok(DiffActivity::Const),
222228
"Dual" => Ok(DiffActivity::Dual),
229+
"Dualv" => Ok(DiffActivity::Dualv),
223230
"DualOnly" => Ok(DiffActivity::DualOnly),
224231
"Duplicated" => Ok(DiffActivity::Duplicated),
225232
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -799,8 +799,13 @@ mod llvm_enzyme {
799799
d_inputs.push(shadow_arg.clone());
800800
}
801801
}
802-
DiffActivity::Dual | DiffActivity::DualOnly => {
803-
for i in 0..x.width {
802+
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
803+
let iterations = if matches!(activity, DiffActivity::Dualv) {
804+
1
805+
} else {
806+
x.width
807+
};
808+
for i in 0..iterations {
804809
let mut shadow_arg = arg.clone();
805810
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
806811
ident.name
@@ -887,8 +892,8 @@ mod llvm_enzyme {
887892
}
888893
};
889894

890-
if let DiffActivity::Dual = x.ret_activity {
891-
let kind = if x.width == 1 {
895+
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
896+
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
892897
// Dual can only be used for f32/f64 ret.
893898
// In that case we now return a tuple with two floats.
894899
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
123123
/// Empty string, to be used where LLVM expects an instruction name, indicating
124124
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
125125
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
126-
const UNNAMED: *const c_char = c"".as_ptr();
126+
pub(crate) const UNNAMED: *const c_char = c"".as_ptr();
127127

128128
impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericBuilder<'_, 'll, CX> {
129129
type Value = <GenericCx<'ll, CX> as BackendTypes>::Value;

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_middle::bug;
1010
use tracing::{debug, trace};
1111

1212
use crate::back::write::llvm_err;
13-
use crate::builder::SBuilder;
13+
use crate::builder::{SBuilder, UNNAMED};
1414
use crate::context::SimpleCx;
1515
use crate::declare::declare_simple_fn;
1616
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
5151
// using iterators and peek()?
5252
fn match_args_from_caller_to_enzyme<'ll>(
5353
cx: &SimpleCx<'ll>,
54+
builder: &SBuilder<'ll,'ll>,
5455
width: u32,
5556
args: &mut Vec<&'ll llvm::Value>,
5657
inputs: &[DiffActivity],
@@ -78,6 +79,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
7879
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
7980
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
8081
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
82+
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
8183
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
8284

8385
while activity_pos < inputs.len() {
@@ -90,13 +92,26 @@ fn match_args_from_caller_to_enzyme<'ll>(
9092
DiffActivity::Active => (enzyme_out, false),
9193
DiffActivity::ActiveOnly => (enzyme_out, false),
9294
DiffActivity::Dual => (enzyme_dup, true),
95+
DiffActivity::Dualv => (enzyme_dupv, true),
9396
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
9497
DiffActivity::Duplicated => (enzyme_dup, true),
9598
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
9699
DiffActivity::FakeActivitySize => (enzyme_const, false),
97100
};
98101
let outer_arg = outer_args[outer_pos];
99102
args.push(cx.get_metadata_value(activity));
103+
if matches!(diff_activity, DiffActivity::Dualv) {
104+
let next_outer_arg = outer_args[outer_pos + 1];
105+
// stride: sizeof(T) * n_elems.
106+
// T=f32 => 4 bytes (the element size is hard-coded below; other element types are not handled here)
107+
// n_elems is the next outer argument (`outer_args[outer_pos + 1]`), expected to be an integer.
108+
// Multiply `4 * next_outer_arg` (element size times element count) to get the stride in bytes.
109+
//let mul = builder
110+
// .build_mul(cx.get_const_i64(4), next_outer_arg)
111+
// .unwrap();
112+
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
113+
args.push(mul);
114+
}
100115
args.push(outer_arg);
101116
if duplicated {
102117
// We know that duplicated args by construction have a following argument,
@@ -125,7 +140,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125140
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
126141
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
127142

128-
for i in 0..(width as usize) {
143+
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
144+
1
145+
} else {
146+
width as usize
147+
};
148+
149+
for i in 0..iterations {
129150
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
130151
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131152
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
@@ -136,7 +157,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
136157
}
137158
args.push(cx.get_metadata_value(enzyme_const));
138159
args.push(next_outer_arg);
139-
outer_pos += 2 + 2 * width as usize;
160+
outer_pos += 2 + 2 * iterations;
140161
activity_pos += 2;
141162
} else {
142163
// A duplicated pointer will have the following two outer_fn arguments:
@@ -360,6 +381,7 @@ fn generate_enzyme_call<'ll>(
360381
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
361382
match_args_from_caller_to_enzyme(
362383
&cx,
384+
&builder,
363385
attrs.width,
364386
&mut args,
365387
&attrs.input_activity,

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
3131
let activity = match da[i] {
3232
DiffActivity::DualOnly
3333
| DiffActivity::Dual
34+
| DiffActivity::Dualv
3435
| DiffActivity::DuplicatedOnly
3536
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
3637
DiffActivity::Const => DiffActivity::Const,

0 commit comments

Comments
 (0)