Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions src/lualib/Await.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// };
//

type ErrorHandler = (this: void, error: unknown) => unknown;

// eslint-disable-next-line @typescript-eslint/promise-function-async
function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
return new Promise((resolve, reject) => {
Expand All @@ -23,30 +25,47 @@ function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
function adopt(value: unknown) {
return value instanceof __TS__Promise ? value : Promise.resolve(value);
}
function fulfilled(value) {
const [success, resultOrError] = coroutine.resume(asyncCoroutine, value);
function fulfilled(value: unknown) {
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine, value);
if (success) {
step(resultOrError);
step(resultOrError, errorOrErrorHandler);
} else {
reject(resultOrError);
}
}
function step(result: unknown) {
function rejected(handler: ErrorHandler | undefined) {
if (handler) {
return (value: unknown) => {
const [success, valueOrError] = pcall(handler, value);
if (success) {
step(valueOrError, handler);
} else {
reject(valueOrError);
}
};
} else {
// If no catch clause, just reject
return value => {
reject(value);
};
}
}
function step(result: unknown, errorHandler: ErrorHandler | undefined) {
if (coroutine.status(asyncCoroutine) === "dead") {
resolve(result);
} else {
adopt(result).then(fulfilled, reason => reject(reason));
adopt(result).then(fulfilled, rejected(errorHandler));
}
}
const [success, resultOrError] = coroutine.resume(asyncCoroutine);
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine);
if (success) {
step(resultOrError);
step(resultOrError, errorOrErrorHandler);
} else {
reject(resultOrError);
reject(errorOrErrorHandler);
}
});
}

function __TS__Await(this: void, thing: unknown) {
return coroutine.yield(thing);
function __TS__Await(this: void, errorHandler: ErrorHandler, thing: unknown) {
return coroutine.yield(errorHandler, thing);
}
3 changes: 2 additions & 1 deletion src/transformation/visitors/async-await.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ export const transformAwaitExpression: FunctionVisitor<ts.AwaitExpression> = (no
}

const expression = context.transformExpression(node.expression);
return transformLuaLibFunction(context, LuaLibFeature.Await, node, expression);
const catchIdentifier = lua.createIdentifier("____catch");
return transformLuaLibFunction(context, LuaLibFeature.Await, node, catchIdentifier, expression);
};

