Skip to content
Merged
6 changes: 6 additions & 0 deletions src/LuaAST.ts
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,12 @@ export function createIdentifier(
return expression;
}

export function createAnnonymousIdentifier(tsOriginal?: ts.Node, parent?: Node): Identifier {
const expression = createNode(SyntaxKind.Identifier, tsOriginal, parent) as Identifier;
expression.text = "____";
return expression;
}

export interface TableIndexExpression extends Expression {
kind: SyntaxKind.TableIndexExpression;
table: Expression;
Expand Down
168 changes: 159 additions & 9 deletions src/LuaTransformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {TSTLErrors} from "./TSTLErrors";

export type StatementVisitResult = tstl.Statement | tstl.Statement[] | undefined;
export type ExpressionVisitResult = tstl.Expression | undefined;

export enum ScopeType {
File = 0x1,
Function = 0x2,
Expand Down Expand Up @@ -1031,6 +1030,144 @@ export class LuaTransformer {
});
}

private transformGeneratorFunction(
parameters: ts.NodeArray<ts.ParameterDeclaration>,
body: ts.Block,
transformedParameters: tstl.Identifier[],
dotsLiteral: tstl.DotsLiteral,
spreadIdentifier?: tstl.Identifier
): [tstl.Statement[], Scope]
{
this.importLuaLibFeature(LuaLibFeature.Symbol);
const [functionBody, functionScope] = this.transformFunctionBody(
parameters,
body,
spreadIdentifier
);

const coroutineIdentifier = tstl.createIdentifier("____co");
const valueIdentifier = tstl.createIdentifier("____value");
const errIdentifier = tstl.createIdentifier("____err");
const itIdentifier = tstl.createIdentifier("____it");

//local ____co = coroutine.create(originalFunction)
const coroutine =
tstl.createVariableDeclarationStatement(coroutineIdentifier,
tstl.createCallExpression(
tstl.createTableIndexExpression(tstl.createIdentifier("coroutine"),
tstl.createStringLiteral("create")
),
[tstl.createFunctionExpression(
tstl.createBlock(functionBody),
transformedParameters,
dotsLiteral,
spreadIdentifier),
]
)
);

const nextBody = [];
// coroutine.resume(__co, ...)
const resumeCall = tstl.createCallExpression(
tstl.createTableIndexExpression(
tstl.createIdentifier("coroutine"),
tstl.createStringLiteral("resume")
),
[coroutineIdentifier, tstl.createDotsLiteral()]
);

// ____err, ____value = coroutine.resume(____co, ...)
nextBody.push(tstl.createVariableDeclarationStatement(
[errIdentifier, valueIdentifier],
resumeCall)
);

//coroutine.status(____co) ~= "dead";
const coStatus = tstl.createCallExpression(
tstl.createTableIndexExpression(
tstl.createIdentifier("coroutine"),
tstl.createStringLiteral("status")
),
[coroutineIdentifier]
);
const status = tstl.createBinaryExpression(
coStatus,
tstl.createStringLiteral("dead"),
tstl.SyntaxKind.EqualityOperator
);
nextBody.push(status);
//if(not ____err){error(____value)}
const errorCheck = tstl.createIfStatement(
tstl.createUnaryExpression(
errIdentifier,
tstl.SyntaxKind.NotOperator
),
tstl.createBlock([
tstl.createExpressionStatement(
tstl.createCallExpression(
tstl.createIdentifier("error"),
[valueIdentifier]
)
),
])
);
nextBody.push(errorCheck);
//{done = coroutine.status(____co) ~= "dead"; value = ____value}
const iteratorResult = tstl.createTableExpression([
tstl.createTableFieldExpression(
status,
tstl.createStringLiteral("done")
),
tstl.createTableFieldExpression(
valueIdentifier,
tstl.createStringLiteral("value")
),
]);
nextBody.push(tstl.createReturnStatement([iteratorResult]));

//function(____, ...)
const nextFunctionDeclaration = tstl.createFunctionExpression(
tstl.createBlock(nextBody),
[tstl.createAnnonymousIdentifier()],
tstl.createDotsLiteral());

//____it = {next = function(____, ...)}
const iterator = tstl.createVariableDeclarationStatement(
itIdentifier,
tstl.createTableExpression([
tstl.createTableFieldExpression(
nextFunctionDeclaration,
tstl.createStringLiteral("next")
),
])
);

const symbolIterator = tstl.createTableIndexExpression(
tstl.createIdentifier("Symbol"),
tstl.createStringLiteral("iterator")
);

const block = [
coroutine,
iterator,
//____it[Symbol.iterator] = {return ____it}
tstl.createAssignmentStatement(
tstl.createTableIndexExpression(
itIdentifier,
symbolIterator
),
tstl.createFunctionExpression(
tstl.createBlock(
[tstl.createReturnStatement([itIdentifier])]
)
)
),
//return ____it
tstl.createReturnStatement([itIdentifier]),
];
return [block, functionScope];
}

