Skip to content

Commit

Permalink
Reverse vector mode (rust-lang#611)
Browse files Browse the repository at this point in the history
* Implement reverse vector mode

* add tests
  • Loading branch information
tgymnich committed Apr 18, 2022
1 parent 5989d49 commit cf73e23
Show file tree
Hide file tree
Showing 12 changed files with 1,008 additions and 354 deletions.
831 changes: 576 additions & 255 deletions enzyme/Enzyme/AdjointGenerator.h

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,11 +970,19 @@ class Enzyme : public ModulePass {
if (differentialReturn) {
if (differet)
args.push_back(differet);
else if (cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy())
args.push_back(
ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));
else if (auto ST =
dyn_cast<StructType>(cast<Function>(fn)->getReturnType())) {
else if (cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy()) {
Constant *seed =
ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0);
if (width == 1) {
args.push_back(seed);
} else {
ArrayType *arrayType =
ArrayType::get(cast<Function>(fn)->getReturnType(), width);
args.push_back(ConstantArray::get(
arrayType, SmallVector<Constant *, 3>(width, seed)));
}
} else if (auto ST = dyn_cast<StructType>(
cast<Function>(fn)->getReturnType())) {
SmallVector<Constant *, 2> csts;
for (auto e : ST->elements()) {
csts.push_back(ConstantFP::get(e, 1.0));
Expand Down Expand Up @@ -1090,7 +1098,9 @@ class Enzyme : public ModulePass {
// Adapt the returned vector type to the struct type expected by our calling
// convention.
if (width > 1 && !diffret->getType()->isEmptyTy() &&
!diffret->getType()->isVoidTy()) {
!diffret->getType()->isVoidTy() &&
(mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeSplit)) {

/// Actual return type (including struct return)
Type *returnType =
Expand Down
17 changes: 12 additions & 5 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2848,7 +2848,9 @@ void createInvertedTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
}

if (!handled) {
gutils->setDiffe(orig, Constant::getNullValue(orig->getType()), Builder);
gutils->setDiffe(
orig, Constant::getNullValue(gutils->getShadowType(orig->getType())),
Builder);

for (BasicBlock *opred : predecessors(oBB)) {
auto oval = orig->getIncomingValueForBlock(opred);
Expand Down Expand Up @@ -3515,11 +3517,16 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
if (key.additionalType)
endarg--;
differetval = endarg;
if (differetval->getType() != key.todiff->getReturnType()) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";

if (!key.todiff->getReturnType()->isVoidTy()) {
if (!(differetval->getType() ==
gutils->getShadowType(key.todiff->getReturnType()))) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";
}
assert(differetval->getType() ==
gutils->getShadowType(key.todiff->getReturnType()));
}
assert(differetval->getType() == key.todiff->getReturnType());
}

// Explicitly handle all returns first to ensure that return instructions know
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1798,14 +1798,15 @@ FunctionType *getFunctionTypeForClone(
constant_args[argno] == DIFFE_TYPE::DUP_NONEED) {
ArgTypes.push_back(GradientUtils::getShadowType(I, width));
} else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF) {
RetTypes.push_back(I);
RetTypes.push_back(GradientUtils::getShadowType(I, width));
}
++argno;
}

if (diffeReturnArg) {
assert(!FTy->getReturnType()->isVoidTy());
ArgTypes.push_back(FTy->getReturnType());
ArgTypes.push_back(
GradientUtils::getShadowType(FTy->getReturnType(), width));
}
if (additionalArg) {
ArgTypes.push_back(additionalArg);
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3382,8 +3382,6 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit:
prefix = "fwddiffe";
if (width > 1)
prefix += std::to_string(width);
break;
case DerivativeMode::ReverseModeCombined:
case DerivativeMode::ReverseModeGradient:
Expand All @@ -3393,6 +3391,9 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n");
}

if (width > 1)
prefix += std::to_string(width);

auto newFunc = Logic.PPC.CloneFunctionWithReturns(
mode, width, todiff, invertedPointers, constant_args, constant_values,
nonconstant_values, returnvals, returnValue, retType,
Expand Down Expand Up @@ -3751,7 +3752,6 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
auto rule = [&CD](ArrayRef<Constant *> Vals) {
return ConstantStruct::get(CD->getType(), Vals);
};

return applyChainRule(CD->getType(), Vals, BuilderM, rule);
} else if (auto CD = dyn_cast<ConstantVector>(oval)) {
SmallVector<Constant *, 1> Vals;
Expand Down
Loading

0 comments on commit cf73e23

Please sign in to comment.