Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion language-extensions/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ declare type LuaIterable<TValue, TState = undefined> = Iterable<TValue> &
LuaIterator<TValue, TState> &
LuaExtension<"__luaIterableBrand">;

/**
* Represents an object that can be iterated with pairs()
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
*
* @param TKey The type of the key returned each iteration.
* @param TValue The type of the value returned each iteration.
*/
declare type LuaPairsIterable<TKey extends AnyNotNil, TValue> = Iterable<[TKey, TValue]> &
LuaExtension<"__luaPairsIterableBrand">;

/**
* Calls to functions with this type are translated to `left + right`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
Expand Down Expand Up @@ -535,7 +545,7 @@ declare type LuaTableDeleteMethod<TKey extends AnyNotNil> = ((key: TKey) => bool
* @param TKey The type of the keys used to access the table.
* @param TValue The type of the values stored in the table.
*/
declare interface LuaTable<TKey extends AnyNotNil = AnyNotNil, TValue = any> {
declare interface LuaTable<TKey extends AnyNotNil = AnyNotNil, TValue = any> extends LuaPairsIterable<TKey, TValue> {
length: LuaLengthMethod<number>;
get: LuaTableGetMethod<TKey, TValue>;
set: LuaTableSetMethod<TKey, TValue>;
Expand Down
4 changes: 4 additions & 0 deletions src/transformation/utils/diagnostics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ export const invalidMultiIterableWithoutDestructuring = createErrorDiagnosticFac
"LuaIterable with a LuaMultiReturn return value type must be destructured."
);

export const invalidPairsIterableWithoutDestructuring = createErrorDiagnosticFactory(
"LuaPairsIterable type must be destructured in a for...of statement."
);

export const unsupportedAccessorInObjectLiteral = createErrorDiagnosticFactory(
"Accessors in object literal are not supported."
);
Expand Down
2 changes: 2 additions & 0 deletions src/transformation/utils/language-extensions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export enum ExtensionKind {
RangeFunction = "RangeFunction",
VarargConstant = "VarargConstant",
IterableType = "IterableType",
PairsIterableType = "PairsIterableType",
AdditionOperatorType = "AdditionOperatorType",
AdditionOperatorMethodType = "AdditionOperatorMethodType",
SubtractionOperatorType = "SubtractionOperatorType",
Expand Down Expand Up @@ -66,6 +67,7 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
[ExtensionKind.RangeFunction]: "__luaRangeFunctionBrand",
[ExtensionKind.VarargConstant]: "__luaVarargConstantBrand",
[ExtensionKind.IterableType]: "__luaIterableBrand",
[ExtensionKind.PairsIterableType]: "__luaPairsIterableBrand",
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
Expand Down
63 changes: 63 additions & 0 deletions src/transformation/visitors/language-extensions/pairsIterable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import * as ts from "typescript";
import * as lua from "../../../LuaAST";
import { cast } from "../../../utils";
import { TransformationContext } from "../../context";
import { invalidPairsIterableWithoutDestructuring } from "../../utils/diagnostics";
import * as extensions from "../../utils/language-extensions";
import { getVariableDeclarationBinding } from "../loops/utils";
import { transformArrayBindingElement } from "../variable-declaration";

function isPairsIterableType(type: ts.Type): boolean {
return extensions.isExtensionType(type, extensions.ExtensionKind.PairsIterableType);
}

export function isPairsIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
const type = context.checker.getTypeAtLocation(expression);
return isPairsIterableType(type);
}

export function transformForOfPairsIterableStatement(
context: TransformationContext,
statement: ts.ForOfStatement,
block: lua.Block
): lua.Statement {
const pairsCall = lua.createCallExpression(lua.createIdentifier("pairs"), [
context.transformExpression(statement.expression),
]);

let identifiers: lua.Identifier[] = [];

if (ts.isVariableDeclarationList(statement.initializer)) {
// Variables declared in for loop
// for key, value in iterable do
const binding = getVariableDeclarationBinding(context, statement.initializer);
if (ts.isArrayBindingPattern(binding)) {
identifiers = binding.elements.map(e => transformArrayBindingElement(context, e));
} else {
context.diagnostics.push(invalidPairsIterableWithoutDestructuring(binding));
}
} else if (ts.isArrayLiteralExpression(statement.initializer)) {
// Variables NOT declared in for loop - catch iterator values in temps and assign
// for ____key, ____value in iterable do
// key, value = ____key, ____value
identifiers = statement.initializer.elements.map(e => context.createTempNameForNode(e));
if (identifiers.length > 0) {
block.statements.unshift(
lua.createAssignmentStatement(
statement.initializer.elements.map(e =>
cast(context.transformExpression(e), lua.isAssignmentLeftHandSideExpression)
),
identifiers
)
);
}
} else {
context.diagnostics.push(invalidPairsIterableWithoutDestructuring(statement.initializer));
}

if (identifiers.length === 0) {
identifiers.push(lua.createAnonymousIdentifier());
}

return lua.createForInStatement(block, identifiers, [pairsCall], statement);
}
3 changes: 3 additions & 0 deletions src/transformation/visitors/loops/for-of.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { annotationRemoved } from "../../utils/diagnostics";
import { LuaLibFeature, transformLuaLibFunction } from "../../utils/lualib";
import { isArrayType } from "../../utils/typescript";
import { isIterableExpression, transformForOfIterableStatement } from "../language-extensions/iterable";
import { isPairsIterableExpression, transformForOfPairsIterableStatement } from "../language-extensions/pairsIterable";
import { isRangeFunction, transformRangeStatement } from "../language-extensions/range";
import { transformForInitializer, transformLoopBody } from "./utils";

Expand Down Expand Up @@ -47,6 +48,8 @@ export const transformForOfStatement: FunctionVisitor<ts.ForOfStatement> = (node
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.ForRange));
} else if (isIterableExpression(context, node.expression)) {
return transformForOfIterableStatement(context, node, body);
} else if (isPairsIterableExpression(context, node.expression)) {
return transformForOfPairsIterableStatement(context, node, body);
} else if (isLuaIteratorType(context, node.expression)) {
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.LuaIterator));
} else if (isArrayType(context, context.checker.getTypeAtLocation(node.expression))) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP

