Skip to content

Commit b692c28

Browse files
committed
fix return, break and continue inside try in async functions (#1706)
1 parent 5176fcd commit b692c28

File tree

5 files changed

+292
-18
lines changed

5 files changed

+292
-18
lines changed

src/transformation/utils/scope.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ export interface Scope {
3838
importStatements?: lua.Statement[];
3939
loopContinued?: LoopContinued;
4040
functionReturned?: boolean;
41+
asyncTryHasReturn?: boolean;
42+
asyncTryHasBreak?: boolean;
43+
asyncTryHasContinue?: LoopContinued;
4144
}
4245

4346
export interface HoistingResult {
@@ -84,6 +87,23 @@ export function findScope(context: TransformationContext, scopeTypes: ScopeType)
8487
}
8588
}
8689

90+
export function findAsyncTryScopeInStack(context: TransformationContext): Scope | undefined {
91+
for (const scope of walkScopesUp(context)) {
92+
if (scope.type === ScopeType.Function) return undefined;
93+
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
94+
}
95+
return undefined;
96+
}
97+
98+
/** Like findAsyncTryScopeInStack, but also stops at Loop boundaries. */
99+
export function findAsyncTryScopeBeforeLoop(context: TransformationContext): Scope | undefined {
100+
for (const scope of walkScopesUp(context)) {
101+
if (scope.type === ScopeType.Function || scope.type === ScopeType.Loop) return undefined;
102+
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
103+
}
104+
return undefined;
105+
}
106+
87107
export function addScopeVariableDeclaration(scope: Scope, declaration: lua.VariableDeclarationStatement) {
88108
scope.variableDeclarations ??= [];
89109

src/transformation/visitors/break-continue.ts

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@ import * as ts from "typescript";
22
import { LuaTarget } from "../../CompilerOptions";
33
import * as lua from "../../LuaAST";
44
import { FunctionVisitor } from "../context";
5-
import { findScope, LoopContinued, ScopeType } from "../utils/scope";
5+
import { findAsyncTryScopeBeforeLoop, findScope, LoopContinued, ScopeType } from "../utils/scope";
6+
import { isInAsyncFunction } from "../utils/typescript";
67

78
export const transformBreakStatement: FunctionVisitor<ts.BreakStatement> = (breakStatement, context) => {
8-
void context;
9+
const tryScope = isInAsyncFunction(breakStatement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
10+
if (tryScope) {
11+
tryScope.asyncTryHasBreak = true;
12+
return [
13+
lua.createAssignmentStatement(
14+
lua.createIdentifier("____hasBroken"),
15+
lua.createBooleanLiteral(true),
16+
breakStatement
17+
),
18+
lua.createReturnStatement([], breakStatement),
19+
];
20+
}
921
return lua.createBreakStatement(breakStatement);
1022
};
1123

@@ -28,6 +40,19 @@ export const transformContinueStatement: FunctionVisitor<ts.ContinueStatement> =
2840
scope.loopContinued = continuedWith;
2941
}
3042

43+
const tryScope = isInAsyncFunction(statement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
44+
if (tryScope) {
45+
tryScope.asyncTryHasContinue = continuedWith;
46+
return [
47+
lua.createAssignmentStatement(
48+
lua.createIdentifier("____hasContinued"),
49+
lua.createBooleanLiteral(true),
50+
statement
51+
),
52+
lua.createReturnStatement([], statement),
53+
];
54+
}
55+
3156
const label = `__continue${scope?.id ?? ""}`;
3257

3358
switch (continuedWith) {

src/transformation/visitors/errors.ts

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { FunctionVisitor, TransformationContext } from "../context";
55
import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics";
66
import { createUnpackCall } from "../utils/lua-ast";
77
import { transformLuaLibFunction } from "../utils/lualib";
8-
import { Scope, ScopeType } from "../utils/scope";
8+
import { findScope, LoopContinued, Scope, ScopeType } from "../utils/scope";
99
import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript";
1010
import { wrapInAsyncAwaiter } from "./async-await";
1111
import { transformScopeBlock } from "./block";
@@ -14,7 +14,7 @@ import { isInMultiReturnFunction } from "./language-extensions/multi";
1414
import { createReturnStatement } from "./return";
1515

1616
const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context) => {
17-
const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
17+
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
1818

1919
if (
2020
(context.options.luaTarget === LuaTarget.Lua50 || context.options.luaTarget === LuaTarget.Lua51) &&
@@ -31,13 +31,14 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
3131
return tryBlock.statements;
3232
}
3333

34-
// __TS__AsyncAwaiter(<catch block>)
34+
// __TS__AsyncAwaiter(<try block>)
3535
const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false);
3636
const awaiterIdentifier = lua.createIdentifier("____try");
3737
const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter);
3838

39-
// local ____try = __TS__AsyncAwaiter(<catch block>)
40-
const result: lua.Statement[] = [awaiterDefinition];
39+
// Transform catch/finally and collect scope info before building the result
40+
let catchScope: Scope | undefined;
41+
const chainCalls: lua.Statement[] = [];
4142

4243
if (statement.finallyBlock) {
4344
const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally"));
@@ -49,27 +50,88 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
4950
[awaiterIdentifier, finallyFunction],
5051
statement.finallyBlock
5152
);
52-
// ____try.finally(<finally function>)
53-
result.push(lua.createExpressionStatement(finallyCall));
53+
chainCalls.push(lua.createExpressionStatement(finallyCall));
5454
}
5555

