Skip to content

Commit 73a4ed0

Browse files
authored
Hoisting switch fix (#1125)
* fixes and refactors for hoisting to fix switch statements and make logic more clear * applied hoisting fix to new switch implementation also refactored a few things for clarity and added tests * fixed issues with hoisting from default clause * added snapshots for a couple tests and a comment * fixed edge case with hoisting in a solo default clause * addressing review feedback Co-authored-by: Tom <tomblind@users.noreply.github.com>
1 parent a8325da commit 73a4ed0

File tree

5 files changed

+360
-55
lines changed

5 files changed

+360
-55
lines changed

src/transformation/utils/lua-ast.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ export function createLocalOrExportedOrGlobalDeclaration(
175175
const isTopLevelVariable = scope.type === ScopeType.File;
176176

177177
if (context.isModule || !isTopLevelVariable) {
178-
if (scope.type === ScopeType.Switch || (!isFunctionDeclaration && hasMultipleReferences(scope, lhs))) {
178+
if (!isFunctionDeclaration && hasMultipleReferences(scope, lhs)) {
179179
// Split declaration and assignment of identifiers that reference themselves in their declaration
180180
declaration = lua.createVariableDeclarationStatement(lhs, undefined, tsOriginal);
181181
if (rhs) {
@@ -185,15 +185,13 @@ export function createLocalOrExportedOrGlobalDeclaration(
185185
declaration = lua.createVariableDeclarationStatement(lhs, rhs, tsOriginal);
186186
}
187187

188-
// Remember local variable declarations for hoisting later
189-
if (!scope.variableDeclarations) {
190-
scope.variableDeclarations = [];
191-
}
192-
193-
scope.variableDeclarations.push(declaration);
188+
if (!isFunctionDeclaration) {
189+
// Remember local variable declarations for hoisting later
190+
if (!scope.variableDeclarations) {
191+
scope.variableDeclarations = [];
192+
}
194193

195-
if (scope.type === ScopeType.Switch) {
196-
declaration = undefined;
194+
scope.variableDeclarations.push(declaration);
197195
}
198196
} else if (rhs) {
199197
// global

src/transformation/utils/scope.ts

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ export interface Scope {
3333
functionReturned?: boolean;
3434
}
3535

36+
export interface HoistingResult {
37+
statements: lua.Statement[];
38+
hoistedStatements: lua.Statement[];
39+
hoistedIdentifiers: lua.Identifier[];
40+
}
41+
3642
const scopeStacks = new WeakMap<TransformationContext, Scope[]>();
3743
function getScopeStack(context: TransformationContext): Scope[] {
3844
return getOrUpdate(scopeStacks, context, () => []);
@@ -133,16 +139,47 @@ export function isFunctionScopeWithDefinition(scope: Scope): scope is Scope & {
133139
return scope.node !== undefined && ts.isFunctionLike(scope.node);
134140
}
135141

136-
export function performHoisting(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
142+
export function separateHoistedStatements(context: TransformationContext, statements: lua.Statement[]): HoistingResult {
137143
const scope = peekScope(context);
138-
let result = statements;
139-
result = hoistFunctionDefinitions(context, scope, result);
140-
result = hoistVariableDeclarations(context, scope, result);
141-
result = hoistImportStatements(scope, result);
142-
return result;
144+
const allHoistedStatments: lua.Statement[] = [];
145+
const allHoistedIdentifiers: lua.Identifier[] = [];
146+
147+
let { unhoistedStatements, hoistedStatements, hoistedIdentifiers } = hoistFunctionDefinitions(
148+
context,
149+
scope,
150+
statements
151+
);
152+
allHoistedStatments.push(...hoistedStatements);
153+
allHoistedIdentifiers.push(...hoistedIdentifiers);
154+
155+
({ unhoistedStatements, hoistedIdentifiers } = hoistVariableDeclarations(context, scope, unhoistedStatements));
156+
allHoistedIdentifiers.push(...hoistedIdentifiers);
157+
158+
({ unhoistedStatements, hoistedStatements } = hoistImportStatements(scope, unhoistedStatements));
159+
allHoistedStatments.unshift(...hoistedStatements);
160+
161+
return {
162+
statements: unhoistedStatements,
163+
hoistedStatements: allHoistedStatments,
164+
hoistedIdentifiers: allHoistedIdentifiers,
165+
};
166+
}
167+
168+
export function performHoisting(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
169+
const result = separateHoistedStatements(context, statements);
170+
const modifiedStatements = [...result.hoistedStatements, ...result.statements];
171+
if (result.hoistedIdentifiers.length > 0) {
172+
modifiedStatements.unshift(lua.createVariableDeclarationStatement(result.hoistedIdentifiers));
173+
}
174+
return modifiedStatements;
143175
}
144176

145177
function shouldHoistSymbol(context: TransformationContext, symbolId: lua.SymbolId, scope: Scope): boolean {
178+
// Always hoist in top-level of switch statements
179+
if (scope.type === ScopeType.Switch) {
180+
return true;
181+
}
182+
146183
const symbolInfo = getSymbolInfo(context, symbolId);
147184
if (!symbolInfo) {
148185
return false;
@@ -183,65 +220,80 @@ function hoistVariableDeclarations(
183220
context: TransformationContext,
184221
scope: Scope,
185222
statements: lua.Statement[]
186-
): lua.Statement[] {
223+
): { unhoistedStatements: lua.Statement[]; hoistedIdentifiers: lua.Identifier[] } {
187224
if (!scope.variableDeclarations) {
188-
return statements;
225+
return { unhoistedStatements: statements, hoistedIdentifiers: [] };
189226
}
190227

191-
const result = [...statements];
192-
const hoistedLocals: lua.Identifier[] = [];
228+
const unhoistedStatements = [...statements];
229+
const hoistedIdentifiers: lua.Identifier[] = [];
193230
for (const declaration of scope.variableDeclarations) {
194231
const symbols = declaration.left.map(i => i.symbolId).filter(isNonNull);
195232
if (symbols.some(s => shouldHoistSymbol(context, s, scope))) {
196-
const index = result.indexOf(declaration);
197-
assert(index > -1);
233+
const index = unhoistedStatements.indexOf(declaration);
234+
if (index < 0) {
235+
continue; // statements array may not contain all statements in the scope (switch-case)
236+
}
198237

199238
if (declaration.right) {
200239
const assignment = lua.createAssignmentStatement(declaration.left, declaration.right);
201240
lua.setNodePosition(assignment, declaration); // Preserve position info for sourcemap
202-
result.splice(index, 1, assignment);
241+
unhoistedStatements.splice(index, 1, assignment);
203242
} else {
204-
result.splice(index, 1);
243+
unhoistedStatements.splice(index, 1);
205244
}
206245

207-
hoistedLocals.push(...declaration.left);
208-
} else if (scope.type === ScopeType.Switch) {
209-
assert(!declaration.right);
210-
hoistedLocals.push(...declaration.left);
246+
hoistedIdentifiers.push(...declaration.left);
211247
}
212248
}
213249

214-
if (hoistedLocals.length > 0) {
215-
result.unshift(lua.createVariableDeclarationStatement(hoistedLocals));
216-
}
217-
218-
return result;
250+
return { unhoistedStatements, hoistedIdentifiers };
219251
}
220252

221253
function hoistFunctionDefinitions(
222254
context: TransformationContext,
223255
scope: Scope,
224256
statements: lua.Statement[]
225-
): lua.Statement[] {
257+
): { unhoistedStatements: lua.Statement[]; hoistedStatements: lua.Statement[]; hoistedIdentifiers: lua.Identifier[] } {
226258
if (!scope.functionDefinitions) {
227-
return statements;
259+
return { unhoistedStatements: statements, hoistedStatements: [], hoistedIdentifiers: [] };
228260
}
229261

230-
const result = [...statements];
231-
const hoistedFunctions: Array<lua.VariableDeclarationStatement | lua.AssignmentStatement> = [];
262+
const unhoistedStatements = [...statements];
263+
const hoistedStatements: lua.Statement[] = [];
264+
const hoistedIdentifiers: lua.Identifier[] = [];
232265
for (const [functionSymbolId, functionDefinition] of scope.functionDefinitions) {
233266
assert(functionDefinition.definition);
234267

235268
if (shouldHoistSymbol(context, functionSymbolId, scope)) {
236-
const index = result.indexOf(functionDefinition.definition);
237-
result.splice(index, 1);
238-
hoistedFunctions.push(functionDefinition.definition);
269+
const index = unhoistedStatements.indexOf(functionDefinition.definition);
270+
if (index < 0) {
271+
continue; // statements array may not contain all statements in the scope (switch-case)
272+
}
273+
unhoistedStatements.splice(index, 1);
274+
275+
if (lua.isVariableDeclarationStatement(functionDefinition.definition)) {
276+
// Separate function definition and variable declaration
277+
assert(functionDefinition.definition.right);
278+
hoistedIdentifiers.push(...functionDefinition.definition.left);
279+
hoistedStatements.push(
280+
lua.createAssignmentStatement(
281+
functionDefinition.definition.left,
282+
functionDefinition.definition.right
283+
)
284+
);
285+
} else {
286+
hoistedStatements.push(functionDefinition.definition);
287+
}
239288
}
240289
}
241290

242-
return [...hoistedFunctions, ...result];
291+
return { unhoistedStatements, hoistedStatements, hoistedIdentifiers };
243292
}
244293

245-
function hoistImportStatements(scope: Scope, statements: lua.Statement[]): lua.Statement[] {
246-
return scope.importStatements ? [...scope.importStatements, ...statements] : statements;
294+
function hoistImportStatements(
295+
scope: Scope,
296+
statements: lua.Statement[]
297+
): { unhoistedStatements: lua.Statement[]; hoistedStatements: lua.Statement[] } {
298+
return { unhoistedStatements: statements, hoistedStatements: scope.importStatements ?? [] };
247299
}

src/transformation/visitors/switch.ts

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
33
import { FunctionVisitor, TransformationContext } from "../context";
4-
import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope";
4+
import { popScope, pushScope, ScopeType, separateHoistedStatements } from "../utils/scope";
55

66
const containsBreakOrReturn = (nodes: Iterable<ts.Node>): boolean => {
77
for (const s of nodes) {
@@ -55,15 +55,25 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
5555

5656
// If the switch only has a default clause, wrap it in a single do.
5757
// Otherwise, we need to generate a set of if statements to emulate the switch.
58-
let statements: lua.Statement[] = [];
58+
const statements: lua.Statement[] = [];
59+
const hoistedStatements: lua.Statement[] = [];
60+
const hoistedIdentifiers: lua.Identifier[] = [];
5961
const clauses = statement.caseBlock.clauses;
6062
if (clauses.length === 1 && ts.isDefaultClause(clauses[0])) {
6163
const defaultClause = clauses[0].statements;
6264
if (defaultClause.length) {
63-
statements.push(lua.createDoStatement(context.transformStatements(defaultClause)));
65+
const {
66+
statements: defaultStatements,
67+
hoistedStatements: defaultHoistedStatements,
68+
hoistedIdentifiers: defaultHoistedIdentifiers,
69+
} = separateHoistedStatements(context, context.transformStatements(defaultClause));
70+
hoistedStatements.push(...defaultHoistedStatements);
71+
hoistedIdentifiers.push(...defaultHoistedIdentifiers);
72+
statements.push(lua.createDoStatement(defaultStatements));
6473
}
6574
} else {
6675
// Build up the condition for each if statement
76+
let defaultTransformed = false;
6777
let isInitialCondition = true;
6878
let condition: lua.Expression | undefined = undefined;
6979
for (let i = 0; i < clauses.length; i++) {
@@ -124,10 +134,21 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
124134
}
125135

126136
// Transform the clause and append the final break statement if necessary
127-
const clauseStatements = context.transformStatements(clause.statements);
137+
const {
138+
statements: clauseStatements,
139+
hoistedStatements: clauseHoistedStatements,
140+
hoistedIdentifiers: clauseHoistedIdentifiers,
141+
} = separateHoistedStatements(context, context.transformStatements(clause.statements));
128142
if (i === clauses.length - 1 && !containsBreakOrReturn(clause.statements)) {
129143
clauseStatements.push(lua.createBreakStatement());
130144
}
145+
hoistedStatements.push(...clauseHoistedStatements);
146+
hoistedIdentifiers.push(...clauseHoistedIdentifiers);
147+
148+
// Remember that we transformed default clause so we don't duplicate hoisted statements later
149+
if (ts.isDefaultClause(clause)) {
150+
defaultTransformed = true;
151+
}
131152

132153
// Push if statement for case
133154
statements.push(lua.createIfStatement(conditionVariable, lua.createBlock(clauseStatements)));
@@ -145,11 +166,25 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
145166
(clause, index) => index >= start && containsBreakOrReturn(clause.statements)
146167
);
147168

148-
// Combine the default and all fallthrough statements
149-
const defaultStatements: lua.Statement[] = [];
150-
clauses
151-
.slice(start, end >= 0 ? end + 1 : undefined)
152-
.forEach(c => defaultStatements.push(...context.transformStatements(c.statements)));
169+
const {
170+
statements: defaultStatements,
171+
hoistedStatements: defaultHoistedStatements,
172+
hoistedIdentifiers: defaultHoistedIdentifiers,
173+
} = separateHoistedStatements(context, context.transformStatements(clauses[start].statements));
174+
175+
// Only push hoisted statements if this is the first time we're transforming the default clause
176+
if (!defaultTransformed) {
177+
hoistedStatements.push(...defaultHoistedStatements);
178+
hoistedIdentifiers.push(...defaultHoistedIdentifiers);
179+
}
180+
181+
// Combine the fallthrough statements
182+
for (const clause of clauses.slice(start + 1, end >= 0 ? end + 1 : undefined)) {
183+
let statements = context.transformStatements(clause.statements);
184+
// Drop hoisted statements as they were already added when clauses were initially transformed above
185+
({ statements } = separateHoistedStatements(context, statements));
186+
defaultStatements.push(...statements);
187+
}
153188

154189
// Add the default clause if it has any statements
155190
// The switch will always break on the final clause and skip execution if valid to do so
@@ -160,7 +195,11 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
160195
}
161196

162197
// Hoist the variable, function, and import statements to the top of the switch
163-
statements = performHoisting(context, statements);
198+
statements.unshift(...hoistedStatements);
199+
if (hoistedIdentifiers.length > 0) {
200+
statements.unshift(lua.createVariableDeclarationStatement(hoistedIdentifiers));
201+
}
202+
164203
popScope(context);
165204

166205
// Add the switch expression after hoisting

test/unit/__snapshots__/switch.spec.ts.snap

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,82 @@ end
3434
return ____exports"
3535
`;
3636

37+
exports[`switch hoisting hoisting from default clause is not duplicated when falling through 1`] = `
38+
"local ____exports = {}
39+
function ____exports.__main(self)
40+
local x = 1
41+
local result = \\"\\"
42+
repeat
43+
local ____switch3 = x
44+
local hoisted
45+
function hoisted(self)
46+
return \\"hoisted\\"
47+
end
48+
local ____cond3 = ____switch3 == 1
49+
if ____cond3 then
50+
result = hoisted(nil)
51+
break
52+
end
53+
____cond3 = ____cond3 or (____switch3 == 2)
54+
if ____cond3 then
55+
result = \\"2\\"
56+
end
57+
if ____cond3 then
58+
result = \\"default\\"
59+
end
60+
____cond3 = ____cond3 or (____switch3 == 3)
61+
if ____cond3 then
62+
result = \\"3\\"
63+
break
64+
end
65+
do
66+
result = \\"default\\"
67+
result = \\"3\\"
68+
end
69+
until true
70+
return result
71+
end
72+
return ____exports"
73+
`;
74+
75+
exports[`switch hoisting hoisting from fallthrough clause after default is not duplicated 1`] = `
76+
"local ____exports = {}
77+
function ____exports.__main(self)
78+
local x = 1
79+
local result = \\"\\"
80+
repeat
81+
local ____switch3 = x
82+
local hoisted
83+
function hoisted(self)
84+
return \\"hoisted\\"
85+
end
86+
local ____cond3 = ____switch3 == 1
87+
if ____cond3 then
88+
result = hoisted(nil)
89+
break
90+
end
91+
____cond3 = ____cond3 or (____switch3 == 2)
92+
if ____cond3 then
93+
result = \\"2\\"
94+
end
95+
if ____cond3 then
96+
result = \\"default\\"
97+
end
98+
____cond3 = ____cond3 or (____switch3 == 3)
99+
if ____cond3 then
100+
result = \\"3\\"
101+
break
102+
end
103+
do
104+
result = \\"default\\"
105+
result = \\"3\\"
106+
end
107+
until true
108+
return result
109+
end
110+
return ____exports"
111+
`;
112+
37113
exports[`switch produces optimal output 1`] = `
38114
"require(\\"lualib_bundle\\");
39115
local ____exports = {}

0 commit comments

Comments
 (0)