export function isAsyncFunction(declaration: ts.FunctionLikeDeclaration): boolean {
Expand Down
56 changes: 23 additions & 33 deletions src/transformation/visitors/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ import * as ts from "typescript";
import * as lua from "../../LuaAST";
import { FunctionVisitor } from "../context";
import { createUnpackCall } from "../utils/lua-ast";
import { findScope, ScopeType } from "../utils/scope";
import { ScopeType } from "../utils/scope";
import { transformScopeBlock } from "./block";
import { transformIdentifier } from "./identifier";
import { isInMultiReturnFunction } from "./language-extensions/multi";
import { createReturnStatement } from "./return";

export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
Expand All @@ -15,24 +16,24 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen

const result: lua.Statement[] = [];

let returnedIdentifier: lua.Identifier | undefined;
const returnedIdentifier = lua.createIdentifier("____hasReturned");
let returnCondition: lua.Expression | undefined;

const pCall = lua.createIdentifier("pcall");
const tryCall = lua.createCallExpression(pCall, [lua.createFunctionExpression(tryBlock)]);

if (statement.catchClause && statement.catchClause.block.statements.length > 0) {
// try with catch
let [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);
if (statement.catchClause.variableDeclaration) {
// Replace ____returned with catch variable
returnedIdentifier = transformIdentifier(
context,
statement.catchClause.variableDeclaration.name as ts.Identifier
);
} else if (tryScope.functionReturned || catchScope.functionReturned) {
returnedIdentifier = lua.createIdentifier("____returned");
}
const [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);

const catchParameter = statement.catchClause.variableDeclaration
? transformIdentifier(context, statement.catchClause.variableDeclaration.name as ts.Identifier)
: undefined;
const catchParameters = () => (catchParameter ? [lua.cloneIdentifier(catchParameter)] : []);

const catchIdentifier = lua.createIdentifier("____catch");
const catchFunction = lua.createFunctionExpression(catchBlock, catchParameters());
result.push(lua.createVariableDeclarationStatement(catchIdentifier, catchFunction));

const tryReturnIdentifiers = [tryResultIdentifier]; // ____try
if (returnedIdentifier) {
Expand All @@ -44,20 +45,18 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
}
result.push(lua.createVariableDeclarationStatement(tryReturnIdentifiers, tryCall));

if ((tryScope.functionReturned || catchScope.functionReturned) && returnedIdentifier) {
// Wrap catch in function if try or catch has return
const catchCall = lua.createCallExpression(lua.createFunctionExpression(catchBlock), []);
const catchAssign = lua.createAssignmentStatement(
[lua.cloneIdentifier(returnedIdentifier), lua.cloneIdentifier(returnValueIdentifier)],
catchCall
);
catchBlock = lua.createBlock([catchAssign]);
}
// Wrap catch in function if try or catch has return
const catchCall = lua.createCallExpression(catchIdentifier, [lua.cloneIdentifier(returnedIdentifier)]);
const catchAssign = lua.createAssignmentStatement(
[lua.cloneIdentifier(returnedIdentifier), lua.cloneIdentifier(returnValueIdentifier)],
catchCall
);

const notTryCondition = lua.createUnaryExpression(tryResultIdentifier, lua.SyntaxKind.NotOperator);
result.push(lua.createIfStatement(notTryCondition, catchBlock));
result.push(lua.createIfStatement(notTryCondition, lua.createBlock([catchAssign])));
} else if (tryScope.functionReturned) {
// try with return, but no catch
returnedIdentifier = lua.createIdentifier("____returned");
// returnedIdentifier = lua.createIdentifier("____returned");
const returnedVariables = [tryResultIdentifier, returnedIdentifier, returnValueIdentifier];
result.push(lua.createVariableDeclarationStatement(returnedVariables, tryCall));

Expand All @@ -77,24 +76,15 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
}

if (returnCondition && returnedIdentifier) {
// With catch clause:
// if ____returned then return ____returnValue end
// No catch clause:
// if ____try and ____returned then return ____returnValue end
const returnValues: lua.Expression[] = [];
const parentTryCatch = findScope(context, ScopeType.Function | ScopeType.Try | ScopeType.Catch);
if (parentTryCatch && parentTryCatch.type !== ScopeType.Function) {
// Nested try/catch needs to prefix a 'true' return value
returnValues.push(lua.createBooleanLiteral(true));
}

if (isInMultiReturnFunction(context, statement)) {
returnValues.push(createUnpackCall(context, lua.cloneIdentifier(returnValueIdentifier)));
} else {
returnValues.push(lua.cloneIdentifier(returnValueIdentifier));
}

const returnStatement = lua.createReturnStatement(returnValues);
const returnStatement = createReturnStatement(context, returnValues, statement);
const ifReturnedStatement = lua.createIfStatement(returnCondition, lua.createBlock([returnStatement]));
result.push(ifReturnedStatement);
}
Expand Down
63 changes: 46 additions & 17 deletions src/transformation/visitors/return.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
canBeMultiReturnType,
} from "./language-extensions/multi";
import { invalidMultiFunctionReturnType } from "../utils/diagnostics";
import { findFirstNodeAbove } from "../utils/typescript";

