Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions src/transformation/visitors/conditional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@ import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope"
import { transformBlockOrStatement } from "./block";
import { canBeFalsy } from "../utils/typescript";

type EvaluatedExpression = [precedingStatemens: lua.Statement[], value: lua.Expression];

function transformProtectedConditionalExpression(
context: TransformationContext,
expression: ts.ConditionalExpression
expression: ts.ConditionalExpression,
condition: EvaluatedExpression,
whenTrue: EvaluatedExpression,
whenFalse: EvaluatedExpression
): lua.Expression {
const tempVar = context.createTempNameForNode(expression.condition);

const condition = context.transformExpression(expression.condition);

const [trueStatements, val1] = transformInPrecedingStatementScope(context, () =>
context.transformExpression(expression.whenTrue)
const trueStatements = whenTrue[0].concat(
lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), whenTrue[1], expression.whenTrue)
);
trueStatements.push(lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), val1, expression.whenTrue));

const [falseStatements, val2] = transformInPrecedingStatementScope(context, () =>
context.transformExpression(expression.whenFalse)
const falseStatements = whenFalse[0].concat(
lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), whenFalse[1], expression.whenFalse)
);
falseStatements.push(lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), val2, expression.whenFalse));

context.addPrecedingStatements([
lua.createVariableDeclarationStatement(tempVar, undefined, expression.condition),
...condition[0],
lua.createIfStatement(
condition,
condition[1],
lua.createBlock(trueStatements, expression.whenTrue),
lua.createBlock(falseStatements, expression.whenFalse),
expression
Expand All @@ -37,17 +39,27 @@ function transformProtectedConditionalExpression(
}

export const transformConditionalExpression: FunctionVisitor<ts.ConditionalExpression> = (expression, context) => {
if (canBeFalsy(context, context.checker.getTypeAtLocation(expression.whenTrue))) {
return transformProtectedConditionalExpression(context, expression);
const condition = transformInPrecedingStatementScope(context, () =>
context.transformExpression(expression.condition)
);
const whenTrue = transformInPrecedingStatementScope(context, () =>
context.transformExpression(expression.whenTrue)
);
const whenFalse = transformInPrecedingStatementScope(context, () =>
context.transformExpression(expression.whenFalse)
);
if (
whenTrue[0].length > 0 ||
whenFalse[0].length > 0 ||
canBeFalsy(context, context.checker.getTypeAtLocation(expression.whenTrue))
) {
return transformProtectedConditionalExpression(context, expression, condition, whenTrue, whenFalse);
}

const condition = context.transformExpression(expression.condition);
const val1 = context.transformExpression(expression.whenTrue);
const val2 = context.transformExpression(expression.whenFalse);

// condition and v1 or v2
const conditionAnd = lua.createBinaryExpression(condition, val1, lua.SyntaxKind.AndOperator);
return lua.createBinaryExpression(conditionAnd, val2, lua.SyntaxKind.OrOperator, expression);
context.addPrecedingStatements(condition[0]);
const conditionAnd = lua.createBinaryExpression(condition[1], whenTrue[1], lua.SyntaxKind.AndOperator);
return lua.createBinaryExpression(conditionAnd, whenFalse[1], lua.SyntaxKind.OrOperator, expression);
};

export function transformIfStatement(statement: ts.IfStatement, context: TransformationContext): lua.IfStatement {
Expand Down
26 changes: 26 additions & 0 deletions test/unit/conditionals.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,29 @@ test.each([false, true, null])("Ternary conditional with generic whenTrue branch
})
.expectToMatchJsResult();
});

test.each([false, true])("Ternary conditional with preceding statements in true branch (%p)", trueVal => {
// language=TypeScript
util.testFunction`
let i = 0;
const result = ${trueVal} ? i += 1 : i;
return { result, i };
`
.setOptions({
strictNullChecks: true,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this option required here, I don't see any undefineds/nulls in the tests?

Copy link
Copy Markdown
Contributor Author

@GlassBricks GlassBricks Mar 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without it, the test doesn't start as failing and detect the issue:
The type number with no strict null checks may be null and therefore falsy, so the if statement was used.

})
.expectToMatchJsResult();
});

test.each([false, true])("Ternary conditional with preceding statements in false branch (%p)", trueVal => {
// language=TypeScript
util.testFunction`
let i = 0;
const result = ${trueVal} ? i : i += 2;
return { result, i };
`
.setOptions({
strictNullChecks: true,
})
.expectToMatchJsResult();
});