Skip to content

Commit

Permalink
riscv64: Refactor and simplify some branches/fcmp (bytecodealliance#7142)
Browse files Browse the repository at this point in the history

* riscv64: Add codegen tests for fcmp and fcmp branches

* riscv64: Refactor and simplify some branches/fcmp

This commit is aimed at simplifying the layers necessary to generate a
branch for a `brif` statement. Previously this involved a
`lower_cond_br` helper which bottomed out in emitting a `CondBr`
instruction. The intention here was to cut all that out and emit a
`CondBr` directly.

Along the way I've additionally taken the liberty of simplifying `fcmp`
as well. This moves the "prefer ordered compares" logic into `emit_fcmp`
so it can benefit the `fcmp` instruction as well. This additionally
trimmed some abstractions around branches which shouldn't be necessary
any longer.
  • Loading branch information
alexcrichton committed Oct 4, 2023
1 parent 5c60d64 commit 56f9381
Show file tree
Hide file tree
Showing 4 changed files with 813 additions and 134 deletions.
179 changes: 60 additions & 119 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -2651,9 +2651,6 @@
(decl int_zero_reg (Type) ValueRegs)
(extern constructor int_zero_reg int_zero_reg)

(decl lower_cond_br (IntCC ValueRegs MachLabelSlice Type) Unit)
(extern constructor lower_cond_br lower_cond_br)

;; Convert a truthy value, possibly of more than one register (an I128), to
;; one register.
;;
Expand All @@ -2668,31 +2665,44 @@
(hi XReg (value_regs_get regs 1)))
(rv_or lo hi)))

;; Consume an `IntegerCompare`, producing a conditional branch on its result.
;; The `then` and `else` targets and the comparison are handed unchanged to
;; `MInst.CondBr`; the instruction is returned as a side effect and still has
;; to be emitted by the caller (e.g. via `emit_side_effect`).
(decl cond_br (IntegerCompare CondBrTarget CondBrTarget) SideEffectNoResult)
(rule (cond_br cmp then else)
(SideEffectNoResult.Inst
(MInst.CondBr then else cmp)))

;; Helper for emitting the `j` mnemonic, an unconditional jump to label.
(decl rv_j (MachLabel) SideEffectNoResult)
(rule (rv_j label)
(SideEffectNoResult.Inst (MInst.Jal label)))

;; Construct an IntegerCompare value.
(decl int_compare (IntCC XReg XReg) IntegerCompare)
(extern constructor int_compare int_compare)

;; Convert a `MachLabel` into a `CondBrTarget`. The constructor is implemented
;; externally, and the `convert` declaration registers it as an implicit
;; conversion so a `MachLabel` can be written directly wherever a
;; `CondBrTarget` is expected (as the `cond_br` callers do with `then`/`else`).
(decl label_to_br_target (MachLabel) CondBrTarget)
(extern constructor label_to_br_target label_to_br_target)
(convert MachLabel CondBrTarget label_to_br_target)

(decl partial lower_branch (Inst MachLabelSlice) Unit)
(rule (lower_branch (jump _) (single_target label))
(emit_side_effect (SideEffectNoResult.Inst (MInst.Jal label))))
(emit_side_effect (rv_j label)))

;; Default behavior for branching based on an input value.
(rule (lower_branch (brif v @ (value_type (fits_in_64 ty)) _ _) targets)
(lower_cond_br (IntCC.NotEqual) (zext v) targets ty))
(rule 2 (lower_branch (brif v @ (value_type $I128)_ _) targets)
(lower_cond_br (IntCC.NotEqual) (truthy_to_reg v) targets $I64))
(rule (lower_branch (brif v @ (value_type (fits_in_64 ty)) _ _) (two_targets then else))
(emit_side_effect (cond_br (int_compare (IntCC.NotEqual) (zext v) (zero_reg)) then else)))
(rule 2 (lower_branch (brif v @ (value_type $I128)_ _) (two_targets then else))
(emit_side_effect (cond_br (int_compare (IntCC.NotEqual) (truthy_to_reg v) (zero_reg)) then else)))

