cranelift: Port ushr SIMD lowerings to ISLE on x64 #3688

Merged · 1 commit · Jan 13, 2022

Changes from all commits
5 changes: 5 additions & 0 deletions cranelift/codegen/src/isa/x64/inst.isle
@@ -1305,6 +1305,11 @@
(rule (psllq src1 src2)
(xmm_rmi_reg (SseOpcode.Psllq) src1 src2))

;; Helper for creating `psrlw` instructions.
(decl psrlw (Reg RegMemImm) Reg)
(rule (psrlw src1 src2)
(xmm_rmi_reg (SseOpcode.Psrlw) src1 src2))

;; Helper for creating `psrld` instructions.
(decl psrld (Reg RegMemImm) Reg)
(rule (psrld src1 src2)
(xmm_rmi_reg (SseOpcode.Psrld) src1 src2))
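The `psrlw`/`psrld`/`psrlq` helpers above each wrap a single SSE2 instruction that logically right-shifts every 16-, 32-, or 64-bit lane independently. As a rough scalar model — not part of this patch, function names are illustrative — the semantics they encode look like this, including the hardware behavior that an out-of-range amount zeroes the lane:

```rust
// Scalar model of the lane semantics behind `psrlw` and `psrld`: each lane is
// shifted independently, and amounts >= the lane width produce zero, matching
// these instructions' hardware behavior.
fn psrlw_model(lanes: [u16; 8], amt: u64) -> [u16; 8] {
    lanes.map(|l| if amt >= 16 { 0 } else { l >> amt })
}

fn psrld_model(lanes: [u32; 4], amt: u64) -> [u32; 4] {
    lanes.map(|l| if amt >= 32 { 0 } else { l >> amt })
}

fn main() {
    assert_eq!(psrlw_model([0x8000; 8], 15), [0x0001; 8]);
    assert_eq!(psrld_model([0xffff_ffff; 4], 32), [0; 4]);
}
```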
73 changes: 62 additions & 11 deletions cranelift/codegen/src/isa/x64/lower.isle
@@ -595,13 +595,17 @@

;; When the shift amount is known, we can statically (i.e. at compile time)
;; determine the mask to use and only emit that.
(decl ishl_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ishl_i8x16_mask_for_const ishl_i8x16_mask_for_const)
(rule (ishl_i8x16_mask (RegMemImm.Imm amt))
(ishl_i8x16_mask_for_const amt))

;; Otherwise, we must emit the entire mask table and dynamically (i.e. at run
;; time) find the correct mask offset in the table. We do this use `lea` to find
;; the base address of the mask table and then complex addressing to offset to
;; the right mask: `base_address + amt << 4`
;; time) find the correct mask offset in the table. We use `lea` to find the
;; base address of the mask table and then complex addressing to offset to the
;; right mask: `base_address + amt << 4`
(decl ishl_i8x16_mask_table () SyntheticAmode)
(extern constructor ishl_i8x16_mask_table ishl_i8x16_mask_table)
(rule (ishl_i8x16_mask (RegMemImm.Reg amt))
(let ((mask_table SyntheticAmode (ishl_i8x16_mask_table))
(base_mask_addr Reg (lea mask_table))
(mask_offset Reg (shl $I64 amt (Imm8Reg.Imm8 4))))
(amode_to_synthetic_amode (amode_imm_reg_reg_shift 0
base_mask_addr
mask_offset
0))))
@@ -613,14 +617,6 @@
(rule (ishl_i8x16_mask (RegMemImm.Mem amt))
(ishl_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))

;; Get the address of the mask for a constant 8x16 shift amount.
(decl ishl_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ishl_i8x16_mask_for_const ishl_i8x16_mask_for_const)

;; Get the address of the mask table for a dynamic 8x16 shift amount.
(decl ishl_i8x16_mask_table () SyntheticAmode)
(extern constructor ishl_i8x16_mask_table ishl_i8x16_mask_table)

;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
(rule (lower (has_type $I16X8 (ishl src amt)))
(value_reg (psllw (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))

@@ -671,6 +667,11 @@
(rule (lower (has_type $I128 (ushr src amt)))
(let ((amt_ Reg (lo_reg amt)))
(shr_i128 (put_in_regs src) amt_)))
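The `shr_i128` helper referenced here is defined elsewhere in `lower.isle` and isn't shown in this diff. As a rough scalar sketch of the two-halves decomposition such a shift performs — illustrative function only, assuming Cranelift's convention that shift amounts are taken modulo the type width:

```rust
// Scalar model of a 128-bit logical right shift decomposed into 64-bit halves.
fn shr_u128_via_halves(lo: u64, hi: u64, amt: u32) -> (u64, u64) {
    let amt = amt & 127; // shift amounts are taken modulo the type width
    match amt {
        0 => (lo, hi),
        1..=63 => ((lo >> amt) | (hi << (64 - amt)), hi >> amt),
        _ => (hi >> (amt - 64), 0),
    }
}

fn main() {
    let v: u128 = 0x1234_5678_9abc_def0 << 64;
    let (lo, hi) = shr_u128_via_halves(v as u64, (v >> 64) as u64, 68);
    assert_eq!(((hi as u128) << 64) | lo as u128, v >> 68);
}
```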

;; SSE.

;; x64 has no 8x16 shift instructions, so we use the same 16x8-shift-and-mask
;; approach as for 8x16 `ishl`.
(rule (lower (has_type $I8X16 (ushr src amt)))
(let ((src_ Reg (put_in_reg src))
(amt_gpr RegMemImm (put_in_reg_mem_imm amt))
(amt_xmm RegMemImm (reg_mem_imm_to_xmm amt_gpr))
;; Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be
;; correct for half of the lanes; the others must be fixed up with
;; the mask below.
(unmasked Reg (psrlw src_ amt_xmm))
(mask_addr SyntheticAmode (ushr_i8x16_mask amt_gpr))
(mask Reg (x64_load $I8X16 mask_addr (ExtKind.None))))
(value_reg (sse_and $I8X16 unmasked (RegMem.Reg mask)))))
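As a plain-Rust sketch of the trick this rule encodes — illustrative only, the function and loop below are not Cranelift code — shift each 16-bit lane, then AND with `0xff >> amt` replicated across the vector, which zeroes exactly the bits that leaked in from each byte's high-order neighbor:

```rust
// Emulate `ushr.i8x16` with a 16x8 shift plus a mask, mirroring the rule above.
fn ushr_i8x16_via_16x8(bytes: [u8; 16], amt: u32) -> [u8; 16] {
    assert!(amt < 8, "the lowering assumes a lane-masked shift amount");
    let mask = 0xffu8 >> amt; // zeroes the bits that crossed in from the neighbor
    let mut out = [0u8; 16];
    for i in (0..16).step_by(2) {
        // Each adjacent byte pair forms one little-endian 16-bit lane, as in `psrlw`.
        let lane = u16::from_le_bytes([bytes[i], bytes[i + 1]]);
        let shifted = (lane >> amt).to_le_bytes();
        out[i] = shifted[0] & mask;
        out[i + 1] = shifted[1] & mask;
    }
    out
}

fn main() {
    let input: [u8; 16] = [0x81; 16];
    let expected = input.map(|b| b >> 1); // true per-byte logical shift
    assert_eq!(ushr_i8x16_via_16x8(input, 1), expected);
}
```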

;; Get the address of the mask to use when fixing up the lanes that weren't
;; correctly generated by the 16x8 shift.
(decl ushr_i8x16_mask (RegMemImm) SyntheticAmode)

;; When the shift amount is known, we can statically (i.e. at compile time)
;; determine the mask to use and only emit that.
(decl ushr_i8x16_mask_for_const (u32) SyntheticAmode)
(extern constructor ushr_i8x16_mask_for_const ushr_i8x16_mask_for_const)
(rule (ushr_i8x16_mask (RegMemImm.Imm amt))
(ushr_i8x16_mask_for_const amt))

;; Otherwise, we must emit the entire mask table and dynamically (i.e. at run
;; time) find the correct mask offset in the table. We use `lea` to find the
;; base address of the mask table and then complex addressing to offset to the
;; right mask: `base_address + amt << 4`
(decl ushr_i8x16_mask_table () SyntheticAmode)
(extern constructor ushr_i8x16_mask_table ushr_i8x16_mask_table)
(rule (ushr_i8x16_mask (RegMemImm.Reg amt))
(let ((mask_table SyntheticAmode (ushr_i8x16_mask_table))
(base_mask_addr Reg (lea mask_table))
(mask_offset Reg (shl $I64 amt (Imm8Reg.Imm8 4))))
(amode_to_synthetic_amode (amode_imm_reg_reg_shift 0
base_mask_addr
mask_offset
0))))
(rule (ushr_i8x16_mask (RegMemImm.Mem amt))
(ushr_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))
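A sketch of the address arithmetic these rules emit with `lea` plus complex addressing — hypothetical helper, not Cranelift code: each mask occupies 16 bytes, so the row for shift amount `amt` sits at `base_address + (amt << 4)`:

```rust
// Compute the dynamic mask lookup: row `amt` of the table starts at
// byte offset `amt << 4`, i.e. `amt * 16`.
fn mask_row(table: &[u8; 128], amt: u64) -> &[u8] {
    assert!(amt < 8, "the lowering assumes a lane-masked shift amount");
    let offset = (amt << 4) as usize; // equivalent to amt * 16
    &table[offset..offset + 16]
}

fn main() {
    // Stand-in table with the real layout: row r holds sixteen copies of 0xff >> r.
    let mut table = [0u8; 128];
    for r in 0..8 {
        table[r * 16..r * 16 + 16].fill(0xffu8 >> r);
    }
    assert!(mask_row(&table, 3).iter().all(|&b| b == 0x1f));
}
```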

;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
(rule (lower (has_type $I16X8 (ushr src amt)))
(value_reg (psrlw (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))
(rule (lower (has_type $I32X4 (ushr src amt)))
(value_reg (psrld (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))
(rule (lower (has_type $I64X2 (ushr src amt)))
(value_reg (psrlq (put_in_reg src)
(reg_mem_imm_to_xmm (put_in_reg_mem_imm amt)))))
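For comparison — not part of the patch — the same three instructions are reachable from Rust's `core::arch` intrinsics, which, like these rules, take a dynamic shift amount in the low 64 bits of an XMM register (hence `reg_mem_imm_to_xmm` above). A minimal sketch assuming an SSE2-capable x86-64 host:

```rust
#[cfg(target_arch = "x86_64")]
fn demo() {
    use core::arch::x86_64::*;
    // SSE2 is part of the x86-64 baseline, so no runtime feature check is needed.
    unsafe {
        // Like `reg_mem_imm_to_xmm`, place the shift amount in the low 64 bits
        // of an XMM register.
        let amt = _mm_set_epi64x(0, 3);
        let v = _mm_set1_epi16(-1); // eight 0xffff lanes
        let r = _mm_srl_epi16(v, amt); // psrlw: each lane becomes 0x1fff
        let _r32 = _mm_srl_epi32(_mm_set1_epi32(-1), amt); // psrld
        let _r64 = _mm_srl_epi64(_mm_set1_epi64x(-1), amt); // psrlq
        let mut lanes = [0u16; 8];
        _mm_storeu_si128(lanes.as_mut_ptr() as *mut __m128i, r);
        assert!(lanes.iter().all(|&l| l == 0x1fff));
    }
}

fn main() {
    #[cfg(target_arch = "x86_64")]
    demo();
}
```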

;;;; Rules for `sshr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; `i64` and smaller.
167 changes: 4 additions & 163 deletions cranelift/codegen/src/isa/x64/lower.rs
@@ -1539,10 +1539,11 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
| Opcode::Bnot
| Opcode::Bitselect
| Opcode::Vselect
| Opcode::Ushr
| Opcode::Sshr
| Opcode::Ishl => implemented_in_isle(ctx),

Opcode::Ushr | Opcode::Rotl | Opcode::Rotr => {
Opcode::Rotl | Opcode::Rotr => {
let dst_ty = ctx.output_ty(insn, 0);
debug_assert_eq!(ctx.input_ty(insn, 0), dst_ty);

@@ -1558,11 +1559,7 @@
// This implementation uses the last two encoding methods.
let (size, lhs) = match dst_ty {
types::I8 | types::I16 => match op {
Opcode::Ushr => (
OperandSize::Size32,
extend_input_to_reg(ctx, inputs[0], ExtSpec::ZeroExtendTo32),
),
Opcode::Rotl | Opcode::Rotr => (
Opcode::Rotr => (
OperandSize::from_ty(dst_ty),
put_input_in_reg(ctx, inputs[0]),
),
@@ -1589,8 +1586,6 @@
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

let shift_kind = match op {
Opcode::Ushr => ShiftKind::ShiftRightLogical,
Opcode::Rotl => ShiftKind::RotateLeft,
Opcode::Rotr => ShiftKind::RotateRight,
_ => unreachable!(),
};
@@ -1607,9 +1602,6 @@
let dst = get_output_reg(ctx, outputs[0]);

match op {
Opcode::Ushr | Opcode::Rotl => {
implemented_in_isle(ctx);
}
Opcode::Rotr => {
// (mov tmp, src)
// (ushr.i128 tmp, amt)
@@ -1642,159 +1634,8 @@
}
_ => unreachable!(),
}
} else if dst_ty == types::I8X16 && op == Opcode::Ushr {
// Since the x86 instruction set does not have any 8x16 shift instructions (even in higher feature sets
// like AVX), we lower the `ishl.i8x16` and `ushr.i8x16` to a sequence of instructions. The basic idea,
// whether the `shift_by` amount is an immediate or not, is to use a 16x8 shift and then mask off the
// incorrect bits to 0s (see below for handling signs in `sshr.i8x16`).
let src = put_input_in_reg(ctx, inputs[0]);
let shift_by = input_to_reg_mem_imm(ctx, inputs[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

// If necessary, move the shift index into the lowest bits of a vector register.
let shift_by_moved = match &shift_by {
RegMemImm::Imm { .. } => shift_by.clone(),
RegMemImm::Reg { reg } => {
let tmp_shift_by = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::gpr_to_xmm(
SseOpcode::Movd,
RegMem::reg(*reg),
OperandSize::Size32,
tmp_shift_by,
));
RegMemImm::reg(tmp_shift_by.to_reg())
}
RegMemImm::Mem { .. } => unimplemented!("load shift amount to XMM register"),
};

// Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be correct for half of the lanes;
// the others must be fixed up with the mask below.
let shift_opcode = match op {
Opcode::Ushr => SseOpcode::Psrlw,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
};
ctx.emit(Inst::gen_move(dst, src, dst_ty));
ctx.emit(Inst::xmm_rmi_reg(shift_opcode, shift_by_moved, dst));

// Choose which mask to use to fixup the shifted lanes. Since we must use a 16x8 shift, we need to fix
// up the bits that migrate from one half of the lane to the other. Each 16-byte mask (which rustfmt
// forces to multiple lines) is indexed by the shift amount: e.g. if we shift right by 0 (no movement),
// we want to retain all the bits so we mask with `0xff`; if we shift right by 1, we want to retain all
// bits except the MSB so we mask with `0x7f`; etc.
const USHR_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x1f, 0x1f, 0x1f, 0x1f,
0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x0f,
0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f,
0x0f, 0x0f, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
0x07, 0x07, 0x07, 0x07, 0x07, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
];

let mask = match op {
Opcode::Ushr => &USHR_MASKS,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
};

// Figure out the address of the shift mask.
let mask_address = match shift_by {
RegMemImm::Imm { simm32 } => {
// When the shift amount is known, we can statically (i.e. at compile time) determine the mask to
// use and only emit that.
debug_assert!(simm32 < 8);
let mask_offset = simm32 as usize * 16;
let mask_constant = ctx.use_constant(VCodeConstantData::WellKnown(
&mask[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}
RegMemImm::Reg { reg } => {
// Otherwise, we must emit the entire mask table and dynamically (i.e. at run time) find the correct
// mask offset in the table. We do this using LEA to find the base address of the mask table and then
// complex addressing to offset to the right mask: `base_address + (shift_by << 4)`
let base_mask_address = ctx.alloc_tmp(types::I64).only_reg().unwrap();
let mask_offset = ctx.alloc_tmp(types::I64).only_reg().unwrap();
let mask_constant = ctx.use_constant(VCodeConstantData::WellKnown(mask));
ctx.emit(Inst::lea(
SyntheticAmode::ConstantOffset(mask_constant),
base_mask_address,
));
ctx.emit(Inst::gen_move(mask_offset, reg, types::I64));
ctx.emit(Inst::shift_r(
OperandSize::Size64,
ShiftKind::ShiftLeft,
Some(4),
mask_offset,
));
Amode::imm_reg_reg_shift(
0,
base_mask_address.to_reg(),
mask_offset.to_reg(),
0,
)
.into()
}
RegMemImm::Mem { addr: _ } => unimplemented!("load mask address"),
};

// Load the mask into a temporary register, `mask_value`.
let mask_value = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::load(dst_ty, mask_address, mask_value, ExtKind::None));

// Remove the bits that would have disappeared in a true 8x16 shift. TODO in the future,
// this AND instruction could be coalesced with the load above.
let sse_op = match dst_ty {
types::F32X4 => SseOpcode::Andps,
types::F64X2 => SseOpcode::Andpd,
_ => SseOpcode::Pand,
};
ctx.emit(Inst::xmm_rm_r(sse_op, RegMem::from(mask_value), dst));
} else {
// For the remaining packed shifts not covered above, x86 has implementations that can either:
// - shift using an immediate
// - shift using a dynamic value given in the lower bits of another XMM register.
let src = put_input_in_reg(ctx, inputs[0]);
let shift_by = input_to_reg_mem_imm(ctx, inputs[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
let sse_op = match dst_ty {
types::I16X8 => match op {
Opcode::Ushr => SseOpcode::Psrlw,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
types::I32X4 => match op {
Opcode::Ushr => SseOpcode::Psrld,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
types::I64X2 => match op {
Opcode::Ushr => SseOpcode::Psrlq,
_ => unimplemented!("{} is not implemented for type {}", op, dst_ty),
},
_ => unreachable!(),
};

// If necessary, move the shift index into the lowest bits of a vector register.
let shift_by = match shift_by {
RegMemImm::Imm { .. } => shift_by,
RegMemImm::Reg { reg } => {
let tmp_shift_by = ctx.alloc_tmp(dst_ty).only_reg().unwrap();
ctx.emit(Inst::gpr_to_xmm(
SseOpcode::Movd,
RegMem::reg(reg),
OperandSize::Size32,
tmp_shift_by,
));
RegMemImm::reg(tmp_shift_by.to_reg())
}
RegMemImm::Mem { .. } => unimplemented!("load shift amount to XMM register"),
};

// Move the `src` to the same register as `dst`.
ctx.emit(Inst::gen_move(dst, src, dst_ty));

ctx.emit(Inst::xmm_rmi_reg(sse_op, shift_by, dst));
implemented_in_isle(ctx);
}
}

37 changes: 34 additions & 3 deletions cranelift/codegen/src/isa/x64/lower/isle.rs
@@ -270,15 +270,33 @@ where
debug_assert!(amt < 8);
let mask_offset = amt as usize * 16;
let mask_constant = self.lower_ctx.use_constant(VCodeConstantData::WellKnown(
&I8X16_SHL_MASKS[mask_offset..mask_offset + 16],
&I8X16_ISHL_MASKS[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}

fn ishl_i8x16_mask_table(&mut self) -> SyntheticAmode {
let mask_table = self
.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&I8X16_SHL_MASKS));
.use_constant(VCodeConstantData::WellKnown(&I8X16_ISHL_MASKS));
SyntheticAmode::ConstantOffset(mask_table)
}

fn ushr_i8x16_mask_for_const(&mut self, amt: u32) -> SyntheticAmode {
// When the shift amount is known, we can statically (i.e. at compile
// time) determine the mask to use and only emit that.
debug_assert!(amt < 8);
let mask_offset = amt as usize * 16;
let mask_constant = self.lower_ctx.use_constant(VCodeConstantData::WellKnown(
&I8X16_USHR_MASKS[mask_offset..mask_offset + 16],
));
SyntheticAmode::ConstantOffset(mask_constant)
}

fn ushr_i8x16_mask_table(&mut self) -> SyntheticAmode {
let mask_table = self
.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&I8X16_USHR_MASKS));
SyntheticAmode::ConstantOffset(mask_table)
}
}
@@ -289,8 +307,9 @@ where
// right by 0 (no movement), we want to retain all the bits so we mask with
// `0xff`; if we shift right by 1, we want to retain all bits except the MSB so
// we mask with `0x7f`; etc.

#[rustfmt::skip] // Preserve 16 bytes (i.e. one mask) per row.
const I8X16_SHL_MASKS: [u8; 128] = [
const I8X16_ISHL_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe, 0xfe,
0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc, 0xfc,
@@ -301,6 +320,18 @@ const I8X16_ISHL_MASKS: [u8; 128] = [
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
];

#[rustfmt::skip] // Preserve 16 bytes (i.e. one mask) per row.
const I8X16_USHR_MASKS: [u8; 128] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f,
0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f, 0x1f,
0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f,
0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
];
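A small test sketch — not part of the PR, and assuming it lives in the same module as these constants — makes the tables' invariant explicit: row `r` of the `ushr` table holds sixteen copies of `0xff >> r`, and row `r` of the `ishl` table holds sixteen copies of `0xff << r`:

```rust
#[cfg(test)]
mod shift_mask_tests {
    use super::{I8X16_ISHL_MASKS, I8X16_USHR_MASKS};

    #[test]
    fn rows_match_per_byte_shift_masks() {
        for r in 0..8 {
            // ushr: the top `r` bits of each byte must be cleared.
            assert!(I8X16_USHR_MASKS[r * 16..r * 16 + 16]
                .iter()
                .all(|&b| b == 0xffu8 >> r));
            // ishl: the bottom `r` bits of each byte must be cleared.
            assert!(I8X16_ISHL_MASKS[r * 16..r * 16 + 16]
                .iter()
                .all(|&b| b == 0xffu8 << r));
        }
    }
}
```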

#[inline]
fn to_simm32(constant: i64) -> Option<RegMemImm> {
if constant == ((constant << 32) >> 32) {
Some(RegMemImm::imm(constant as u32))
} else {
None
}
}
@@ -1,4 +1,4 @@
src/clif.isle f176ef3bba99365
src/prelude.isle 7b911d3b894ae17
src/isa/x64/inst.isle dbfa857f7f2c5d9f
src/isa/x64/lower.isle 5a737854091e1189
src/isa/x64/inst.isle 41304d8ef6f7d816
src/isa/x64/lower.isle 4689585f55f41438