Skip to content

Commit 0924b37

Browse files
authored
Fix async try (#1278)
* Initial prototype * this works * Kinda works but not really * Actually it does work but test was wrong * finalize solution
1 parent 43d31d0 commit 0924b37

File tree

6 files changed

+299
-72
lines changed

6 files changed

+299
-72
lines changed

src/lualib/Await.ts

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,62 +16,44 @@
1616

1717
import { __TS__Promise } from "./Promise";
1818

19-
type ErrorHandler = (this: void, error: unknown) => unknown;
20-
2119
// eslint-disable-next-line @typescript-eslint/promise-function-async
2220
export function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
2321
return new Promise((resolve, reject) => {
22+
let resolved = false;
2423
const asyncCoroutine = coroutine.create(generator);
2524

2625
// eslint-disable-next-line @typescript-eslint/promise-function-async
2726
function adopt(value: unknown) {
2827
return value instanceof __TS__Promise ? value : Promise.resolve(value);
2928
}
3029
function fulfilled(value: unknown) {
31-
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine, value);
30+
const [success, resultOrError] = coroutine.resume(asyncCoroutine, value);
3231
if (success) {
33-
step(resultOrError, errorOrErrorHandler);
34-
} else {
35-
reject(errorOrErrorHandler);
36-
}
37-
}
38-
function rejected(handler: ErrorHandler | undefined) {
39-
if (handler) {
40-
return (value: unknown) => {
41-
const [success, hasReturnedOrError, returnedValue] = pcall(handler, value);
42-
if (success) {
43-
if (hasReturnedOrError) {
44-
resolve(returnedValue);
45-
} else {
46-
step(hasReturnedOrError, handler);
47-
}
48-
} else {
49-
reject(hasReturnedOrError);
50-
}
51-
};
32+
step(resultOrError);
5233
} else {
53-
// If no catch clause, just reject
54-
return value => {
55-
reject(value);
56-
};
34+
reject(resultOrError);
5735
}
5836
}
59-
function step(result: unknown, errorHandler: ErrorHandler | undefined) {
37+
function step(result: unknown) {
38+
if (resolved) return;
6039
if (coroutine.status(asyncCoroutine) === "dead") {
6140
resolve(result);
6241
} else {
63-
adopt(result).then(fulfilled, rejected(errorHandler));
42+
adopt(result).then(fulfilled, reject);
6443
}
6544
}
66-
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine);
45+
const [success, resultOrError] = coroutine.resume(asyncCoroutine, (v: unknown) => {
46+
resolved = true;
47+
adopt(v).then(resolve, reject);
48+
});
6749
if (success) {
68-
step(resultOrError, errorOrErrorHandler);
50+
step(resultOrError);
6951
} else {
70-
reject(errorOrErrorHandler);
52+
reject(resultOrError);
7153
}
7254
});
7355
}
7456

