@@ -935,12 +935,25 @@ class AdjointGenerator
935935 }
936936
937937 void visitAtomicRMWInst (llvm::AtomicRMWInst &I) {
938- if (Mode == DerivativeMode::ForwardMode) {
939- IRBuilder<> BuilderZ (&I);
940- getForwardBuilder (BuilderZ);
941- switch (I.getOperation ()) {
942- case AtomicRMWInst::FAdd:
943- case AtomicRMWInst::FSub: {
938+
939+ if (gutils->isConstantInstruction (&I) && gutils->isConstantValue (&I)) {
940+ if (Mode == DerivativeMode::ReverseModeGradient ||
941+ Mode == DerivativeMode::ForwardModeSplit) {
942+ eraseIfUnused (I, /* erase*/ true , /* check*/ false );
943+ } else {
944+ eraseIfUnused (I);
945+ }
946+ return ;
947+ }
948+
949+ switch (I.getOperation ()) {
950+ case AtomicRMWInst::FAdd:
951+ case AtomicRMWInst::FSub: {
952+
953+ if (Mode == DerivativeMode::ForwardMode ||
954+ Mode == DerivativeMode::ForwardModeSplit) {
955+ IRBuilder<> BuilderZ (&I);
956+ getForwardBuilder (BuilderZ);
944957 auto rule = [&](Value *ptr, Value *dif) -> Value * {
945958 if (!gutils->isConstantInstruction (&I)) {
946959 assert (ptr);
@@ -981,32 +994,84 @@ class AdjointGenerator
981994 setDiffe (&I, diff, BuilderZ);
982995 return ;
983996 }
984- default :
985- break ;
997+ if (Mode == DerivativeMode::ReverseModePrimal) {
998+ eraseIfUnused (I);
999+ return ;
9861000 }
987- }
988- if (!gutils->isConstantInstruction (&I) || !gutils->isConstantValue (&I)) {
989- if (looseTypeAnalysis) {
990- auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
991- auto valType = I.getValOperand ()->getType ();
992- auto storeSize = DL.getTypeSizeInBits (valType) / 8 ;
993- auto fp = TR.firstPointer (storeSize, I.getPointerOperand (),
994- /* errifnotfound*/ false ,
995- /* pointerIntSame*/ true );
996- if (!fp.isKnown () && valType->isIntOrIntVectorTy ()) {
997- goto noerror;
1001+ if ((Mode == DerivativeMode::ReverseModeCombined ||
1002+ Mode == DerivativeMode::ReverseModeGradient) &&
1003+ gutils->isConstantValue (&I)) {
1004+ if (!gutils->isConstantValue (I.getValOperand ())) {
1005+ assert (!gutils->isConstantValue (I.getPointerOperand ()));
1006+ IRBuilder<> Builder2 (&I);
1007+ getReverseBuilder (Builder2);
1008+ Value *ip = gutils->invertPointerM (I.getPointerOperand (), Builder2);
1009+ auto order = I.getOrdering ();
1010+ if (order == AtomicOrdering::Release)
1011+ order = AtomicOrdering::Monotonic;
1012+ else if (order == AtomicOrdering::AcquireRelease)
1013+ order = AtomicOrdering::Acquire;
1014+
1015+ auto rule = [&](Value *ip) -> Value * {
1016+ #if LLVM_VERSION_MAJOR > 7
1017+ LoadInst *dif1 =
1018+ Builder2.CreateLoad (I.getType (), ip, I.isVolatile ());
1019+ #else
1020+ LoadInst *dif1 = Builder2.CreateLoad (ip, I.isVolatile ());
1021+ #endif
1022+
1023+ #if LLVM_VERSION_MAJOR >= 11
1024+ dif1->setAlignment (I.getAlign ());
1025+ #else
1026+ const DataLayout &DL = I.getModule ()->getDataLayout ();
1027+ auto tmpAlign = DL.getTypeStoreSize (I.getValOperand ()->getType ());
1028+ #if LLVM_VERSION_MAJOR >= 10
1029+ dif1->setAlignment (MaybeAlign (tmpAlign.getFixedSize ()));
1030+ #else
1031+ dif1->setAlignment (tmpAlign);
1032+ #endif
1033+ #endif
1034+ dif1->setOrdering (order);
1035+ dif1->setSyncScopeID (I.getSyncScopeID ());
1036+ return dif1;
1037+ };
1038+ Value *diff = applyChainRule (I.getType (), Builder2, rule, ip);
1039+
1040+ addToDiffe (I.getValOperand (), diff, Builder2,
1041+ I.getValOperand ()->getType ()->getScalarType ());
9981042 }
1043+ if (Mode == DerivativeMode::ReverseModeGradient) {
1044+ eraseIfUnused (I, /* erase*/ true , /* check*/ false );
1045+ } else
1046+ eraseIfUnused (I);
1047+ return ;
9991048 }
1000- TR. dump () ;
1001- llvm::errs () << " oldFunc: " << *gutils-> newFunc << " \n " ;
1002- llvm::errs () << " I: " << I << " \n " ;
1003- assert ( 0 && " Active atomic inst not handled " ) ;
1049+ break ;
1050+ }
1051+ default :
1052+ break ;
10041053 }
1005- noerror:;
10061054
1007- if (Mode == DerivativeMode::ReverseModeGradient) {
1008- eraseIfUnused (I, /* erase*/ true , /* check*/ false );
1055+ if (looseTypeAnalysis) {
1056+ auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
1057+ auto valType = I.getValOperand ()->getType ();
1058+ auto storeSize = DL.getTypeSizeInBits (valType) / 8 ;
1059+ auto fp = TR.firstPointer (storeSize, I.getPointerOperand (),
1060+ /* errifnotfound*/ false ,
1061+ /* pointerIntSame*/ true );
1062+ if (!fp.isKnown () && valType->isIntOrIntVectorTy ()) {
1063+ if (Mode == DerivativeMode::ReverseModeGradient ||
1064+ Mode == DerivativeMode::ReverseModeGradient) {
1065+ eraseIfUnused (I, /* erase*/ true , /* check*/ false );
1066+ } else
1067+ eraseIfUnused (I);
1068+ return ;
1069+ }
10091070 }
1071+ TR.dump ();
1072+ llvm::errs () << " oldFunc: " << *gutils->newFunc << " \n " ;
1073+ llvm::errs () << " I: " << I << " \n " ;
1074+ llvm_unreachable (" Active atomic inst not yet handled" );
10101075 }
10111076
10121077 void visitStoreInst (llvm::StoreInst &SI) {
0 commit comments