Skip to content

Commit 1011713

Browse files
authored
LuaIterable language extension (#981)
* added LuaIterable and LuaMultiIterable language extensions * indentation fix * remove LuaMultiIterable in favor of LuaIterable<LuaMultiReturn> * switched from expectToHaveDiagnostics to expectDiagnosticsToMatchSnapshot * updates based on feedback and discussions - LuaIterable type reworked for better lua compatibility - language extensions now checked by brand instead of type alias name - fixed LuaMultiReturn indirect forward, which also affected LuaIterable - reorganized tests and added some for manual iterable usage * fixed issue with no-state vs state iterables, and updated tests to check both (also added test for property based iterables) * updated extension kind checking to reduce complexity * updated new multi tests to actually use multiple values * replaced raw brands with LuaExtension type Co-authored-by: Tom <tomblind@users.noreply.github.com>
1 parent d075097 commit 1011713

File tree

14 files changed

+913
-221
lines changed

14 files changed

+913
-221
lines changed

language-extensions/index.d.ts

Lines changed: 109 additions & 105 deletions
Large diffs are not rendered by default.

src/lualib/Iterator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function __TS__IteratorStringStep(this: string, index: number): [number, string]
2727
function __TS__Iterator<T>(
2828
this: void,
2929
iterable: string | GeneratorIterator | Iterable<T> | readonly T[]
30-
): [(...args: any[]) => [any, any] | [], ...any[]] {
30+
): [(...args: any[]) => [any, any] | [], ...any[]] | LuaIterable<LuaMultiReturn<[number, T]>> {
3131
if (typeof iterable === "string") {
3232
return [__TS__IteratorStringStep, iterable, 0];
3333
} else if ("____coroutine" in iterable) {
@@ -36,6 +36,6 @@ function __TS__Iterator<T>(
3636
const iterator = iterable[Symbol.iterator]();
3737
return [__TS__IteratorIteratorStep, iterator];
3838
} else {
39-
return ipairs(iterable as readonly T[]) as any;
39+
return ipairs(iterable as readonly T[]);
4040
}
4141
}

src/lualib/declarations/global.d.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,4 @@ declare function unpack<T>(list: T[], i?: number, j?: number): T[];
2525
declare function select<T>(index: number, ...args: T[]): T;
2626
declare function select<T>(index: "#", ...args: T[]): number;
2727

28-
/**
29-
* @luaIterator
30-
* @tupleReturn
31-
*/
32-
type LuaTupleIterator<T extends any[]> = Iterable<T> & { " LuaTupleIterator": never };
33-
34-
declare function ipairs<T>(t: Record<number, T>): LuaTupleIterator<[number, T]>;
28+
declare function ipairs<T>(t: Record<number, T>): LuaIterable<LuaMultiReturn<[number, T]>, Record<number, T>>;

src/transformation/utils/diagnostics.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ export const luaIteratorForbiddenUsage = createErrorDiagnosticFactory(
118118
"the '@tupleReturn' annotation."
119119
);
120120

121+
export const invalidMultiIterableWithoutDestructuring = createErrorDiagnosticFactory(
122+
"LuaIterable with a LuaMultiReturn return value type must be destructured."
123+
);
124+
121125
export const unsupportedAccessorInObjectLiteral = createErrorDiagnosticFactory(
122126
"Accessors in object literal are not supported."
123127
);
Lines changed: 58 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import * as ts from "typescript";
2-
import * as path from "path";
2+
import { TransformationContext } from "../context";
33

44
export enum ExtensionKind {
55
MultiFunction = "MultiFunction",
66
MultiType = "MultiType",
77
RangeFunction = "RangeFunction",
8+
IterableType = "IterableType",
89
AdditionOperatorType = "AdditionOperatorType",
910
AdditionOperatorMethodType = "AdditionOperatorMethodType",
1011
SubtractionOperatorType = "SubtractionOperatorType",
@@ -43,74 +44,66 @@ export enum ExtensionKind {
4344
LengthOperatorMethodType = "LengthOperatorMethodType",
4445
}
4546

46-
const functionNameToExtensionKind: { [name: string]: ExtensionKind } = {
47-
$multi: ExtensionKind.MultiFunction,
48-
$range: ExtensionKind.RangeFunction,
47+
const extensionKindToFunctionName: { [T in ExtensionKind]?: string } = {
48+
[ExtensionKind.MultiFunction]: "$multi",
49+
[ExtensionKind.RangeFunction]: "$range",
4950
};
5051

51-
const typeNameToExtensionKind: { [name: string]: ExtensionKind } = {
52-
LuaMultiReturn: ExtensionKind.MultiType,
53-
LuaAddition: ExtensionKind.AdditionOperatorType,
54-
LuaAdditionMethod: ExtensionKind.AdditionOperatorMethodType,
55-
LuaSubtraction: ExtensionKind.SubtractionOperatorType,
56-
LuaSubtractionMethod: ExtensionKind.SubtractionOperatorMethodType,
57-
LuaMultiplication: ExtensionKind.MultiplicationOperatorType,
58-
LuaMultiplicationMethod: ExtensionKind.MultiplicationOperatorMethodType,
59-
LuaDivision: ExtensionKind.DivisionOperatorType,
60-
LuaDivisionMethod: ExtensionKind.DivisionOperatorMethodType,
61-
LuaModulo: ExtensionKind.ModuloOperatorType,
62-
LuaModuloMethod: ExtensionKind.ModuloOperatorMethodType,
63-
LuaPower: ExtensionKind.PowerOperatorType,
64-
LuaPowerMethod: ExtensionKind.PowerOperatorMethodType,
65-
LuaFloorDivision: ExtensionKind.FloorDivisionOperatorType,
66-
LuaFloorDivisionMethod: ExtensionKind.FloorDivisionOperatorMethodType,
67-
LuaBitwiseAnd: ExtensionKind.BitwiseAndOperatorType,
68-
LuaBitwiseAndMethod: ExtensionKind.BitwiseAndOperatorMethodType,
69-
LuaBitwiseOr: ExtensionKind.BitwiseOrOperatorType,
70-
LuaBitwiseOrMethod: ExtensionKind.BitwiseOrOperatorMethodType,
71-
LuaBitwiseExclusiveOr: ExtensionKind.BitwiseExclusiveOrOperatorType,
72-
LuaBitwiseExclusiveOrMethod: ExtensionKind.BitwiseExclusiveOrOperatorMethodType,
73-
LuaBitwiseLeftShift: ExtensionKind.BitwiseLeftShiftOperatorType,
74-
LuaBitwiseLeftShiftMethod: ExtensionKind.BitwiseLeftShiftOperatorMethodType,
75-
LuaBitwiseRightShift: ExtensionKind.BitwiseRightShiftOperatorType,
76-
LuaBitwiseRightShiftMethod: ExtensionKind.BitwiseRightShiftOperatorMethodType,
77-
LuaConcat: ExtensionKind.ConcatOperatorType,
78-
LuaConcatMethod: ExtensionKind.ConcatOperatorMethodType,
79-
LuaLessThan: ExtensionKind.LessThanOperatorType,
80-
LuaLessThanMethod: ExtensionKind.LessThanOperatorMethodType,
81-
LuaGreaterThan: ExtensionKind.GreaterThanOperatorType,
82-
LuaGreaterThanMethod: ExtensionKind.GreaterThanOperatorMethodType,
83-
LuaNegation: ExtensionKind.NegationOperatorType,
84-
LuaNegationMethod: ExtensionKind.NegationOperatorMethodType,
85-
LuaBitwiseNot: ExtensionKind.BitwiseNotOperatorType,
86-
LuaBitwiseNotMethod: ExtensionKind.BitwiseNotOperatorMethodType,
87-
LuaLength: ExtensionKind.LengthOperatorType,
88-
LuaLengthMethod: ExtensionKind.LengthOperatorMethodType,
52+
const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
53+
[ExtensionKind.MultiFunction]: "__luaMultiFunctionBrand",
54+
[ExtensionKind.MultiType]: "__luaMultiReturnBrand",
55+
[ExtensionKind.RangeFunction]: "__luaRangeFunctionBrand",
56+
[ExtensionKind.IterableType]: "__luaIterableBrand",
57+
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
58+
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
59+
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
60+
[ExtensionKind.SubtractionOperatorMethodType]: "__luaSubtractionMethodBrand",
61+
[ExtensionKind.MultiplicationOperatorType]: "__luaMultiplicationBrand",
62+
[ExtensionKind.MultiplicationOperatorMethodType]: "__luaMultiplicationMethodBrand",
63+
[ExtensionKind.DivisionOperatorType]: "__luaDivisionBrand",
64+
[ExtensionKind.DivisionOperatorMethodType]: "__luaDivisionMethodBrand",
65+
[ExtensionKind.ModuloOperatorType]: "__luaModuloBrand",
66+
[ExtensionKind.ModuloOperatorMethodType]: "__luaModuloMethodBrand",
67+
[ExtensionKind.PowerOperatorType]: "__luaPowerBrand",
68+
[ExtensionKind.PowerOperatorMethodType]: "__luaPowerMethodBrand",
69+
[ExtensionKind.FloorDivisionOperatorType]: "__luaFloorDivisionBrand",
70+
[ExtensionKind.FloorDivisionOperatorMethodType]: "__luaFloorDivisionMethodBrand",
71+
[ExtensionKind.BitwiseAndOperatorType]: "__luaBitwiseAndBrand",
72+
[ExtensionKind.BitwiseAndOperatorMethodType]: "__luaBitwiseAndMethodBrand",
73+
[ExtensionKind.BitwiseOrOperatorType]: "__luaBitwiseOrBrand",
74+
[ExtensionKind.BitwiseOrOperatorMethodType]: "__luaBitwiseOrMethodBrand",
75+
[ExtensionKind.BitwiseExclusiveOrOperatorType]: "__luaBitwiseExclusiveOrBrand",
76+
[ExtensionKind.BitwiseExclusiveOrOperatorMethodType]: "__luaBitwiseExclusiveOrMethodBrand",
77+
[ExtensionKind.BitwiseLeftShiftOperatorType]: "__luaBitwiseLeftShiftBrand",
78+
[ExtensionKind.BitwiseLeftShiftOperatorMethodType]: "__luaBitwiseLeftShiftMethodBrand",
79+
[ExtensionKind.BitwiseRightShiftOperatorType]: "__luaBitwiseRightShiftBrand",
80+
[ExtensionKind.BitwiseRightShiftOperatorMethodType]: "__luaBitwiseRightShiftMethodBrand",
81+
[ExtensionKind.ConcatOperatorType]: "__luaConcatBrand",
82+
[ExtensionKind.ConcatOperatorMethodType]: "__luaConcatMethodBrand",
83+
[ExtensionKind.LessThanOperatorType]: "__luaLessThanBrand",
84+
[ExtensionKind.LessThanOperatorMethodType]: "__luaLessThanMethodBrand",
85+
[ExtensionKind.GreaterThanOperatorType]: "__luaGreaterThanBrand",
86+
[ExtensionKind.GreaterThanOperatorMethodType]: "__luaGreaterThanMethodBrand",
87+
[ExtensionKind.NegationOperatorType]: "__luaNegationBrand",
88+
[ExtensionKind.NegationOperatorMethodType]: "__luaNegationMethodBrand",
89+
[ExtensionKind.BitwiseNotOperatorType]: "__luaBitwiseNotBrand",
90+
[ExtensionKind.BitwiseNotOperatorMethodType]: "__luaBitwiseNotMethodBrand",
91+
[ExtensionKind.LengthOperatorType]: "__luaLengthBrand",
92+
[ExtensionKind.LengthOperatorMethodType]: "__luaLengthMethodBrand",
8993
};
9094

91-
function isSourceFileFromLanguageExtensions(sourceFile: ts.SourceFile): boolean {
92-
const extensionDirectory = path.resolve(__dirname, "../../../language-extensions");
93-
const sourceFileDirectory = path.dirname(path.normalize(sourceFile.fileName));
94-
return extensionDirectory === sourceFileDirectory;
95+
export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): boolean {
96+
const typeBrand = extensionKindToTypeBrand[extensionKind];
97+
return typeBrand !== undefined && type.getProperty(typeBrand) !== undefined;
9598
}
9699

97-
export function getExtensionKind(declaration: ts.Declaration): ExtensionKind | undefined {
98-
const sourceFile = declaration.getSourceFile();
99-
if (isSourceFileFromLanguageExtensions(sourceFile)) {
100-
if (ts.isFunctionDeclaration(declaration) && declaration.name?.text) {
101-
const extensionKind = functionNameToExtensionKind[declaration.name.text];
102-
if (extensionKind) {
103-
return extensionKind;
104-
}
105-
}
106-
107-
if (ts.isTypeAliasDeclaration(declaration)) {
108-
const extensionKind = typeNameToExtensionKind[declaration.name.text];
109-
if (extensionKind) {
110-
return extensionKind;
111-
}
112-
}
113-
114-
throw new Error("Unknown extension kind");
115-
}
100+
export function isExtensionFunction(
101+
context: TransformationContext,
102+
symbol: ts.Symbol,
103+
extensionKind: ExtensionKind
104+
): boolean {
105+
return (
106+
symbol.getName() === extensionKindToFunctionName[extensionKind] &&
107+
symbol.declarations.some(d => isExtensionType(context.checker.getTypeAtLocation(d), extensionKind))
108+
);
116109
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import * as ts from "typescript";
2+
import * as lua from "../../../LuaAST";
3+
import * as extensions from "../../utils/language-extensions";
4+
import { TransformationContext } from "../../context";
5+
import { getVariableDeclarationBinding, transformForInitializer } from "../loops/utils";
6+
import { transformArrayBindingElement } from "../variable-declaration";
7+
import { invalidMultiIterableWithoutDestructuring } from "../../utils/diagnostics";
8+
import { cast } from "../../../utils";
9+
import { isMultiReturnType } from "./multi";
10+
11+
export function isIterableType(type: ts.Type): boolean {
12+
return extensions.isExtensionType(type, extensions.ExtensionKind.IterableType);
13+
}
14+
15+
export function returnsIterableType(context: TransformationContext, node: ts.CallExpression): boolean {
16+
const signature = context.checker.getResolvedSignature(node);
17+
const type = signature?.getReturnType();
18+
return type ? isIterableType(type) : false;
19+
}
20+
21+
export function isIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
22+
const type = context.checker.getTypeAtLocation(expression);
23+
return isIterableType(type);
24+
}
25+
26+
function transformForOfMultiIterableStatement(
27+
context: TransformationContext,
28+
statement: ts.ForOfStatement,
29+
block: lua.Block
30+
): lua.Statement {
31+
const luaIterator = context.transformExpression(statement.expression);
32+
let identifiers: lua.Identifier[] = [];
33+
34+
if (ts.isVariableDeclarationList(statement.initializer)) {
35+
// Variables declared in for loop
36+
// for ${initializer} in ${iterable} do
37+
const binding = getVariableDeclarationBinding(context, statement.initializer);
38+
if (ts.isArrayBindingPattern(binding)) {
39+
identifiers = binding.elements.map(e => transformArrayBindingElement(context, e));
40+
} else {
41+
context.diagnostics.push(invalidMultiIterableWithoutDestructuring(binding));
42+
}
43+
} else if (ts.isArrayLiteralExpression(statement.initializer)) {
44+
// Variables NOT declared in for loop - catch iterator values in temps and assign
45+
// for ____value0 in ${iterable} do
46+
// ${initializer} = ____value0
47+
identifiers = statement.initializer.elements.map((_, i) => lua.createIdentifier(`____value${i}`));
48+
if (identifiers.length > 0) {
49+
block.statements.unshift(
50+
lua.createAssignmentStatement(
51+
statement.initializer.elements.map(e =>
52+
cast(context.transformExpression(e), lua.isAssignmentLeftHandSideExpression)
53+
),
54+
identifiers
55+
)
56+
);
57+
}
58+
} else {
59+
context.diagnostics.push(invalidMultiIterableWithoutDestructuring(statement.initializer));
60+
}
61+
62+
if (identifiers.length === 0) {
63+
identifiers.push(lua.createAnonymousIdentifier());
64+
}
65+
66+
return lua.createForInStatement(block, identifiers, [luaIterator], statement);
67+
}
68+
69+
export function transformForOfIterableStatement(
70+
context: TransformationContext,
71+
statement: ts.ForOfStatement,
72+
block: lua.Block
73+
): lua.Statement {
74+
const type = context.checker.getTypeAtLocation(statement.expression);
75+
if (type.aliasTypeArguments?.length === 2 && isMultiReturnType(type.aliasTypeArguments[0])) {
76+
return transformForOfMultiIterableStatement(context, statement, block);
77+
}
78+
79+
const luaIterator = context.transformExpression(statement.expression);
80+
const identifier = transformForInitializer(context, statement.initializer, block);
81+
return lua.createForInStatement(block, [identifier], [luaIterator], statement);
82+
}

src/transformation/visitors/language-extensions/multi.ts

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
import * as ts from "typescript";
22
import * as extensions from "../../utils/language-extensions";
33
import { TransformationContext } from "../../context";
4-
import { invalidMultiFunctionUse } from "../../utils/diagnostics";
54
import { findFirstNodeAbove } from "../../utils/typescript";
6-
7-
const isMultiFunctionDeclaration = (declaration: ts.Declaration): boolean =>
8-
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.MultiFunction;
9-
10-
const isMultiTypeDeclaration = (declaration: ts.Declaration): boolean =>
11-
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.MultiType;
5+
import { isIterableExpression } from "./iterable";
6+
import { invalidMultiFunctionUse } from "../../utils/diagnostics";
127

138
export function isMultiReturnType(type: ts.Type): boolean {
14-
return type.aliasSymbol?.declarations?.some(isMultiTypeDeclaration) ?? false;
9+
return extensions.isExtensionType(type, extensions.ExtensionKind.MultiType);
1510
}
1611

1712
export function isMultiFunctionCall(context: TransformationContext, expression: ts.CallExpression): boolean {
18-
const type = context.checker.getTypeAtLocation(expression.expression);
19-
return type.symbol?.declarations?.some(isMultiFunctionDeclaration) ?? false;
13+
return isMultiFunctionNode(context, expression.expression);
2014
}
2115

2216
export function returnsMultiType(context: TransformationContext, node: ts.CallExpression): boolean {
@@ -30,8 +24,8 @@ export function isMultiReturnCall(context: TransformationContext, expression: ts
3024
}
3125

3226
export function isMultiFunctionNode(context: TransformationContext, node: ts.Node): boolean {
33-
const type = context.checker.getTypeAtLocation(node);
34-
return type.symbol?.declarations?.some(isMultiFunctionDeclaration) ?? false;
27+
const symbol = context.checker.getSymbolAtLocation(node);
28+
return symbol ? extensions.isExtensionFunction(context, symbol, extensions.ExtensionKind.MultiFunction) : false;
3529
}
3630

3731
export function isInMultiReturnFunction(context: TransformationContext, node: ts.Node) {
@@ -86,6 +80,11 @@ export function shouldMultiReturnCallBeWrapped(context: TransformationContext, n
8680
return false;
8781
}
8882

83+
// LuaIterable in for...of
84+
if (ts.isForOfStatement(node.parent) && isIterableExpression(context, node)) {
85+
return false;
86+
}
87+
8988
return true;
9089
}
9190

@@ -99,8 +98,7 @@ export function findMultiAssignmentViolations(
9998
if (!ts.isShorthandPropertyAssignment(element)) continue;
10099
const valueSymbol = context.checker.getShorthandAssignmentValueSymbol(element);
101100
if (valueSymbol) {
102-
const declaration = valueSymbol.valueDeclaration;
103-
if (declaration && isMultiFunctionDeclaration(declaration)) {
101+
if (extensions.isExtensionFunction(context, valueSymbol, extensions.ExtensionKind.MultiFunction)) {
104102
context.diagnostics.push(invalidMultiFunctionUse(element));
105103
result.push(element);
106104
}

src/transformation/visitors/language-extensions/operators.ts

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ const unaryOperatorMappings = new Map<extensions.ExtensionKind, lua.UnaryOperato
4949
[extensions.ExtensionKind.LengthOperatorMethodType, lua.SyntaxKind.LengthOperator],
5050
]);
5151

52-
const operatorMapExtensions = new Set<extensions.ExtensionKind>([
53-
...binaryOperatorMappings.keys(),
54-
...unaryOperatorMappings.keys(),
55-
]);
52+
const operatorMapExtensions = [...binaryOperatorMappings.keys(), ...unaryOperatorMappings.keys()];
5653

5754
const bitwiseOperatorMapExtensions = new Set<extensions.ExtensionKind>([
5855
extensions.ExtensionKind.BitwiseAndOperatorType,
@@ -84,25 +81,15 @@ function getOperatorMapExtensionKindForCall(context: TransformationContext, node
8481
if (!typeDeclaration) {
8582
return;
8683
}
87-
const mapping = extensions.getExtensionKind(typeDeclaration);
88-
if (mapping !== undefined && operatorMapExtensions.has(mapping)) {
89-
return mapping;
90-
}
91-
}
92-
93-
function isOperatorMapDeclaration(declaration: ts.Declaration) {
94-
const typeDeclaration = getTypeDeclaration(declaration);
95-
if (typeDeclaration) {
96-
const extensionKind = extensions.getExtensionKind(typeDeclaration);
97-
return extensionKind !== undefined ? operatorMapExtensions.has(extensionKind) : false;
98-
}
84+
const type = context.checker.getTypeFromTypeNode(typeDeclaration.type);
85+
return operatorMapExtensions.find(extensionKind => extensions.isExtensionType(type, extensionKind));
9986
}
10087

10188
function isOperatorMapType(context: TransformationContext, type: ts.Type): boolean {
10289
if (type.isUnionOrIntersection()) {
10390
return type.types.some(t => isOperatorMapType(context, t));
10491
} else {
105-
return type.symbol?.declarations?.some(isOperatorMapDeclaration);
92+
return operatorMapExtensions.some(extensionKind => extensions.isExtensionType(type, extensionKind));
10693
}
10794
}
10895

src/transformation/visitors/language-extensions/range.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,13 @@ import { transformArguments } from "../call";
88
import { assert } from "../../../utils";
99
import { invalidRangeControlVariable } from "../../utils/diagnostics";
1010

11-
const isRangeFunctionDeclaration = (declaration: ts.Declaration): boolean =>
12-
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.RangeFunction;
13-
1411
export function isRangeFunction(context: TransformationContext, expression: ts.CallExpression): boolean {
15-
const type = context.checker.getTypeAtLocation(expression.expression);
16-
return type.symbol?.declarations?.some(isRangeFunctionDeclaration) ?? false;
12+
return isRangeFunctionNode(context, expression.expression);
1713
}
1814

1915
export function isRangeFunctionNode(context: TransformationContext, node: ts.Node): boolean {
2016
const symbol = context.checker.getSymbolAtLocation(node);
21-
return symbol?.declarations?.some(isRangeFunctionDeclaration) ?? false;
17+
return symbol ? extensions.isExtensionFunction(context, symbol, extensions.ExtensionKind.RangeFunction) : false;
2218
}
2319

2420
function getControlVariable(context: TransformationContext, statement: ts.ForOfStatement) {

0 commit comments

Comments
 (0)