Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

riscv64: Implement SIMD icmp #6609

Merged
merged 1 commit into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,8 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_f64x2_cmp",
"simd_f64x2_pmin_pmax",
"simd_f64x2_rounding",
"simd_i16x8_cmp",
"simd_i32x4_cmp",
"simd_i32x4_trunc_sat_f32x4",
"simd_i32x4_trunc_sat_f64x2",
"simd_i64x2_cmp",
"simd_i8x16_cmp",
"simd_load",
"simd_splat",
]
Expand Down
48 changes: 43 additions & 5 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,14 @@ impl VecAluOpRRR {
VecAluOpRRR::VwaddWV | VecAluOpRRR::VwaddWX => 0b110101,
VecAluOpRRR::VwsubuWV | VecAluOpRRR::VwsubuWX => 0b110110,
VecAluOpRRR::VwsubWV | VecAluOpRRR::VwsubWX => 0b110111,
VecAluOpRRR::VmsltVX => 0b011011,
VecAluOpRRR::VmseqVV | VecAluOpRRR::VmseqVX => 0b011000,
VecAluOpRRR::VmsneVV | VecAluOpRRR::VmsneVX => 0b011001,
VecAluOpRRR::VmsltuVV | VecAluOpRRR::VmsltuVX => 0b011010,
VecAluOpRRR::VmsltVV | VecAluOpRRR::VmsltVX => 0b011011,
VecAluOpRRR::VmsleuVV | VecAluOpRRR::VmsleuVX => 0b011100,
VecAluOpRRR::VmsleVV | VecAluOpRRR::VmsleVX => 0b011101,
VecAluOpRRR::VmsgtuVX => 0b011110,
VecAluOpRRR::VmsgtVX => 0b011111,
}
}

Expand All @@ -381,7 +388,13 @@ impl VecAluOpRRR {
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM
| VecAluOpRRR::VrgatherVV => VecOpCategory::OPIVV,
| VecAluOpRRR::VrgatherVV
| VecAluOpRRR::VmseqVV
| VecAluOpRRR::VmsneVV
| VecAluOpRRR::VmsltuVV
| VecAluOpRRR::VmsltVV
| VecAluOpRRR::VmsleuVV
| VecAluOpRRR::VmsleVV => VecOpCategory::OPIVV,
VecAluOpRRR::VwaddVV
| VecAluOpRRR::VwaddWV
| VecAluOpRRR::VwadduVV
Expand Down Expand Up @@ -427,8 +440,15 @@ impl VecAluOpRRR {
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM
| VecAluOpRRR::VrgatherVX
| VecAluOpRRR::VmseqVX
| VecAluOpRRR::VmsneVX
| VecAluOpRRR::VmsltuVX
| VecAluOpRRR::VmsltVX
| VecAluOpRRR::VrgatherVX => VecOpCategory::OPIVX,
| VecAluOpRRR::VmsleuVX
| VecAluOpRRR::VmsleVX
| VecAluOpRRR::VmsgtuVX
| VecAluOpRRR::VmsgtVX => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
Expand Down Expand Up @@ -522,6 +542,12 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VsaddVI => 0b100001,
VecAluOpRRImm5::VrgatherVI => 0b001100,
VecAluOpRRImm5::VmvrV => 0b100111,
VecAluOpRRImm5::VmseqVI => 0b011000,
VecAluOpRRImm5::VmsneVI => 0b011001,
VecAluOpRRImm5::VmsleuVI => 0b011100,
VecAluOpRRImm5::VmsleVI => 0b011101,
VecAluOpRRImm5::VmsgtuVI => 0b011110,
VecAluOpRRImm5::VmsgtVI => 0b011111,
}
}

Expand All @@ -541,7 +567,13 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VrgatherVI
| VecAluOpRRImm5::VmvrV => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VmvrV
| VecAluOpRRImm5::VmseqVI
| VecAluOpRRImm5::VmsneVI
| VecAluOpRRImm5::VmsleuVI
| VecAluOpRRImm5::VmsleVI
| VecAluOpRRImm5::VmsgtuVI
| VecAluOpRRImm5::VmsgtVI => VecOpCategory::OPIVI,
}
}

Expand All @@ -561,7 +593,13 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => false,
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VmseqVI
| VecAluOpRRImm5::VmsneVI
| VecAluOpRRImm5::VmsleuVI
| VecAluOpRRImm5::VmsleVI
| VecAluOpRRImm5::VmsgtuVI
| VecAluOpRRImm5::VmsgtVI => false,
}
}

Expand Down
254 changes: 253 additions & 1 deletion cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@
(VredminuVS)
(VrgatherVV)
(VcompressVM)
(VmseqVV)
(VmsneVV)
(VmsltuVV)
(VmsltVV)
(VmsleuVV)
(VmsleVV)


