Skip to content

Commit 3eae0a2

Browse files
authored
added LuaPairsIterable extension (#1176)
Co-authored-by: Tom <tomblind@users.noreply.github.com>
1 parent c955a63 commit 3eae0a2

File tree

8 files changed

+245
-1
lines changed

8 files changed

+245
-1
lines changed

language-extensions/index.d.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ declare type LuaIterable<TValue, TState = undefined> = Iterable<TValue> &
7878
LuaIterator<TValue, TState> &
7979
LuaExtension<"__luaIterableBrand">;
8080

81+
/**
82+
* Represents an object that can be iterated with pairs()
83+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
84+
*
85+
* @param TKey The type of the key returned each iteration.
86+
* @param TValue The type of the value returned each iteration.
87+
*/
88+
declare type LuaPairsIterable<TKey extends AnyNotNil, TValue> = Iterable<[TKey, TValue]> &
89+
LuaExtension<"__luaPairsIterableBrand">;
90+
8191
/**
8292
* Calls to functions with this type are translated to `left + right`.
8393
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
@@ -535,7 +545,7 @@ declare type LuaTableDeleteMethod<TKey extends AnyNotNil> = ((key: TKey) => bool
535545
* @param TKey The type of the keys used to access the table.
536546
* @param TValue The type of the values stored in the table.
537547
*/
538-
declare interface LuaTable<TKey extends AnyNotNil = AnyNotNil, TValue = any> {
548+
declare interface LuaTable<TKey extends AnyNotNil = AnyNotNil, TValue = any> extends LuaPairsIterable<TKey, TValue> {
539549
length: LuaLengthMethod<number>;
540550
get: LuaTableGetMethod<TKey, TValue>;
541551
set: LuaTableSetMethod<TKey, TValue>;

src/transformation/utils/diagnostics.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ export const invalidMultiIterableWithoutDestructuring = createErrorDiagnosticFac
7272
"LuaIterable with a LuaMultiReturn return value type must be destructured."
7373
);
7474

75+
export const invalidPairsIterableWithoutDestructuring = createErrorDiagnosticFactory(
76+
"LuaPairsIterable type must be destructured in a for...of statement."
77+
);
78+
7579
export const unsupportedAccessorInObjectLiteral = createErrorDiagnosticFactory(
7680
"Accessors in object literal are not supported."
7781
);

src/transformation/utils/language-extensions.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export enum ExtensionKind {
77
RangeFunction = "RangeFunction",
88
VarargConstant = "VarargConstant",
99
IterableType = "IterableType",
10+
PairsIterableType = "PairsIterableType",
1011
AdditionOperatorType = "AdditionOperatorType",
1112
AdditionOperatorMethodType = "AdditionOperatorMethodType",
1213
SubtractionOperatorType = "SubtractionOperatorType",
@@ -66,6 +67,7 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
6667
[ExtensionKind.RangeFunction]: "__luaRangeFunctionBrand",
6768
[ExtensionKind.VarargConstant]: "__luaVarargConstantBrand",
6869
[ExtensionKind.IterableType]: "__luaIterableBrand",
70+
[ExtensionKind.PairsIterableType]: "__luaPairsIterableBrand",
6971
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
7072
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
7173
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import * as ts from "typescript";
2+
import * as lua from "../../../LuaAST";
3+
import { cast } from "../../../utils";
4+
import { TransformationContext } from "../../context";
5+
import { invalidPairsIterableWithoutDestructuring } from "../../utils/diagnostics";
6+
import * as extensions from "../../utils/language-extensions";
7+
import { getVariableDeclarationBinding } from "../loops/utils";
8+
import { transformArrayBindingElement } from "../variable-declaration";
9+
10+
function isPairsIterableType(type: ts.Type): boolean {
11+
return extensions.isExtensionType(type, extensions.ExtensionKind.PairsIterableType);
12+
}
13+
14+
export function isPairsIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
15+
const type = context.checker.getTypeAtLocation(expression);
16+
return isPairsIterableType(type);
17+
}
18+
19+
export function transformForOfPairsIterableStatement(
20+
context: TransformationContext,
21+
statement: ts.ForOfStatement,
22+
block: lua.Block
23+
): lua.Statement {
24+
const pairsCall = lua.createCallExpression(lua.createIdentifier("pairs"), [
25+
context.transformExpression(statement.expression),
26+
]);
27+
28+
let identifiers: lua.Identifier[] = [];
29+
30+
if (ts.isVariableDeclarationList(statement.initializer)) {
31+
// Variables declared in for loop
32+
// for key, value in iterable do
33+
const binding = getVariableDeclarationBinding(context, statement.initializer);
34+
if (ts.isArrayBindingPattern(binding)) {
35+
identifiers = binding.elements.map(e => transformArrayBindingElement(context, e));
36+
} else {
37+
context.diagnostics.push(invalidPairsIterableWithoutDestructuring(binding));
38+
}
39+
} else if (ts.isArrayLiteralExpression(statement.initializer)) {
40+
// Variables NOT declared in for loop - catch iterator values in temps and assign
41+
// for ____key, ____value in iterable do
42+
// key, value = ____key, ____value
43+
identifiers = statement.initializer.elements.map(e => context.createTempNameForNode(e));
44+
if (identifiers.length > 0) {
45+
block.statements.unshift(
46+
lua.createAssignmentStatement(
47+
statement.initializer.elements.map(e =>
48+
cast(context.transformExpression(e), lua.isAssignmentLeftHandSideExpression)
49+
),
50+
identifiers
51+
)
52+
);
53+
}
54+
} else {
55+
context.diagnostics.push(invalidPairsIterableWithoutDestructuring(statement.initializer));
56+
}
57+
58+
if (identifiers.length === 0) {
59+
identifiers.push(lua.createAnonymousIdentifier());
60+
}
61+
62+
return lua.createForInStatement(block, identifiers, [pairsCall], statement);
63+
}

src/transformation/visitors/loops/for-of.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { annotationRemoved } from "../../utils/diagnostics";
66
import { LuaLibFeature, transformLuaLibFunction } from "../../utils/lualib";
77
import { isArrayType } from "../../utils/typescript";
88
import { isIterableExpression, transformForOfIterableStatement } from "../language-extensions/iterable";
9+
import { isPairsIterableExpression, transformForOfPairsIterableStatement } from "../language-extensions/pairsIterable";
910
import { isRangeFunction, transformRangeStatement } from "../language-extensions/range";
1011
import { transformForInitializer, transformLoopBody } from "./utils";
1112

@@ -47,6 +48,8 @@ export const transformForOfStatement: FunctionVisitor<ts.ForOfStatement> = (node
4748
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.ForRange));
4849
} else if (isIterableExpression(context, node.expression)) {
4950
return transformForOfIterableStatement(context, node, body);
51+
} else if (isPairsIterableExpression(context, node.expression)) {
52+
return transformForOfPairsIterableStatement(context, node, body);
5053
} else if (isLuaIteratorType(context, node.expression)) {
5154
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.LuaIterator));
5255
} else if (isArrayType(context, context.checker.getTypeAtLocation(node.expression))) {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Jest Snapshot v1, https://goo.gl/fbAQLP
2+
3+
exports[`invalid LuaPairsIterable without destructuring ("for (const s of testIterable) {}"): code 1`] = `
4+
"local ____exports = {}
5+
function ____exports.__main(self)
6+
local testIterable = {a1 = \\"a2\\", b1 = \\"b2\\", c1 = \\"c2\\"}
7+
for ____ in pairs(testIterable) do
8+
end
9+
end
10+
return ____exports"
11+
`;
12+
13+
exports[`invalid LuaPairsIterable without destructuring ("for (const s of testIterable) {}"): diagnostics 1`] = `"main.ts(5,20): error TSTL: LuaPairsIterable type must be destructured in a for...of statement."`;
14+
15+
exports[`invalid LuaPairsIterable without destructuring ("let s; for (s of testIterable) {}"): code 1`] = `
16+
"local ____exports = {}
17+
function ____exports.__main(self)
18+
local testIterable = {a1 = \\"a2\\", b1 = \\"b2\\", c1 = \\"c2\\"}
19+
local s
20+
for ____ in pairs(testIterable) do
21+
end
22+
end
23+
return ____exports"
24+
`;
25+
26+
exports[`invalid LuaPairsIterable without destructuring ("let s; for (s of testIterable) {}"): diagnostics 1`] = `"main.ts(5,21): error TSTL: LuaPairsIterable type must be destructured in a for...of statement."`;
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import * as util from "../../util";
2+
import { invalidPairsIterableWithoutDestructuring } from "../../../src/transformation/utils/diagnostics";
3+
import { LuaTarget } from "../../../src";
4+
5+
const testIterable = `
6+
const testIterable = {a1: "a2", b1: "b2", c1: "c2"} as unknown as LuaPairsIterable<string, string>;
7+
`;
8+
9+
const testResults = {
10+
a1: "a2",
11+
b1: "b2",
12+
c1: "c2",
13+
};
14+
15+
test("pairs iterable", () => {
16+
util.testFunction`
17+
${testIterable}
18+
const results: Record<string, string> = {};
19+
for (const [k, v] of testIterable) {
20+
results[k] = v;
21+
}
22+
return results;
23+
`
24+
.withLanguageExtensions()
25+
.expectToEqual(testResults);
26+
});
27+
28+
test("pairs iterable with external control variable", () => {
29+
util.testFunction`
30+
${testIterable}
31+
const results: Record<string, string> = {};
32+
let k: string, v: string;
33+
for ([k, v] of testIterable) {
34+
results[k] = v;
35+
}
36+
return results;
37+
`
38+
.withLanguageExtensions()
39+
.expectToEqual(testResults);
40+
});
41+
42+
test("pairs iterable function forward", () => {
43+
util.testFunction`
44+
${testIterable}
45+
function forward() { return testIterable; }
46+
const results: Record<string, string> = {};
47+
for (const [k, v] of forward()) {
48+
results[k] = v;
49+
}
50+
return results;
51+
`
52+
.withLanguageExtensions()
53+
.expectToEqual(testResults);
54+
});
55+
56+
test("pairs iterable function indirect forward", () => {
57+
util.testFunction`
58+
${testIterable}
59+
function forward() { const iter = testIterable; return iter; }
60+
const results: Record<string, string> = {};
61+
for (const [k, v] of forward()) {
62+
results[k] = v;
63+
}
64+
return results;
65+
`
66+
.withLanguageExtensions()
67+
.expectToEqual(testResults);
68+
});
69+
70+
test("pairs iterable arrow function forward", () => {
71+
util.testFunction`
72+
${testIterable}
73+
const forward = () => testIterable;
74+
const results: Record<string, string> = {};
75+
for (const [k, v] of forward()) {
76+
results[k] = v;
77+
}
78+
return results;
79+
`
80+
.withLanguageExtensions()
81+
.expectToEqual(testResults);
82+
});
83+
84+
test("pairs iterable with __pairs metamethod", () => {
85+
util.testFunction`
86+
class PairsTest {
87+
__pairs() {
88+
const kvp = [ ["a1", "a2"], ["b1", "b2"], ["c1", "c2"] ];
89+
let i = 0;
90+
return () => {
91+
if (i < kvp.length) {
92+
const [k, v] = kvp[i++];
93+
return $multi(k, v);
94+
}
95+
};
96+
}
97+
}
98+
const tester = new PairsTest() as PairsTest & LuaPairsIterable<string, string>;
99+
const results: Record<string, string> = {};
100+
for (const [k, v] of tester) {
101+
results[k] = v;
102+
}
103+
return results;
104+
`
105+
.withLanguageExtensions()
106+
.setOptions({ luaTarget: LuaTarget.Lua53 })
107+
.expectToEqual(testResults);
108+
});
109+
110+
test.each(["for (const s of testIterable) {}", "let s; for (s of testIterable) {}"])(
111+
"invalid LuaPairsIterable without destructuring (%p)",
112+
statement => {
113+
util.testFunction`
114+
${testIterable}
115+
${statement}
116+
`
117+
.withLanguageExtensions()
118+
.expectDiagnosticsToMatchSnapshot([invalidPairsIterableWithoutDestructuring.code]);
119+
}
120+
);

test/unit/language-extensions/table.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,4 +325,20 @@ describe("LuaTable extension interface", () => {
325325
util.testFunction(statement).withLanguageExtensions().expectToHaveNoDiagnostics();
326326
}
327327
);
328+
329+
test("table pairs iterate", () => {
330+
util.testFunction`
331+
const tbl = new LuaTable<string, number>();
332+
tbl.set("foo", 1);
333+
tbl.set("bar", 3);
334+
tbl.set("baz", 5);
335+
const results: Record<string, number> = {};
336+
for (const [k, v] of tbl) {
337+
results[k] = v;
338+
}
339+
return results;
340+
`
341+
.withLanguageExtensions()
342+
.expectToEqual({ foo: 1, bar: 3, baz: 5 });
343+
});
328344
});

0 commit comments

Comments
 (0)