Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch statement improvements #2126

Merged
merged 9 commits into from
Dec 9, 2022
22 changes: 14 additions & 8 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1807,23 +1807,29 @@ impl<'a, W: Write> Writer<'a, W> {
for case in cases {
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}:", l2, value, type_postfix)?
write!(self.out, "{}case {}{}:", l2, value, type_postfix)?
}
crate::SwitchValue::Default => writeln!(self.out, "{}default:", l2)?,
crate::SwitchValue::Default => write!(self.out, "{}default:", l2)?,
}

let write_block_braces = !(case.fall_through && case.body.is_empty());
if write_block_braces {
writeln!(self.out, " {{")?;
} else {
writeln!(self.out)?;
}

for sta in case.body.iter() {
self.write_stmt(sta, ctx, l2.next())?;
}

// Write fallthrough comment if the case is fallthrough,
// otherwise write a break, if the case is not already
// broken out of at the end of its body.
if case.fall_through {
writeln!(self.out, "{}/* fallthrough */", l2.next())?;
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
} else if case.body.last().map_or(true, |s| !s.is_terminator()) {
if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) {
writeln!(self.out, "{}break;", l2.next())?;
}

if write_block_braces {
writeln!(self.out, "{}}}", l2)?;
}
}

writeln!(self.out, "{}}}", level)?
Expand Down
47 changes: 40 additions & 7 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1825,18 +1825,47 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

for (i, case) in cases.iter().enumerate() {
match case.value {
crate::SwitchValue::Integer(value) => writeln!(
crate::SwitchValue::Integer(value) => write!(
self.out,
"{}case {}{}: {{",
"{}case {}{}:",
indent_level_1, value, type_postfix
)?,
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", indent_level_1)?
write!(self.out, "{}default:", indent_level_1)?
}
}

// FXC doesn't support fallthrough so we duplicate the body of the following case blocks
if case.fall_through {
// The new block is not only stylistic, it plays a role here:
// We might end up having to write the same case body
// multiple times due to FXC not supporting fallthrough.
// Therefore, some `Expression`s written by `Statement::Emit`
// will end up having the same name (`_expr<handle_index>`).
// So we need to put each case in its own scope.
let write_block_braces = !(case.fall_through && case.body.is_empty());
if write_block_braces {
writeln!(self.out, " {{")?;
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
} else {
writeln!(self.out)?;
}

// Although FXC does support a series of case clauses before
// a block[^yes], it does not support fallthrough from a
// non-empty case block to the next[^no]. If this case has a
// non-empty body with a fallthrough, emulate that by
// duplicating the bodies of all the cases it would fall
// into as extensions of this case's own body. This makes
// the HLSL output potentially quadratic in the size of the
// Naga IR.
//
// [^yes]: ```hlsl
// case 1:
// case 2: do_stuff()
// ```
// [^no]: ```hlsl
// case 1: do_this();
// case 2: do_that();
// ```
if case.fall_through && !case.body.is_empty() {
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
let curr_len = i + 1;
let end_case_idx = curr_len
+ cases
Expand All @@ -1861,12 +1890,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
for sta in case.body.iter() {
self.write_stmt(module, sta, func_ctx, indent_level_2)?;
}
if case.body.last().map_or(true, |s| !s.is_terminator()) {
if !case.fall_through
&& case.body.last().map_or(true, |s| !s.is_terminator())
{
writeln!(self.out, "{}break;", indent_level_2)?;
}
}

writeln!(self.out, "{}}}", indent_level_1)?;
if write_block_braces {
writeln!(self.out, "{}}}", indent_level_1)?;
}
}

writeln!(self.out, "{}}}", level)?
Expand Down
17 changes: 14 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2546,19 +2546,30 @@ impl<W: Write> Writer<W> {
for case in cases.iter() {
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}: {{", lcase, value, type_postfix)?;
write!(self.out, "{}case {}{}:", lcase, value, type_postfix)?;
}
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", lcase)?;
write!(self.out, "{}default:", lcase)?;
}
}

let write_block_braces = !(case.fall_through && case.body.is_empty());
if write_block_braces {
writeln!(self.out, " {{")?;
} else {
writeln!(self.out)?;
}

self.put_block(lcase.next(), &case.body, context)?;
if !case.fall_through
&& case.body.last().map_or(true, |s| !s.is_terminator())
{
writeln!(self.out, "{}break;", lcase.next())?;
}
writeln!(self.out, "{}}}", lcase)?;

if write_block_braces {
writeln!(self.out, "{}}}", lcase)?;
}
}
writeln!(self.out, "{}}}", level)?;
}
Expand Down
51 changes: 25 additions & 26 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1680,32 +1680,37 @@ impl<'w> BlockContext<'w> {
spirv::SelectionControl::NONE,
));

