Skip to content

Commit

Permalink
riscv64: Improve signed and zero extend codegen (bytecodealliance#5844)
Browse files Browse the repository at this point in the history
* riscv64: Remove unused code

* riscv64: Group extend rules

* riscv64: Remove more unused rules

* riscv64: Cleanup existing extension rules

* riscv64: Move the existing Extend rules to ISLE

* riscv64: Use `sext.w` when extending

* riscv64: Remove duplicate extend tests

* riscv64: Use `zbb` instructions when extending values

* riscv64: Use `zbkb` extensions when zero extending

* riscv64: Enable additional tests for extend i128

* riscv64: Fix formatting for `Inst::Extend`

* riscv64: Reverse register for pack

* riscv64: Misc Cleanups

* riscv64: Cleanup extend rules
  • Loading branch information
afonso360 authored Feb 22, 2023
1 parent 6e6a103 commit f6c6bc2
Show file tree
Hide file tree
Showing 68 changed files with 1,918 additions and 1,581 deletions.
203 changes: 107 additions & 96 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@
(Clmul)
(Clmulh)
(Clmulr)

;; Zbkb: Bit-manipulation for Cryptography
(Pack)
(Packw)
(Packh)
))


Expand Down Expand Up @@ -858,22 +863,6 @@
(_ Unit (emit (MInst.AluRRImm12 op dst src (imm12_zero)))))
dst))

;; extend int if need.
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
;;; for I8, I16, and I32 ...
(rule -1
(ext_int_if_need signed val ty)
(gen_extend val signed (ty_bits ty) 64))
;;; otherwise this is a I64 or I128
;;; no need to extend.
(rule
(ext_int_if_need _ r $I64)
r)
(rule
(ext_int_if_need _ r $I128)
r)


;; Helper for get negative of Imm12
(decl neg_imm12 (Imm12) Imm12)
(extern constructor neg_imm12 neg_imm12)
Expand Down Expand Up @@ -1031,50 +1020,116 @@
;; add low and high together.
(result Reg (alu_add high low)))
(value_regs result (load_u64_constant 0))))

;; Extends an integer if it is smaller than 64 bits.
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
;;; For values smaller than 64 bits, we need to extend them to 64 bits
(rule 0 (ext_int_if_need $true val (fits_in_32 (ty_int ty)))
(sext val ty $I64))
(rule 0 (ext_int_if_need $false val (fits_in_32 (ty_int ty)))
(zext val ty $I64))
;; If the value is larger than one machine register, we don't need to do anything
(rule 1 (ext_int_if_need _ r $I64) r)
(rule 2 (ext_int_if_need _ r $I128) r)

(decl gen_extend (Reg bool u8 u8) Reg)
(rule
(gen_extend r is_signed from_bits to_bits)
(let
((tmp WritableReg (temp_writable_reg $I16))
(_ Unit (emit (MInst.Extend tmp r is_signed from_bits to_bits))))
tmp))

;; val is_signed from_bits to_bits
(decl lower_extend (Reg bool u8 u8) ValueRegs)
(rule -1
(lower_extend r is_signed from_bits to_bits)
(gen_extend r is_signed from_bits to_bits))
;; Performs a zero extension of the given value
(decl zext (ValueRegs Type Type) ValueRegs)
(rule (zext val from_ty to_ty) (extend val (ExtendOp.Zero) from_ty to_ty))

;;;; for I128 signed extend.
(rule 1
(lower_extend r $true 64 128)
(let
((tmp Reg (alu_rrr (AluOPRRR.Slt) r (zero_reg)))
(high Reg (gen_extend tmp $true 1 64)))
(value_regs (gen_move2 r $I64 $I64) high)))
;; Performs a signed extension of the given value
(decl sext (ValueRegs Type Type) ValueRegs)
(rule (sext val from_ty to_ty) (extend val (ExtendOp.Signed) from_ty to_ty))

(rule
(lower_extend r $true from_bits 128)
(let
((tmp Reg (gen_extend r $true from_bits 64))
(tmp2 Reg (alu_rrr (AluOPRRR.Slt) tmp (zero_reg)))
(high Reg (gen_extend tmp2 $true 1 64)))
(value_regs (gen_move2 tmp $I64 $I64) high)))
(type ExtendOp
(enum
(Zero)
(Signed)))

;; Performs either a sign or zero extension of the given value
(decl extend (ValueRegs ExtendOp Type Type) ValueRegs)

;;; Generic Rules Extending to I64
(decl pure extend_shift_op (ExtendOp) AluOPRRI)
(rule (extend_shift_op (ExtendOp.Zero)) (AluOPRRI.Srli))
(rule (extend_shift_op (ExtendOp.Signed)) (AluOPRRI.Srai))

