Skip to content

Commit 3d0e98d

Browse files
authored
Simplify optional chaining expressions where possible (#1381)
* Simplify optional chaining expressions where possible * Fix case when right expression with preceding statements modifies left expression * Use if statement when left side can be a boolean * use `expressionResultIsUsed` where it was extracted from * Fix duplicate call (reversion) in ArrayPush * Fix canBeFalsyWhenNotNull This shouldn't depend on strictNullChecks * Add snapshot tests to some optional chaining tests. This tests that the output statements are correct.
1 parent d315218 commit 3d0e98d

File tree

9 files changed

+428
-77
lines changed

9 files changed

+428
-77
lines changed

src/transformation/builtins/array.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { TransformationContext } from "../context";
55
import { unsupportedProperty } from "../utils/diagnostics";
66
import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
77
import { transformArguments, transformCallAndArguments } from "../visitors/call";
8-
import { findFirstNonOuterParent, typeAlwaysHasSomeOfFlags } from "../utils/typescript";
8+
import { expressionResultIsUsed, typeAlwaysHasSomeOfFlags } from "../utils/typescript";
99
import { moveToPrecedingTemp } from "../visitors/expression-list";
1010
import { isUnpackCall, wrapInTable } from "../utils/lua-ast";
1111

@@ -54,8 +54,6 @@ function transformSingleElementArrayPush(
5454
caller: lua.Expression,
5555
param: lua.Expression
5656
): lua.Expression {
57-
const expressionIsUsed = !ts.isExpressionStatement(findFirstNonOuterParent(node));
58-
5957
const arrayIdentifier = lua.isIdentifier(caller) ? caller : moveToPrecedingTemp(context, caller);
6058

6159
// #array + 1
@@ -65,6 +63,7 @@ function transformSingleElementArrayPush(
6563
lua.SyntaxKind.AdditionOperator
6664
);
6765

66+
const expressionIsUsed = expressionResultIsUsed(node);
6867
if (expressionIsUsed) {
6968
// store length in a temp
7069
lengthExpression = moveToPrecedingTemp(context, lengthExpression);

src/transformation/utils/typescript/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ export function findFirstNonOuterParent(node: ts.Node): ts.Node {
3232
return current;
3333
}
3434

35+
export function expressionResultIsUsed(node: ts.Expression): boolean {
36+
return !ts.isExpressionStatement(findFirstNonOuterParent(node));
37+
}
38+
3539
export function getFirstDeclarationInFile(symbol: ts.Symbol, sourceFile: ts.SourceFile): ts.Declaration | undefined {
3640
const originalSourceFile = ts.getParseTreeNode(sourceFile) ?? sourceFile;
3741
const declarations = (symbol.getDeclarations() ?? []).filter(d => d.getSourceFile() === originalSourceFile);

src/transformation/utils/typescript/types.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ export function canBeFalsy(context: TransformationContext, type: ts.Type): boole
115115
}
116116

117117
export function canBeFalsyWhenNotNull(context: TransformationContext, type: ts.Type): boolean {
118-
const strictNullChecks = context.options.strict === true || context.options.strictNullChecks === true;
119-
if (!strictNullChecks && !type.isLiteral()) return true;
120118
const falsyFlags =
121119
ts.TypeFlags.Boolean |
122120
ts.TypeFlags.BooleanLiteral |

src/transformation/visitors/expression-statement.ts

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
3-
import { FunctionVisitor, tempSymbolId, TransformationContext } from "../context";
3+
import { FunctionVisitor, tempSymbolId } from "../context";
44
import { transformBinaryExpressionStatement } from "./binary-expression";
55
import { transformUnaryExpressionStatement } from "./unary-expression";
66

@@ -15,15 +15,10 @@ export const transformExpressionStatement: FunctionVisitor<ts.ExpressionStatemen
1515
return binaryExpressionResult;
1616
}
1717

18-
return transformExpressionToStatement(context, node.expression);
18+
return wrapInStatement(context.transformExpression(node.expression));
1919
};
2020

21-
export function transformExpressionToStatement(
22-
context: TransformationContext,
23-
expression: ts.Expression
24-
): lua.Statement | undefined {
25-
const result = context.transformExpression(expression);
26-
21+
export function wrapInStatement(result: lua.Expression): lua.Statement | undefined {
2722
const isTempVariable = lua.isIdentifier(result) && result.symbolId === tempSymbolId;
2823
if (isTempVariable) {
2924
return undefined;

src/transformation/visitors/identifier.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { invalidCallExtensionUse } from "../utils/diagnostics";
77
import { createExportedIdentifier, getSymbolExportScope } from "../utils/export";
88
import { createSafeName, hasUnsafeIdentifierName } from "../utils/safe-names";
99
import { getIdentifierSymbolId } from "../utils/symbols";
10-
import { isOptionalContinuation } from "./optional-chaining";
10+
import { getOptionalContinuationData, isOptionalContinuation } from "./optional-chaining";
1111
import { isStandardLibraryType } from "../utils/typescript";
1212
import { getExtensionKindForNode, getExtensionKindForSymbol } from "../utils/language-extensions";
1313
import { callExtensions } from "./language-extensions/call-extension";
@@ -16,13 +16,16 @@ import { isIdentifierExtensionValue, reportInvalidExtensionValue } from "./langu
1616
export function transformIdentifier(context: TransformationContext, identifier: ts.Identifier): lua.Identifier {
1717
return transformNonValueIdentifier(context, identifier, context.checker.getSymbolAtLocation(identifier));
1818
}
19+
1920
function transformNonValueIdentifier(
2021
context: TransformationContext,
2122
identifier: ts.Identifier,
2223
symbol: ts.Symbol | undefined
2324
) {
2425
if (isOptionalContinuation(identifier)) {
25-
return lua.createIdentifier(identifier.text, undefined, tempSymbolId);
26+
const result = lua.createIdentifier(identifier.text, undefined, tempSymbolId);
27+
getOptionalContinuationData(identifier)!.usedIdentifiers.push(result);
28+
return result;
2629
}
2730

2831
const extensionKind = symbol

src/transformation/visitors/optional-chaining.ts

Lines changed: 100 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
3-
import { TransformationContext, tempSymbolId } from "../context";
3+
import { tempSymbolId, TransformationContext } from "../context";
44
import { assert, assertNever } from "../../utils";
55
import { transformInPrecedingStatementScope } from "../utils/preceding-statements";
6-
import { transformPropertyAccessExpressionWithCapture, transformElementAccessExpressionWithCapture } from "./access";
6+
import { transformElementAccessExpressionWithCapture, transformPropertyAccessExpressionWithCapture } from "./access";
77
import { shouldMoveToTemp } from "./expression-list";
8+
import { canBeFalsyWhenNotNull, expressionResultIsUsed } from "../utils/typescript";
9+
import { wrapInStatement } from "./expression-statement";
810

911
type NormalOptionalChain = ts.PropertyAccessChain | ts.ElementAccessChain | ts.CallChain;
1012

@@ -56,7 +58,7 @@ export function captureThisValue(
5658
thisValueCapture: lua.Identifier,
5759
tsOriginal: ts.Node
5860
): lua.Expression {
59-
if (!shouldMoveToTemp(context, expression, tsOriginal) && !isOptionalContinuation(tsOriginal)) {
61+
if (!shouldMoveToTemp(context, expression, tsOriginal)) {
6062
return expression;
6163
}
6264
const tempAssignment = lua.createAssignmentStatement(thisValueCapture, expression, tsOriginal);
@@ -66,6 +68,7 @@ export function captureThisValue(
6668

6769
export interface OptionalContinuation {
6870
contextualCall?: lua.CallExpression;
71+
usedIdentifiers: lua.Identifier[];
6972
}
7073

7174
const optionalContinuations = new WeakMap<ts.Identifier, OptionalContinuation>();
@@ -74,12 +77,16 @@ const optionalContinuations = new WeakMap<ts.Identifier, OptionalContinuation>()
7477
function createOptionalContinuationIdentifier(text: string, tsOriginal: ts.Expression): ts.Identifier {
7578
const identifier = ts.factory.createIdentifier(text);
7679
ts.setOriginalNode(identifier, tsOriginal);
77-
optionalContinuations.set(identifier, {});
80+
optionalContinuations.set(identifier, {
81+
usedIdentifiers: [],
82+
});
7883
return identifier;
7984
}
85+
8086
export function isOptionalContinuation(node: ts.Node): boolean {
8187
return ts.isIdentifier(node) && optionalContinuations.has(node);
8288
}
89+
8390
export function getOptionalContinuationData(identifier: ts.Identifier): OptionalContinuation | undefined {
8491
return optionalContinuations.get(identifier);
8592
}
@@ -90,16 +97,16 @@ export function transformOptionalChain(context: TransformationContext, node: ts.
9097

9198
export function transformOptionalChainWithCapture(
9299
context: TransformationContext,
93-
node: ts.OptionalChain,
100+
tsNode: ts.OptionalChain,
94101
thisValueCapture: lua.Identifier | undefined,
95102
isDelete?: ts.DeleteExpression
96103
): ExpressionWithThisValue {
97-
const luaTemp = context.createTempNameForNode(node);
104+
const luaTempName = context.createTempName("opt");
98105

99-
const { expression: tsLeftExpression, chain } = flattenChain(node);
106+
const { expression: tsLeftExpression, chain } = flattenChain(tsNode);
100107

101108
// build temp.b.c.d
102-
const tsTemp = createOptionalContinuationIdentifier(luaTemp.text, tsLeftExpression);
109+
const tsTemp = createOptionalContinuationIdentifier(luaTempName, tsLeftExpression);
103110
let tsRightExpression: ts.Expression = tsTemp;
104111
for (const link of chain) {
105112
if (ts.isPropertyAccessExpression(link)) {
@@ -121,26 +128,27 @@ export function transformOptionalChainWithCapture(
121128
// transform right expression first to check if thisValue capture is needed
122129
// capture and return thisValue if requested from outside
123130
let returnThisValue: lua.Expression | undefined;
124-
const [rightPrecedingStatements, rightAssignment] = transformInPrecedingStatementScope(context, () => {
125-
let result: lua.Expression;
126-
if (thisValueCapture) {
127-
({ expression: result, thisValue: returnThisValue } = transformExpressionWithThisValueCapture(
128-
context,
129-
tsRightExpression,
130-
thisValueCapture
131-
));
132-
} else {
133-
result = context.transformExpression(tsRightExpression);
131+
const [rightPrecedingStatements, rightExpression] = transformInPrecedingStatementScope(context, () => {
132+
if (!thisValueCapture) {
133+
return context.transformExpression(tsRightExpression);
134134
}
135-
return lua.createAssignmentStatement(luaTemp, result);
135+
136+
const { expression: result, thisValue } = transformExpressionWithThisValueCapture(
137+
context,
138+
tsRightExpression,
139+
thisValueCapture
140+
);
141+
returnThisValue = thisValue;
142+
return result;
136143
});
137144

138145
// transform left expression, handle thisValue if needed by rightExpression
139146
const thisValueCaptureName = context.createTempName("this");
140147
const leftThisValueTemp = lua.createIdentifier(thisValueCaptureName, undefined, tempSymbolId);
141148
let capturedThisValue: lua.Expression | undefined;
142149

143-
const rightContextualCall = getOptionalContinuationData(tsTemp)?.contextualCall;
150+
const optionalContinuationData = getOptionalContinuationData(tsTemp);
151+
const rightContextualCall = optionalContinuationData?.contextualCall;
144152
const [leftPrecedingStatements, leftExpression] = transformInPrecedingStatementScope(context, () => {
145153
let result: lua.Expression;
146154
if (rightContextualCall) {
@@ -177,26 +185,78 @@ export function transformOptionalChainWithCapture(
177185
}
178186
}
179187

180-
// <left preceding statements>
181-
// local temp = <left>
182-
// if temp ~= nil then
183-
// <right preceding statements>
184-
// temp = temp.b.c.d
185-
// end
186-
// return temp
187-
188-
context.addPrecedingStatements([
189-
...leftPrecedingStatements,
190-
lua.createVariableDeclarationStatement(luaTemp, leftExpression),
191-
lua.createIfStatement(
192-
lua.createBinaryExpression(luaTemp, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
193-
lua.createBlock([...rightPrecedingStatements, rightAssignment])
194-
),
195-
]);
196-
return {
197-
expression: luaTemp,
198-
thisValue: returnThisValue,
199-
};
188+
// evaluate optional chain
189+
context.addPrecedingStatements(leftPrecedingStatements);
190+
191+
// try use existing variable instead of creating new one, if possible
192+
let leftIdentifier: lua.Identifier | undefined;
193+
const usedLuaIdentifiers = optionalContinuationData?.usedIdentifiers;
194+
const reuseLeftIdentifier =
195+
usedLuaIdentifiers &&
196+
usedLuaIdentifiers.length > 0 &&
197+
lua.isIdentifier(leftExpression) &&
198+
(rightPrecedingStatements.length === 0 || !shouldMoveToTemp(context, leftExpression, tsLeftExpression));
199+
if (reuseLeftIdentifier) {
200+
leftIdentifier = leftExpression;
201+
for (const usedIdentifier of usedLuaIdentifiers) {
202+
usedIdentifier.text = leftIdentifier.text;
203+
}
204+
} else {
205+
leftIdentifier = lua.createIdentifier(luaTempName, undefined, tempSymbolId);
206+
context.addPrecedingStatements(lua.createVariableDeclarationStatement(leftIdentifier, leftExpression));
207+
}
208+
209+
if (!expressionResultIsUsed(tsNode) || isDelete) {
210+
// if left ~= nil then
211+
// <right preceding statements>
212+
// <right expression>
213+
// end
214+
215+
const innerExpression = wrapInStatement(rightExpression);
216+
const innerStatements = rightPrecedingStatements;
217+
if (innerExpression) innerStatements.push(innerExpression);
218+
219+
context.addPrecedingStatements(
220+
lua.createIfStatement(
221+
lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
222+
lua.createBlock(innerStatements)
223+
)
224+
);
225+
return { expression: lua.createNilLiteral(), thisValue: returnThisValue };
226+
} else if (
227+
rightPrecedingStatements.length === 0 &&
228+
!canBeFalsyWhenNotNull(context, context.checker.getTypeAtLocation(tsLeftExpression))
229+
) {
230+
// return a && a.b
231+
return {
232+
expression: lua.createBinaryExpression(leftIdentifier, rightExpression, lua.SyntaxKind.AndOperator, tsNode),
233+
thisValue: returnThisValue,
234+
};
235+
} else {
236+
let resultIdentifier: lua.Identifier;
237+
if (!reuseLeftIdentifier) {
238+
// reuse temp variable for output
239+
resultIdentifier = leftIdentifier;
240+
} else {
241+
resultIdentifier = lua.createIdentifier(context.createTempName("opt_result"), undefined, tempSymbolId);
242+
context.addPrecedingStatements(lua.createVariableDeclarationStatement(resultIdentifier));
243+
}
244+
// if left ~= nil then
245+
// <right preceding statements>
246+
// result = <right expression>
247+
// end
248+
// return result
249+
context.addPrecedingStatements(
250+
lua.createIfStatement(
251+
lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator),
252+
lua.createBlock([
253+
...rightPrecedingStatements,
254+
lua.createAssignmentStatement(resultIdentifier, rightExpression),
255+
])
256+
)
257+
);
258+
return { expression: resultIdentifier, thisValue: returnThisValue };
259+
}
200260
}
201261

202262
export function transformOptionalDeleteExpression(

src/transformation/visitors/void.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
33
import { FunctionVisitor } from "../context";
4-
import { transformExpressionToStatement } from "./expression-statement";
4+
import { wrapInStatement } from "./expression-statement";
55

66
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/void
77
export const transformVoidExpression: FunctionVisitor<ts.VoidExpression> = (node, context) => {
88
// If content is a literal it is safe to replace the entire expression with nil
99
if (!ts.isLiteralExpression(node.expression)) {
10-
const statements = transformExpressionToStatement(context, node.expression);
10+
const statements = wrapInStatement(context.transformExpression(node.expression));
1111
if (statements) context.addPrecedingStatements(statements);
1212
}
1313

0 commit comments

Comments
 (0)