let default_id = self.gen_id();
let mut default_id = None;
// id of previous empty fall-through case
let mut last_id = None;

let mut reached_default = false;
let mut raw_cases = Vec::with_capacity(cases.len());
let mut case_ids = Vec::with_capacity(cases.len());
for case in cases.iter() {
// take id of previous empty fall-through case or generate a new one
let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
jimblandy marked this conversation as resolved.
Show resolved Hide resolved

if case.fall_through && case.body.is_empty() {
last_id = Some(label_id);
}

case_ids.push(label_id);

match case.value {
crate::SwitchValue::Integer(value) => {
let label_id = self.gen_id();
// No cases should be added after the default case is encountered
// since the default case catches all
if !reached_default {
raw_cases.push(super::instructions::Case {
value: value as Word,
label_id,
});
}
case_ids.push(label_id);
raw_cases.push(super::instructions::Case {
value: value as Word,
label_id,
});
}
crate::SwitchValue::Default => {
case_ids.push(default_id);
reached_default = true;
default_id = Some(label_id);
}
}
}

let default_id = default_id.unwrap();

self.function.consume(
block,
Instruction::switch(selector_id, default_id, &raw_cases),
Expand All @@ -1716,7 +1721,12 @@ impl<'w> BlockContext<'w> {
..loop_context
};

for (i, (case, label_id)) in cases.iter().zip(case_ids.iter()).enumerate() {
for (i, (case, label_id)) in cases
.iter()
.zip(case_ids.iter())
.filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
.enumerate()
{
let case_finish_id = if case.fall_through {
case_ids[i + 1]
} else {
Expand All @@ -1732,17 +1742,6 @@ impl<'w> BlockContext<'w> {
)?;
}

// If no default was encountered write a empty block to satisfy the presence of
// a block the default label
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
if !reached_default {
self.write_block(
default_id,
&[],
BlockExit::Branch { target: merge_id },
inner_context,
)?;
}

block = Block::new(merge_id);
}
crate::Statement::Loop {
Expand Down
48 changes: 35 additions & 13 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,25 +880,47 @@ impl<W: Write> Writer<W> {
};

let l2 = level.next();
if !cases.is_empty() {
for case in cases {
match case.value {
crate::SwitchValue::Integer(value) => {
writeln!(self.out, "{}case {}{}: {{", l2, value, type_postfix)?;
let mut new_case = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: It's a bummer that we're not removing fall_through, too; the complexity. From our sync conversation, you mentioned that it wasn't feasible to remove at this time, because IR needs to keep it for other front ends (IIRC?). Is there an issue to track this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That but also the fact that we still translate a switch with multiple cases with the same body as a sequence of switch cases with no body and fall_through = true ending in a case with the body and fall_through = false.
i.e.

switch(v) {
    case 0, 1: { p = 3; }
}

V

SwitchCase { value: 0, fall_through = true, body: {} },
SwitchCase { value: 1, fall_through = false, body: { p = 3 } }

I'll add some more docs on the Switch statement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea(non-blocking): Maybe we could change SwitchCase model to remove fall_through and make body an Option? Concretely:

/// <snip>
struct SwitchCase {
    // …

    /// When [`None`], a subsequent case should contain the body associated with this case.
    body: Option<Block>, // replaces `fall_through` and current definition of `body`
}

A one-off enum might also serve well here, but I'm less certain about that.

Copy link
Member Author

@teoxoy teoxoy Nov 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other frontends still make use of fall_through with non empty bodies though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code that decides when to write "case" is fine, but it does seem like the IR is not making the writer's life easy. We might be able to improve the IR, but that should be a separate PR.

Suggestion: Would it work to do something like this?

// Gather cases into groups that all execute the same block.
// Only the last case in each group has a non-empty body.
for case_group in cases.split_inclusive(|case| !(case.fall_through && case.body.is_empty())) {
    write!("case ");
    for value in case_group.iter().map(|case| case.value) {
        write!("{}, ", value);
    }
    writeln!(":");
    write_block(case_group.last().body);
}

(I'm just making this up, please check my logic carefully!)

Granted, this would need to be changed when we implement fallthrough, but I think if we tradesplit_inclusive for .enumerate().filter() and then gather up all the consecutive bodies we need to concatenate, we could keep this as a loop whose iterations correspond to emitted "case" statements, not IR cases, which is what I think makes it clearer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code that decides when to write "case" is fine, but it does seem like the IR is not making the writer's life easy. We might be able to improve the IR, but that should be a separate PR.

I thought about this but I could not come up with a satisfactory change because we still need to support "real" fall-through for GLSL and SPV.

Granted, this would need to be changed when we implement fallthrough, but I think if we trade split_inclusive for .enumerate().filter() and then gather up all the consecutive bodies we need to concatenate, we could keep this as a loop whose iterations correspond to emitted "case" statements, not IR cases, which is what I think makes it clearer.

I don't see how .enumerate().filter() would work since we still need to emit the case values of the fall_through cases. As far as I can tell the only way to do this would be to allocate a new structure and have one pass restructuring the data then another to write it which I don't think is really worth it.

for case in cases {
if case.fall_through && !case.body.is_empty() {
// TODO: we could do the same workaround as we did for the HLSL backend
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: I know that in wgpu, cwfitzgerald requests that TODOs be either resolved or converted to issues. I see that there are lots of other TODOs in the repo, currently; does naga maintainership take a stance on "WIP"-ish comments?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hasn't been brought up as far as I'm aware. It would be nice but I think we should do a round of converting all TODOs to issues first (which has been on my mind for a while but not a priority currently).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't introduce any new TODO comments to Naga. The issue tracker is a much better tool for keeping track of these sorts of plans and ideas, so please file an issue for this. (Instead of a "TODO" comment, a link to an issue is fine, if you think it would be helpful to future readers.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return Err(Error::Unimplemented(
"fall-through switch case block".into(),
));
}

match case.value {
crate::SwitchValue::Integer(value) => {
if new_case {
write!(self.out, "{}case ", l2)?;
}
crate::SwitchValue::Default => {
writeln!(self.out, "{}default: {{", l2)?;
write!(self.out, "{}{}", value, type_postfix)?;
}
crate::SwitchValue::Default => {
if new_case {
if case.fall_through {
write!(self.out, "{}case ", l2)?;
} else {
write!(self.out, "{}", l2)?;
}
}
write!(self.out, "default")?;
}
}

for sta in case.body.iter() {
self.write_stmt(module, sta, func_ctx, l2.next())?;
}
new_case = !case.fall_through;

if case.fall_through {
writeln!(self.out, "{}fallthrough;", l2.next())?;
}
if case.fall_through {
write!(self.out, ", ")?;
} else {
writeln!(self.out, ": {{")?;
}

for sta in case.body.iter() {
self.write_stmt(module, sta, func_ctx, l2.next())?;
}

if !case.fall_through {
writeln!(self.out, "{}}}", l2)?;
}
}
Expand Down
Loading