;; In the most generic case, we shift left and then shift right.
;; The type of right shift is determined by the extend op.
(rule 0 (extend val extend_op (fits_in_32 from_ty) (fits_in_64 to_ty))
(let ((val Reg (value_regs_get val 0))
(shift Imm12 (imm_from_bits (u64_sub 64 (ty_bits from_ty))))
(left Reg (alu_rr_imm12 (AluOPRRI.Slli) val shift))
(shift_op AluOPRRI (extend_shift_op extend_op))
(right Reg (alu_rr_imm12 shift_op left shift)))
right))

;; If we are zero extending a U8 we can use a `andi` instruction.
(rule 1 (extend val (ExtendOp.Zero) $I8 (fits_in_64 to_ty))
(let ((val Reg (value_regs_get val 0)))
(alu_rr_imm12 (AluOPRRI.Andi) val (imm12_const 255))))

;; When signed extending from 32 to 64 bits we can use a
;; `addiw val 0`. Also known as a `sext.w`
(rule 1 (extend val (ExtendOp.Signed) $I32 $I64)
(let ((val Reg (value_regs_get val 0)))
(alu_rr_imm12 (AluOPRRI.Addiw) val (imm12_const 0))))


;; No point in trying to use `packh` here to zero extend 8 bit values
;; since we can just use `andi` instead which is part of the base ISA.

;; If we have the `zbkb` extension `packw` can be used to zero extend 16 bit values
(rule 1 (extend val (ExtendOp.Zero) $I16 (fits_in_64 _))
(if-let $true (has_zbkb))
(let ((val Reg (value_regs_get val 0)))
(alu_rrr (AluOPRRR.Packw) val (zero_reg))))

;; If we have the `zbkb` extension `pack` can be used to zero extend 32 bit registers
(rule 1 (extend val (ExtendOp.Zero) $I32 $I64)
(if-let $true (has_zbkb))
(let ((val Reg (value_regs_get val 0)))
(alu_rrr (AluOPRRR.Pack) val (zero_reg))))

;;;; for I128 unsigned extend.
(rule 1
(lower_extend r $false 64 128)
(value_regs (gen_move2 r $I64 $I64) (load_u64_constant 0)))

(rule
(lower_extend r $false from_bits 128)
(value_regs (gen_extend r $false from_bits 64) (load_u64_constant 0)))
;; If we have the `zbb` extension we can use the dedicated `sext.b` instruction.
(rule 1 (extend val (ExtendOp.Signed) $I8 (fits_in_64 _))
(if-let $true (has_zbb))
(let ((val Reg (value_regs_get val 0)))
(alu_rr_imm12 (AluOPRRI.Sextb) val (imm12_const 0))))

;; If we have the `zbb` extension we can use the dedicated `sext.h` instruction.
(rule 1 (extend val (ExtendOp.Signed) $I16 (fits_in_64 _))
(if-let $true (has_zbb))
(let ((val Reg (value_regs_get val 0)))
(alu_rr_imm12 (AluOPRRI.Sexth) val (imm12_const 0))))

;; If we have the `zbb` extension we can use the dedicated `zext.h` instruction.
(rule 2 (extend val (ExtendOp.Zero) $I16 (fits_in_64 _))
(if-let $true (has_zbb))
(let ((val Reg (value_regs_get val 0)))
(alu_rr_imm12 (AluOPRRI.Zexth) val (imm12_const 0))))

;;; Signed rules extending to I128
;; Extend the bottom part, and extract the sign bit from the bottom as the top
(rule 2 (extend val (ExtendOp.Signed) (fits_in_64 from_ty) $I128)
(let ((val Reg (value_regs_get val 0))
(low Reg (extend val (ExtendOp.Signed) from_ty $I64))
(high Reg (alu_rr_imm12 (AluOPRRI.Srai) low (imm12_const 63))))
(value_regs low high)))

;;; Unsigned rules extending to I128
;; Extend the bottom register to I64 and then just zero out the top half.
(rule 3 (extend val (ExtendOp.Zero) (fits_in_64 from_ty) $I128)
(let ((val Reg (value_regs_get val 0))
(low Reg (extend val (ExtendOp.Zero) from_ty $I64))
(high Reg (load_u64_constant 0)))
(value_regs low high)))

;; Catch all rule for ignoring extensions of the same type.
(rule 4 (extend val _ ty ty) val)


;; extract the sign bit of integer.
(decl ext_sign_bit (Type Reg) Reg)
(extern constructor ext_sign_bit ext_sign_bit)

