Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions language-extensions/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ declare type LuaIterable<TValue, TState = undefined> = Iterable<TValue> &
declare type LuaPairsIterable<TKey extends AnyNotNil, TValue> = Iterable<[TKey, TValue]> &
LuaExtension<"__luaPairsIterableBrand">;

/**
* Represents an object that can be iterated with pairs(), where only the key value is used.
*
* @param TKey The type of the key returned each iteration.
*/
declare type LuaPairsKeyIterable<TKey extends AnyNotNil> = Iterable<TKey> & LuaExtension<"__luaPairsKeyIterableBrand">;

/**
* Calls to functions with this type are translated to `left + right`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
Expand Down Expand Up @@ -500,6 +507,24 @@ declare type LuaTableSet<TTable extends AnyTable, TKey extends AnyNotNil, TValue
declare type LuaTableSetMethod<TKey extends AnyNotNil, TValue> = ((key: TKey, value: TValue) => void) &
LuaExtension<"__luaTableSetMethodBrand">;

/**
* Calls to functions with this type are translated to `table[key] = true`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
*
* @param TTable The type to access as a Lua table.
* @param TKey The type of the key to use to access the table.
*/
declare type LuaTableAddKey<TTable extends AnyTable, TKey extends AnyNotNil> = ((table: TTable, key: TKey) => void) &
LuaExtension<"__luaTableAddKeyBrand">;

/**
* Calls to methods with this type are translated to `table[key] = true`, where `table` is the object with the method.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
* @param TKey The type of the key to use to access the table.
*/
declare type LuaTableAddKeyMethod<TKey extends AnyNotNil> = ((key: TKey) => void) &
LuaExtension<"__luaTableAddKeyMethodBrand">;

/**
* Calls to functions with this type are translated to `table[key] ~= nil`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
Expand Down Expand Up @@ -574,3 +599,69 @@ declare type LuaTableConstructor = (new <TKey extends AnyNotNil = AnyNotNil, TVa
* @param TValue The type of the values stored in the table.
*/
declare const LuaTable: LuaTableConstructor;

/**
* A convenience type for working directly with a Lua table, used as a map.
*
* This differs from LuaTable in that the `get` method may return `nil`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
* @param K The type of the keys used to access the table.
* @param V The type of the values stored in the table.
*/
declare interface LuaMap<K extends AnyNotNil = AnyNotNil, V = any> extends LuaPairsIterable<K, V> {
get: LuaTableGetMethod<K, V | undefined>;
set: LuaTableSetMethod<K, V>;
has: LuaTableHasMethod<K>;
delete: LuaTableDeleteMethod<K>;
}

/**
* A convenience type for working directly with a Lua table, used as a map.
*
* This differs from LuaTable in that the `get` method may return `nil`.
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
* @param K The type of the keys used to access the table.
* @param V The type of the values stored in the table.
*/
declare const LuaMap: (new <K extends AnyNotNil = AnyNotNil, V = any>() => LuaMap<K, V>) &
LuaExtension<"__luaTableNewBrand">;

/**
* Readonly version of {@link LuaMap}.
*
* @param K The type of the keys used to access the table.
* @param V The type of the values stored in the table.
*/
declare interface LuaReadonlyMap<K extends AnyNotNil = AnyNotNil, V = any> extends LuaPairsIterable<K, V> {
get: LuaTableGetMethod<K, V>;
has: LuaTableHasMethod<K>;
}

/**
* A convenience type for working directly with a Lua table, used as a set.
*
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
* @param T The type of the keys used to access the table.
*/
declare interface LuaSet<T extends AnyNotNil = AnyNotNil> extends LuaPairsKeyIterable<T> {
add: LuaTableAddKeyMethod<T>;
has: LuaTableHasMethod<T>;
delete: LuaTableDeleteMethod<T>;
}

/**
* A convenience type for working directly with a Lua table, used as a set.
*
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
* @param T The type of the keys used to access the table.
*/
declare const LuaSet: (new <T extends AnyNotNil = AnyNotNil>() => LuaSet<T>) & LuaExtension<"__luaTableNewBrand">;

