Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions src/lualib/Await.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,44 @@

import { __TS__Promise } from "./Promise";

type ErrorHandler = (this: void, error: unknown) => unknown;

// eslint-disable-next-line @typescript-eslint/promise-function-async
export function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) {
return new Promise((resolve, reject) => {
let resolved = false;
const asyncCoroutine = coroutine.create(generator);

// eslint-disable-next-line @typescript-eslint/promise-function-async
function adopt(value: unknown) {
return value instanceof __TS__Promise ? value : Promise.resolve(value);
}
function fulfilled(value: unknown) {
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine, value);
const [success, resultOrError] = coroutine.resume(asyncCoroutine, value);
if (success) {
step(resultOrError, errorOrErrorHandler);
} else {
reject(errorOrErrorHandler);
}
}
function rejected(handler: ErrorHandler | undefined) {
if (handler) {
return (value: unknown) => {
const [success, hasReturnedOrError, returnedValue] = pcall(handler, value);
if (success) {
if (hasReturnedOrError) {
resolve(returnedValue);
} else {
step(hasReturnedOrError, handler);
}
} else {
reject(hasReturnedOrError);
}
};
step(resultOrError);
} else {
// If no catch clause, just reject
return value => {
reject(value);
};
reject(resultOrError);
}
}
function step(result: unknown, errorHandler: ErrorHandler | undefined) {
function step(result: unknown) {
if (resolved) return;
if (coroutine.status(asyncCoroutine) === "dead") {
resolve(result);
} else {
adopt(result).then(fulfilled, rejected(errorHandler));
adopt(result).then(fulfilled, reject);
}
}
const [success, errorOrErrorHandler, resultOrError] = coroutine.resume(asyncCoroutine);
const [success, resultOrError] = coroutine.resume(asyncCoroutine, (v: unknown) => {
resolved = true;
adopt(v).then(resolve, reject);
});
if (success) {
step(resultOrError, errorOrErrorHandler);
step(resultOrError);
} else {
reject(errorOrErrorHandler);
reject(resultOrError);
}
});
}

export function __TS__Await(this: void, errorHandler: ErrorHandler, thing: unknown) {
return coroutine.yield(errorHandler, thing);
export function __TS__Await(this: void, thing: unknown) {
return coroutine.yield(thing);
}
29 changes: 13 additions & 16 deletions src/transformation/visitors/async-await.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,32 @@ import * as lua from "../../LuaAST";
import { FunctionVisitor, TransformationContext } from "../context";
import { awaitMustBeInAsyncFunction } from "../utils/diagnostics";
import { importLuaLibFeature, LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
import { findFirstNodeAbove } from "../utils/typescript";
import { isInAsyncFunction } from "../utils/typescript";

export const transformAwaitExpression: FunctionVisitor<ts.AwaitExpression> = (node, context) => {
// Check if await is inside an async function, it is not allowed at top level or in non-async functions
const containingFunction = findFirstNodeAbove(node, ts.isFunctionLike);
if (
containingFunction === undefined ||
!containingFunction.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword)
) {
if (!isInAsyncFunction(node)) {
context.diagnostics.push(awaitMustBeInAsyncFunction(node));
}

const expression = context.transformExpression(node.expression);
const catchIdentifier = lua.createIdentifier("____catch");
return transformLuaLibFunction(context, LuaLibFeature.Await, node, catchIdentifier, expression);
return transformLuaLibFunction(context, LuaLibFeature.Await, node, expression);
};

export function isAsyncFunction(declaration: ts.FunctionLikeDeclaration): boolean {
return declaration.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) ?? false;
}

export function wrapInAsyncAwaiter(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
export function wrapInAsyncAwaiter(
context: TransformationContext,
statements: lua.Statement[],
includeResolveParameter = true
): lua.CallExpression {
importLuaLibFeature(context, LuaLibFeature.Await);

return [
lua.createReturnStatement([
lua.createCallExpression(lua.createIdentifier("__TS__AsyncAwaiter"), [
lua.createFunctionExpression(lua.createBlock(statements)),
]),
]),
];
const parameters = includeResolveParameter ? [lua.createIdentifier("____awaiter_resolve")] : [];

return lua.createCallExpression(lua.createIdentifier("__TS__AsyncAwaiter"), [
lua.createFunctionExpression(lua.createBlock(statements), parameters),
]);
}
96 changes: 77 additions & 19 deletions src/transformation/visitors/errors.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import * as ts from "typescript";
import { LuaTarget } from "../..";
import { LuaLibFeature, LuaTarget } from "../..";
import * as lua from "../../LuaAST";
import { FunctionVisitor } from "../context";
import { FunctionVisitor, TransformationContext } from "../context";
import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics";
import { createUnpackCall } from "../utils/lua-ast";
import { ScopeType } from "../utils/scope";
import { transformLuaLibFunction } from "../utils/lualib";
import { Scope, ScopeType } from "../utils/scope";
import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript";
import { wrapInAsyncAwaiter } from "./async-await";
import { transformScopeBlock } from "./block";
import { transformIdentifier } from "./identifier";
import { isInMultiReturnFunction } from "./language-extensions/multi";
import { createReturnStatement } from "./return";

export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context) => {
const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);

if (
context.options.luaTarget === LuaTarget.Lua51 &&
isInAsyncFunction(statement) &&
!context.options.lua51AllowTryCatchInAsyncAwait
) {
if (context.options.luaTarget === LuaTarget.Lua51 && !context.options.lua51AllowTryCatchInAsyncAwait) {
context.diagnostics.push(
unsupportedForTargetButOverrideAvailable(
statement,
Expand All @@ -30,6 +28,57 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen
return tryBlock.statements;
}

// __TS__AsyncAwaiter(<catch block>)
const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false);
const awaiterIdentifier = lua.createIdentifier("____try");
const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter);

// local ____try = __TS__AsyncAwaiter(<catch block>)
const result: lua.Statement[] = [awaiterDefinition];

if (statement.finallyBlock) {
const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally"));
const finallyFunction = lua.createFunctionExpression(
lua.createBlock(context.transformStatements(statement.finallyBlock.statements))
);
const finallyCall = lua.createCallExpression(
awaiterFinally,
[awaiterIdentifier, finallyFunction],
statement.finallyBlock
);
// ____try.finally(<finally function>)
result.push(lua.createExpressionStatement(finallyCall));
}

if (statement.catchClause) {
// ____try.catch(<catch function>)
const [catchFunction] = transformCatchClause(context, statement.catchClause);
if (catchFunction.params) {
catchFunction.params.unshift(lua.createAnonymousIdentifier());
}

const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch"));
const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]);

// await ____try.catch(<catch function>)
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall);
result.push(lua.createExpressionStatement(promiseAwait, statement));
} else {
// await ____try
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier);
result.push(lua.createExpressionStatement(promiseAwait, statement));
}

