Skip to content

Commit 3235c6c

Browse files
authored
Correctly handle awaited promises rejecting in a try/catch (#1144)
* Make await throw if awaited thing is a rejected promise * fix almost all tests * Also fix lambas in async * Fix bug in try/catch adding extra return twice * Fix prettier
1 parent a38a86b commit 3235c6c

File tree

5 files changed

+167
-61
lines changed

5 files changed

+167
-61
lines changed

src/lualib/Await.ts

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
// };
1515
//
1616

17+
type ErrorHandler = (this: void, error: unknown) => unknown;
18+
1719
// eslint-disable-next-line @typescript-eslint/promise-function-async
1820
function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
1921
return new Promise((resolve, reject) => {
@@ -23,30 +25,47 @@ function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
2325
function adopt(value: unknown) {
2426
return value instanceof __TS__Promise ? value : Promise.resolve(value);
2527
}
26-
function fulfilled(value) {
27-
const [success, resultOrError] = coroutine.resume(asyncCoroutine, value);
28+
function fulfilled(value: unknown) {
29+
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine, value);
2830
if (success) {
29-
step(resultOrError);
31+
step(resultOrError, errorOrErrorHandler);
3032
} else {
3133
reject(resultOrError);
3234
}
3335
}
34-
function step(result: unknown) {
36+
function rejected(handler: ErrorHandler | undefined) {
37+
if (handler) {
38+
return (value: unknown) => {
39+
const [success, valueOrError] = pcall(handler, value);
40+
if (success) {
41+
step(valueOrError, handler);
42+
} else {
43+
reject(valueOrError);
44+
}
45+
};
46+
} else {
47+
// If no catch clause, just reject
48+
return value => {
49+
reject(value);
50+
};
51+
}
52+
}
53+
function step(result: unknown, errorHandler: ErrorHandler | undefined) {
3554
if (coroutine.status(asyncCoroutine) === "dead") {
3655
resolve(result);
3756
} else {
38-
adopt(result).then(fulfilled, reason => reject(reason));
57+
adopt(result).then(fulfilled, rejected(errorHandler));
3958
}
4059
}
41-
const [success, resultOrError] = coroutine.resume(asyncCoroutine);
60+
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine);
4261
if (success) {
43-
step(resultOrError);
62+
step(resultOrError, errorOrErrorHandler);
4463
} else {
45-
reject(resultOrError);
64+
reject(errorOrErrorHandler);
4665
}
4766
});
4867
}
4968

50-
function __TS__Await(this: void, thing: unknown) {
51-
return coroutine.yield(thing);
69+
function __TS__Await(this: void, errorHandler: ErrorHandler, thing: unknown) {
70+
return coroutine.yield(errorHandler, thing);
5271
}

src/transformation/visitors/async-await.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ export const transformAwaitExpression: FunctionVisitor<ts.AwaitExpression> = (no
1616
}
1717

1818
const expression = context.transformExpression(node.expression);
19-
return transformLuaLibFunction(context, LuaLibFeature.Await, node, expression);
19+
const catchIdentifier = lua.createIdentifier("____catch");
20+
return transformLuaLibFunction(context, LuaLibFeature.Await, node, catchIdentifier, expression);
2021
};
2122

