diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ad01a206c93fb3..8bbd5ab1818590 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14380,6 +14380,31 @@ static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
   return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
 }
 
+static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
+  // Restrict ourselves to the form CodeGenPrepare typically constructs.
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
+  if (!GEP || GEP->getNumOperands() != 2)
+    return false;
+
+  Value *Base = GEP->getOperand(0);
+  Value *Offsets = GEP->getOperand(1);
+
+  // We only care about scalar_base+vector_offsets.
+  if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
+    return false;
+
+  // Sink extends that would allow us to use 32-bit offset vectors.
+  if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
+    auto *OffsetsInst = cast<Instruction>(Offsets);
+    if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
+        OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
+      Ops.push_back(&GEP->getOperandUse(1));
+  }
+
+  // Sink the GEP.
+  return true;
+}
+
 /// Check if sinking \p I's operands to I's basic block is profitable, because
 /// the operands can be folded into a target instruction, e.g.
 /// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
@@ -14481,6 +14506,16 @@ bool AArch64TargetLowering::shouldSinkOperands(
       Ops.push_back(&II->getArgOperandUse(0));
       Ops.push_back(&II->getArgOperandUse(1));
       return true;
+    case Intrinsic::masked_gather:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(0));
+      return true;
+    case Intrinsic::masked_scatter:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(1));
+      return true;
     default:
       return false;
     }
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-gather-scatter-addressing.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-gather-scatter-addressing.ll
new file mode 100644
index 00000000000000..73322836d1b84a
--- /dev/null
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-gather-scatter-addressing.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt -S --codegenprepare < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Sink the GEP to make use of scalar+vector addressing modes.
+define <vscale x 4 x float> @gather_offsets_sink_gep(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_gep(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i32> [[INDICES]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i32> %indices
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Sink sext to make use of scalar+sxtw(vector) addressing modes.
+define <vscale x 4 x float> @gather_offsets_sink_sext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_sext(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; As above but ensure both the GEP and sext are sunk.
+define <vscale x 4 x float> @gather_offsets_sink_sext_get(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_sext_get(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink GEPs that cannot benefit from SVE's scalar+vector addressing modes.
+define <vscale x 4 x float> @gather_no_scalar_base(<vscale x 4 x ptr> %bases, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_no_scalar_base(
+; CHECK-SAME: <vscale x 4 x ptr> [[BASES:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr float, <vscale x 4 x ptr> [[BASES]], <vscale x 4 x i32> [[INDICES]]
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %ptrs = getelementptr float, <vscale x 4 x ptr> %bases, <vscale x 4 x i32> %indices
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink extends whose result type is already favourable for SVE's sxtw/uxtw addressing modes.
+; NOTE: We still want to sink the GEP.
+define <vscale x 4 x float> @gather_offset_type_too_small(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_type_too_small(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i8> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[INDICES_SEXT:%.*]] = sext <vscale x 4 x i8> [[INDICES]] to <vscale x 4 x i32>
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i32> [[INDICES_SEXT]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i8> %indices to <vscale x 4 x i32>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i32> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink extends that cannot benefit from SVE's sxtw/uxtw addressing modes.
+; NOTE: We still want to sink the GEP.
+define <vscale x 4 x float> @gather_offset_type_too_big(ptr %base, <vscale x 4 x i48> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_type_too_big(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i48> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[INDICES_SEXT:%.*]] = sext <vscale x 4 x i48> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[INDICES_SEXT]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i48> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Sink zext to make use of scalar+uxtw(vector) addressing modes.
+; TODO: There's an argument here to split the extend into i8->i32 and i32->i64,
+; which would be especially useful if the i8s are the result of a load because
+; it would maintain the use of sign-extending loads.
+define <vscale x 4 x float> @gather_offset_sink_zext(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_sink_zext(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i8> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <vscale x 4 x i8> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT:    [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[LOAD]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.zext = zext <vscale x 4 x i8> %indices to <vscale x 4 x i64>
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.zext
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Ensure we support scatters as well as gathers.
+define void @scatter_offsets_sink_sext_get(<vscale x 4 x float> %data, ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define void @scatter_offsets_sink_sext_get(
+; CHECK-SAME: <vscale x 4 x float> [[DATA:%.*]], ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK:       cond.block:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT:    tail call void @llvm.masked.scatter.nxv4f32.nxv4p0(<vscale x 4 x float> [[DATA]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT:    ret void
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  tail call void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float> %data, <vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask)
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float>, <vscale x 4 x ptr>, i32, <vscale x 4 x i1>)
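Note (illustrative, not part of the patch): with the GEP and extend sunk next to the intrinsic, instruction selection can fold the whole address computation into SVE's scalar+vector addressing modes instead of materialising a vector of 64-bit pointers. A rough sketch of the codegen this is expected to enable for the sext case, assuming an SVE-enabled target (register assignments are hypothetical):

  ; gather:  base in x0, 32-bit indices in z1.s, predicate in p0
  ld1w { z0.s }, p0/z, [x0, z1.s, sxtw #2]
  ; scatter: data in z0.s, same base/offsets/predicate
  st1w { z0.s }, p0, [x0, z1.s, sxtw #2]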