11import * as ts from "typescript" ;
22import * as lua from "../../LuaAST" ;
3- import { TransformationContext , tempSymbolId } from "../context" ;
3+ import { tempSymbolId , TransformationContext } from "../context" ;
44import { assert , assertNever } from "../../utils" ;
55import { transformInPrecedingStatementScope } from "../utils/preceding-statements" ;
6- import { transformPropertyAccessExpressionWithCapture , transformElementAccessExpressionWithCapture } from "./access" ;
6+ import { transformElementAccessExpressionWithCapture , transformPropertyAccessExpressionWithCapture } from "./access" ;
77import { shouldMoveToTemp } from "./expression-list" ;
8+ import { canBeFalsyWhenNotNull , expressionResultIsUsed } from "../utils/typescript" ;
9+ import { wrapInStatement } from "./expression-statement" ;
810
911type NormalOptionalChain = ts . PropertyAccessChain | ts . ElementAccessChain | ts . CallChain ;
1012
@@ -56,7 +58,7 @@ export function captureThisValue(
5658 thisValueCapture : lua . Identifier ,
5759 tsOriginal : ts . Node
5860) : lua . Expression {
59- if ( ! shouldMoveToTemp ( context , expression , tsOriginal ) && ! isOptionalContinuation ( tsOriginal ) ) {
61+ if ( ! shouldMoveToTemp ( context , expression , tsOriginal ) ) {
6062 return expression ;
6163 }
6264 const tempAssignment = lua . createAssignmentStatement ( thisValueCapture , expression , tsOriginal ) ;
@@ -66,6 +68,7 @@ export function captureThisValue(
6668
6769export interface OptionalContinuation {
6870 contextualCall ?: lua . CallExpression ;
71+ usedIdentifiers : lua . Identifier [ ] ;
6972}
7073
7174const optionalContinuations = new WeakMap < ts . Identifier , OptionalContinuation > ( ) ;
@@ -74,12 +77,16 @@ const optionalContinuations = new WeakMap<ts.Identifier, OptionalContinuation>()
7477function createOptionalContinuationIdentifier ( text : string , tsOriginal : ts . Expression ) : ts . Identifier {
7578 const identifier = ts . factory . createIdentifier ( text ) ;
7679 ts . setOriginalNode ( identifier , tsOriginal ) ;
77- optionalContinuations . set ( identifier , { } ) ;
80+ optionalContinuations . set ( identifier , {
81+ usedIdentifiers : [ ] ,
82+ } ) ;
7883 return identifier ;
7984}
85+
8086export function isOptionalContinuation ( node : ts . Node ) : boolean {
8187 return ts . isIdentifier ( node ) && optionalContinuations . has ( node ) ;
8288}
89+
8390export function getOptionalContinuationData ( identifier : ts . Identifier ) : OptionalContinuation | undefined {
8491 return optionalContinuations . get ( identifier ) ;
8592}
@@ -90,16 +97,16 @@ export function transformOptionalChain(context: TransformationContext, node: ts.
9097
9198export function transformOptionalChainWithCapture (
9299 context : TransformationContext ,
93- node : ts . OptionalChain ,
100+ tsNode : ts . OptionalChain ,
94101 thisValueCapture : lua . Identifier | undefined ,
95102 isDelete ?: ts . DeleteExpression
96103) : ExpressionWithThisValue {
97- const luaTemp = context . createTempNameForNode ( node ) ;
104+ const luaTempName = context . createTempName ( "opt" ) ;
98105
99- const { expression : tsLeftExpression , chain } = flattenChain ( node ) ;
106+ const { expression : tsLeftExpression , chain } = flattenChain ( tsNode ) ;
100107
101108 // build temp.b.c.d
102- const tsTemp = createOptionalContinuationIdentifier ( luaTemp . text , tsLeftExpression ) ;
109+ const tsTemp = createOptionalContinuationIdentifier ( luaTempName , tsLeftExpression ) ;
103110 let tsRightExpression : ts . Expression = tsTemp ;
104111 for ( const link of chain ) {
105112 if ( ts . isPropertyAccessExpression ( link ) ) {
@@ -121,26 +128,27 @@ export function transformOptionalChainWithCapture(
121128 // transform right expression first to check if thisValue capture is needed
122129 // capture and return thisValue if requested from outside
123130 let returnThisValue : lua . Expression | undefined ;
124- const [ rightPrecedingStatements , rightAssignment ] = transformInPrecedingStatementScope ( context , ( ) => {
125- let result : lua . Expression ;
126- if ( thisValueCapture ) {
127- ( { expression : result , thisValue : returnThisValue } = transformExpressionWithThisValueCapture (
128- context ,
129- tsRightExpression ,
130- thisValueCapture
131- ) ) ;
132- } else {
133- result = context . transformExpression ( tsRightExpression ) ;
131+ const [ rightPrecedingStatements , rightExpression ] = transformInPrecedingStatementScope ( context , ( ) => {
132+ if ( ! thisValueCapture ) {
133+ return context . transformExpression ( tsRightExpression ) ;
134134 }
135- return lua . createAssignmentStatement ( luaTemp , result ) ;
135+
136+ const { expression : result , thisValue } = transformExpressionWithThisValueCapture (
137+ context ,
138+ tsRightExpression ,
139+ thisValueCapture
140+ ) ;
141+ returnThisValue = thisValue ;
142+ return result ;
136143 } ) ;
137144
138145 // transform left expression, handle thisValue if needed by rightExpression
139146 const thisValueCaptureName = context . createTempName ( "this" ) ;
140147 const leftThisValueTemp = lua . createIdentifier ( thisValueCaptureName , undefined , tempSymbolId ) ;
141148 let capturedThisValue : lua . Expression | undefined ;
142149
143- const rightContextualCall = getOptionalContinuationData ( tsTemp ) ?. contextualCall ;
150+ const optionalContinuationData = getOptionalContinuationData ( tsTemp ) ;
151+ const rightContextualCall = optionalContinuationData ?. contextualCall ;
144152 const [ leftPrecedingStatements , leftExpression ] = transformInPrecedingStatementScope ( context , ( ) => {
145153 let result : lua . Expression ;
146154 if ( rightContextualCall ) {
@@ -177,26 +185,78 @@ export function transformOptionalChainWithCapture(
177185 }
178186 }
179187
180- // <left preceding statements>
181- // local temp = <left>
182- // if temp ~= nil then
183- // <right preceding statements>
184- // temp = temp.b.c.d
185- // end
186- // return temp
187-
188- context . addPrecedingStatements ( [
189- ...leftPrecedingStatements ,
190- lua . createVariableDeclarationStatement ( luaTemp , leftExpression ) ,
191- lua . createIfStatement (
192- lua . createBinaryExpression ( luaTemp , lua . createNilLiteral ( ) , lua . SyntaxKind . InequalityOperator ) ,
193- lua . createBlock ( [ ...rightPrecedingStatements , rightAssignment ] )
194- ) ,
195- ] ) ;
196- return {
197- expression : luaTemp ,
198- thisValue : returnThisValue ,
199- } ;
188+ // evaluate optional chain
189+ context . addPrecedingStatements ( leftPrecedingStatements ) ;
190+
191+ // try use existing variable instead of creating new one, if possible
192+ let leftIdentifier : lua . Identifier | undefined ;
193+ const usedLuaIdentifiers = optionalContinuationData ?. usedIdentifiers ;
194+ const reuseLeftIdentifier =
195+ usedLuaIdentifiers &&
196+ usedLuaIdentifiers . length > 0 &&
197+ lua . isIdentifier ( leftExpression ) &&
198+ ( rightPrecedingStatements . length === 0 || ! shouldMoveToTemp ( context , leftExpression , tsLeftExpression ) ) ;
199+ if ( reuseLeftIdentifier ) {
200+ leftIdentifier = leftExpression ;
201+ for ( const usedIdentifier of usedLuaIdentifiers ) {
202+ usedIdentifier . text = leftIdentifier . text ;
203+ }
204+ } else {
205+ leftIdentifier = lua . createIdentifier ( luaTempName , undefined , tempSymbolId ) ;
206+ context . addPrecedingStatements ( lua . createVariableDeclarationStatement ( leftIdentifier , leftExpression ) ) ;
207+ }
208+
209+ if ( ! expressionResultIsUsed ( tsNode ) || isDelete ) {
210+ // if left ~= nil then
211+ // <right preceding statements>
212+ // <right expression>
213+ // end
214+
215+ const innerExpression = wrapInStatement ( rightExpression ) ;
216+ const innerStatements = rightPrecedingStatements ;
217+ if ( innerExpression ) innerStatements . push ( innerExpression ) ;
218+
219+ context . addPrecedingStatements (
220+ lua . createIfStatement (
221+ lua . createBinaryExpression ( leftIdentifier , lua . createNilLiteral ( ) , lua . SyntaxKind . InequalityOperator ) ,
222+ lua . createBlock ( innerStatements )
223+ )
224+ ) ;
225+ return { expression : lua . createNilLiteral ( ) , thisValue : returnThisValue } ;
226+ } else if (
227+ rightPrecedingStatements . length === 0 &&
228+ ! canBeFalsyWhenNotNull ( context , context . checker . getTypeAtLocation ( tsLeftExpression ) )
229+ ) {
230+ // return a && a.b
231+ return {
232+ expression : lua . createBinaryExpression ( leftIdentifier , rightExpression , lua . SyntaxKind . AndOperator , tsNode ) ,
233+ thisValue : returnThisValue ,
234+ } ;
235+ } else {
236+ let resultIdentifier : lua . Identifier ;
237+ if ( ! reuseLeftIdentifier ) {
238+ // reuse temp variable for output
239+ resultIdentifier = leftIdentifier ;
240+ } else {
241+ resultIdentifier = lua . createIdentifier ( context . createTempName ( "opt_result" ) , undefined , tempSymbolId ) ;
242+ context . addPrecedingStatements ( lua . createVariableDeclarationStatement ( resultIdentifier ) ) ;
243+ }
244+ // if left ~= nil then
245+ // <right preceding statements>
246+ // result = <right expression>
247+ // end
248+ // return result
249+ context . addPrecedingStatements (
250+ lua . createIfStatement (
251+ lua . createBinaryExpression ( leftIdentifier , lua . createNilLiteral ( ) , lua . SyntaxKind . InequalityOperator ) ,
252+ lua . createBlock ( [
253+ ...rightPrecedingStatements ,
254+ lua . createAssignmentStatement ( resultIdentifier , rightExpression ) ,
255+ ] )
256+ )
257+ ) ;
258+ return { expression : resultIdentifier , thisValue : returnThisValue } ;
259+ }
200260}
201261
202262export function transformOptionalDeleteExpression (
0 commit comments