Skip to content

Commit ad0af85

Browse files
authored
Add atomic fadd for reverse mode (rust-lang#849)
* Add atomic fadd for reverse mode
* Fix lower version
* Fix for version
* add maybealign
1 parent 0665cc8 commit ad0af85

4 files changed

Lines changed: 227 additions & 26 deletions

File tree

enzyme/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -58,10 +58,14 @@ if (NOT DEFINED LLVM_EXTERNAL_LIT)
5858
message("found llvm match ${CMAKE_MATCH_1} dir ${LLVM_DIR}")
5959
if (EXISTS ${LLVM_DIR}/../../../bin/llvm-lit)
6060
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/../../../bin/llvm-lit)
61+
else()
62+
set(LLVM_EXTERNAL_LIT lit)
6163
endif()
6264
else()
6365
if (EXISTS ${LLVM_DIR}/bin/llvm-lit)
6466
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/bin/llvm-lit)
67+
else()
68+
set(LLVM_EXTERNAL_LIT lit)
6569
endif()
6670
endif()
6771
endif()

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -935,12 +935,25 @@ class AdjointGenerator
935935
}
936936

937937
void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
938-
if (Mode == DerivativeMode::ForwardMode) {
939-
IRBuilder<> BuilderZ(&I);
940-
getForwardBuilder(BuilderZ);
941-
switch (I.getOperation()) {
942-
case AtomicRMWInst::FAdd:
943-
case AtomicRMWInst::FSub: {
938+
939+
if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) {
940+
if (Mode == DerivativeMode::ReverseModeGradient ||
941+
Mode == DerivativeMode::ForwardModeSplit) {
942+
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
943+
} else {
944+
eraseIfUnused(I);
945+
}
946+
return;
947+
}
948+
949+
switch (I.getOperation()) {
950+
case AtomicRMWInst::FAdd:
951+
case AtomicRMWInst::FSub: {
952+
953+
if (Mode == DerivativeMode::ForwardMode ||
954+
Mode == DerivativeMode::ForwardModeSplit) {
955+
IRBuilder<> BuilderZ(&I);
956+
getForwardBuilder(BuilderZ);
944957
auto rule = [&](Value *ptr, Value *dif) -> Value * {
945958
if (!gutils->isConstantInstruction(&I)) {
946959
assert(ptr);
@@ -981,32 +994,84 @@ class AdjointGenerator
981994
setDiffe(&I, diff, BuilderZ);
982995
return;
983996
}
984-
default:
985-
break;
997+
if (Mode == DerivativeMode::ReverseModePrimal) {
998+
eraseIfUnused(I);
999+
return;
9861000
}
987-
}
988-
if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) {
989-
if (looseTypeAnalysis) {
990-
auto &DL = gutils->newFunc->getParent()->getDataLayout();
991-
auto valType = I.getValOperand()->getType();
992-
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
993-
auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
994-
/*errifnotfound*/ false,
995-
/*pointerIntSame*/ true);
996-
if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
997-
goto noerror;
1001+
if ((Mode == DerivativeMode::ReverseModeCombined ||
1002+
Mode == DerivativeMode::ReverseModeGradient) &&
1003+
gutils->isConstantValue(&I)) {
1004+
if (!gutils->isConstantValue(I.getValOperand())) {
1005+
assert(!gutils->isConstantValue(I.getPointerOperand()));
1006+
IRBuilder<> Builder2(&I);
1007+
getReverseBuilder(Builder2);
1008+
Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2);
1009+
auto order = I.getOrdering();
1010+
if (order == AtomicOrdering::Release)
1011+
order = AtomicOrdering::Monotonic;
1012+
else if (order == AtomicOrdering::AcquireRelease)
1013+
order = AtomicOrdering::Acquire;
1014+
1015+
auto rule = [&](Value *ip) -> Value * {
1016+
#if LLVM_VERSION_MAJOR > 7
1017+
LoadInst *dif1 =
1018+
Builder2.CreateLoad(I.getType(), ip, I.isVolatile());
1019+
#else
1020+
LoadInst *dif1 = Builder2.CreateLoad(ip, I.isVolatile());
1021+
#endif
1022+
1023+
#if LLVM_VERSION_MAJOR >= 11
1024+
dif1->setAlignment(I.getAlign());
1025+
#else
1026+
const DataLayout &DL = I.getModule()->getDataLayout();
1027+
auto tmpAlign = DL.getTypeStoreSize(I.getValOperand()->getType());
1028+
#if LLVM_VERSION_MAJOR >= 10
1029+
dif1->setAlignment(MaybeAlign(tmpAlign.getFixedSize()));
1030+
#else
1031+
dif1->setAlignment(tmpAlign);
1032+
#endif
1033+
#endif
1034+
dif1->setOrdering(order);
1035+
dif1->setSyncScopeID(I.getSyncScopeID());
1036+
return dif1;
1037+
};
1038+
Value *diff = applyChainRule(I.getType(), Builder2, rule, ip);
1039+
1040+
addToDiffe(I.getValOperand(), diff, Builder2,
1041+
I.getValOperand()->getType()->getScalarType());
9981042
}
1043+
if (Mode == DerivativeMode::ReverseModeGradient) {
1044+
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
1045+
} else
1046+
eraseIfUnused(I);
1047+
return;
9991048
}
1000-
TR.dump();
1001-
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
1002-
llvm::errs() << "I: " << I << "\n";
1003-
assert(0 && "Active atomic inst not handled");
1049+
break;
1050+
}
1051+
default:
1052+
break;
10041053
}
1005-
noerror:;
10061054

1007-
if (Mode == DerivativeMode::ReverseModeGradient) {
1008-
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
1055+
if (looseTypeAnalysis) {
1056+
auto &DL = gutils->newFunc->getParent()->getDataLayout();
1057+
auto valType = I.getValOperand()->getType();
1058+
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
1059+
auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
1060+
/*errifnotfound*/ false,
1061+
/*pointerIntSame*/ true);
1062+
if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
1063+
if (Mode == DerivativeMode::ReverseModeGradient ||
1064+
Mode == DerivativeMode::ReverseModeGradient) {
1065+
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
1066+
} else
1067+
eraseIfUnused(I);
1068+
return;
1069+
}
10091070
}
1071+
TR.dump();
1072+
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
1073+
llvm::errs() << "I: " << I << "\n";
1074+
llvm_unreachable("Active atomic inst not yet handled");
10101075
}
10111076