/**
* Readonly version of {@link LuaSet}.
*
* @param T The type of the keys used to access the table.
*/
declare interface LuaReadonlySet<T extends AnyNotNil = AnyNotNil> extends LuaPairsKeyIterable<T> {
has: LuaTableHasMethod<T>;
}
6 changes: 6 additions & 0 deletions src/transformation/utils/language-extensions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export enum ExtensionKind {
VarargConstant = "VarargConstant",
IterableType = "IterableType",
PairsIterableType = "PairsIterableType",
PairsKeyIterableType = "PairsKeyIterableType",
AdditionOperatorType = "AdditionOperatorType",
AdditionOperatorMethodType = "AdditionOperatorMethodType",
SubtractionOperatorType = "SubtractionOperatorType",
Expand Down Expand Up @@ -53,6 +54,8 @@ export enum ExtensionKind {
TableHasMethodType = "TableHasMethodType",
TableSetType = "TableSetType",
TableSetMethodType = "TableSetMethodType",
TableAddType = "TableAddType",
TableAddMethodType = "TableAddMethodType",
}

const extensionKindToValueName: { [T in ExtensionKind]?: string } = {
Expand All @@ -68,6 +71,7 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
[ExtensionKind.VarargConstant]: "__luaVarargConstantBrand",
[ExtensionKind.IterableType]: "__luaIterableBrand",
[ExtensionKind.PairsIterableType]: "__luaPairsIterableBrand",
[ExtensionKind.PairsKeyIterableType]: "__luaPairsKeyIterableBrand",
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
Expand Down Expand Up @@ -113,6 +117,8 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
[ExtensionKind.TableHasMethodType]: "__luaTableHasMethodBrand",
[ExtensionKind.TableSetType]: "__luaTableSetBrand",
[ExtensionKind.TableSetMethodType]: "__luaTableSetMethodBrand",
[ExtensionKind.TableAddType]: "__luaTableAddKeyBrand",
[ExtensionKind.TableAddMethodType]: "__luaTableAddKeyMethodBrand",
};

export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): boolean {
Expand Down
23 changes: 22 additions & 1 deletion src/transformation/visitors/language-extensions/pairsIterable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { cast } from "../../../utils";
import { TransformationContext } from "../../context";
import { invalidPairsIterableWithoutDestructuring } from "../../utils/diagnostics";
import * as extensions from "../../utils/language-extensions";
import { getVariableDeclarationBinding } from "../loops/utils";
import { getVariableDeclarationBinding, transformForInitializer } from "../loops/utils";
import { transformArrayBindingElement } from "../variable-declaration";

function isPairsIterableType(type: ts.Type): boolean {
Expand All @@ -16,6 +16,15 @@ export function isPairsIterableExpression(context: TransformationContext, expres
return isPairsIterableType(type);
}

function isPairsKeyIterableType(type: ts.Type): boolean {
return extensions.isExtensionType(type, extensions.ExtensionKind.PairsKeyIterableType);
}

export function isPairsKeyIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
const type = context.checker.getTypeAtLocation(expression);
return isPairsKeyIterableType(type);
}

export function transformForOfPairsIterableStatement(
context: TransformationContext,
statement: ts.ForOfStatement,
Expand Down Expand Up @@ -61,3 +70,15 @@ export function transformForOfPairsIterableStatement(

return lua.createForInStatement(block, identifiers, [pairsCall], statement);
}

export function transformForOfPairsKeyIterableStatement(
context: TransformationContext,
statement: ts.ForOfStatement,
block: lua.Block
): lua.Statement {
const pairsCall = lua.createCallExpression(lua.createIdentifier("pairs"), [
context.transformExpression(statement.expression),
]);
const identifier = transformForInitializer(context, statement.initializer, block);
return lua.createForInStatement(block, [identifier], [pairsCall], statement);
}
35 changes: 35 additions & 0 deletions src/transformation/visitors/language-extensions/table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ const tableCallExtensions = [
extensions.ExtensionKind.TableHasMethodType,
extensions.ExtensionKind.TableSetType,
extensions.ExtensionKind.TableSetMethodType,
extensions.ExtensionKind.TableAddType,
extensions.ExtensionKind.TableAddMethodType,
];

const tableExtensions = [extensions.ExtensionKind.TableNewType, ...tableCallExtensions];
Expand Down Expand Up @@ -77,6 +79,13 @@ export function transformTableExtensionCall(
) {
return transformTableSetExpression(context, node, extensionType);
}

if (
extensionType === extensions.ExtensionKind.TableAddType ||
extensionType === extensions.ExtensionKind.TableAddMethodType
) {
return transformTableAddExpression(context, node, extensionType);
}
}

function transformTableDeleteExpression(
Expand Down Expand Up @@ -172,3 +181,29 @@ function transformTableSetExpression(
);
return lua.createNilLiteral();
}

function transformTableAddExpression(
context: TransformationContext,
node: ts.CallExpression,
extensionKind: extensions.ExtensionKind
): lua.Expression {
const args = node.arguments.slice();
if (
extensionKind === extensions.ExtensionKind.TableAddMethodType &&
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
) {
// In case of method (no table argument), push method owner to front of args list
args.unshift(node.expression.expression);
}

// arg0[arg1] = true
const [table, value] = transformExpressionList(context, args);
context.addPrecedingStatements(
lua.createAssignmentStatement(
lua.createTableIndexExpression(table, value),
lua.createBooleanLiteral(true),
node
)
);
return lua.createNilLiteral();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any thoughts on returning table instead? This is how sets work in TS:

const myset = new Set<number>().add(1).add(2).add(3);

Though I guess that could get us in an annoying spot where we have to worry about side effects evaluating table, so I'd be fine leaving this as it is.

Copy link
Copy Markdown
Contributor Author

@GlassBricks GlassBricks Jul 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is possible to implement, however it would be inconsistent with existing lua table extension functions.
Perhaps this could be done in another PR.

}
9 changes: 8 additions & 1 deletion src/transformation/visitors/loops/for-of.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ import { annotationRemoved } from "../../utils/diagnostics";
import { LuaLibFeature, transformLuaLibFunction } from "../../utils/lualib";
import { isArrayType } from "../../utils/typescript";
import { isIterableExpression, transformForOfIterableStatement } from "../language-extensions/iterable";
import { isPairsIterableExpression, transformForOfPairsIterableStatement } from "../language-extensions/pairsIterable";
import {
isPairsIterableExpression,
transformForOfPairsIterableStatement,
isPairsKeyIterableExpression,
transformForOfPairsKeyIterableStatement,
} from "../language-extensions/pairsIterable";
import { isRangeFunction, transformRangeStatement } from "../language-extensions/range";
import { transformForInitializer, transformLoopBody } from "./utils";

Expand Down Expand Up @@ -50,6 +55,8 @@ export const transformForOfStatement: FunctionVisitor<ts.ForOfStatement> = (node
return transformForOfIterableStatement(context, node, body);
} else if (isPairsIterableExpression(context, node.expression)) {
return transformForOfPairsIterableStatement(context, node, body);
} else if (isPairsKeyIterableExpression(context, node.expression)) {
return transformForOfPairsKeyIterableStatement(context, node, body);
} else if (isLuaIteratorType(context, node.expression)) {
context.diagnostics.push(annotationRemoved(node.expression, AnnotationKind.LuaIterator));
} else if (isArrayType(context, context.checker.getTypeAtLocation(node.expression))) {
Expand Down
99 changes: 99 additions & 0 deletions test/unit/language-extensions/pairsIterable.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,102 @@ test.each(["for (const s of testIterable) {}", "let s; for (s of testIterable) {
.expectDiagnosticsToMatchSnapshot([invalidPairsIterableWithoutDestructuring.code]);
}
);

const testKeyIterable = `
const testKeyIterable = {a1: true, b1: true, c1: true} as unknown as LuaPairsKeyIterable<string>;
`;

test("pairs key iterable", () => {
util.testFunction`
${testKeyIterable}
const results: Record<string, boolean> = {};
for (const k of testKeyIterable) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ a1: true, b1: true, c1: true });
});

test("pairs key iterable with external control variable", () => {
util.testFunction`
${testKeyIterable}
const results: Record<string, boolean> = {};
let k: string;
for (k of testKeyIterable) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ a1: true, b1: true, c1: true });
});