;; Vector-Scalar Opcodes
(VaddVX)
Expand Down Expand Up @@ -169,7 +176,14 @@
(VmergeVXM)
(VfmergeVFM)
(VrgatherVX)
(VmseqVX)
(VmsneVX)
(VmsltuVX)
(VmsltVX)
(VmsleuVX)
(VmsleVX)
(VmsgtuVX)
(VmsgtVX)
))


Expand Down Expand Up @@ -199,6 +213,12 @@
;; This opcode represents multiple instructions `vmv1r`/`vmv2r`/`vmv4r`/etc...
;; The immediate field specifies how many registers should be copied.
(VmvrV)
(VmseqVI)
(VmsneVI)
(VmsleuVI)
(VmsleVI)
(VmsgtuVI)
(VmsgtVI)
))

;; Imm only ALU Ops
Expand Down Expand Up @@ -969,11 +989,126 @@
(rule (rv_vcompress_vm vs2 vs1 vstate)
(vec_alu_rrr (VecAluOpRRR.VcompressVM) vs2 vs1 (unmasked) vstate))

;; Helper for emitting the `vmseq.vv` (Vector Mask Set If Equal) instruction.
;; Produces a mask register where bit i is set iff vs2[i] == vs1[i].
(decl rv_vmseq_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmseq_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmseqVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmseq.vx` (Vector Mask Set If Equal) instruction.
;; Compares each element of `vs2` against the scalar register `vs1`.
(decl rv_vmseq_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmseq_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmseqVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmseq.vi` (Vector Mask Set If Equal) instruction.
;; Compares each element of `vs2` against a 5-bit sign-extended immediate.
(decl rv_vmseq_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmseq_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmseqVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsne.vv` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsne_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsneVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsne.vx` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsne_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsneVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsne.vi` (Vector Mask Set If Not Equal) instruction.
(decl rv_vmsne_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsne_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsneVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsltu.vv` (Vector Mask Set If Less Than, Unsigned) instruction.
;; NOTE: RVV defines no `vmslt{u}.vi` encoding, so there are no `.vi` helpers here.
(decl rv_vmsltu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsltu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsltu.vx` (Vector Mask Set If Less Than, Unsigned) instruction.
(decl rv_vmsltu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsltu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmslt.vv` (Vector Mask Set If Less Than) instruction.
(decl rv_vmslt_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmslt.vx` (Vector Mask Set If Less Than) instruction.
(decl rv_vmslt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsltVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vv` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
(decl rv_vmsleu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsleu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vx` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
(decl rv_vmsleu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsleu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsleu.vi` (Vector Mask Set If Less Than or Equal, Unsigned) instruction.
;; The immediate is a 5-bit sign-extended value.
(decl rv_vmsleu_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsleu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsleuVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsle.vv` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsle_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsle.vx` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsle_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsleVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsle.vi` (Vector Mask Set If Less Than or Equal) instruction.
(decl rv_vmsle_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsle_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsleVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgtu.vv` (Vector Mask Set If Greater Than, Unsigned) instruction.
;; RVV has no `vmsgtu.vv` encoding; this is an alias for `vmsltu.vv` with the operands inverted.
(decl rv_vmsgtu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vv vs2 vs1 mask vstate) (rv_vmsltu_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsgtu.vx` (Vector Mask Set If Greater Than, Unsigned) instruction.
(decl rv_vmsgtu_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsgtuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsgtu.vi` (Vector Mask Set If Greater Than, Unsigned) instruction.
(decl rv_vmsgtu_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsgtu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsgtuVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgt.vv` (Vector Mask Set If Greater Than) instruction.
;; RVV has no `vmsgt.vv` encoding; this is an alias for `vmslt.vv` with the operands inverted.
(decl rv_vmsgt_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgt_vv vs2 vs1 mask vstate) (rv_vmslt_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsgt.vx` (Vector Mask Set If Greater Than) instruction.
(decl rv_vmsgt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmsgt_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmsgtVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmsgt.vi` (Vector Mask Set If Greater Than) instruction.
(decl rv_vmsgt_vi (VReg Imm5 VecOpMasking VState) VReg)
(rule (rv_vmsgt_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VmsgtVI) vs2 imm mask vstate))

;; Helper for emitting the `vmsgeu.vv` (Vector Mask Set If Greater Than or Equal, Unsigned) instruction.
;; RVV has no `vmsgeu.vv` encoding; this is an alias for `vmsleu.vv` with the operands inverted.
(decl rv_vmsgeu_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsgeu_vv vs2 vs1 mask vstate) (rv_vmsleu_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vmsge.vv` (Vector Mask Set If Greater Than or Equal) instruction.
;; RVV has no `vmsge.vv` encoding; this is an alias for `vmsle.vv` with the operands inverted.
(decl rv_vmsge_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vmsge_vv vs2 vs1 mask vstate) (rv_vmsle_vv vs1 vs2 mask vstate))

;; Helper for emitting the `vzext.vf2` instruction.
;; Zero-extend SEW/2 source to SEW destination
(decl rv_vzext_vf2 (VReg VecOpMasking VState) VReg)
Expand Down Expand Up @@ -1078,3 +1213,120 @@
;; Slides the vector down by half the lane count of `ty`, i.e. moves the upper
;; half of the lanes into the lower half. (Declaration of `gen_slidedown_half`
;; is above; this is the fallback rule at priority 0.)
(rule 0 (gen_slidedown_half (ty_vec_fits_in_register ty) src)
(if-let amt (u64_udiv (ty_lane_count ty) 2))
(rv_vslidedown_vx src (imm $I64 amt) (unmasked) ty))


;; Expands a mask into SEW wide lanes. Enabled lanes are set to all ones, disabled
;; lanes are set to all zeros.
;;
;; Implemented by splatting 0 into a vector and then merging -1 (all ones) into
;; the lanes selected by `mask` via `vmerge.vim`.
(decl gen_expand_mask (Type VReg) VReg)
(rule (gen_expand_mask ty mask)
(if-let zero (imm5_from_i8 0))
(if-let neg1 (imm5_from_i8 -1))
(rv_vmerge_vim (rv_vmv_vi zero ty) neg1 mask ty))


;; Builds a vector mask corresponding to the IntCC operation.
;;
;; Rule priorities select the cheapest operand form (higher priority wins):
;;   0: vector-vector (`.vv`) fallback
;;   1/2: one operand is a splatted scalar -> vector-scalar (`.vx`)
;;   3/4: one operand is a splatted 5-bit immediate -> vector-immediate (`.vi`)
;; For the commutative conditions (Equal / NotEqual) a splat on either side can
;; be folded (rules 2 and 4 swap the operands). For the ordered conditions only
;; the right-hand splat forms are matched here.
(decl gen_icmp_mask (Type IntCC Value Value) VReg)

;; IntCC.Equal

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x y)
(rv_vmseq_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x (splat y))
(rv_vmseq_vx x y (unmasked) ty))

;; Equality is commutative, so a splat on the left can use the same `.vx` form.
(rule 2 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) (splat x) y)
(rv_vmseq_vx y x (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) x (replicated_imm5 y))
(rv_vmseq_vi x y (unmasked) ty))

(rule 4 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.Equal) (replicated_imm5 x) y)
(rv_vmseq_vi y x (unmasked) ty))

;; IntCC.NotEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x y)
(rv_vmsne_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x (splat y))
(rv_vmsne_vx x y (unmasked) ty))