return result;
};

export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statement, context) => {
if (isInAsyncFunction(statement)) {
return transformAsyncTry(statement, context);
}

const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);

if (context.options.luaTarget === LuaTarget.Lua51 && isInGeneratorFunction(statement)) {
context.diagnostics.push(
unsupportedForTarget(statement, "try/catch inside generator functions", LuaTarget.Lua51)
Expand All @@ -50,15 +99,7 @@ export const transformTryStatement: FunctionVisitor<ts.TryStatement> = (statemen

if (statement.catchClause && statement.catchClause.block.statements.length > 0) {
// try with catch
const [catchBlock, catchScope] = transformScopeBlock(context, statement.catchClause.block, ScopeType.Catch);

const catchParameter = statement.catchClause.variableDeclaration
? transformIdentifier(context, statement.catchClause.variableDeclaration.name as ts.Identifier)
: undefined;
const catchFunction = lua.createFunctionExpression(
catchBlock,
catchParameter ? [lua.cloneIdentifier(catchParameter)] : []
);
const [catchFunction, catchScope] = transformCatchClause(context, statement.catchClause);
const catchIdentifier = lua.createIdentifier("____catch");
result.push(lua.createVariableDeclarationStatement(catchIdentifier, catchFunction));

Expand Down Expand Up @@ -138,3 +179,20 @@ export const transformThrowStatement: FunctionVisitor<ts.ThrowStatement> = (stat
statement
);
};

function transformCatchClause(
context: TransformationContext,
catchClause: ts.CatchClause
): [lua.FunctionExpression, Scope] {
const [catchBlock, catchScope] = transformScopeBlock(context, catchClause.block, ScopeType.Catch);

const catchParameter = catchClause.variableDeclaration
? transformIdentifier(context, catchClause.variableDeclaration.name as ts.Identifier)
: undefined;
const catchFunction = lua.createFunctionExpression(
catchBlock,
catchParameter ? [lua.cloneIdentifier(catchParameter)] : []
);

return [catchFunction, catchScope];
}
2 changes: 1 addition & 1 deletion src/transformation/visitors/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ export function transformFunctionBody(
scope.node = node;
let bodyStatements = transformFunctionBodyContent(context, body);
if (node && isAsyncFunction(node)) {
bodyStatements = wrapInAsyncAwaiter(context, bodyStatements);
bodyStatements = [lua.createReturnStatement([wrapInAsyncAwaiter(context, bodyStatements)])];
}
const headerStatements = transformFunctionBodyHeader(context, scope, parameters, spreadIdentifier);
popScope(context);
Expand Down
9 changes: 6 additions & 3 deletions src/transformation/visitors/return.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,15 @@ export function createReturnStatement(
): lua.ReturnStatement {
const results = [...values];

if (isInAsyncFunction(node)) {
return lua.createReturnStatement([
lua.createCallExpression(lua.createIdentifier("____awaiter_resolve"), [lua.createNilLiteral(), ...values]),
]);
}

if (isInTryCatch(context)) {
// Bubble up explicit return flag and check if we're inside a try/catch block
results.unshift(lua.createBooleanLiteral(true));
} else if (isInAsyncFunction(node)) {
// Add nil error handler in async function and not in try
results.unshift(lua.createNilLiteral());
}

return lua.createReturnStatement(results, node);
Expand Down
Loading