Skip to content
3 changes: 3 additions & 0 deletions src/Decorator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ export class Decorator {
return DecoratorKind.NoSelf;
case "noselfinfile":
return DecoratorKind.NoSelfInFile;
case "vararg":
return DecoratorKind.Vararg;
case "forrange":
return DecoratorKind.ForRange;
}
Expand Down Expand Up @@ -63,5 +65,6 @@ export enum DecoratorKind {
LuaTable = "LuaTable",
NoSelf = "NoSelf",
NoSelfInFile = "NoSelfInFile",
Vararg = "Vararg",
ForRange = "ForRange",
}
48 changes: 37 additions & 11 deletions src/LuaTransformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ interface SymbolInfo {
}

interface FunctionDefinitionInfo {
referencedSymbols: Set<tstl.SymbolId>;
referencedSymbols: Map<tstl.SymbolId, ts.Node[]>;
definition?: tstl.VariableDeclarationStatement | tstl.AssignmentStatement;
}

interface Scope {
type: ScopeType;
id: number;
referencedSymbols?: Set<tstl.SymbolId>;
referencedSymbols?: Map<tstl.SymbolId, ts.Node[]>;
variableDeclarations?: tstl.VariableDeclarationStatement[];
functionDefinitions?: Map<tstl.SymbolId, FunctionDefinitionInfo>;
importStatements?: tstl.Statement[];
Expand Down Expand Up @@ -1388,12 +1388,31 @@ export class LuaTransformer {
return [paramNames, dotsLiteral, restParamName];
}

protected isRestParameterReferenced(identifier: tstl.Identifier, scope: Scope): boolean {
if (!identifier.symbolId) {
return true;
}
if (scope.referencedSymbols === undefined) {
return false;
}
const references = scope.referencedSymbols.get(identifier.symbolId);
if (!references) {
return false;
}
// Ignore references to @vararg types in spread elements
return references.some(
r => !r.parent || !ts.isSpreadElement(r.parent) || !tsHelper.isVarArgType(r, this.checker)
);
}

protected transformFunctionBody(
parameters: ts.NodeArray<ts.ParameterDeclaration>,
body: ts.Block,
spreadIdentifier?: tstl.Identifier
): [tstl.Statement[], Scope] {
this.pushScope(ScopeType.Function);
const bodyStatements = this.performHoisting(this.transformStatements(body.statements));
const scope = this.popScope();

const headerStatements = [];

Expand Down Expand Up @@ -1426,18 +1445,14 @@ export class LuaTransformer {
}

// Push spread operator here
if (spreadIdentifier) {
if (spreadIdentifier && this.isRestParameterReferenced(spreadIdentifier, scope)) {
const spreadTable = this.wrapInTable(tstl.createDotsLiteral());
headerStatements.push(tstl.createVariableDeclarationStatement(spreadIdentifier, spreadTable));
}

// Binding pattern statements need to be after spread table is declared
headerStatements.push(...bindingPatternDeclarations);

const bodyStatements = this.performHoisting(this.transformStatements(body.statements));

const scope = this.popScope();

return [headerStatements.concat(bodyStatements), scope];
}

Expand Down Expand Up @@ -1844,7 +1859,7 @@ export class LuaTransformer {
if (!scope.functionDefinitions) {
scope.functionDefinitions = new Map();
}
const functionInfo = { referencedSymbols: functionScope.referencedSymbols || new Set() };
const functionInfo = { referencedSymbols: functionScope.referencedSymbols || new Map() };
scope.functionDefinitions.set(name.symbolId, functionInfo);
}
return this.createLocalOrExportedOrGlobalDeclaration(name, functionExpression, functionDeclaration);
Expand Down Expand Up @@ -4603,6 +4618,10 @@ export class LuaTransformer {
return innerExpression;
}

if (ts.isIdentifier(expression.expression) && tsHelper.isVarArgType(expression.expression, this.checker)) {
return tstl.createDotsLiteral(expression);
}

const type = this.checker.getTypeAtLocation(expression.expression);
if (tsHelper.isArrayType(type, this.checker, this.program)) {
return this.createUnpackCall(innerExpression, expression);
Expand Down Expand Up @@ -5278,13 +5297,20 @@ export class LuaTransformer {
if (declaration && identifier.pos < declaration.pos) {
throw TSTLErrors.ReferencedBeforeDeclaration(identifier);
}
} else if (symbolId !== undefined) {
}

if (symbolId !== undefined) {
//Mark symbol as seen in all current scopes
for (const scope of this.scopeStack) {
if (!scope.referencedSymbols) {
scope.referencedSymbols = new Set();
scope.referencedSymbols = new Map();
}
let references = scope.referencedSymbols.get(symbolId);
if (!references) {
references = [];
scope.referencedSymbols.set(symbolId, references);
}
scope.referencedSymbols.add(symbolId);
references.push(identifier);
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions src/TSHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ export class TSHelper {
return TSHelper.getCustomDecorators(type, checker).has(DecoratorKind.LuaIterator);
}

public static isRestParameter(node: ts.Node, checker: ts.TypeChecker): boolean {
const symbol = checker.getSymbolAtLocation(node);
if (!symbol) {
return false;
}
const declarations = symbol.getDeclarations();
if (!declarations) {
return false;
}
return declarations.some(d => ts.isParameter(d) && d.dotDotDotToken !== undefined);
}

public static isVarArgType(node: ts.Node, checker: ts.TypeChecker): boolean {
const type = checker.getTypeAtLocation(node);
return type !== undefined && TSHelper.getCustomDecorators(type, checker).has(DecoratorKind.Vararg);
}

public static isForRangeType(node: ts.Node, checker: ts.TypeChecker): boolean {
const type = checker.getTypeAtLocation(node);
return TSHelper.getCustomDecorators(type, checker).has(DecoratorKind.ForRange);
Expand Down
2 changes: 1 addition & 1 deletion test/translation/__snapshots__/transformation.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ end"
exports[`Transformation (functionRestArguments) 1`] = `
"function varargsFunction(self, a, ...)
local b = ({...})
local c = b
end"
`;

Expand Down Expand Up @@ -319,7 +320,6 @@ end
function MyClass.prototype.____constructor(self)
end
function MyClass.prototype.varargsFunction(self, a, ...)
local b = ({...})
end"
`;

Expand Down
4 changes: 3 additions & 1 deletion test/translation/transformation/functionRestArguments.ts
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
function varargsFunction(a: string, ...b: string[]): void {}
function varargsFunction(a: string, ...b: string[]): void {
const c = b;
}
106 changes: 106 additions & 0 deletions test/unit/functions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,109 @@ test("Function rest binding pattern", () => {

expect(result).toBe("defxyzabc");
});

test.each([{}, { noHoisting: true }])("Function rest parameter", compilerOptions => {
const code = `
function foo(a: unknown, ...b: string[]) {
return b.join("");
}
return foo("A", "B", "C", "D");
`;

expect(util.transpileAndExecute(code, compilerOptions)).toBe("BCD");
});

test.each([{}, { noHoisting: true }])("Function nested rest parameter", compilerOptions => {
const code = `
function foo(a: unknown, ...b: string[]) {
function bar() {
return b.join("");
}
return bar();
}
return foo("A", "B", "C", "D");
`;

expect(util.transpileAndExecute(code, compilerOptions)).toBe("BCD");
});

test.each([{}, { noHoisting: true }])("Function nested rest spread", compilerOptions => {
const code = `
function foo(a: unknown, ...b: string[]) {
function bar() {
const c = [...b];
return c.join("");
}
return bar();
}
return foo("A", "B", "C", "D");
`;

expect(util.transpileAndExecute(code, compilerOptions)).toBe("BCD");
});

test.each([{}, { noHoisting: true }])("Function rest parameter (unreferenced)", compilerOptions => {
const code = `
function foo(a: unknown, ...b: string[]) {
return "foobar";
}
return foo("A", "B", "C", "D");
`;

expect(util.transpileString(code, compilerOptions)).not.toMatch("b = ({...})");
expect(util.transpileAndExecute(code, compilerOptions)).toBe("foobar");
});

test.each([{}, { noHoisting: true }])("@vararg", compilerOptions => {
const code = `
/** @vararg */ type LuaVarArg<A extends unknown[]> = A & { __luaVarArg?: never };
function foo(a: unknown, ...b: LuaVarArg<unknown[]>) {
const c = [...b];
return c.join("");
}
function bar(a: unknown, ...b: LuaVarArg<unknown[]>) {
return foo(a, ...b);
}
return bar("A", "B", "C", "D");
`;

const lua = util.transpileString(code, compilerOptions);
expect(lua).not.toMatch("b = ({...})");
expect(lua).not.toMatch("unpack");
expect(util.transpileAndExecute(code, compilerOptions)).toBe("BCD");
});

test.each([{}, { noHoisting: true }])("@vararg array access", compilerOptions => {
const code = `
/** @vararg */ type LuaVarArg<A extends unknown[]> = A & { __luaVarArg?: never };
function foo(a: unknown, ...b: LuaVarArg<unknown[]>) {
const c = [...b];
return c.join("") + b[0];
}
return foo("A", "B", "C", "D");
`;

expect(util.transpileAndExecute(code, compilerOptions)).toBe("BCDB");
});

test.each([{}, { noHoisting: true }])("@vararg global", compilerOptions => {
const code = `
/** @vararg */ type LuaVarArg<A extends unknown[]> = A & { __luaVarArg?: never };
declare const arg: LuaVarArg<string[]>;
const arr = [...arg];
const result = arr.join("");
`;

const luaBody = util.transpileString(code, compilerOptions, false);
expect(luaBody).not.toMatch("unpack");

const lua = `
function test(...)
${luaBody}
return result
end
return test("A", "B", "C", "D")
`;

expect(util.executeLua(lua)).toBe("ABCD");
});