Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/transformation/builtins/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { TransformationContext } from "../context";
import { unsupportedProperty } from "../utils/diagnostics";
import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
import { transformArguments, transformCallAndArguments } from "../visitors/call";
import { findFirstNonOuterParent, typeAlwaysHasSomeOfFlags } from "../utils/typescript";
import { expressionResultIsUsed, typeAlwaysHasSomeOfFlags } from "../utils/typescript";
import { moveToPrecedingTemp } from "../visitors/expression-list";
import { isUnpackCall, wrapInTable } from "../utils/lua-ast";

Expand Down Expand Up @@ -54,8 +54,6 @@ function transformSingleElementArrayPush(
caller: lua.Expression,
param: lua.Expression
): lua.Expression {
const expressionIsUsed = !ts.isExpressionStatement(findFirstNonOuterParent(node));

const arrayIdentifier = lua.isIdentifier(caller) ? caller : moveToPrecedingTemp(context, caller);

// #array + 1
Expand All @@ -65,6 +63,7 @@ function transformSingleElementArrayPush(
lua.SyntaxKind.AdditionOperator
);

const expressionIsUsed = expressionResultIsUsed(node);
if (expressionIsUsed) {
// store length in a temp
lengthExpression = moveToPrecedingTemp(context, lengthExpression);
Expand Down
4 changes: 4 additions & 0 deletions src/transformation/utils/typescript/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ export function findFirstNonOuterParent(node: ts.Node): ts.Node {
return current;
}

export function expressionResultIsUsed(node: ts.Expression): boolean {
return !ts.isExpressionStatement(findFirstNonOuterParent(node));
}

export function getFirstDeclarationInFile(symbol: ts.Symbol, sourceFile: ts.SourceFile): ts.Declaration | undefined {
const originalSourceFile = ts.getParseTreeNode(sourceFile) ?? sourceFile;
const declarations = (symbol.getDeclarations() ?? []).filter(d => d.getSourceFile() === originalSourceFile);
Expand Down
2 changes: 0 additions & 2 deletions src/transformation/utils/typescript/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ export function canBeFalsy(context: TransformationContext, type: ts.Type): boole
}

export function canBeFalsyWhenNotNull(context: TransformationContext, type: ts.Type): boolean {
const strictNullChecks = context.options.strict === true || context.options.strictNullChecks === true;
if (!strictNullChecks && !type.isLiteral()) return true;
const falsyFlags =
ts.TypeFlags.Boolean |
ts.TypeFlags.BooleanLiteral |
Expand Down
11 changes: 3 additions & 8 deletions src/transformation/visitors/expression-statement.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as ts from "typescript";
import * as lua from "../../LuaAST";
import { FunctionVisitor, tempSymbolId, TransformationContext } from "../context";
import { FunctionVisitor, tempSymbolId } from "../context";
import { transformBinaryExpressionStatement } from "./binary-expression";
import { transformUnaryExpressionStatement } from "./unary-expression";

Expand All @@ -15,15 +15,10 @@ export const transformExpressionStatement: FunctionVisitor<ts.ExpressionStatemen
return binaryExpressionResult;
}

return transformExpressionToStatement(context, node.expression);
return wrapInStatement(context.transformExpression(node.expression));
};

export function transformExpressionToStatement(
context: TransformationContext,
expression: ts.Expression
): lua.Statement | undefined {
const result = context.transformExpression(expression);

export function wrapInStatement(result: lua.Expression): lua.Statement | undefined {
const isTempVariable = lua.isIdentifier(result) && result.symbolId === tempSymbolId;
if (isTempVariable) {
return undefined;
Expand Down
7 changes: 5 additions & 2 deletions src/transformation/visitors/identifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { invalidCallExtensionUse } from "../utils/diagnostics";
import { createExportedIdentifier, getSymbolExportScope } from "../utils/export";
import { createSafeName, hasUnsafeIdentifierName } from "../utils/safe-names";
import { getIdentifierSymbolId } from "../utils/symbols";
import { isOptionalContinuation } from "./optional-chaining";
import { getOptionalContinuationData, isOptionalContinuation } from "./optional-chaining";
import { isStandardLibraryType } from "../utils/typescript";
import { getExtensionKindForNode, getExtensionKindForSymbol } from "../utils/language-extensions";
import { callExtensions } from "./language-extensions/call-extension";
Expand All @@ -16,13 +16,16 @@ import { isIdentifierExtensionValue, reportInvalidExtensionValue } from "./langu
export function transformIdentifier(context: TransformationContext, identifier: ts.Identifier): lua.Identifier {
return transformNonValueIdentifier(context, identifier, context.checker.getSymbolAtLocation(identifier));
}

function transformNonValueIdentifier(
context: TransformationContext,
identifier: ts.Identifier,
symbol: ts.Symbol | undefined
) {
if (isOptionalContinuation(identifier)) {
return lua.createIdentifier(identifier.text, undefined, tempSymbolId);
const result = lua.createIdentifier(identifier.text, undefined, tempSymbolId);
getOptionalContinuationData(identifier)!.usedIdentifiers.push(result);
return result;
}

const extensionKind = symbol
Expand Down
140 changes: 100 additions & 40 deletions src/transformation/visitors/optional-chaining.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import * as ts from "typescript";
import * as lua from "../../LuaAST";
import { TransformationContext, tempSymbolId } from "../context";
import { tempSymbolId, TransformationContext } from "../context";
import { assert, assertNever } from "../../utils";
import { transformInPrecedingStatementScope } from "../utils/preceding-statements";
import { transformPropertyAccessExpressionWithCapture, transformElementAccessExpressionWithCapture } from "./access";
import { transformElementAccessExpressionWithCapture, transformPropertyAccessExpressionWithCapture } from "./access";
import { shouldMoveToTemp } from "./expression-list";
import { canBeFalsyWhenNotNull, expressionResultIsUsed } from "../utils/typescript";
import { wrapInStatement } from "./expression-statement";

type NormalOptionalChain = ts.PropertyAccessChain | ts.ElementAccessChain | ts.CallChain;

Expand Down Expand Up @@ -56,7 +58,7 @@ export function captureThisValue(
thisValueCapture: lua.Identifier,
tsOriginal: ts.Node
): lua.Expression {
if (!shouldMoveToTemp(context, expression, tsOriginal) && !isOptionalContinuation(tsOriginal)) {
if (!shouldMoveToTemp(context, expression, tsOriginal)) {
return expression;
}
const tempAssignment = lua.createAssignmentStatement(thisValueCapture, expression, tsOriginal);
Expand All @@ -66,6 +68,7 @@ export function captureThisValue(

export interface OptionalContinuation {
contextualCall?: lua.CallExpression;
usedIdentifiers: lua.Identifier[];
}

const optionalContinuations = new WeakMap<ts.Identifier, OptionalContinuation>();
Expand All @@ -74,12 +77,16 @@ const optionalContinuations = new WeakMap<ts.Identifier, OptionalContinuation>()
function createOptionalContinuationIdentifier(text: string, tsOriginal: ts.Expression): ts.Identifier {
const identifier = ts.factory.createIdentifier(text);
ts.setOriginalNode(identifier, tsOriginal);
optionalContinuations.set(identifier, {});
optionalContinuations.set(identifier, {
usedIdentifiers: [],
});
return identifier;
}

export function isOptionalContinuation(node: ts.Node): boolean {
return ts.isIdentifier(node) && optionalContinuations.has(node);
}

export function getOptionalContinuationData(identifier: ts.Identifier): OptionalContinuation | undefined {
return optionalContinuations.get(identifier);
}
Expand All @@ -90,16 +97,16 @@ export function transformOptionalChain(context: TransformationContext, node: ts.

export function transformOptionalChainWithCapture(
context: TransformationContext,
node: ts.OptionalChain,
tsNode: ts.OptionalChain,
thisValueCapture: lua.Identifier | undefined,
isDelete?: ts.DeleteExpression
): ExpressionWithThisValue {
const luaTemp = context.createTempNameForNode(node);
const luaTempName = context.createTempName("opt");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this from what it was?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"node" was too generic and so confusing; there are a lot of variables, TS and Lua, in the body.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't actually "node" literal though, it was already deducing a very specific name based on what the node is, see the changes in the snapshots where it changed ____table_has_result_0 to ___opt_0

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, earlier I thought the comment was addressing something else.

The reason for this is that long-ish optional chains, which are somewhat common, generate long names such as "____a_b_c_foo_result_c_1". "____opt" is shorter and still signifies what the original TS was.


const { expression: tsLeftExpression, chain } = flattenChain(node);
const { expression: tsLeftExpression, chain } = flattenChain(tsNode);

// build temp.b.c.d
const tsTemp = createOptionalContinuationIdentifier(luaTemp.text, tsLeftExpression);
const tsTemp = createOptionalContinuationIdentifier(luaTempName, tsLeftExpression);
let tsRightExpression: ts.Expression = tsTemp;
for (const link of chain) {
if (ts.isPropertyAccessExpression(link)) {
Expand All @@ -121,26 +128,27 @@ export function transformOptionalChainWithCapture(
// transform right expression first to check if thisValue capture is needed
// capture and return thisValue if requested from outside
let returnThisValue: lua.Expression | undefined;
const [rightPrecedingStatements, rightAssignment] = transformInPrecedingStatementScope(context, () => {
let result: lua.Expression;
if (thisValueCapture) {
({ expression: result, thisValue: returnThisValue } = transformExpressionWithThisValueCapture(
context,
tsRightExpression,
thisValueCapture
));
} else {
result = context.transformExpression(tsRightExpression);
const [rightPrecedingStatements, rightExpression] = transformInPrecedingStatementScope(context, () => {
if (!thisValueCapture) {
return context.transformExpression(tsRightExpression);
}
return lua.createAssignmentStatement(luaTemp, result);

const { expression: result, thisValue } = transformExpressionWithThisValueCapture(
context,
tsRightExpression,
thisValueCapture
);
returnThisValue = thisValue;
return result;
});

// transform left expression, handle thisValue if needed by rightExpression
const thisValueCaptureName = context.createTempName("this");
const leftThisValueTemp = lua.createIdentifier(thisValueCaptureName, undefined, tempSymbolId);
let capturedThisValue: lua.Expression | undefined;

const rightContextualCall = getOptionalContinuationData(tsTemp)?.contextualCall;
const optionalContinuationData = getOptionalContinuationData(tsTemp);
const rightContextualCall = optionalContinuationData?.contextualCall;
const [leftPrecedingStatements, leftExpression] = transformInPrecedingStatementScope(context, () => {
let result: lua.Expression;
if (rightContextualCall) {
Expand Down Expand Up @@ -177,26 +185,78 @@ export function transformOptionalChainWithCapture(
}
}

// <left preceding statements>
// local temp = <left>
// if temp ~= nil then
// <right preceding statements>
// temp = temp.b.c.d
// end
// return temp

context.addPrecedingStatements([
...leftPrecedingStatements,
lua.createVariableDeclarationStatement(luaTemp, leftExpression),
lua.createIfStatement(
lua.createBinaryExpression(luaTemp, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
lua.createBlock([...rightPrecedingStatements, rightAssignment])
),
]);
return {
expression: luaTemp,
thisValue: returnThisValue,
};
// evaluate optional chain
context.addPrecedingStatements(leftPrecedingStatements);

// try use existing variable instead of creating new one, if possible
let leftIdentifier: lua.Identifier | undefined;
const usedLuaIdentifiers = optionalContinuationData?.usedIdentifiers;
const reuseLeftIdentifier =
usedLuaIdentifiers &&
usedLuaIdentifiers.length > 0 &&
lua.isIdentifier(leftExpression) &&
(rightPrecedingStatements.length === 0 || !shouldMoveToTemp(context, leftExpression, tsLeftExpression));
if (reuseLeftIdentifier) {
leftIdentifier = leftExpression;
for (const usedIdentifier of usedLuaIdentifiers) {
usedIdentifier.text = leftIdentifier.text;
}
} else {
leftIdentifier = lua.createIdentifier(luaTempName, undefined, tempSymbolId);
context.addPrecedingStatements(lua.createVariableDeclarationStatement(leftIdentifier, leftExpression));
}

if (!expressionResultIsUsed(tsNode) || isDelete) {
// if left ~= nil then
// <right preceding statements>
// <right expression>
// end

const innerExpression = wrapInStatement(rightExpression);
const innerStatements = rightPrecedingStatements;
if (innerExpression) innerStatements.push(innerExpression);

context.addPrecedingStatements(
lua.createIfStatement(
lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
lua.createBlock(innerStatements)
)
);
return { expression: lua.createNilLiteral(), thisValue: returnThisValue };
} else if (
rightPrecedingStatements.length === 0 &&
!canBeFalsyWhenNotNull(context, context.checker.getTypeAtLocation(tsLeftExpression))
) {
// return a && a.b
return {
expression: lua.createBinaryExpression(leftIdentifier, rightExpression, lua.SyntaxKind.AndOperator, tsNode),
thisValue: returnThisValue,
};
} else {
let resultIdentifier: lua.Identifier;
if (!reuseLeftIdentifier) {
// reuse temp variable for output
resultIdentifier = leftIdentifier;
} else {
resultIdentifier = lua.createIdentifier(context.createTempName("opt_result"), undefined, tempSymbolId);
context.addPrecedingStatements(lua.createVariableDeclarationStatement(resultIdentifier));
}
// if left ~= nil then
// <right preceding statements>
// result = <right expression>
// end
// return result
context.addPrecedingStatements(
lua.createIfStatement(
lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
lua.createBlock([
...rightPrecedingStatements,
lua.createAssignmentStatement(resultIdentifier, rightExpression),
])
)
);
return { expression: resultIdentifier, thisValue: returnThisValue };
}
}

export function transformOptionalDeleteExpression(
Expand Down
4 changes: 2 additions & 2 deletions src/transformation/visitors/void.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import * as ts from "typescript";
import * as lua from "../../LuaAST";
import { FunctionVisitor } from "../context";
import { transformExpressionToStatement } from "./expression-statement";
import { wrapInStatement } from "./expression-statement";

// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/void
export const transformVoidExpression: FunctionVisitor<ts.VoidExpression> = (node, context) => {
// If content is a literal it is safe to replace the entire expression with nil
if (!ts.isLiteralExpression(node.expression)) {
const statements = transformExpressionToStatement(context, node.expression);
const statements = wrapInStatement(context.transformExpression(node.expression));
if (statements) context.addPrecedingStatements(statements);
}

Expand Down
Loading