5656
if (statement.catchClause) {
57-
// ____try.catch(<catch function>)
58-
const [catchFunction] = transformCatchClause(context, statement.catchClause);
57+
const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause);
58+
catchScope = cScope;
5959
if (catchFunction.params) {
6060
catchFunction.params.unshift(lua.createAnonymousIdentifier());
6161
}
6262

6363
const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch"));
6464
const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]);
65-
66-
// await ____try.catch(<catch function>)
6765
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall);
68-
result.push(lua.createExpressionStatement(promiseAwait, statement));
66+
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
6967
} else {
70-
// await ____try
7168
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier);
72-
result.push(lua.createExpressionStatement(promiseAwait, statement));
69+
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
70+
}
71+
72+
const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn;
73+
const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak;
74+
const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue;
75+
76+
// Build result in output order: flag declarations, awaiter, chain calls, post-checks
77+
const result: lua.Statement[] = [];
78+
79+
if (hasReturn || hasBreak || hasContinue !== undefined) {
80+
const flagDecls: lua.Identifier[] = [];
81+
if (hasReturn) {
82+
flagDecls.push(lua.createIdentifier("____hasReturned"));
83+
flagDecls.push(lua.createIdentifier("____returnValue"));
84+
}
85+
if (hasBreak) {
86+
flagDecls.push(lua.createIdentifier("____hasBroken"));
87+
}
88+
if (hasContinue !== undefined) {
89+
flagDecls.push(lua.createIdentifier("____hasContinued"));
90+
}
91+
result.push(lua.createVariableDeclarationStatement(flagDecls));
92+
}
93+
94+
result.push(awaiterDefinition);
95+
result.push(...chainCalls);
96+
97+
if (hasReturn) {
98+
result.push(
99+
lua.createIfStatement(
100+
lua.createIdentifier("____hasReturned"),
101+
lua.createBlock([createReturnStatement(context, [lua.createIdentifier("____returnValue")], statement)])
102+
)
103+
);
104+
}
105+
106+
if (hasBreak) {
107+
result.push(
108+
lua.createIfStatement(lua.createIdentifier("____hasBroken"), lua.createBlock([lua.createBreakStatement()]))
109+
);
110+
}
111+
112+
if (hasContinue !== undefined) {
113+
const loopScope = findScope(context, ScopeType.Loop);
114+
const label = `__continue${loopScope?.id ?? ""}`;
115+
116+
const continueStatements: lua.Statement[] = [];
117+
switch (hasContinue) {
118+
case LoopContinued.WithGoto:
119+
continueStatements.push(lua.createGotoStatement(label));
120+
break;
121+
case LoopContinued.WithContinue:
122+
continueStatements.push(lua.createContinueStatement());
123+
break;
124+
case LoopContinued.WithRepeatBreak:
125+
continueStatements.push(
126+
lua.createAssignmentStatement(lua.createIdentifier(label), lua.createBooleanLiteral(true))
127+
);
128+
continueStatements.push(lua.createBreakStatement());
129+
break;
130+
}
131+
132+
result.push(
133+
lua.createIfStatement(lua.createIdentifier("____hasContinued"), lua.createBlock(continueStatements))
134+
);
73135
}
74136

75137
return result;

src/transformation/visitors/return.ts

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import * as lua from "../../LuaAST";
33
import { FunctionVisitor, TransformationContext } from "../context";
44
import { validateAssignment } from "../utils/assignment-validation";
55
import { createUnpackCall, wrapInTable } from "../utils/lua-ast";
6-
import { ScopeType, walkScopesUp } from "../utils/scope";
6+
import { findAsyncTryScopeInStack, ScopeType, walkScopesUp } from "../utils/scope";
77
import { transformArguments } from "./call";
88
import {
99
returnsMultiType,
@@ -68,6 +68,8 @@ export function transformExpressionBodyToReturnStatement(
6868
}
6969

7070
export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (statement, context) => {
71+
const asyncTryScope = isInAsyncFunction(statement) ? findAsyncTryScopeInStack(context) : undefined;
72+
7173
let results: lua.Expression[];
7274

7375
if (statement.expression) {
@@ -77,12 +79,35 @@ export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (st
7779
validateAssignment(context, statement, expressionType, returnType);
7880
}
7981

80-
results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context));
82+
// In async try, we handle return propagation via flag variables (asyncTryHasReturn)
83+
// rather than pcall return values (functionReturned set by isInTryCatch), so we skip
84+
// isInTryCatch but still need insideTryCatch=true for multi-return wrapping.
85+
results = transformExpressionsInReturn(
86+
context,
87+
statement.expression,
88+
asyncTryScope ? true : isInTryCatch(context)
89+
);
8190
} else {
8291
// Empty return
8392
results = [];
8493
}
8594

95+
if (asyncTryScope) {
96+
asyncTryScope.asyncTryHasReturn = true;
97+
const stmts: lua.Statement[] = [
98+
lua.createAssignmentStatement(
99+
lua.createIdentifier("____hasReturned"),
100+
lua.createBooleanLiteral(true),
101+
statement
102+
),
103+
];
104+
if (results.length > 0) {
105+
stmts.push(lua.createAssignmentStatement(lua.createIdentifier("____returnValue"), results[0], statement));
106+
}
107+
stmts.push(lua.createReturnStatement([], statement));
108+
return stmts;
109+
}
110+
86111
return createReturnStatement(context, results, statement);
87112
};
88113

0 commit comments

Comments
 (0)