|
1 | 1 | import * as ts from "typescript"; |
2 | | -import { LuaTarget } from "../../CompilerOptions"; |
3 | 2 | import * as lua from "../../LuaAST"; |
4 | 3 | import { FunctionVisitor } from "../context"; |
5 | | -import { unsupportedForTarget } from "../utils/diagnostics"; |
6 | 4 | import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope"; |
7 | 5 |
|
8 | | -export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (statement, context) => { |
9 | | - if (context.luaTarget === LuaTarget.Universal || context.luaTarget === LuaTarget.Lua51) { |
10 | | - context.diagnostics.push(unsupportedForTarget(statement, "Switch statements", LuaTarget.Lua51)); |
| 6 | +const containsBreakStatement = (statements: ts.Node[]): boolean => { |
| 7 | + for (const s of statements) { |
| 8 | + if ( |
| 9 | + ts.isSwitchStatement(s) || |
| 10 | + ts.isWhileStatement(s) || |
| 11 | + ts.isDoStatement(s) || |
| 12 | + ts.isForStatement(s) || |
| 13 | + ts.isForInStatement(s) || |
| 14 | + ts.isForOfStatement(s) |
| 15 | + ) { |
| 16 | + // Ignore: Break statements are valid as children of these |
| 17 | + // statements without breaking the clause |
| 18 | + } else if (ts.isBreakStatement(s)) { |
| 19 | + return true; |
| 20 | + } else if (containsBreakStatement(s.getChildren())) { |
| 21 | + return true; |
| 22 | + } |
11 | 23 | } |
12 | 24 |
|
| 25 | + return false; |
| 26 | +}; |
| 27 | + |
| 28 | +export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (statement, context) => { |
13 | 29 | const scope = pushScope(context, ScopeType.Switch); |
14 | 30 |
|
15 | 31 | // Give the switch a unique name to prevent nested switches from acting up. |
16 | 32 | const switchName = `____switch${scope.id}`; |
17 | 33 | const switchVariable = lua.createIdentifier(switchName); |
18 | 34 |
|
19 | | - let statements: lua.Statement[] = []; |
| 35 | + // Collect the fallthrough bodies for each case as defined by the switch. |
| 36 | + const caseBody: lua.Statement[][] = []; |
| 37 | + for (let i = 0; i < statement.caseBlock.clauses.length; i++) { |
| 38 | + const end = statement.caseBlock.clauses |
| 39 | + .slice(i) |
| 40 | + .findIndex(clause => containsBreakStatement([...clause.statements])); |
| 41 | + caseBody[i] = statement.caseBlock.clauses |
| 42 | + .slice(i, end >= 0 ? end + i + 1 : undefined) |
| 43 | + .reduce<lua.Statement[]>( |
| 44 | + (statements, clause) => [ |
| 45 | + ...statements, |
| 46 | + lua.createDoStatement(context.transformStatements(clause.statements)), |
| 47 | + ], |
| 48 | + [] |
| 49 | + ); |
| 50 | + } |
20 | 51 |
|
21 | 52 | // Starting from the back, concatenating ifs into one big if/elseif statement |
22 | | - const concatenatedIf = statement.caseBlock.clauses.reduceRight((previousCondition, clause, index) => { |
23 | | - if (ts.isDefaultClause(clause)) { |
24 | | - // Skip default clause here (needs to be included to ensure index lines up with index later) |
25 | | - return previousCondition; |
26 | | - } |
27 | | - |
28 | | - // If the clause condition holds, go to the correct label |
29 | | - const condition = lua.createBinaryExpression( |
30 | | - switchVariable, |
31 | | - context.transformExpression(clause.expression), |
32 | | - lua.SyntaxKind.EqualityOperator |
33 | | - ); |
| 53 | + const defaultIndex = statement.caseBlock.clauses.findIndex(c => ts.isDefaultClause(c)); |
| 54 | + const concatenatedIf = statement.caseBlock.clauses.reduceRight<lua.IfStatement | lua.Block | undefined>( |
| 55 | + (previousCondition, clause, index) => { |
| 56 | + if (ts.isDefaultClause(clause)) { |
| 57 | + // Skip default clause here (needs to be included to ensure index lines up with index later) |
| 58 | + return previousCondition; |
| 59 | + } |
34 | 60 |
|
35 | | - const goto = lua.createGotoStatement(`${switchName}_case_${index}`); |
36 | | - return lua.createIfStatement(condition, lua.createBlock([goto]), previousCondition); |
37 | | - }, undefined as lua.IfStatement | undefined); |
| 61 | + // If the clause condition holds, go to the correct label |
| 62 | + const condition = lua.createBinaryExpression( |
| 63 | + switchVariable, |
| 64 | + context.transformExpression(clause.expression), |
| 65 | + lua.SyntaxKind.EqualityOperator |
| 66 | + ); |
38 | 67 |
|
39 | | - if (concatenatedIf) { |
40 | | - statements.push(concatenatedIf); |
41 | | - } |
| 68 | + return lua.createIfStatement(condition, lua.createBlock(caseBody[index]), previousCondition); |
| 69 | + }, |
| 70 | + defaultIndex >= 0 ? lua.createBlock(caseBody[defaultIndex]) : undefined |
| 71 | + ); |
42 | 72 |
|
43 | | - const hasDefaultCase = statement.caseBlock.clauses.some(ts.isDefaultClause); |
44 | | - statements.push(lua.createGotoStatement(`${switchName}_${hasDefaultCase ? "case_default" : "end"}`)); |
| 73 | + let statements: lua.Statement[] = []; |
45 | 74 |
|
46 | | - for (const [index, clause] of statement.caseBlock.clauses.entries()) { |
47 | | - const labelName = `${switchName}_case_${ts.isCaseClause(clause) ? index : "default"}`; |
48 | | - statements.push(lua.createLabelStatement(labelName)); |
49 | | - statements.push(lua.createDoStatement(context.transformStatements(clause.statements))); |
| 75 | + if (concatenatedIf) { |
| 76 | + statements.push(concatenatedIf as unknown as lua.IfStatement); |
50 | 77 | } |
51 | 78 |
|
52 | | - statements.push(lua.createLabelStatement(`${switchName}_end`)); |
53 | | - |
54 | 79 | statements = performHoisting(context, statements); |
55 | 80 | popScope(context); |
56 | 81 |
|
|
0 commit comments