Skip to content

Commit 7ff29f0

Browse files
authored
Add LuaMap and LuaSet language extensions (#1303)
* Add LuaTableAdd language extension * Add definitions for LuaSet and LuaMap * Add LuaPairsKeyIterable language extension * Update test * Fix typos * Rename LuaTableAdd to LuaTableAddKey * Add more tests for LuaTableAddKey
1 parent 8733992 commit 7ff29f0

File tree

7 files changed

+330
-2
lines changed

7 files changed

+330
-2
lines changed

language-extensions/index.d.ts

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ declare type LuaIterable<TValue, TState = undefined> = Iterable<TValue> &
8888
declare type LuaPairsIterable<TKey extends AnyNotNil, TValue> = Iterable<[TKey, TValue]> &
8989
LuaExtension<"__luaPairsIterableBrand">;
9090

91+
/**
92+
* Represents an object that can be iterated with pairs(), where only the key value is used.
93+
*
94+
* @param TKey The type of the key returned each iteration.
95+
*/
96+
declare type LuaPairsKeyIterable<TKey extends AnyNotNil> = Iterable<TKey> & LuaExtension<"__luaPairsKeyIterableBrand">;
97+
9198
/**
9299
* Calls to functions with this type are translated to `left + right`.
93100
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
@@ -500,6 +507,24 @@ declare type LuaTableSet<TTable extends AnyTable, TKey extends AnyNotNil, TValue
500507
declare type LuaTableSetMethod<TKey extends AnyNotNil, TValue> = ((key: TKey, value: TValue) => void) &
501508
LuaExtension<"__luaTableSetMethodBrand">;
502509

510+
/**
511+
* Calls to functions with this type are translated to `table[key] = true`.
512+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
513+
*
514+
* @param TTable The type to access as a Lua table.
515+
* @param TKey The type of the key to use to access the table.
516+
*/
517+
declare type LuaTableAddKey<TTable extends AnyTable, TKey extends AnyNotNil> = ((table: TTable, key: TKey) => void) &
518+
LuaExtension<"__luaTableAddKeyBrand">;
519+
520+
/**
521+
* Calls to methods with this type are translated to `table[key] = true`, where `table` is the object with the method.
522+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
523+
* @param TKey The type of the key to use to access the table.
524+
*/
525+
declare type LuaTableAddKeyMethod<TKey extends AnyNotNil> = ((key: TKey) => void) &
526+
LuaExtension<"__luaTableAddKeyMethodBrand">;
527+
503528
/**
504529
* Calls to functions with this type are translated to `table[key] ~= nil`.
505530
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
@@ -574,3 +599,69 @@ declare type LuaTableConstructor = (new <TKey extends AnyNotNil = AnyNotNil, TVa
574599
* @param TValue The type of the values stored in the table.
575600
*/
576601
declare const LuaTable: LuaTableConstructor;
602+
603+
/**
604+
* A convenience type for working directly with a Lua table, used as a map.
605+
*
606+
* This differs from LuaTable in that the `get` method may return `nil`.
607+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
608+
* @param K The type of the keys used to access the table.
609+
* @param V The type of the values stored in the table.
610+
*/
611+
declare interface LuaMap<K extends AnyNotNil = AnyNotNil, V = any> extends LuaPairsIterable<K, V> {
612+
get: LuaTableGetMethod<K, V | undefined>;
613+
set: LuaTableSetMethod<K, V>;
614+
has: LuaTableHasMethod<K>;
615+
delete: LuaTableDeleteMethod<K>;
616+
}
617+
618+
/**
619+
* A convenience type for working directly with a Lua table, used as a map.
620+
*
621+
* This differs from LuaTable in that the `get` method may return `nil`.
622+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
623+
* @param K The type of the keys used to access the table.
624+
* @param V The type of the values stored in the table.
625+
*/
626+
declare const LuaMap: (new <K extends AnyNotNil = AnyNotNil, V = any>() => LuaMap<K, V>) &
627+
LuaExtension<"__luaTableNewBrand">;
628+
629+
/**
630+
* Readonly version of {@link LuaMap}.
631+
*
632+
* @param K The type of the keys used to access the table.
633+
* @param V The type of the values stored in the table.
634+
*/
635+
declare interface LuaReadonlyMap<K extends AnyNotNil = AnyNotNil, V = any> extends LuaPairsIterable<K, V> {
636+
get: LuaTableGetMethod<K, V>;
637+
has: LuaTableHasMethod<K>;
638+
}
639+
640+
/**
641+
* A convenience type for working directly with a Lua table, used as a set.
642+
*
643+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
644+
* @param T The type of the keys used to access the table.
645+
*/
646+
declare interface LuaSet<T extends AnyNotNil = AnyNotNil> extends LuaPairsKeyIterable<T> {
647+
add: LuaTableAddKeyMethod<T>;
648+
has: LuaTableHasMethod<T>;
649+
delete: LuaTableDeleteMethod<T>;
650+
}
651+
652+
/**
653+
* A convenience type for working directly with a Lua table, used as a set.
654+
*
655+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
656+
* @param T The type of the keys used to access the table.
657+
*/
658+
declare const LuaSet: (new <T extends AnyNotNil = AnyNotNil>() => LuaSet<T>) & LuaExtension<"__luaTableNewBrand">;
659+
660+
/**
661+
* Readonly version of {@link LuaSet}.
662+
*
663+
* @param T The type of the keys used to access the table.
664+
*/
665+
declare interface LuaReadonlySet<T extends AnyNotNil = AnyNotNil> extends LuaPairsKeyIterable<T> {
666+
has: LuaTableHasMethod<T>;
667+
}

src/transformation/utils/language-extensions.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export enum ExtensionKind {
88
VarargConstant = "VarargConstant",
99
IterableType = "IterableType",
1010
PairsIterableType = "PairsIterableType",
11+
PairsKeyIterableType = "PairsKeyIterableType",
1112
AdditionOperatorType = "AdditionOperatorType",
1213
AdditionOperatorMethodType = "AdditionOperatorMethodType",
1314
SubtractionOperatorType = "SubtractionOperatorType",
@@ -53,6 +54,8 @@ export enum ExtensionKind {
5354
TableHasMethodType = "TableHasMethodType",
5455
TableSetType = "TableSetType",
5556
TableSetMethodType = "TableSetMethodType",
57+
TableAddType = "TableAddType",
58+
TableAddMethodType = "TableAddMethodType",
5659
}
5760

5861
const extensionKindToValueName: { [T in ExtensionKind]?: string } = {
@@ -68,6 +71,7 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
6871
[ExtensionKind.VarargConstant]: "__luaVarargConstantBrand",
6972
[ExtensionKind.IterableType]: "__luaIterableBrand",
7073
[ExtensionKind.PairsIterableType]: "__luaPairsIterableBrand",
74+
[ExtensionKind.PairsKeyIterableType]: "__luaPairsKeyIterableBrand",
7175
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
7276
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
7377
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
@@ -113,6 +117,8 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
113117
[ExtensionKind.TableHasMethodType]: "__luaTableHasMethodBrand",
114118
[ExtensionKind.TableSetType]: "__luaTableSetBrand",
115119
[ExtensionKind.TableSetMethodType]: "__luaTableSetMethodBrand",
120+
[ExtensionKind.TableAddType]: "__luaTableAddKeyBrand",
121+
[ExtensionKind.TableAddMethodType]: "__luaTableAddKeyMethodBrand",
116122
};
117123

118124
export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): boolean {

src/transformation/visitors/language-extensions/pairsIterable.ts

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { cast } from "../../../utils";
44
import { TransformationContext } from "../../context";
55
import { invalidPairsIterableWithoutDestructuring } from "../../utils/diagnostics";
66
import * as extensions from "../../utils/language-extensions";
7-
import { getVariableDeclarationBinding } from "../loops/utils";
7+
import { getVariableDeclarationBinding, transformForInitializer } from "../loops/utils";
88
import { transformArrayBindingElement } from "../variable-declaration";
99

1010
function isPairsIterableType(type: ts.Type): boolean {
@@ -16,6 +16,15 @@ export function isPairsIterableExpression(context: TransformationContext, expres
1616
return isPairsIterableType(type);
1717
}
1818

19+
function isPairsKeyIterableType(type: ts.Type): boolean {
20+
return extensions.isExtensionType(type, extensions.ExtensionKind.PairsKeyIterableType);
21+
}
22+
23+
export function isPairsKeyIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
24+
const type = context.checker.getTypeAtLocation(expression);
25+
return isPairsKeyIterableType(type);
26+
}
27+
1928
export function transformForOfPairsIterableStatement(
2029
context: TransformationContext,
2130
statement: ts.ForOfStatement,
@@ -61,3 +70,15 @@ export function transformForOfPairsIterableStatement(
6170

6271
return lua.createForInStatement(block, identifiers, [pairsCall], statement);
6372
}
73+
74+
export function transformForOfPairsKeyIterableStatement(
75+
context: TransformationContext,
76+
statement: ts.ForOfStatement,
77+
block: lua.Block
78+
): lua.Statement {
79+
const pairsCall = lua.createCallExpression(lua.createIdentifier("pairs"), [
80+
context.transformExpression(statement.expression),
81+
]);
82+
const identifier = transformForInitializer(context, statement.initializer, block);
83+
return lua.createForInStatement(block, [identifier], [pairsCall], statement);
84+
}

src/transformation/visitors/language-extensions/table.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ const tableCallExtensions = [
1515
extensions.ExtensionKind.TableHasMethodType,
1616
extensions.ExtensionKind.TableSetType,
1717
extensions.ExtensionKind.TableSetMethodType,
18+
extensions.ExtensionKind.TableAddType,
19+
extensions.ExtensionKind.TableAddMethodType,
1820
];
1921

2022
const tableExtensions = [extensions.ExtensionKind.TableNewType, ...tableCallExtensions];
@@ -77,6 +79,13 @@ export function transformTableExtensionCall(
7779
) {
7880
return transformTableSetExpression(context, node, extensionType);
7981
}
82+
83+
if (
84+
extensionType === extensions.ExtensionKind.TableAddType ||
85+
extensionType === extensions.ExtensionKind.TableAddMethodType
86+
) {
87+
return transformTableAddExpression(context, node, extensionType);
88+
}
8089
}
8190

8291
function transformTableDeleteExpression(
@@ -172,3 +181,29 @@ function transformTableSetExpression(
172181
);
173182
return lua.createNilLiteral();
174183
}
184+
185+
function transformTableAddExpression(
186+
context: TransformationContext,
187+
node: ts.CallExpression,
188+
extensionKind: extensions.ExtensionKind
189+
): lua.Expression {
190+
const args = node.arguments.slice();
191+
if (
192+
extensionKind === extensions.ExtensionKind.TableAddMethodType &&
193+
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
194+
) {
195+
// In case of method (no table argument), push method owner to front of args list
196+
args.unshift(node.expression.expression);
197+
}
198+
199+
// arg0[arg1] = true
200+
const [table, value] = transformExpressionList(context, args);
201+
context.addPrecedingStatements(
202+
lua.createAssignmentStatement(
203+
lua.createTableIndexExpression(table, value),
204+
lua.createBooleanLiteral(true),
205+
node
206+
)
207+
);
208+
return lua.createNilLiteral();
209+
}

src/transformation/visitors/loops/for-of.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import { annotationRemoved } from "../../utils/diagnostics";
66
import { LuaLibFeature, transformLuaLibFunction } from "../../utils/lualib";
77
import { isArrayType } from "../../utils/typescript";
88
import { isIterableExpression, transformForOfIterableStatement } from "../language-extensions/iterable";
9-
import { isPairsIterableExpression, transformForOfPairsIterableStatement } from "../language-extensions/pairsIterable";
9+
import {
10+
isPairsIterableExpression,
11+
transformForOfPairsIterableStatement,
12+
isPairsKeyIterableExpression,
13+
transformForOfPairsKeyIterableStatement,
14+
} from "../language-extensions/pairsIterable";
1015
import { isRangeFunction, transformRangeStatement } from "../language-extensions/range";
1116
import { transformForInitializer, transformLoopBody } from "./utils";
1217

@@ -50,6 +55,8 @@ export const transformForOfStatement: FunctionVisitor<ts.ForOfStatement> = (node
5055
return transformForOfIterableStatement(context, node, body);
5156
} else if (isPairsIterableExpression(context, node.expression)) {
5257
return transformForOfPairsIterableStatement(context, node, body);
58+
} else if (isPairsKeyIterableExpression(context, node.expression)) {
59+
return transformForOfPairsKeyIterableStatement(context, node, body);
5360
} else if (isLuaIteratorType(context, node.expression)) {
5461
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.LuaIterator));
5562
} else if (isArrayType(context, context.checker.getTypeAtLocation(node.expression))) {

test/unit/language-extensions/pairsIterable.spec.ts

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,102 @@ test.each(["for (const s of testIterable) {}", "let s; for (s of testIterable) {
118118
.expectDiagnosticsToMatchSnapshot([invalidPairsIterableWithoutDestructuring.code]);
119119
}
120120
);
121+
122+
const testKeyIterable = `
123+
const testKeyIterable = {a1: true, b1: true, c1: true} as unknown as LuaPairsKeyIterable<string>;
124+
`;
125+
126+
test("pairs key iterable", () => {
127+
util.testFunction`
128+
${testKeyIterable}
129+
const results: Record<string, boolean> = {};
130+
for (const k of testKeyIterable) {
131+
results[k] = true;
132+
}
133+
return results;
134+
`
135+
.withLanguageExtensions()
136+
.expectToEqual({ a1: true, b1: true, c1: true });
137+
});
138+
139+
test("pairs key iterable with external control variable", () => {
140+
util.testFunction`
141+
${testKeyIterable}
142+
const results: Record<string, boolean> = {};
143+
let k: string;
144+
for (k of testKeyIterable) {
145+
results[k] = true;
146+
}
147+
return results;
148+
`
149+
.withLanguageExtensions()
150+
.expectToEqual({ a1: true, b1: true, c1: true });
151+
});
152+
153+
test("pairs key iterable function forward", () => {
154+
util.testFunction`
155+
${testKeyIterable}
156+
function forward() { return testKeyIterable; }
157+
const results: Record<string, boolean> = {};
158+
for (const k of forward()) {
159+
results[k] = true;
160+
}
161+
return results;
162+
`
163+
.withLanguageExtensions()
164+
.expectToEqual({ a1: true, b1: true, c1: true });
165+
});
166+
167+
test("pairs key iterable function indirect forward", () => {
168+
util.testFunction`
169+
${testKeyIterable}
170+
function forward() { const iter = testKeyIterable; return iter; }
171+
const results: Record<string, boolean> = {};
172+
for (const k of forward()) {
173+
results[k] = true;
174+
}
175+
return results;
176+
`
177+
.withLanguageExtensions()
178+
.expectToEqual({ a1: true, b1: true, c1: true });
179+
});
180+
181+
test("pairs key iterable arrow function forward", () => {
182+
util.testFunction`
183+
${testKeyIterable}
184+
const forward = () => testKeyIterable;
185+
const results: Record<string, boolean> = {};
186+
for (const k of forward()) {
187+
results[k] = true;
188+
}
189+
return results;
190+
`
191+
.withLanguageExtensions()
192+
.expectToEqual({ a1: true, b1: true, c1: true });
193+
});
194+
195+
test("pairs key iterable with __pairs metamethod", () => {
196+
util.testFunction`
197+
class PairsTest {
198+
__pairs() {
199+
const kvp = [ ["a1", true], ["b1", true], ["c1", true] ];
200+
let i = 0;
201+
return () => {
202+
if (i < kvp.length) {
203+
const [k, v] = kvp[i++];
204+
return $multi(k, v);
205+
}
206+
};
207+
}
208+
}
209+
const tester = new PairsTest() as PairsTest & LuaPairsKeyIterable<string>;
210+
const results: Record<string, boolean> = {};
211+
for (const k of tester) {
212+
results[k] = true;
213+
}
214+
return results;
215+
`
216+
.withLanguageExtensions()
217+
.setOptions({ luaTarget: LuaTarget.Lua53 })
218+
.expectToEqual({ a1: true, b1: true, c1: true });
219+
});

0 commit comments

Comments
 (0)