exports[`invalid LuaPairsIterable without destructuring ("for (const s of testIterable) {}"): code 1`] = `
"local ____exports = {}
function ____exports.__main(self)
local testIterable = {a1 = \\"a2\\", b1 = \\"b2\\", c1 = \\"c2\\"}
for ____ in pairs(testIterable) do
end
end
return ____exports"
`;

exports[`invalid LuaPairsIterable without destructuring ("for (const s of testIterable) {}"): diagnostics 1`] = `"main.ts(5,20): error TSTL: LuaPairsIterable type must be destructured in a for...of statement."`;

exports[`invalid LuaPairsIterable without destructuring ("let s; for (s of testIterable) {}"): code 1`] = `
"local ____exports = {}
function ____exports.__main(self)
local testIterable = {a1 = \\"a2\\", b1 = \\"b2\\", c1 = \\"c2\\"}
local s
for ____ in pairs(testIterable) do
end
end
return ____exports"
`;

exports[`invalid LuaPairsIterable without destructuring ("let s; for (s of testIterable) {}"): diagnostics 1`] = `"main.ts(5,21): error TSTL: LuaPairsIterable type must be destructured in a for...of statement."`;
120 changes: 120 additions & 0 deletions test/unit/language-extensions/pairsIterable.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import * as util from "../../util";
import { invalidPairsIterableWithoutDestructuring } from "../../../src/transformation/utils/diagnostics";
import { LuaTarget } from "../../../src";

const testIterable = `
const testIterable = {a1: "a2", b1: "b2", c1: "c2"} as unknown as LuaPairsIterable<string, string>;
`;

const testResults = {
a1: "a2",
b1: "b2",
c1: "c2",
};

test("pairs iterable", () => {
util.testFunction`
${testIterable}
const results: Record<string, string> = {};
for (const [k, v] of testIterable) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual(testResults);
});

test("pairs iterable with external control variable", () => {
util.testFunction`
${testIterable}
const results: Record<string, string> = {};
let k: string, v: string;
for ([k, v] of testIterable) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual(testResults);
});

test("pairs iterable function forward", () => {
util.testFunction`
${testIterable}
function forward() { return testIterable; }
const results: Record<string, string> = {};
for (const [k, v] of forward()) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual(testResults);
});

test("pairs iterable function indirect forward", () => {
util.testFunction`
${testIterable}
function forward() { const iter = testIterable; return iter; }
const results: Record<string, string> = {};
for (const [k, v] of forward()) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual(testResults);
});

test("pairs iterable arrow function forward", () => {
util.testFunction`
${testIterable}
const forward = () => testIterable;
const results: Record<string, string> = {};
for (const [k, v] of forward()) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual(testResults);
});

test("pairs iterable with __pairs metamethod", () => {
util.testFunction`
class PairsTest {
__pairs() {
const kvp = [ ["a1", "a2"], ["b1", "b2"], ["c1", "c2"] ];
let i = 0;
return () => {
if (i < kvp.length) {
const [k, v] = kvp[i++];
return $multi(k, v);
}
};
}
}
const tester = new PairsTest() as PairsTest & LuaPairsIterable<string, string>;
const results: Record<string, string> = {};
for (const [k, v] of tester) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.setOptions({ luaTarget: LuaTarget.Lua53 })
.expectToEqual(testResults);
});

test.each(["for (const s of testIterable) {}", "let s; for (s of testIterable) {}"])(
"invalid LuaPairsIterable without destructuring (%p)",
statement => {
util.testFunction`
${testIterable}
${statement}
`
.withLanguageExtensions()
.expectDiagnosticsToMatchSnapshot([invalidPairsIterableWithoutDestructuring.code]);
}
);
16 changes: 16 additions & 0 deletions test/unit/language-extensions/table.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -325,4 +325,20 @@ describe("LuaTable extension interface", () => {
util.testFunction(statement).withLanguageExtensions().expectToHaveNoDiagnostics();
}
);

test("table pairs iterate", () => {
util.testFunction`
const tbl = new LuaTable<string, number>();
tbl.set("foo", 1);
tbl.set("bar", 3);
tbl.set("baz", 5);
const results: Record<string, number> = {};
for (const [k, v] of tbl) {
results[k] = v;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ foo: 1, bar: 3, baz: 5 });
});
});