;; Branching on the result of an fcmp
(rule 1
(lower_branch (brif (maybe_uextend (fcmp cc a @ (value_type ty) b)) _ _) (two_targets then else))
(if-let $true (floatcc_unordered cc))
(emit_side_effect (cond_br (emit_fcmp (floatcc_complement cc) ty a b) else then)))

(rule 1
(lower_branch (brif (maybe_uextend (fcmp cc a @ (value_type ty) b)) _ _) (two_targets then else))
(if-let $false (floatcc_unordered cc))
;; Branching on the result of an fcmp.
(rule 1 (lower_branch (brif (maybe_uextend (fcmp cc a @ (value_type ty) b)) _ _) (two_targets then else))
(emit_side_effect (cond_br (emit_fcmp cc ty a b) then else)))

;; Convert an `FCmp` into the `IntegerCompare` that tests its register against
;; the zero register:
;;   * `FCmp.One r`  => `r != 0`
;;   * `FCmp.Zero r` => `r == 0`
;; Registered as an implicit conversion so an `FCmp` (e.g. the result of
;; `emit_fcmp`) can be passed directly where an `IntegerCompare` is expected.
(decl fcmp_to_compare (FCmp) IntegerCompare)
(rule (fcmp_to_compare (FCmp.One r)) (int_compare (IntCC.NotEqual) r (zero_reg)))
(rule (fcmp_to_compare (FCmp.Zero r)) (int_compare (IntCC.Equal) r (zero_reg)))
(convert FCmp IntegerCompare fcmp_to_compare)


(decl lower_br_table (Reg MachLabelSlice) Unit)
(extern constructor lower_br_table lower_br_table)
Expand Down Expand Up @@ -2857,131 +2867,62 @@

;;;; Helpers for floating point comparisons ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(decl not (XReg) XReg)
(rule (not x) (rv_xori x (imm_from_bits 1)))

(decl is_not_nan (Type FReg) XReg)
(rule (is_not_nan ty a) (rv_feq ty a a))

(decl ordered (Type FReg FReg) XReg)
(rule (ordered ty a b) (rv_and (is_not_nan ty a) (is_not_nan ty b)))

(type CmpResult (enum
(Result
(result XReg)
(invert bool))))

;; Wrapper for the common case when constructing comparison results. It assumes
;; that the result isn't negated.
(decl cmp_result (XReg) CmpResult)
(rule (cmp_result result) (CmpResult.Result result $false))

;; Wrapper for the case where it's more convenient to construct the negated
;; version of the comparison.
(decl cmp_result_invert (XReg) CmpResult)
(rule (cmp_result_invert result) (CmpResult.Result result $true))

;; Consume a CmpResult, producing a branch on its result.
(decl cond_br (CmpResult CondBrTarget CondBrTarget) SideEffectNoResult)
(rule (cond_br cmp then else)
(SideEffectNoResult.Inst
(MInst.CondBr then else (cmp_integer_compare cmp))))

;; Construct an IntegerCompare value.
(decl int_compare (IntCC XReg XReg) IntegerCompare)
(extern constructor int_compare int_compare)

;; Convert a comparison into a branch test.
(decl cmp_integer_compare (CmpResult) IntegerCompare)

(rule
(cmp_integer_compare (CmpResult.Result res $false))
(int_compare (IntCC.NotEqual) res (zero_reg)))

(rule
(cmp_integer_compare (CmpResult.Result res $true))
(int_compare (IntCC.Equal) res (zero_reg)))
;; The result of a floating-point comparison: a register `r` together with
;; which of its values signals success. Keeping the polarity in the variant
;; lets a comparison be inverted by swapping variants instead of modifying `r`.
(type FCmp (enum
;; The comparison succeeded if `r` is one
(One (r XReg))
;; The comparison succeeded if `r` is zero
(Zero (r XReg))
))

;; Convert a comparison into a boolean value.
(decl cmp_value (CmpResult) XReg)
(rule (cmp_value (CmpResult.Result res $false)) res)
(rule (cmp_value (CmpResult.Result res $true)) (not res))
;; Invert the sense of an `FCmp`. The register is kept as-is and only the
;; variant is swapped (`One` <-> `Zero`), so these rules produce no new
;; register value.
(decl fcmp_invert (FCmp) FCmp)
(rule (fcmp_invert (FCmp.One r)) (FCmp.Zero r))
(rule (fcmp_invert (FCmp.Zero r)) (FCmp.One r))

;; Compare two floating point numbers and return a zero/non-zero result.
(decl emit_fcmp (FloatCC Type FReg FReg) CmpResult)
(decl emit_fcmp (FloatCC Type FReg FReg) FCmp)

;; a is not nan && b is not nan
(rule
(emit_fcmp (FloatCC.Ordered) ty a b)
(cmp_result (ordered ty a b)))
;; Direct codegen for unordered comparisons is not that efficient, so invert
;; the comparison to get an ordered comparison and generate that. Then invert
;; the result to produce the final fcmp result.
(rule 0 (emit_fcmp cc ty a b)
(if-let $true (floatcc_unordered cc))
(fcmp_invert (emit_fcmp (floatcc_complement cc) ty a b)))

;; a is nan || b is nan
;; == !(a is not nan && b is not nan)
(rule
(emit_fcmp (FloatCC.Unordered) ty a b)
(cmp_result_invert (ordered ty a b)))
;; a is not nan && b is not nan
(rule 1 (emit_fcmp (FloatCC.Ordered) ty a b)
(FCmp.One (ordered ty a b)))

;; a == b
(rule
(emit_fcmp (FloatCC.Equal) ty a b)
(cmp_result (rv_feq ty a b)))
(rule 1 (emit_fcmp (FloatCC.Equal) ty a b)
(FCmp.One (rv_feq ty a b)))

;; a != b
;; == !(a == b)
(rule
(emit_fcmp (FloatCC.NotEqual) ty a b)
(cmp_result_invert (rv_feq ty a b)))
(rule 1 (emit_fcmp (FloatCC.NotEqual) ty a b)
(FCmp.Zero (rv_feq ty a b)))

;; a < b || a > b
(rule
(emit_fcmp (FloatCC.OrderedNotEqual) ty a b)
(cmp_result (rv_or (rv_flt ty a b) (rv_fgt ty a b))))

;; !(ordered a b) || a == b
(rule
(emit_fcmp (FloatCC.UnorderedOrEqual) ty a b)
(cmp_result (rv_or (not (ordered ty a b)) (rv_feq ty a b))))
(rule 1 (emit_fcmp (FloatCC.OrderedNotEqual) ty a b)
(FCmp.One (rv_or (rv_flt ty a b) (rv_fgt ty a b))))

;; a < b
(rule
(emit_fcmp (FloatCC.LessThan) ty a b)
(cmp_result (rv_flt ty a b)))
(rule 1 (emit_fcmp (FloatCC.LessThan) ty a b)
(FCmp.One (rv_flt ty a b)))

;; a <= b
(rule
(emit_fcmp (FloatCC.LessThanOrEqual) ty a b)
(cmp_result (rv_fle ty a b)))
(rule 1 (emit_fcmp (FloatCC.LessThanOrEqual) ty a b)
(FCmp.One (rv_fle ty a b)))

;; a > b
(rule
(emit_fcmp (FloatCC.GreaterThan) ty a b)
(cmp_result (rv_fgt ty a b)))
(rule 1 (emit_fcmp (FloatCC.GreaterThan) ty a b)
(FCmp.One (rv_fgt ty a b)))

;; a >= b
(rule
(emit_fcmp (FloatCC.GreaterThanOrEqual) ty a b)
(cmp_result (rv_fge ty a b)))

;; !(ordered a b) || a < b
;; == !(ordered a b && a >= b)
(rule
(emit_fcmp (FloatCC.UnorderedOrLessThan) ty a b)
(cmp_result_invert (rv_and (ordered ty a b) (rv_fge ty a b))))

;; !(ordered a b) || a <= b
;; == !(ordered a b && a > b)
(rule
(emit_fcmp (FloatCC.UnorderedOrLessThanOrEqual) ty a b)
(cmp_result_invert (rv_and (ordered ty a b) (rv_fgt ty a b))))

;; !(ordered a b) || a > b
;; == !(ordered a b && a <= b)
(rule
(emit_fcmp (FloatCC.UnorderedOrGreaterThan) ty a b)
(cmp_result_invert (rv_and (ordered ty a b) (rv_fle ty a b))))

;; !(ordered a b) || a >= b
;; == !(ordered a b && a < b)
(rule
(emit_fcmp (FloatCC.UnorderedOrGreaterThanOrEqual) ty a b)
(cmp_result_invert (rv_and (ordered ty a b) (rv_flt ty a b))))
(rule 1 (emit_fcmp (FloatCC.GreaterThanOrEqual) ty a b)
(FCmp.One (rv_fge ty a b)))
6 changes: 5 additions & 1 deletion cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1855,7 +1855,11 @@

;;;;; Rules for `fcmp`;;;;;;;;;
(rule 0 (lower (fcmp cc x @ (value_type (ty_scalar_float ty)) y))
(cmp_value (emit_fcmp cc ty x y)))
(lower_fcmp (emit_fcmp cc ty x y)))