10121077
void visitStoreInst(llvm::StoreInst &SI) {
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
3+
; ModuleID = '<source>'
4+
source_filename = "<source>"
5+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
define dso_local void @foo1(double* %p, double %v) {
9+
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
10+
ret void
11+
}
12+
define dso_local void @foo2(double* %p, double %v) {
13+
%a10 = atomicrmw volatile fadd double* %p, double %v acquire
14+
ret void
15+
}
16+
define dso_local void @foo3(double* %p, double %v) {
17+
%a10 = atomicrmw volatile fadd double* %p, double %v release
18+
ret void
19+
}
20+
define dso_local void @foo4(double* %p, double %v) {
21+
%a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
22+
ret void
23+
}
24+
define dso_local void @foo5(double* %p, double %v) {
25+
%a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
26+
ret void
27+
}
28+
define dso_local void @foo6(double* %p, double %v) {
29+
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
30+
ret void
31+
}
32+
33+
define void @caller(double* %a, double* %b, double %v) {
34+
%r1 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), double* %a, double* %b, double %v)
35+
%r2 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo2 to i8*), double* %a, double* %b, double %v)
36+
%r3 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo3 to i8*), double* %a, double* %b, double %v)
37+
%r4 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo4 to i8*), double* %a, double* %b, double %v)
38+
%r5 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo5 to i8*), double* %a, double* %b, double %v)
39+
%r6 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), double* %a, double* %b, double %v)
40+
ret void
41+
}
42+
43+
declare double @_Z17__enzyme_autodiffPviRdS0_(i8*, double*, double*, double)
44+
45+
46+
; CHECK: define internal { double } @diffefoo1(double* %p, double* %"p'", double %v)
47+
; CHECK-NEXT: invert:
48+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
49+
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
50+
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
51+
; CHECK-NEXT: ret { double } %1
52+
; CHECK-NEXT: }
53+
54+
; CHECK: define internal { double } @diffefoo2(double* %p, double* %"p'", double %v)
55+
; CHECK-NEXT: invert:
56+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acquire
57+
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
58+
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
59+
; CHECK-NEXT: ret { double } %1
60+
; CHECK-NEXT: }
61+
62+
; CHECK: define internal { double } @diffefoo3(double* %p, double* %"p'", double %v)
63+
; CHECK-NEXT: invert:
64+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v release
65+
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
66+
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
67+
; CHECK-NEXT: ret { double } %1
68+
; CHECK-NEXT: }
69+
70+
; CHECK: define internal { double } @diffefoo4(double* %p, double* %"p'", double %v)
71+
; CHECK-NEXT: invert:
72+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
73+
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
74+
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
75+
; CHECK-NEXT: ret { double } %1
76+
; CHECK-NEXT: }
77+
78+
; CHECK: define internal { double } @diffefoo5(double* %p, double* %"p'", double %v)
79+
; CHECK-NEXT: invert:
80+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
81+
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" seq_cst, align 8
82+
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
83+
; CHECK-NEXT: ret { double } %1
84+
; CHECK-NEXT: }
85+
86+
; CHECK: define internal { double } @diffefoo6(double* %p, double* %"p'", double %v)
87+
; CHECK-NEXT: invert:
88+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
89+
; CHECK-NEXT: ret { double } zeroinitializer
90+
; CHECK-NEXT: }
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
3+
; ModuleID = '<source>'
4+
source_filename = "<source>"
5+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
define dso_local void @foo1(double* %p, double %v) {
9+
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
10+
ret void
11+
}
12+
define dso_local void @foo6(double* %p, double %v) {
13+
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
14+
ret void
15+
}
16+
17+
define void @caller(double* %a, double* %b, double %v) {
18+
%r1 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
19+
%r6 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
20+
ret void
21+
}
22+
23+
declare [2 x double] @_Z17__enzyme_autodiffPviRdS0_(...)
24+
25+
; CHECK: define internal { [2 x double] } @diffe2foo1(double* %p, [2 x double*] %"p'", double %v)
26+
; CHECK-NEXT: invert:
27+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
28+
; CHECK-NEXT: %0 = extractvalue [2 x double*] %"p'", 0
29+
; CHECK-NEXT: %1 = load atomic volatile double, double* %0 monotonic, align 8
30+
; CHECK-NEXT: %2 = extractvalue [2 x double*] %"p'", 1
31+
; CHECK-NEXT: %3 = load atomic volatile double, double* %2 monotonic, align 8
32+
; CHECK-NEXT: %.fca.0.insert5 = insertvalue [2 x double] {{(undef|poison)}}, double %1, 0
33+
; CHECK-NEXT: %.fca.1.insert8 = insertvalue [2 x double] %.fca.0.insert5, double %3, 1
34+
; CHECK-NEXT: %4 = insertvalue { [2 x double] } undef, [2 x double] %.fca.1.insert8, 0
35+
; CHECK-NEXT: ret { [2 x double] } %4
36+
; CHECK-NEXT: }
37+
38+
; CHECK: define internal { [2 x double] } @diffe2foo6(double* %p, [2 x double*] %"p'", double %v)
39+
; CHECK-NEXT: invert:
40+
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
41+
; CHECK-NEXT: ret { [2 x double] } zeroinitializer
42+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)