;; Inequality is commutative, so a splat on the left can use the same `.vx` form.
(rule 2 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) (splat x) y)
(rv_vmsne_vx y x (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) x (replicated_imm5 y))
(rv_vmsne_vi x y (unmasked) ty))

(rule 4 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.NotEqual) (replicated_imm5 x) y)
(rv_vmsne_vi y x (unmasked) ty))

;; IntCC.UnsignedLessThan
;; No `.vi` rule: RVV defines no `vmsltu.vi` encoding.

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThan) x y)
(rv_vmsltu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThan) x (splat y))
(rv_vmsltu_vx x y (unmasked) ty))

;; IntCC.SignedLessThan
;; No `.vi` rule: RVV defines no `vmslt.vi` encoding.

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThan) x y)
(rv_vmslt_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThan) x (splat y))
(rv_vmslt_vx x y (unmasked) ty))

;; IntCC.UnsignedLessThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x y)
(rv_vmsleu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x (splat y))
(rv_vmsleu_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedLessThanOrEqual) x (replicated_imm5 y))
(rv_vmsleu_vi x y (unmasked) ty))

;; IntCC.SignedLessThanOrEqual

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x y)
(rv_vmsle_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x (splat y))
(rv_vmsle_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedLessThanOrEqual) x (replicated_imm5 y))
(rv_vmsle_vi x y (unmasked) ty))

;; IntCC.UnsignedGreaterThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x y)
(rv_vmsgtu_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x (splat y))
(rv_vmsgtu_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThan) x (replicated_imm5 y))
(rv_vmsgtu_vi x y (unmasked) ty))

;; IntCC.SignedGreaterThan

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x y)
(rv_vmsgt_vv x y (unmasked) ty))

(rule 1 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x (splat y))
(rv_vmsgt_vx x y (unmasked) ty))

(rule 3 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThan) x (replicated_imm5 y))
(rv_vmsgt_vi x y (unmasked) ty))

;; IntCC.UnsignedGreaterThanOrEqual
;; Only the `.vv` form here; `rv_vmsgeu_vv` expands to `vmsleu.vv` with swapped operands.

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.UnsignedGreaterThanOrEqual) x y)
(rv_vmsgeu_vv x y (unmasked) ty))

;; IntCC.SignedGreaterThanOrEqual
;; Only the `.vv` form here; `rv_vmsge_vv` expands to `vmsle.vv` with swapped operands.

(rule 0 (gen_icmp_mask (ty_vec_fits_in_register ty) (IntCC.SignedGreaterThanOrEqual) x y)
(rv_vmsge_vv x y (unmasked) ty))
Loading