cranelift: Port ushr SIMD lowerings to ISLE on x64
fitzgen committed Jan 13, 2022
1 parent 46ade3d commit 4e34dd8
Showing 6 changed files with 315 additions and 282 deletions.
5 changes: 5 additions & 0 deletions cranelift/codegen/src/isa/x64/inst.isle
@@ -1305,6 +1305,11 @@
(rule (psllq src1 src2)
(xmm_rmi_reg (SseOpcode.Psllq) src1 src2))

;; Helper for creating `psrlw` instructions.
(decl psrlw (Reg RegMemImm) Reg)
(rule (psrlw src1 src2)
(xmm_rmi_reg (SseOpcode.Psrlw) src1 src2))

;; Helper for creating `psrld` instructions.
(decl psrld (Reg RegMemImm) Reg)
(rule (psrld src1 src2)
73 changes: 62 additions & 11 deletions cranelift/codegen/src/isa/x64/lower.isle
@@ -595,13 +595,17 @@

;; When the shift amount is known, we can statically (i.e. at compile time)
;; determine the mask to use and only emit that.
(decl ishl_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ishl_i8x16_mask_for_const ishl_i8x16_mask_for_const)
(rule (ishl_i8x16_mask (RegMemImm.Imm amt))
(ishl_i8x16_mask_for_const amt))

;; Otherwise, we must emit the entire mask table and dynamically (i.e. at run
;; time) find the correct mask offset in the table. We do this use `lea` to find
;; the base address of the mask table and then complex addressing to offset to
;; the right mask: `base_address + amt << 4`
;; time) find the correct mask offset in the table. We use `lea` to find the
;; base address of the mask table and then complex addressing to offset to the
;; right mask: `base_address + amt << 4`
(decl ishl_i8x16_mask_table () SyntheticAmode)
(extern constructor ishl_i8x16_mask_table ishl_i8x16_mask_table)
(rule (ishl_i8x16_mask (RegMemImm.Reg amt))
(let ((mask_table SyntheticAmode (ishl_i8x16_mask_table))
(base_mask_addr Reg (lea mask_table))
@@ -613,14 +617,6 @@
(rule (ishl_i8x16_mask (RegMemImm.Mem amt))
(ishl_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))

;; Get the address of the mask for a constant 8x16 shift amount.
(decl ishl_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ishl_i8x16_mask_for_const ishl_i8x16_mask_for_const)

;; Get the address of the mask table for a dynamic 8x16 shift amount.
(decl ishl_i8x16_mask_table () SyntheticAmode)
(extern constructor ishl_i8x16_mask_table ishl_i8x16_mask_table)

;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
(rule (lower (has_type $I16X8 (ishl src amt)))
(value_reg (psllw (put_in_reg src)
@@ -671,6 +667,61 @@
(let ((amt_ Reg (lo_reg amt)))
(shr_i128 (put_in_regs src) amt_)))

;; SSE.

;; There are no 8x16 SIMD shift instructions in x64, so we use the same
;; 16x8-shift-and-mask approach as for 8x16 `ishl`.
(rule (lower (has_type $I8X16 (ushr src amt)))
(let ((src_ Reg (put_in_reg src))
(amt_gpr RegMemImm (put_in_reg_mem_imm amt))
(amt_xmm RegMemImm (reg_mem_imm_to_xmm amt_gpr))
;; Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be
;; correct for half of the lanes; the others must be fixed up with
;; the mask below.
(unmasked Reg (psrlw src_ amt_xmm))
(mask_addr SyntheticAmode (ushr_i8x16_mask amt_gpr))
(mask Reg (x64_load $I8X16 mask_addr (ExtKind.None))))
(value_reg (sse_and $I8X16 unmasked (RegMem.Reg mask)))))
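
As an aside for readers, here is a scalar sketch of what this lowering computes (illustrative only, not part of the commit; the function name is made up): `psrlw` shifts each 16-bit pair as a unit, and the mask `0xff >> amt` then clears the bits that migrated from the adjacent high byte lane into each low byte.

// Scalar model of the `ushr.i8x16` lowering above: a 16-bit logical right
// shift followed by a per-byte mask. Hypothetical helper, for illustration.
fn ushr_i8x16_model(lanes: [u8; 16], amt: u32) -> [u8; 16] {
    assert!(amt < 8);
    let mask = 0xffu8 >> amt; // the value the mask table stores for `amt`
    let mut out = [0u8; 16];
    for pair in 0..8 {
        // Treat two adjacent byte lanes as one little-endian u16, like `psrlw`.
        let lo = lanes[2 * pair] as u16;
        let hi = lanes[2 * pair + 1] as u16;
        let shifted = ((hi << 8) | lo) >> amt;
        // The low byte now holds `amt` stray bits from the high lane; masking
        // clears them. The high byte's top bits are already zero.
        out[2 * pair] = (shifted as u8) & mask;
        out[2 * pair + 1] = ((shifted >> 8) as u8) & mask;
    }
    out
}

For every input and every `amt < 8`, this agrees with shifting each byte lane independently (`lanes.map(|b| b >> amt)`), which is why the 16x8 shift plus mask is a valid lowering.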

;; Get the address of the mask to use when fixing up the lanes that weren't
;; correctly generated by the 16x8 shift.
(decl ushr_i8x16_mask (RegMemImm) SyntheticAmode)

;; When the shift amount is known, we can statically (i.e. at compile time)
;; determine the mask to use and only emit that.
(decl ushr_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ushr_i8x16_mask_for_const ushr_i8x16_mask_for_const)
(rule (ushr_i8x16_mask (RegMemImm.Imm amt))
(ushr_i8x16_mask_for_const amt))

;; Otherwise, we must emit the entire mask table and dynamically (i.e. at run
;; time) find the correct mask offset in the table. We use `lea` to find the
;; base address of the mask table and then complex addressing to offset to the
;; right mask: `base_address + amt << 4`
(decl ushr_i8x16_mask_table () SyntheticAmode)
(extern constructor ushr_i8x16_mask_table ushr_i8x16_mask_table)
(rule (ushr_i8x16_mask (RegMemImm.Reg amt))
(let ((mask_table SyntheticAmode (ushr_i8x16_mask_table))
(base_mask_addr Reg (lea mask_table))
(mask_offset Reg (shl $I64 amt (Imm8Reg.Imm8 4))))
(amode_to_synthetic_amode (amode_imm_reg_reg_shift 0
base_mask_addr
mask_offset
0))))
(rule (ushr_i8x16_mask (RegMemImm.Mem amt))
(ushr_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))
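
To make the `base_address + amt << 4` addressing concrete (a sketch under assumed names, not code from the commit): each mask occupies 16 bytes, so shifting the amount left by 4 scales it to a row offset within the 128-byte table.

// Sketch of the dynamic mask lookup: scale the shift amount by 16 (one mask
// row) and index into the 128-byte table. The lowered code computes the same
// address with `lea` plus a scaled addressing mode instead of a slice index.
fn mask_row(table: &[u8; 128], amt: u8) -> &[u8] {
    assert!(amt < 8);
    let offset = (amt as usize) << 4; // amt * 16; e.g. amt = 3 -> offset 48
    &table[offset..offset + 16]       // for the ushr table, amt = 3 selects sixteen 0x1f bytes
}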

;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
(rule (lower (has_type $I16X8 (ushr src amt)))
(value_reg (psrlw (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))
(rule (lower (has_type $I32X4 (ushr src amt)))
(value_reg (psrld (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))
(rule (lower (has_type $I64X2 (ushr src amt)))
(value_reg (psrlq (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))

;;;; Rules for `sshr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; `i64` and smaller.
167 changes: 4 additions & 163 deletions cranelift/codegen/src/isa/x64/lower.rs
@@ -1539,10 +1539,11 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
| Opcode::Bnot
| Opcode::Bitselect
| Opcode::Vselect
| Opcode::Ushr
| Opcode::Sshr
| Opcode::Ishl => implemented_in_isle(ctx),

Opcode::Ushr | Opcode::Rotl | Opcode::Rotr => {
Opcode::Rotl | Opcode::Rotr => {
let dst_ty = ctx.output_ty(insn, 0);
debug_assert_eq!(ctx.input_ty(insn, 0), dst_ty);

@@ -1558,11 +1559,7 @@
// This implementation uses the last two encoding methods.
let (size, lhs) = match dst_ty {
types::I8 | types::I16 => match op {
Opcode::Ushr => (
OperandSize::Size32,
extend_input_to_reg(ctx, inputs[0], ExtSpec::ZeroExtendTo32),
),
Opcode::Rotl | Opcode::Rotr => (
Opcode::Rotr => (
OperandSize::from_ty(dst_ty),
put_input_in_reg(ctx, inputs[0]),
),
@@ -1589,8 +1586,6 @@
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

let shift_kind = match op {
Opcode::Ushr => ShiftKind::ShiftRightLogical,
Opcode::Rotl => ShiftKind::RotateLeft,
Opcode::Rotr => ShiftKind::RotateRight,
_ => unreachable!(),
};
@@ -1607,9 +1602,6 @@
let dst = get_output_reg(ctx, outputs[0]);

match op {
Opcode::Ushr | Opcode::Rotl => {
implemented_in_isle(ctx);
}
Opcode::Rotr => {
// (mov tmp, src)
// (ushr.i128 tmp, amt)
@@ -1642,159 +1634,8 @@
}
_ => unreachable!(),
}
} else if dst_ty == types::I8X16 && op == Opcode::Ushr {
// Since the x86 instruction set does not have any 8x16 shift instructions (even in higher feature sets
// like AVX), we lower the `ishl.i8x16` and `ushr.i8x16` to a sequence of instructions. The basic idea,
// whether the `shift_by` amount is an immediate or not, is to use a 16x8 shift and then mask off the
// incorrect bits to 0s (see below for handling signs in `sshr.i8x16`).
let src = put_input_in_reg(ctx, inputs[0]);
let shift_by = input_to_reg_mem_imm(ctx, inputs[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

// If necessary, move the shift index into the lowest bits of a vector register.
let shift_by_moved = match &shift_by {
RegMemImm::Imm { .. } => shift_by.clone(),
RegMemImm::Reg { reg } => {
let tmp_shift_by = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::gpr_to_xmm(
SseOpcode::Movd,
RegMem::reg(*reg),
OperandSize::Size32,
tmp_shift_by,
));
RegMemImm::reg(tmp_shift_by.to_reg())
}
RegMemImm::Mem { .. } => unimplemented!("load shift amount to XMM register"),
};

// Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be correct for half of the lanes;
// the others must be fixed up with the mask below.
let shift_opcode = match op {
Opcode::Ushr => SseOpcode::Psrlw,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
};
ctx.emit(Inst::gen_move(dst, src, dst_ty));
ctx.emit(Inst::xmm_rmi_reg(shift_opcode, shift_by_moved, dst));

// Choose which mask to use to fixup the shifted lanes. Since we must use a 16x8 shift, we need to fix
// up the bits that migrate from one half of the lane to the other. Each 16-byte mask (which rustfmt
// forces to multiple lines) is indexed by the shift amount: e.g. if we shift right by 0 (no movement),
// we want to retain all the bits so we mask with `0xff`; if we shift right by 1, we want to retain all
// bits except the MSB so we mask with `0x7f`; etc.
const USHR_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x1f, 0x1f, 0x1f, 0x1f,
0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x0f,
0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f,
0x0f, 0x0f, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
0x07, 0x07, 0x07, 0x07, 0x07, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
];

let mask = match op {
Opcode::Ushr => &USHR_MASKS,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
};

// Figure out the address of the shift mask.
let mask_address = match shift_by {
RegMemImm::Imm { simm32 } => {
// When the shift amount is known, we can statically (i.e. at compile time) determine the mask to
// use and only emit that.
debug_assert!(simm32 < 8);
let mask_offset = simm32 as usize * 16;
let mask_constant = ctx.use_constant(VCodeConstantData::WellKnown(
&mask[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}
RegMemImm::Reg { reg } => {
// Otherwise, we must emit the entire mask table and dynamically (i.e. at run time) find the correct
// mask offset in the table. We do this use LEA to find the base address of the mask table and then
// complex addressing to offset to the right mask: `base_address + shift_by * 4`
let base_mask_address = ctx.alloc_tmp(types::I64).only_reg().unwrap();
let mask_offset = ctx.alloc_tmp(types::I64).only_reg().unwrap();
let mask_constant = ctx.use_constant(VCodeConstantData::WellKnown(mask));
ctx.emit(Inst::lea(
SyntheticAmode::ConstantOffset(mask_constant),
base_mask_address,
));
ctx.emit(Inst::gen_move(mask_offset, reg, types::I64));
ctx.emit(Inst::shift_r(
OperandSize::Size64,
ShiftKind::ShiftLeft,
Some(4),
mask_offset,
));
Amode::imm_reg_reg_shift(
0,
base_mask_address.to_reg(),
mask_offset.to_reg(),
0,
)
.into()
}
RegMemImm::Mem { addr: _ } => unimplemented!("load mask address"),
};

// Load the mask into a temporary register, `mask_value`.
let mask_value = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::load(dst_ty, mask_address, mask_value, ExtKind::None));

// Remove the bits that would have disappeared in a true 8x16 shift. TODO in the future,
// this AND instruction could be coalesced with the load above.
let sse_op = match dst_ty {
types::F32X4 => SseOpcode::Andps,
types::F64X2 => SseOpcode::Andpd,
_ => SseOpcode::Pand,
};
ctx.emit(Inst::xmm_rm_r(sse_op, RegMem::from(mask_value), dst));
} else {
// For the remaining packed shifts not covered above, x86 has implementations that can either:
// - shift using an immediate
// - shift using a dynamic value given in the lower bits of another XMM register.
let src = put_input_in_reg(ctx, inputs[0]);
let shift_by = input_to_reg_mem_imm(ctx, inputs[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
let sse_op = match dst_ty {
types::I16X8 => match op {
Opcode::Ushr => SseOpcode::Psrlw,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
types::I32X4 => match op {
Opcode::Ushr => SseOpcode::Psrld,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
types::I64X2 => match op {
Opcode::Ushr => SseOpcode::Psrlq,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
_ => unreachable!(),
};

// If necessary, move the shift index into the lowest bits of a vector register.
let shift_by = match shift_by {
RegMemImm::Imm { .. } => shift_by,
RegMemImm::Reg { reg } => {
let tmp_shift_by = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::gpr_to_xmm(
SseOpcode::Movd,
RegMem::reg(reg),
OperandSize::Size32,
tmp_shift_by,
));
RegMemImm::reg(tmp_shift_by.to_reg())
}
RegMemImm::Mem { .. } => unimplemented!("load shift amount to XMM register"),
};

// Move the `src` to the same register as `dst`.
ctx.emit(Inst::gen_move(dst, src, dst_ty));

ctx.emit(Inst::xmm_rmi_reg(sse_op, shift_by, dst));
implemented_in_isle(ctx);
}
}

37 changes: 34 additions & 3 deletions cranelift/codegen/src/isa/x64/lower/isle.rs
@@ -270,15 +270,33 @@ where
debug_assert!(amt < 8);
let mask_offset = amt as usize * 16;
let mask_constant = self.lower_ctx.use_constant(VCodeConstantData::WellKnown(
&I8X16_SHL_MASKS[mask_offset..mask_offset + 16],
&I8X16_ISHL_MASKS[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}

fn ishl_i8x16_mask_table(&mut self) -> SyntheticAmode {
let mask_table = self
.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&I8X16_SHL_MASKS));
.use_constant(VCodeConstantData::WellKnown(&I8X16_ISHL_MASKS));
SyntheticAmode::ConstantOffset(mask_table)
}

fn ushr_i8x16_mask_for_const(&mut self, amt: u32) -> SyntheticAmode {
// When the shift amount is known, we can statically (i.e. at compile
// time) determine the mask to use and only emit that.
debug_assert!(amt < 8);
let mask_offset = amt as usize * 16;
let mask_constant = self.lower_ctx.use_constant(VCodeConstantData::WellKnown(
&I8X16_USHR_MASKS[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}

fn ushr_i8x16_mask_table(&mut self) -> SyntheticAmode {
let mask_table = self
.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&I8X16_USHR_MASKS));
SyntheticAmode::ConstantOffset(mask_table)
}
}
@@ -289,8 +307,9 @@ where
// right by 0 (no movement), we want to retain all the bits so we mask with
// `0xff`; if we shift right by 1, we want to retain all bits except the MSB so
// we mask with `0x7f`; etc.

#[rustfmt::skip] // Preserve 16 bytes (i.e. one mask) per row.
const I8X16_SHL_MASKS: [u8; 128] = [
const I8X16_ISHL_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe,
0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc,
@@ -301,6 +320,18 @@ const I8X16_SHL_MASKS: [u8; 128] = [
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
];

#[rustfmt::skip] // Preserve 16 bytes (i.e. one mask) per row.
const I8X16_USHR_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f,
0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f,
0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
];
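
Both tables follow a simple closed form: row `amt` of the `ushr` table is sixteen copies of `0xff >> amt`, and row `amt` of the `ishl` table is sixteen copies of `0xff << amt`. A small generator sketch (illustrative, not part of the commit):

// Builds a 128-byte mask table: 8 rows of 16 identical bytes, indexed by the
// shift amount. `right_shift = true` reproduces I8X16_USHR_MASKS above;
// `false` reproduces I8X16_ISHL_MASKS.
fn shift_mask_table(right_shift: bool) -> [u8; 128] {
    let mut table = [0u8; 128];
    for amt in 0..8u32 {
        let mask = if right_shift { 0xffu8 >> amt } else { 0xffu8 << amt };
        for lane in 0..16 {
            table[amt as usize * 16 + lane] = mask;
        }
    }
    table
}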

#[inline]
fn to_simm32(constant: i64) -> Option<RegMemImm> {
if constant == ((constant << 32) >> 32) {
@@ -1,4 +1,4 @@
src/clif.isle f176ef3bba99365
src/prelude.isle 7b911d3b894ae17
src/isa/x64/inst.isle dbfa857f7f2c5d9f
src/isa/x64/lower.isle 5a737854091e1189
src/isa/x64/inst.isle 41304d8ef6f7d816
src/isa/x64/lower.isle 4689585f55f41438