public transformFunctionDeclaration(functionDeclaration: ts.FunctionDeclaration): StatementVisitResult {
// Don't transform functions without body (overload declarations)
if (!functionDeclaration.body) {
Expand All @@ -1044,22 +1181,28 @@ export class LuaTransformer {
const [params, dotsLiteral, restParamName] = this.transformParameters(functionDeclaration.parameters, context);

const name = this.transformIdentifier(functionDeclaration.name);
const [body, functionScope] = this.transformFunctionBody(
functionDeclaration.parameters,
functionDeclaration.body,
restParamName
);
const [body, functionScope] = functionDeclaration.asteriskToken
? this.transformGeneratorFunction(
functionDeclaration.parameters,
functionDeclaration.body,
params,
dotsLiteral,
restParamName
)
: this.transformFunctionBody(
functionDeclaration.parameters,
functionDeclaration.body,
restParamName
);
const block = tstl.createBlock(body);
const functionExpression = tstl.createFunctionExpression(block, params, dotsLiteral, restParamName);

// Remember symbols referenced in this function for hoisting later
if (!this.options.noHoisting && name.symbolId !== undefined) {
const scope = this.peekScope();
if (!scope.functionDefinitions) { scope.functionDefinitions = new Map(); }
const functionInfo = {referencedSymbols: functionScope.referencedSymbols || new Set()};
scope.functionDefinitions.set(name.symbolId, functionInfo);
}

return this.createLocalOrExportedOrGlobalDeclaration(name, functionExpression, functionDeclaration);
}

Expand Down Expand Up @@ -1219,6 +1362,12 @@ export class LuaTransformer {
return tstl.createExpressionStatement(this.transformExpression(expression));
}

public transformYield(expression: ts.YieldExpression): tstl.Expression {
return tstl.createCallExpression(
tstl.createTableIndexExpression(tstl.createIdentifier("coroutine"), tstl.createStringLiteral("yield")),
expression.expression?[this.transformExpression(expression.expression)]:[], expression);
}

public transformReturn(statement: ts.ReturnStatement): tstl.Statement {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

if (statement.expression) {
const returnType = tsHelper.getContainingFunctionReturnType(statement, this.checker);
Expand Down Expand Up @@ -1731,6 +1880,8 @@ export class LuaTransformer {
return this.transformSpreadElement(expression as ts.SpreadElement);
case ts.SyntaxKind.NonNullExpression:
return this.transformExpression((expression as ts.NonNullExpression).expression);
case ts.SyntaxKind.YieldExpression:
return this.transformYield(expression as ts.YieldExpression);
case ts.SyntaxKind.EmptyStatement:
return undefined;
case ts.SyntaxKind.NotEmittedStatement:
Expand Down Expand Up @@ -1871,7 +2022,6 @@ export class LuaTransformer {
throw TSTLErrors.UnsupportedUnionAccessor(lhs);
}
}

return tstl.createAssignmentStatement(
this.transformExpression(lhs) as tstl.IdentifierOrTableIndexExpression,
right,
Expand Down
58 changes: 58 additions & 0 deletions test/unit/functions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,64 @@ export class FunctionTests {
Expect(result).toBe("foobar");
}

@TestCase(1, 1)
@TestCase(2, 42)
@Test("Generator functions value")
public generatorFunctionValue(iterations: number, expectedResult: number): void {
const code = `function* seq(value: number) {
let a = yield value + 1;
return 42;
}
const gen = seq(0);
let ret: number;
for(let i = 0; i < ${iterations}; ++i)
{
ret = gen.next(i).value;
}
return ret;
`;
const result = util.transpileAndExecute(code);
Expect(result).toBe(expectedResult);
}

@TestCase(1, false)
@TestCase(2, true)
@Test("Generator functions done")
public generatorFunctionDone(iterations: number, expectedResult: boolean): void {
const code = `function* seq(value: number) {
let a = yield value + 1;
return 42;
}
const gen = seq(0);
let ret: boolean;
for(let i = 0; i < ${iterations}; ++i)
{
ret = gen.next(i).done;
}
return ret;
`;
const result = util.transpileAndExecute(code);
Expect(result).toBe(expectedResult);
}

@Test("Generator for..of")
public generatorFunctionForOf(): void {
const code = `function* seq() {
yield(1);
yield(2);
yield(3);
return 4;
}
let result = 0;
for(let i of seq())
{
result = result * 10 + i;
}
return result`;
const result = util.transpileAndExecute(code);
Expect(result).toBe(123);
}

@Test("Function local overriding export")
public functionLocalOverridingExport(): void {
const code =
Expand Down