(decl lower_b128_binary (AluOPRRR ValueRegs ValueRegs) ValueRegs)
(rule
Expand Down Expand Up @@ -1795,50 +1850,6 @@
(rule (lower_icmp cc x y ty)
(gen_icmp cc (ext_int_if_need $false x ty) (ext_int_if_need $false y ty) ty))

(decl lower_icmp_over_flow (ValueRegs ValueRegs Type) Reg)

;;; for I8 I16 I32
(rule 1
(lower_icmp_over_flow x y ty)
(let
((tmp Reg (alu_sub (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))
(tmp2 WritableReg (temp_writable_reg $I64))
(_ Unit (emit (MInst.Extend tmp2 tmp $true (ty_bits ty) 64))))
(gen_icmp (IntCC.NotEqual) (writable_reg_to_reg tmp2) tmp $I64)))

;;; $I64
(rule 3
(lower_icmp_over_flow x y $I64)
(let
((y_sign Reg (alu_rrr (AluOPRRR.Sgt) y (zero_reg)))
(sub_result Reg (alu_sub x y))
(tmp Reg (alu_rrr (AluOPRRR.Slt) sub_result x)))
(gen_icmp (IntCC.NotEqual) y_sign tmp $I64)))

;;; $I128
(rule 2
(lower_icmp_over_flow x y $I128)
(let
( ;; x sign bit.
(xs Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get x 1) (imm12_const 63)))
;; y sign bit.
(ys Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get y 1) (imm12_const 63)))
;;
(sub_result ValueRegs (i128_sub x y))
;; result sign bit.
(rs Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get sub_result 1) (imm12_const 63)))

;;; xs && !ys && !rs
;;; x is positive y is negtive and result is negative.
;;; must overflow
(tmp1 Reg (alu_and xs (alu_and (gen_bit_not ys) (gen_bit_not rs))))
;;; !xs && ys && rs
;;; x is negative y is positive and result is positive.
;;; overflow
(tmp2 Reg (alu_and (gen_bit_not xs) (alu_and ys rs)))
;;;tmp3
(tmp3 Reg (alu_rrr (AluOPRRR.Or) tmp1 tmp2)))
(gen_extend tmp3 $true 1 64)))

(decl i128_sub (ValueRegs ValueRegs) ValueRegs)
(rule
Expand Down
27 changes: 23 additions & 4 deletions cranelift/codegen/src/isa/riscv64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,9 @@ impl AluOPRRR {
Self::Sh3add => "sh3add",
Self::Sh3adduw => "sh3add.uw",
Self::Xnor => "xnor",
Self::Pack => "pack",
Self::Packw => "packw",
Self::Packh => "packh",
}
}

Expand Down Expand Up @@ -785,6 +788,7 @@ impl AluOPRRR {
AluOPRRR::Remw => 0b110,
AluOPRRR::Remuw => 0b111,

// Zbb
AluOPRRR::Adduw => 0b000,
AluOPRRR::Andn => 0b111,
AluOPRRR::Bclr => 0b001,
Expand All @@ -810,6 +814,11 @@ impl AluOPRRR {
AluOPRRR::Sh3add => 0b110,
AluOPRRR::Sh3adduw => 0b110,
AluOPRRR::Xnor => 0b100,

// Zbkb
AluOPRRR::Pack => 0b100,
AluOPRRR::Packw => 0b100,
AluOPRRR::Packh => 0b111,
}
}

Expand All @@ -826,11 +835,16 @@ impl AluOPRRR {
| AluOPRRR::Srl
| AluOPRRR::Sra
| AluOPRRR::Or
| AluOPRRR::And => 0b0110011,
| AluOPRRR::And
| AluOPRRR::Pack
| AluOPRRR::Packh => 0b0110011,

AluOPRRR::Addw | AluOPRRR::Subw | AluOPRRR::Sllw | AluOPRRR::Srlw | AluOPRRR::Sraw => {
0b0111011
}
AluOPRRR::Addw
| AluOPRRR::Subw
| AluOPRRR::Sllw
| AluOPRRR::Srlw
| AluOPRRR::Sraw
| AluOPRRR::Packw => 0b0111011,

AluOPRRR::Mul
| AluOPRRR::Mulh
Expand Down Expand Up @@ -937,6 +951,11 @@ impl AluOPRRR {
AluOPRRR::Sh3add => 0b0010000,
AluOPRRR::Sh3adduw => 0b0010000,
AluOPRRR::Xnor => 0b0100000,

// Zbkb
AluOPRRR::Pack => 0b0000100,
AluOPRRR::Packw => 0b0000100,
AluOPRRR::Packh => 0b0000100,
}
}

Expand Down
32 changes: 32 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/emit_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,38 @@ fn test_riscv64_binemit() {
0x400545b3,
));

// Zbkb
insns.push(TestUnit::new(
Inst::AluRRR {
alu_op: AluOPRRR::Pack,
rd: writable_a1(),
rs1: a0(),
rs2: zero_reg(),
},
"pack a1,a0,zero",
0x080545b3,
));
insns.push(TestUnit::new(
Inst::AluRRR {
alu_op: AluOPRRR::Packw,
rd: writable_a1(),
rs1: a0(),
rs2: zero_reg(),
},
"packw a1,a0,zero",
0x080545bb,
));
insns.push(TestUnit::new(
Inst::AluRRR {
alu_op: AluOPRRR::Packh,
rd: writable_a1(),
rs1: a0(),
rs2: zero_reg(),
},
"packh a1,a0,zero",
0x080575b3,
));

//
insns.push(TestUnit::new(
Inst::AluRRR {
Expand Down
Loading

0 comments on commit f6c6bc2

Please sign in to comment.