2223
export function isAsyncFunction(declaration: ts.FunctionLikeDeclaration): boolean {

src/transformation/visitors/errors.ts

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
33
import { FunctionVisitor } from "../context";
44
import { createUnpackCall } from "../utils/lua-ast";
5-
import { findScope, ScopeType } from "../utils/scope";
5+
import { ScopeType } from "../utils/scope";
66
import { transformScopeBlock } from "./block";
77
import { transformIdentifier } from "./identifier";
88
import { isInMultiReturnFunction } from "./language-extensions/multi";
9+
import { createReturnStatement } from "./return";
910

1011
export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
1112
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
@@ -15,24 +16,24 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
1516

1617
const result: lua.Statement[] = [];
1718

18-
let returnedIdentifier: lua.Identifier | undefined;
19+
const returnedIdentifier = lua.createIdentifier("____hasReturned");
1920
let returnCondition: lua.Expression | undefined;
2021

2122
const pCall = lua.createIdentifier("pcall");
2223
const tryCall = lua.createCallExpression(pCall, [lua.createFunctionExpression(tryBlock)]);
2324

2425
if (statement.catchClause && statement.catchClause.block.statements.length > 0) {
2526
// try with catch
26-
let [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);
27-
if (statement.catchClause.variableDeclaration) {
28-
// Replace ____returned with catch variable
29-
returnedIdentifier = transformIdentifier(
30-
context,
31-
statement.catchClause.variableDeclaration.name as ts.Identifier
32-
);
33-
} else if (tryScope.functionReturned || catchScope.functionReturned) {
34-
returnedIdentifier = lua.createIdentifier("____returned");
35-
}
27+
const [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);
28+
29+
const catchParameter = statement.catchClause.variableDeclaration
30+
? transformIdentifier(context, statement.catchClause.variableDeclaration.name as ts.Identifier)
31+
: undefined;
32+
const catchParameters = () => (catchParameter ? [lua.cloneIdentifier(catchParameter)] : []);
33+
34+
const catchIdentifier = lua.createIdentifier("____catch");
35+
const catchFunction = lua.createFunctionExpression(catchBlock, catchParameters());
36+
result.push(lua.createVariableDeclarationStatement(catchIdentifier, catchFunction));
3637

3738
const tryReturnIdentifiers = [tryResultIdentifier]; // ____try
3839
if (returnedIdentifier) {
@@ -44,20 +45,18 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
4445
}
4546
result.push(lua.createVariableDeclarationStatement(tryReturnIdentifiers, tryCall));
4647

47-
if ((tryScope.functionReturned || catchScope.functionReturned) && returnedIdentifier) {
48-
// Wrap catch in function if try or catch has return
49-
const catchCall = lua.createCallExpression(lua.createFunctionExpression(catchBlock), []);
50-
const catchAssign = lua.createAssignmentStatement(
51-
[lua.cloneIdentifier(returnedIdentifier), lua.cloneIdentifier(returnValueIdentifier)],
52-
catchCall
53-
);
54-
catchBlock = lua.createBlock([catchAssign]);
55-
}
48+
// Wrap catch in function if try or catch has return
49+
const catchCall = lua.createCallExpression(catchIdentifier, [lua.cloneIdentifier(returnedIdentifier)]);
50+
const catchAssign = lua.createAssignmentStatement(
51+
[lua.cloneIdentifier(returnedIdentifier), lua.cloneIdentifier(returnValueIdentifier)],
52+
catchCall
53+
);
54+
5655
const notTryCondition = lua.createUnaryExpression(tryResultIdentifier, lua.SyntaxKind.NotOperator);
57-
result.push(lua.createIfStatement(notTryCondition, catchBlock));
56+
result.push(lua.createIfStatement(notTryCondition, lua.createBlock([catchAssign])));
5857
} else if (tryScope.functionReturned) {
5958
// try with return, but no catch
60-
returnedIdentifier = lua.createIdentifier("____returned");
59+
// returnedIdentifier = lua.createIdentifier("____returned");
6160
const returnedVariables = [tryResultIdentifier, returnedIdentifier, returnValueIdentifier];
6261
result.push(lua.createVariableDeclarationStatement(returnedVariables, tryCall));
6362

@@ -77,24 +76,15 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
7776
}
7877

7978
if (returnCondition && returnedIdentifier) {
80-
// With catch clause:
81-
// if ____returned then return ____returnValue end
82-
// No catch clause:
83-
// if ____try and ____returned then return ____returnValue end
8479
const returnValues: lua.Expression[] = [];
85-
const parentTryCatch = findScope(context, ScopeType.Function | ScopeType.Try | ScopeType.Catch);
86-
if (parentTryCatch && parentTryCatch.type !== ScopeType.Function) {
87-
// Nested try/catch needs to prefix a 'true' return value
88-
returnValues.push(lua.createBooleanLiteral(true));
89-
}
9080

9181
if (isInMultiReturnFunction(context, statement)) {
9282
returnValues.push(createUnpackCall(context, lua.cloneIdentifier(returnValueIdentifier)));
9383
} else {
9484
returnValues.push(lua.cloneIdentifier(returnValueIdentifier));
9585
}
9686

97-
const returnStatement = lua.createReturnStatement(returnValues);
87+
const returnStatement = createReturnStatement(context, returnValues, statement);
9888
const ifReturnedStatement = lua.createIfStatement(returnCondition, lua.createBlock([returnStatement]));
9989
result.push(ifReturnedStatement);
10090
}

src/transformation/visitors/return.ts

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import {
1414
canBeMultiReturnType,
1515
} from "./language-extensions/multi";
1616
import { invalidMultiFunctionReturnType } from "../utils/diagnostics";
17+
import { findFirstNodeAbove } from "../utils/typescript";
1718

1819
function transformExpressionsInReturn(
1920
context: TransformationContext,
@@ -55,22 +56,10 @@ export function transformExpressionBodyToReturnStatement(
5556
node: ts.Expression
5657
): lua.Statement {
5758
const expressions = transformExpressionsInReturn(context, node, false);
58-
return lua.createReturnStatement(expressions, node);
59+
return createReturnStatement(context, expressions, node);
5960
}
6061

6162
export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (statement, context) => {
62-
// Bubble up explicit return flag and check if we're inside a try/catch block
63-
let insideTryCatch = false;
64-
for (const scope of walkScopesUp(context)) {
65-
scope.functionReturned = true;
66-
67-
if (scope.type === ScopeType.Function) {
68-
break;
69-
}
70-
71-
insideTryCatch = insideTryCatch || scope.type === ScopeType.Try || scope.type === ScopeType.Catch;
72-
}
73-
7463
let results: lua.Expression[];
7564

7665
if (statement.expression) {
@@ -80,15 +69,55 @@ export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (st
8069
validateAssignment(context, statement, expressionType, returnType);
8170
}
8271

83-
results = transformExpressionsInReturn(context, statement.expression, insideTryCatch);
72+
results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context));
8473
} else {
8574
// Empty return
8675
results = [];
8776
}
8877

89-
if (insideTryCatch) {
78+
return createReturnStatement(context, results, statement);
79+
};
80+
81+
export function createReturnStatement(
82+
context: TransformationContext,
83+
values: lua.Expression[],
84+
node: ts.Node
85+
): lua.ReturnStatement {
86+
const results = [...values];
87+
88+
if (isInTryCatch(context)) {
89+
// Bubble up explicit return flag and check if we're inside a try/catch block
9090
results.unshift(lua.createBooleanLiteral(true));
91+
} else if (isInAsyncFunction(node)) {
92+
// Add nil error handler in async function and not in try
93+
results.unshift(lua.createNilLiteral());
9194
}
9295

93-
return lua.createReturnStatement(results, statement);
94-
};
96+
return lua.createReturnStatement(results, node);
97+
}
98+
99+
function isInAsyncFunction(node: ts.Node): boolean {
100+
// Check if node is in function declaration with `async`
101+
const declaration = findFirstNodeAbove(node, ts.isFunctionLike);
102+
if (!declaration) {
103+
return false;
104+
}
105+
106+
return declaration.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) ?? false;
107+
}
108+
109+
function isInTryCatch(context: TransformationContext): boolean {
110+
// Check if context is in a try or catch
111+
let insideTryCatch = false;
112+
for (const scope of walkScopesUp(context)) {
113+
scope.functionReturned = true;
114+
115+
if (scope.type === ScopeType.Function) {
116+
break;
117+
}
118+
119+
insideTryCatch = insideTryCatch || scope.type === ScopeType.Try || scope.type === ScopeType.Catch;
120+
}
121+
122+
return insideTryCatch;
123+
}