;; Materialize an `FCmp` as a register value for the scalar `fcmp` lowering.
;;   * `FCmp.One r`:  `r` is returned unchanged.
;;   * `FCmp.Zero r`: `r` is passed through `rv_seqz`.
;; NOTE(review): `rv_seqz` is presumably the RISC-V `seqz` (set-if-equal-zero)
;; pseudo-instruction, mapping 0 to 1 and non-zero to 0 -- confirm against its
;; definition, which is not part of this hunk.
(decl lower_fcmp (FCmp) XReg)
(rule (lower_fcmp (FCmp.One r)) r)
(rule (lower_fcmp (FCmp.Zero r)) (rv_seqz r))

(rule 1 (lower (fcmp cc x @ (value_type (ty_vec_fits_in_register ty)) y))
(gen_expand_mask ty (gen_fcmp_mask ty cc x y)))
Expand Down
12 changes: 0 additions & 12 deletions cranelift/codegen/src/isa/riscv64/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,6 @@ impl generated_code::Context for RV64IsleContext<'_, '_, MInst, Riscv64Backend>
}
}

/// Emits the instruction sequence for a conditional branch that compares `a`
/// against `self.int_zero_reg(ty)` using condition code `cc`.
///
/// The sequence is produced by `MInst::lower_br_icmp`, which receives
/// `targets[0]` and `targets[1]` (in that order) wrapped as
/// `CondBrTarget::Label`; every instruction it returns is then emitted.
///
/// # Panics
///
/// Panics if `targets` holds fewer than two labels, since both are indexed
/// directly.
fn lower_cond_br(&mut self, cc: &IntCC, a: ValueRegs, targets: &[MachLabel], ty: Type) -> Unit {
MInst::lower_br_icmp(
*cc,
a,
self.int_zero_reg(ty),
CondBrTarget::Label(targets[0]),
CondBrTarget::Label(targets[1]),
ty,
)
.iter()
.for_each(|i| self.emit(i));
}
fn load_ra(&mut self) -> Reg {
if self.backend.flags.preserve_frame_pointers() {
let tmp = self.temp_writable_reg(I64);
Expand Down
Loading

0 comments on commit 56f9381

Please sign in to comment.