XLA gradient of tf.where evaluates dead branch, leaking NaN/inf #118674

@wuyii8941

Description

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

TensorFlow 2.22.0-dev20260429

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

Summary

Under XLA (jit_compile=True), computing the gradient of tf.where(cond, f(x), g(x)) evaluates the gradients of both branches, including the "dead" branch not selected by the condition. When the dead branch has a NaN or inf gradient at the boundary (e.g., sqrt'(0) = inf, reciprocal'(0) = -inf), that value leaks into the result.

Eager mode correctly computes the gradient only through the selected branch.

Scope

Affects any tf.where pattern that guards a function with a singular gradient:

Pattern                     x value   Eager grad   XLA grad
where(x>0, sqrt(x), 0)      0.0       0            NaN
where(x>0, sqrt(x), 0)      -1.0      0            NaN
where(x>0, 1/x, 1)          0.0       0            NaN
where(x>0, log(x), -100)    0.0       NaN          NaN

The log case is NOT affected because log'(0) = inf but log(0) = -inf: the forward pass already produces -inf, which makes the gradient NaN in both paths, so there is no eager/XLA divergence to fix there.

The sqrt case IS affected because sqrt(0) = 0 (forward is fine) but sqrt'(0) = 1/(2*sqrt(0)) = inf (gradient is singular). The where condition successfully guards the forward pass but not the backward pass under XLA.
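
As a sanity check independent of tf.where, the singular gradient can be observed in isolation (a minimal GradientTape sketch, not part of the original repro):

import tensorflow as tf

x = tf.constant(0.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.sqrt(x)  # forward value is 0.0 and well-defined
print(tape.gradient(y, x).numpy())  # inf, since sqrt'(0) = 1/(2*sqrt(0))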

For mixed arrays, only the dead-branch elements are wrong:

x = tf.constant([-1.0, 0.0, 1.0, 4.0])
# Eager gradient: [0, 0, 0.5, 0.25]
# XLA gradient:   [nan, nan, 0.5, 0.25]
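
A vector version of the standalone repro at the end of this report (same pattern, shown here as a sketch) makes the element-wise divergence visible:

import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0, 4.0])

def grad_eager(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = tf.where(x > 0, tf.sqrt(x), tf.zeros_like(x))
    return tape.gradient(y, x)

grad_xla = tf.function(grad_eager, jit_compile=True)

print(grad_eager(x).numpy())  # [0.  0.  0.5  0.25]
print(grad_xla(x).numpy())    # [nan nan 0.5  0.25]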

Practical Impact

The tf.where guarding pattern is the standard way to write safe gradient functions in TensorFlow:

# Safe Euclidean distance (avoids sqrt(0) gradient explosion)
def safe_distance(x, y):
    d_sq = tf.reduce_sum((x - y) ** 2)
    return tf.where(d_sq > 0, tf.sqrt(d_sq), 0.0)

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([1.0, 2.0, 3.0])  # identical points

# Eager: gradient = [0, 0, 0] (correct)
# XLA:   gradient = [nan, nan, nan] (WRONG)
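
The divergence for this helper can be observed with the same GradientTape / jit_compile pattern as the standalone repro below (a sketch reusing safe_distance, x, and y as defined above):

with tf.GradientTape() as tape:
    tape.watch(x)
    d = safe_distance(x, y)
print(tape.gradient(d, x).numpy())  # [0. 0. 0.] in eager

@tf.function(jit_compile=True)
def distance_grad_xla(x, y):
    with tf.GradientTape() as tape:
        tape.watch(x)
        d = safe_distance(x, y)
    return tape.gradient(d, x)

print(distance_grad_xla(x, y).numpy())  # [nan nan nan] under XLA, per the report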

This pattern appears in:

  • Distance computations (contrastive loss, triplet loss)
  • Normalization (safe L2 normalize)
  • Any function with a removable singularity guarded by tf.where

Code that works correctly in eager mode produces NaN gradients when compiled with XLA, with no warning.

Root Cause

XLA lowers tf.where to HLO Select. During backpropagation, the gradient of Select is:

grad_true_branch = upstream * cast(condition)
grad_false_branch = upstream * cast(~condition)

But XLA first computes grad_f(x) and grad_g(x) for both branches regardless of the condition, then multiplies by the condition mask. When grad_f(x) is NaN/inf at boundary points, NaN * 0 = NaN (IEEE 754), so the dead branch gradient leaks through.

TF eager mode handles this correctly by only computing the gradient through the selected branch.
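
The IEEE 754 part is easy to verify in isolation; a minimal sketch contrasting mask multiplication with a select:

import tensorflow as tf

inf = tf.constant(float("inf"))
mask = tf.constant(0.0)

print((inf * mask).numpy())                                  # nan: inf * 0 is NaN under IEEE 754
print(tf.where(mask > 0, inf, tf.zeros_like(inf)).numpy())   # 0.0: a select never multiplies the dead value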

Cross-framework Check

JAX (0.9.2) returns NaN in both eager and jit — JAX's jnp.where also lowers to HLO Select, so the same gradient issue appears. However, JAX is at least consistent (both modes return NaN). TF has an inconsistency: eager is correct, XLA is wrong.
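
For reference, the JAX behavior can be reproduced in a few lines (a sketch assuming jax is installed):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

print(jax.grad(f)(0.0))           # nan in eager
print(jax.jit(jax.grad(f))(0.0))  # nan under jit as well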

Environment

  • TensorFlow 2.22.0-dev20260429
  • Affects both CPU and GPU XLA
  • Forward pass: agrees (both correct)
  • Gradient: diverges (eager=0, XLA=NaN)
  • JAX 0.9.2: NaN in both eager and jit (consistent, but same root cause)

Recommendation

Fix TF's gradient implementation for tf.where under XLA: use select(cond, grad_f, 0) instead of grad_f * cast(cond) to avoid the NaN * 0 = NaN leakage. Alternatively, lower the gradient computation to HLO Conditional instead of Select so that the dead-branch gradient is never evaluated.
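
Until a fix lands, the usual user-side workaround is the "double where" trick: also guard the input so the dead branch never evaluates at the singular point. A sketch (safe_sqrt is a hypothetical helper name):

import tensorflow as tf

def safe_sqrt(x):
    # Clamp the input so tf.sqrt never sees the singular point,
    # then select the intended output. The dead branch's gradient
    # is now finite, so multiplying by the mask cannot produce NaN.
    safe_x = tf.where(x > 0, x, tf.ones_like(x))
    return tf.where(x > 0, tf.sqrt(safe_x), tf.zeros_like(x))

x = tf.constant([-1.0, 0.0, 1.0, 4.0])

@tf.function(jit_compile=True)
def grad_xla(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = safe_sqrt(x)
    return tape.gradient(y, x)

print(grad_xla(x).numpy())  # [0.  0.  0.5  0.25], finite even under XLA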

Standalone code to reproduce the issue

import tensorflow as tf

x = tf.constant([0.0])

# Eager: gradient = 0 (correct — x=0 selects the zero branch)
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.where(x > 0, tf.sqrt(x), tf.zeros_like(x))
print(t.gradient(y, x).numpy())  # [0.]

# XLA: gradient = NaN (WRONG — sqrt branch gradient leaks)
@tf.function(jit_compile=True)
def grad_xla(x):
    with tf.GradientTape() as t:
        t.watch(x)
        y = tf.where(x > 0, tf.sqrt(x), tf.zeros_like(x))
    return t.gradient(y, x)
print(grad_xla(x).numpy())  # [nan]


Observed on TensorFlow 2.22.0-dev20260429.

Relevant log output