test("pairs key iterable function forward", () => {
util.testFunction`
${testKeyIterable}
function forward() { return testKeyIterable; }
const results: Record<string, boolean> = {};
for (const k of forward()) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ a1: true, b1: true, c1: true });
});

test("pairs key iterable function indirect forward", () => {
util.testFunction`
${testKeyIterable}
function forward() { const iter = testKeyIterable; return iter; }
const results: Record<string, boolean> = {};
for (const k of forward()) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ a1: true, b1: true, c1: true });
});

test("pairs key iterable arrow function forward", () => {
util.testFunction`
${testKeyIterable}
const forward = () => testKeyIterable;
const results: Record<string, boolean> = {};
for (const k of forward()) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.expectToEqual({ a1: true, b1: true, c1: true });
});

test("pairs key iterable with __pairs metamethod", () => {
util.testFunction`
class PairsTest {
__pairs() {
const kvp = [ ["a1", true], ["b1", true], ["c1", true] ];
let i = 0;
return () => {
if (i < kvp.length) {
const [k, v] = kvp[i++];
return $multi(k, v);
}
};
}
}
const tester = new PairsTest() as PairsTest & LuaPairsKeyIterable<string>;
const results: Record<string, boolean> = {};
for (const k of tester) {
results[k] = true;
}
return results;
`
.withLanguageExtensions()
.setOptions({ luaTarget: LuaTarget.Lua53 })
.expectToEqual({ a1: true, b1: true, c1: true });
});
Loading