Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/transformation/utils/diagnostics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,11 @@ export const unsupportedOptionalCompileMembersOnly = createErrorDiagnosticFactor
export const undefinedInArrayLiteral = createErrorDiagnosticFactory(
"Array literals may not contain undefined or null."
);

export const invalidMethodCallExtensionUse = createErrorDiagnosticFactory(
"This language extension must be called as a method."
);

export const invalidSpreadInCallExtension = createErrorDiagnosticFactory(
"Spread elements are not supported in call extensions."
);
77 changes: 77 additions & 0 deletions src/transformation/utils/language-extensions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as ts from "typescript";
import { TransformationContext } from "../context";
import { invalidMethodCallExtensionUse, invalidSpreadInCallExtension } from "./diagnostics";

export enum ExtensionKind {
MultiFunction = "MultiFunction",
Expand Down Expand Up @@ -53,6 +54,7 @@ export enum ExtensionKind {
TableAddKeyType = "TableAddKey",
TableAddKeyMethodType = "TableAddKeyMethod",
}

const extensionValues: Set<string> = new Set(Object.values(ExtensionKind));

export function getExtensionKindForType(context: TransformationContext, type: ts.Type): ExtensionKind | undefined {
Expand Down Expand Up @@ -119,3 +121,78 @@ export function getIterableExtensionKindForNode(
const type = context.checker.getTypeAtLocation(node);
return getIterableExtensionTypeForType(context, type);
}

export const methodExtensionKinds: ReadonlySet<ExtensionKind> = new Set<ExtensionKind>([
ExtensionKind.AdditionOperatorMethodType,
ExtensionKind.SubtractionOperatorMethodType,
ExtensionKind.MultiplicationOperatorMethodType,
ExtensionKind.DivisionOperatorMethodType,
ExtensionKind.ModuloOperatorMethodType,
ExtensionKind.PowerOperatorMethodType,
ExtensionKind.FloorDivisionOperatorMethodType,
ExtensionKind.BitwiseAndOperatorMethodType,
ExtensionKind.BitwiseOrOperatorMethodType,
ExtensionKind.BitwiseExclusiveOrOperatorMethodType,
ExtensionKind.BitwiseLeftShiftOperatorMethodType,
ExtensionKind.BitwiseRightShiftOperatorMethodType,
ExtensionKind.ConcatOperatorMethodType,
ExtensionKind.LessThanOperatorMethodType,
ExtensionKind.GreaterThanOperatorMethodType,
ExtensionKind.NegationOperatorMethodType,
ExtensionKind.BitwiseNotOperatorMethodType,
ExtensionKind.LengthOperatorMethodType,
ExtensionKind.TableDeleteMethodType,
ExtensionKind.TableGetMethodType,
ExtensionKind.TableHasMethodType,
ExtensionKind.TableSetMethodType,
ExtensionKind.TableAddKeyMethodType,
]);

export function getNaryCallExtensionArgs(
context: TransformationContext,
node: ts.CallExpression,
kind: ExtensionKind,
numArgs: number
): readonly ts.Expression[] | undefined {
let expressions: readonly ts.Expression[];
if (node.arguments.some(ts.isSpreadElement)) {
context.diagnostics.push(invalidSpreadInCallExtension(node));
return undefined;
}
if (methodExtensionKinds.has(kind)) {
if (!(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))) {
context.diagnostics.push(invalidMethodCallExtensionUse(node));
return undefined;
}
if (node.arguments.length < numArgs - 1) {
// assumed to be TS error
return undefined;
}
expressions = [node.expression.expression, ...node.arguments];
} else {
if (node.arguments.length < numArgs) {
// assumed to be TS error
return undefined;
}
expressions = node.arguments;
}
return expressions;
}

export function getUnaryCallExtensionArg(
context: TransformationContext,
node: ts.CallExpression,
kind: ExtensionKind
): ts.Expression | undefined {
return getNaryCallExtensionArgs(context, node, kind, 1)?.[0];
}

export function getBinaryCallExtensionArgs(
context: TransformationContext,
node: ts.CallExpression,
kind: ExtensionKind
): readonly [ts.Expression, ts.Expression] | undefined {
const expressions = getNaryCallExtensionArgs(context, node, kind, 2);
if (expressions === undefined) return undefined;
return [expressions[0], expressions[1]];
}
31 changes: 9 additions & 22 deletions src/transformation/visitors/language-extensions/operators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import { TransformationContext } from "../../context";
import { assert } from "../../../utils";
import { LuaTarget } from "../../../CompilerOptions";
import { unsupportedForTarget } from "../../utils/diagnostics";
import { ExtensionKind } from "../../utils/language-extensions";
import { ExtensionKind, getBinaryCallExtensionArgs, getUnaryCallExtensionArg } from "../../utils/language-extensions";
import { LanguageExtensionCallTransformerMap } from "./call-extension";
import { transformOrderedExpressions } from "../expression-list";

const binaryOperatorMappings = new Map<ExtensionKind, lua.BinaryOperator>([
[ExtensionKind.AdditionOperatorType, lua.SyntaxKind.AdditionOperator],
Expand Down Expand Up @@ -81,35 +82,21 @@ for (const kind of unaryOperatorMappings.keys()) {
function transformBinaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) {
if (requiresLua53.has(kind)) checkHasLua53(context, node, kind);

let args: readonly ts.Expression[] = node.arguments;
if (
args.length === 1 &&
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
) {
args = [node.expression.expression, ...args];
}
const args = getBinaryCallExtensionArgs(context, node, kind);
if (!args) return lua.createNilLiteral();

const [left, right] = transformOrderedExpressions(context, args);

const luaOperator = binaryOperatorMappings.get(kind);
assert(luaOperator);
return lua.createBinaryExpression(
context.transformExpression(args[0]),
context.transformExpression(args[1]),
luaOperator
);
return lua.createBinaryExpression(left, right, luaOperator);
}

function transformUnaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) {
if (requiresLua53.has(kind)) checkHasLua53(context, node, kind);

let arg: ts.Expression;
if (
node.arguments.length === 0 &&
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
) {
arg = node.expression.expression;
} else {
arg = node.arguments[0];
}
const arg = getUnaryCallExtensionArg(context, node, kind);
if (!arg) return lua.createNilLiteral();

const luaOperator = unaryOperatorMappings.get(kind);
assert(luaOperator);
Expand Down
98 changes: 42 additions & 56 deletions src/transformation/visitors/language-extensions/table.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import * as ts from "typescript";
import * as lua from "../../../LuaAST";
import { TransformationContext } from "../../context";
import { ExtensionKind, getExtensionKindForNode } from "../../utils/language-extensions";
import { transformExpressionList } from "../expression-list";
import { LanguageExtensionCallTransformer } from "./call-extension";
import {
ExtensionKind,
getBinaryCallExtensionArgs,
getExtensionKindForNode,
getNaryCallExtensionArgs,
} from "../../utils/language-extensions";
import { transformOrderedExpressions } from "../expression-list";
import { LanguageExtensionCallTransformerMap } from "./call-extension";

export function isTableNewCall(context: TransformationContext, node: ts.NewExpression) {
return getExtensionKindForNode(context, node.expression) === ExtensionKind.TableNewType;
}

export const tableNewExtensions = [ExtensionKind.TableNewType];

export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtensionCallTransformer } = {
export const tableExtensionTransformers: LanguageExtensionCallTransformerMap = {
[ExtensionKind.TableDeleteType]: transformTableDeleteExpression,
[ExtensionKind.TableDeleteMethodType]: transformTableDeleteExpression,
[ExtensionKind.TableGetType]: transformTableGetExpression,
Expand All @@ -19,72 +25,56 @@ export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtens
[ExtensionKind.TableHasMethodType]: transformTableHasExpression,
[ExtensionKind.TableSetType]: transformTableSetExpression,
[ExtensionKind.TableSetMethodType]: transformTableSetExpression,
[ExtensionKind.TableAddKeyType]: transformTableAddExpression,
[ExtensionKind.TableAddKeyMethodType]: transformTableAddExpression,
[ExtensionKind.TableAddKeyType]: transformTableAddKeyExpression,
[ExtensionKind.TableAddKeyMethodType]: transformTableAddKeyExpression,
};

function transformTableDeleteExpression(
context: TransformationContext,
node: ts.CallExpression,
extensionKind: ExtensionKind
): lua.Expression {
const args = node.arguments.slice();
if (
extensionKind === ExtensionKind.TableDeleteMethodType &&
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
) {
// In case of method (no table argument), push method owner to front of args list
args.unshift(node.expression.expression);
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
if (!args) {
return lua.createNilLiteral();
}

const [table, accessExpression] = transformExpressionList(context, args);
const [table, key] = transformOrderedExpressions(context, args);
// arg0[arg1] = nil
context.addPrecedingStatements(
lua.createAssignmentStatement(
lua.createTableIndexExpression(table, accessExpression),
lua.createNilLiteral(),
node
)
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createNilLiteral(), node)
);
return lua.createBooleanLiteral(true);
}

function transformWithTableArgument(context: TransformationContext, node: ts.CallExpression): lua.Expression[] {
if (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) {
return transformExpressionList(context, [node.expression.expression, ...node.arguments]);
}
// todo: report diagnostic?
return [lua.createNilLiteral(), ...transformExpressionList(context, node.arguments)];
}

function transformTableGetExpression(
context: TransformationContext,
node: ts.CallExpression,
extensionKind: ExtensionKind
): lua.Expression {
const args =
extensionKind === ExtensionKind.TableGetMethodType
? transformWithTableArgument(context, node)
: transformExpressionList(context, node.arguments);
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
if (!args) {
return lua.createNilLiteral();
}

const [table, accessExpression] = args;
const [table, key] = transformOrderedExpressions(context, args);
// arg0[arg1]
return lua.createTableIndexExpression(table, accessExpression, node);
return lua.createTableIndexExpression(table, key, node);
}

function transformTableHasExpression(
context: TransformationContext,
node: ts.CallExpression,
extensionKind: ExtensionKind
): lua.Expression {
const args =
extensionKind === ExtensionKind.TableHasMethodType
? transformWithTableArgument(context, node)
: transformExpressionList(context, node.arguments);
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
if (!args) {
return lua.createNilLiteral();
}

const [table, accessExpression] = args;
const [table, key] = transformOrderedExpressions(context, args);
// arg0[arg1]
const tableIndexExpression = lua.createTableIndexExpression(table, accessExpression);
const tableIndexExpression = lua.createTableIndexExpression(table, key);

// arg0[arg1] ~= nil
return lua.createBinaryExpression(
Expand All @@ -100,37 +90,33 @@ function transformTableSetExpression(
node: ts.CallExpression,
extensionKind: ExtensionKind
): lua.Expression {
const args =
extensionKind === ExtensionKind.TableSetMethodType
? transformWithTableArgument(context, node)
: transformExpressionList(context, node.arguments);
const args = getNaryCallExtensionArgs(context, node, extensionKind, 3);
if (!args) {
return lua.createNilLiteral();
}

const [table, accessExpression, value] = args;
const [table, key, value] = transformOrderedExpressions(context, args);
// arg0[arg1] = arg2
context.addPrecedingStatements(
lua.createAssignmentStatement(lua.createTableIndexExpression(table, accessExpression), value, node)
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), value, node)
);
return lua.createNilLiteral();
}

function transformTableAddExpression(
function transformTableAddKeyExpression(
context: TransformationContext,
node: ts.CallExpression,
extensionKind: ExtensionKind
): lua.Expression {
const args =
extensionKind === ExtensionKind.TableAddKeyMethodType
? transformWithTableArgument(context, node)
: transformExpressionList(context, node.arguments);
const args = getNaryCallExtensionArgs(context, node, extensionKind, 2);
if (!args) {
return lua.createNilLiteral();
}

const [table, value] = args;
const [table, key] = transformOrderedExpressions(context, args);
// arg0[arg1] = true
context.addPrecedingStatements(
lua.createAssignmentStatement(
lua.createTableIndexExpression(table, value),
lua.createBooleanLiteral(true),
node
)
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createBooleanLiteral(true), node)
);
return lua.createNilLiteral();
}
6 changes: 4 additions & 2 deletions test/unit/__snapshots__/optionalChaining.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ exports[`Unsupported optional chains Compile members only: diagnostics 1`] = `"m
exports[`Unsupported optional chains Language extensions: code 1`] = `
"local ____opt_0 = ({}).has
if ____opt_0 ~= nil then
local ____ = nil[3] ~= nil
end"
`;

exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions."`;
exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `
"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions.
main.ts(2,17): error TSTL: This language extension must be called as a method."
`;

exports[`long optional chain 1`] = `
"local ____exports = {}
Expand Down
12 changes: 12 additions & 0 deletions test/unit/language-extensions/__snapshots__/operators.spec.ts.snap
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP

exports[`does not crash on invalid operator use global function: code 1`] = `""`;

exports[`does not crash on invalid operator use global function: diagnostics 1`] = `"main.ts(3,13): error TS2554: Expected 2 arguments, but got 1."`;

exports[`does not crash on invalid operator use method: code 1`] = `"left = {}"`;

exports[`does not crash on invalid operator use method: diagnostics 1`] = `"main.ts(5,18): error TS2554: Expected 1 arguments, but got 0."`;

exports[`does not crash on invalid operator use unary operator: code 1`] = `"op(_G)"`;

exports[`does not crash on invalid operator use unary operator: diagnostics 1`] = `"main.ts(2,31): error TS2304: Cannot find name 'LuaUnaryMinus'."`;

exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): code 1`] = `"foo = op(_G, 1, 2)"`;

exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): diagnostics 1`] = `"main.ts(3,22): error TSTL: This function must be called directly and cannot be referred to."`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,11 @@ __TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)"
`;

exports[`LuaTableHas extension invalid use method expression ("LuaTable<string, number>"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`;

exports[`does not crash on invalid extension use global function: code 1`] = `""`;

exports[`does not crash on invalid extension use global function: diagnostics 1`] = `"main.ts(3,9): error TS2554: Expected 2 arguments, but got 1."`;

exports[`does not crash on invalid extension use method: code 1`] = `"left = {}"`;

exports[`does not crash on invalid extension use method: diagnostics 1`] = `"main.ts(5,14): error TS2554: Expected 2 arguments, but got 0."`;
Loading