75-
export function __TS__Await(this: void, errorHandler: ErrorHandler, thing: unknown) {
76-
return coroutine.yield(errorHandler, thing);
57+
export function __TS__Await(this: void, thing: unknown) {
58+
return coroutine.yield(thing);
7759
}

src/transformation/visitors/async-await.ts

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,32 @@ import * as lua from "../../LuaAST";
33
import { FunctionVisitor, TransformationContext } from "../context";
44
import { awaitMustBeInAsyncFunction } from "../utils/diagnostics";
55
import { importLuaLibFeature, LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
6-
import { findFirstNodeAbove } from "../utils/typescript";
6+
import { isInAsyncFunction } from "../utils/typescript";
77

88
export const transformAwaitExpression: FunctionVisitor<ts.AwaitExpression> = (node, context) => {
99
// Check if await is inside an async function, it is not allowed at top level or in non-async functions
10-
const containingFunction = findFirstNodeAbove(node, ts.isFunctionLike);
11-
if (
12-
containingFunction === undefined ||
13-
!containingFunction.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword)
14-
) {
10+
if (!isInAsyncFunction(node)) {
1511
context.diagnostics.push(awaitMustBeInAsyncFunction(node));
1612
}
1713

1814
const expression = context.transformExpression(node.expression);
19-
const catchIdentifier = lua.createIdentifier("____catch");
20-
return transformLuaLibFunction(context, LuaLibFeature.Await, node, catchIdentifier, expression);
15+
return transformLuaLibFunction(context, LuaLibFeature.Await, node, expression);
2116
};
2217

2318
export function isAsyncFunction(declaration: ts.FunctionLikeDeclaration): boolean {
2419
return declaration.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) ?? false;
2520
}
2621

27-
export function wrapInAsyncAwaiter(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
22+
export function wrapInAsyncAwaiter(
23+
context: TransformationContext,
24+
statements: lua.Statement[],
25+
includeResolveParameter = true
26+
): lua.CallExpression {
2827
importLuaLibFeature(context, LuaLibFeature.Await);
2928

30-
return [
31-
lua.createReturnStatement([
32-
lua.createCallExpression(lua.createIdentifier("__TS__AsyncAwaiter"), [
33-
lua.createFunctionExpression(lua.createBlock(statements)),
34-
]),
35-
]),
36-
];
29+
const parameters = includeResolveParameter ? [lua.createIdentifier("____awaiter_resolve")] : [];
30+
31+
return lua.createCallExpression(lua.createIdentifier("__TS__AsyncAwaiter"), [
32+
lua.createFunctionExpression(lua.createBlock(statements), parameters),
33+
]);
3734
}

src/transformation/visitors/errors.ts

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
import * as ts from "typescript";
2-
import { LuaTarget } from "../..";
2+
import { LuaLibFeature, LuaTarget } from "../..";
33
import * as lua from "../../LuaAST";
4-
import { FunctionVisitor } from "../context";
4+
import { FunctionVisitor, TransformationContext } from "../context";
55
import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics";
66
import { createUnpackCall } from "../utils/lua-ast";
7-
import { ScopeType } from "../utils/scope";
7+
import { transformLuaLibFunction } from "../utils/lualib";
8+
import { Scope, ScopeType } from "../utils/scope";
89
import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript";
10+
import { wrapInAsyncAwaiter } from "./async-await";
911
import { transformScopeBlock } from "./block";
1012
import { transformIdentifier } from "./identifier";
1113
import { isInMultiReturnFunction } from "./language-extensions/multi";
1214
import { createReturnStatement } from "./return";
1315

14-
export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
15-
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
16+
const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context) => {
17+
const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
1618

17-
if (
18-
context.options.luaTarget === LuaTarget.Lua51 &&
19-
isInAsyncFunction(statement) &&
20-
!context.options.lua51AllowTryCatchInAsyncAwait
21-
) {
19+
if (context.options.luaTarget === LuaTarget.Lua51 && !context.options.lua51AllowTryCatchInAsyncAwait) {
2220
context.diagnostics.push(
2321
unsupportedForTargetButOverrideAvailable(
2422
statement,
@@ -30,6 +28,57 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
3028
return tryBlock.statements;
3129
}
3230

31+
// __TS__AsyncAwaiter(<catch block>)
32+
const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false);
33+
const awaiterIdentifier = lua.createIdentifier("____try");
34+
const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter);
35+
36+
// local ____try = __TS__AsyncAwaiter(<catch block>)
37+
const result: lua.Statement[] = [awaiterDefinition];
38+
39+
if (statement.finallyBlock) {
40+
const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally"));
41+
const finallyFunction = lua.createFunctionExpression(
42+
lua.createBlock(context.transformStatements(statement.finallyBlock.statements))
43+
);
44+
const finallyCall = lua.createCallExpression(
45+
awaiterFinally,
46+
[awaiterIdentifier, finallyFunction],
47+
statement.finallyBlock
48+
);
49+
// ____try.finally(<finally function>)
50+
result.push(lua.createExpressionStatement(finallyCall));
51+
}
52+
53+
if (statement.catchClause) {
54+
// ____try.catch(<catch function>)
55+
const [catchFunction] = transformCatchClause(context, statement.catchClause);
56+
if (catchFunction.params) {
57+
catchFunction.params.unshift(lua.createAnonymousIdentifier());
58+
}
59+
60+
const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch"));
61+
const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]);
62+
63+
// await ____try.catch(<catch function>)
64+
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall);
65+
result.push(lua.createExpressionStatement(promiseAwait, statement));
66+
} else {
67+
// await ____try
68+
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier);
69+
result.push(lua.createExpressionStatement(promiseAwait, statement));
70+
}
71+
72+
return result;
73+
};
74+
75+
export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
76+
if (isInAsyncFunction(statement)) {
77+
return transformAsyncTry(statement, context);
78+
}
79+
80+
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
81+
3382
if (context.options.luaTarget === LuaTarget.Lua51 && isInGeneratorFunction(statement)) {
3483
context.diagnostics.push(
3584
unsupportedForTarget(statement, "try/catch inside generator functions", LuaTarget.Lua51)
@@ -50,15 +99,7 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
5099

51100
if (statement.catchClause && statement.catchClause.block.statements.length > 0) {
52101
// try with catch
53-
const [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);
54-
55-
const catchParameter = statement.catchClause.variableDeclaration
56-
? transformIdentifier(context, statement.catchClause.variableDeclaration.name as ts.Identifier)
57-
: undefined;
58-
const catchFunction = lua.createFunctionExpression(
59-
catchBlock,
60-
catchParameter ? [lua.cloneIdentifier(catchParameter)] : []
61-
);
102+
const [catchFunction, catchScope] = transformCatchClause(context, statement.catchClause);
62103
const catchIdentifier = lua.createIdentifier("____catch");
63104
result.push(lua.createVariableDeclarationStatement(catchIdentifier, catchFunction));
64105

@@ -138,3 +179,20 @@ export const transformThrowStatement: FunctionVisitor<ts.ThrowStatement> = (stat
138179
statement
139180
);
140181
};
182+
183+
function transformCatchClause(
184+
context: TransformationContext,
185+
catchClause: ts.CatchClause
186+
): [lua.FunctionExpression, Scope] {
187+
const [catchBlock, catchScope] = transformScopeBlock(context, catchClause.block, ScopeType.Catch);
188+
189+
const catchParameter = catchClause.variableDeclaration
190+
? transformIdentifier(context, catchClause.variableDeclaration.name as ts.Identifier)
191+
: undefined;
192+
const catchFunction = lua.createFunctionExpression(
193+
catchBlock,
194+
catchParameter ? [lua.cloneIdentifier(catchParameter)] : []
195+
);
196+
197+
return [catchFunction, catchScope];
198+
}

src/transformation/visitors/function.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ export function transformFunctionBody(
163163
scope.node = node;
164164
let bodyStatements = transformFunctionBodyContent(context, body);
165165
if (node && isAsyncFunction(node)) {
166-
bodyStatements = wrapInAsyncAwaiter(context, bodyStatements);
166+
bodyStatements = [lua.createReturnStatement([wrapInAsyncAwaiter(context, bodyStatements)])];
167167
}
168168
const headerStatements = transformFunctionBodyHeader(context, scope, parameters, spreadIdentifier);
169169
popScope(context);

src/transformation/visitors/return.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,15 @@ export function createReturnStatement(
8585
): lua.ReturnStatement {
8686
const results = [...values];
8787

88+
if (isInAsyncFunction(node)) {
89+
return lua.createReturnStatement([
90+
lua.createCallExpression(lua.createIdentifier("____awaiter_resolve"), [lua.createNilLiteral(), ...values]),
91+
]);
92+
}
93+
8894
if (isInTryCatch(context)) {
8995
// Bubble up explicit return flag and check if we're inside a try/catch block
9096
results.unshift(lua.createBooleanLiteral(true));
91-
} else if (isInAsyncFunction(node)) {
92-
// Add nil error handler in async function and not in try
93-
results.unshift(lua.createNilLiteral());
9497
}
9598

9699
return lua.createReturnStatement(results, node);

0 commit comments

Comments
 (0)