test/unit/builtins/async-await.spec.ts

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ test.each(["async function abc() {", "const abc = async () => {"])(
161161
}
162162
);
163163

164+
test("can make inline async functions", () => {
165+
util.testFunction`
166+
const foo = async function() { return "foo"; };
167+
const bar = async function() { return await foo(); };
168+
169+
const { state, value } = bar() as any;
170+
return { state, value };
171+
`.expectToEqual({
172+
state: 1, // __TS__PromiseState.Fulfilled
173+
value: "foo",
174+
});
175+
});
176+
164177
test("can make async lambdas with expression body", () => {
165178
util.testFunction`
166179
const foo = async () => "foo";
@@ -369,3 +382,57 @@ test("async function can forward varargs", () => {
369382
.setTsHeader(promiseTestLib)
370383
.expectToEqual(["resolved", "A", "B", "C"]);
371384
});
385+
386+
// https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1105
387+
describe("try/catch in async function", () => {
388+
test("await inside try/catch returns inside async function", () => {
389+
util.testModule`
390+
export let result = 0;
391+
async function foo(): Promise<number> {
392+
try {
393+
return await new Promise(resolve => resolve(4));
394+
} catch {
395+
throw "an error occurred in the async function"
396+
}
397+
}
398+
foo().then(value => {
399+
result = value;
400+
});
401+
`.expectToEqual({ result: 4 });
402+
});
403+
404+
test("await inside try/catch throws inside async function", () => {
405+
util.testModule`
406+
export let reason = "";
407+
async function foo(): Promise<number> {
408+
try {
409+
return await new Promise((resolve, reject) => reject("test error"));
410+
} catch (e) {
411+
throw "an error occurred in the async function: " + e;
412+
}
413+
}
414+
foo().catch(e => {
415+
reason = e;
416+
});
417+
`.expectToEqual({ reason: "an error occurred in the async function: test error" });
418+
});
419+
420+
test("await inside try/catch deferred rejection uses catch clause", () => {
421+
util.testModule`
422+
export let reason = "";
423+
let reject: (reason: string) => void;
424+
425+
async function foo(): Promise<number> {
426+
try {
427+
return await new Promise((res, rej) => { reject = rej; });
428+
} catch (e) {
429+
throw "an error occurred in the async function: " + e;
430+
}
431+
}
432+
foo().catch(e => {
433+
reason = e;
434+
});
435+
reject("test error");
436+
`.expectToEqual({ reason: "an error occurred in the async function: test error" });
437+
});
438+
});

0 commit comments

Comments
 (0)