function transformExpressionsInReturn(
context: TransformationContext,
Expand Down Expand Up @@ -55,22 +56,10 @@ export function transformExpressionBodyToReturnStatement(
node: ts.Expression
): lua.Statement {
const expressions = transformExpressionsInReturn(context, node, false);
return lua.createReturnStatement(expressions, node);
return createReturnStatement(context, expressions, node);
}

export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (statement, context) => {
// Bubble up explicit return flag and check if we're inside a try/catch block
let insideTryCatch = false;
for (const scope of walkScopesUp(context)) {
scope.functionReturned = true;

if (scope.type === ScopeType.Function) {
break;
}

insideTryCatch = insideTryCatch || scope.type === ScopeType.Try || scope.type === ScopeType.Catch;
}

let results: lua.Expression[];

if (statement.expression) {
Expand All @@ -80,15 +69,55 @@ export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (st
validateAssignment(context, statement, expressionType, returnType);
}

results = transformExpressionsInReturn(context, statement.expression, insideTryCatch);
results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context));
} else {
// Empty return
results = [];
}

if (insideTryCatch) {
return createReturnStatement(context, results, statement);
};

export function createReturnStatement(
context: TransformationContext,
values: lua.Expression[],
node: ts.Node
): lua.ReturnStatement {
const results = [...values];

if (isInTryCatch(context)) {
// Bubble up explicit return flag and check if we're inside a try/catch block
results.unshift(lua.createBooleanLiteral(true));
} else if (isInAsyncFunction(node)) {
// Add nil error handler in async function and not in try
results.unshift(lua.createNilLiteral());
}

return lua.createReturnStatement(results, statement);
};
return lua.createReturnStatement(results, node);
}

function isInAsyncFunction(node: ts.Node): boolean {
// Check if node is in function declaration with `async`
const declaration = findFirstNodeAbove(node, ts.isFunctionLike);
if (!declaration) {
return false;
}

return declaration.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) ?? false;
}

function isInTryCatch(context: TransformationContext): boolean {
// Check if context is in a try or catch
let insideTryCatch = false;
for (const scope of walkScopesUp(context)) {
scope.functionReturned = true;

if (scope.type === ScopeType.Function) {
break;
}

insideTryCatch = insideTryCatch || scope.type === ScopeType.Try || scope.type === ScopeType.Catch;
}

return insideTryCatch;
}
67 changes: 67 additions & 0 deletions test/unit/builtins/async-await.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ test.each(["async function abc() {", "const abc = async () => {"])(
}
);

test("can make inline async functions", () => {
util.testFunction`
const foo = async function() { return "foo"; };
const bar = async function() { return await foo(); };

const { state, value } = bar() as any;
return { state, value };
`.expectToEqual({
state: 1, // __TS__PromiseState.Fulfilled
value: "foo",
});
});

test("can make async lambdas with expression body", () => {
util.testFunction`
const foo = async () => "foo";
Expand Down Expand Up @@ -369,3 +382,57 @@ test("async function can forward varargs", () => {
.setTsHeader(promiseTestLib)
.expectToEqual(["resolved", "A", "B", "C"]);
});

// https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1105
describe("try/catch in async function", () => {
test("await inside try/catch returns inside async function", () => {
util.testModule`
export let result = 0;
async function foo(): Promise<number> {
try {
return await new Promise(resolve => resolve(4));
} catch {
throw "an error occurred in the async function"
}
}
foo().then(value => {
result = value;
});
`.expectToEqual({ result: 4 });
});

test("await inside try/catch throws inside async function", () => {
util.testModule`
export let reason = "";
async function foo(): Promise<number> {
try {
return await new Promise((resolve, reject) => reject("test error"));
} catch (e) {
throw "an error occurred in the async function: " + e;
}
}
foo().catch(e => {
reason = e;
});
`.expectToEqual({ reason: "an error occurred in the async function: test error" });
});

test("await inside try/catch deferred rejection uses catch clause", () => {
util.testModule`
export let reason = "";
let reject: (reason: string) => void;

async function foo(): Promise<number> {
try {
return await new Promise((res, rej) => { reject = rej; });
} catch (e) {
throw "an error occurred in the async function: " + e;
}
}
foo().catch(e => {
reason = e;
});
reject("test error");
`.expectToEqual({ reason: "an error occurred in the async function: test error" });
});
});