Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
TensorFlow 2.22.0-dev20260429
Custom code
Yes
OS platform and distribution
No response
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
Summary
Under XLA (`jit_compile=True`), computing the gradient of `tf.where(cond, f(x), g(x))` evaluates the gradients of both branches, including the "dead" branch not selected by the condition. When the dead branch has a NaN or inf gradient at the boundary (e.g., `sqrt'(0) = inf`, `reciprocal'(0) = -inf`), the non-finite value leaks into the result.
Eager mode correctly computes the gradient only through the selected branch.
Scope
Affects any `tf.where` pattern that guards a function with a singular gradient:

| Pattern | `x` value | Eager grad | XLA grad |
|---|---|---|---|
| `where(x>0, sqrt(x), 0)` | 0.0 | 0 | NaN |
| `where(x>0, sqrt(x), 0)` | -1.0 | 0 | NaN |
| `where(x>0, 1/x, 1)` | 0.0 | 0 | NaN |
| `where(x>0, log(x), -100)` | 0.0 | NaN | NaN |
The log case is NOT affected by the inconsistency: `log'(0) = inf`, but `log(0) = -inf`, so the forward pass of the dead branch already produces `-inf` and the gradient is NaN in both modes.
The sqrt case IS affected: `sqrt(0) = 0` (the forward pass is fine), but `sqrt'(0) = 1/(2*sqrt(0)) = inf` (the gradient is singular). The `where` condition successfully guards the forward pass but not the backward pass under XLA.
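A compact eager check of the log row (a sketch; assumes `import tensorflow as tf`), showing there is no eager/XLA inconsistency for this case:

```python
x = tf.constant([0.0])
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.where(x > 0, tf.math.log(x), -100.0 * tf.ones_like(x))
print(t.gradient(y, x).numpy())  # [nan] in eager too, per the table above
```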
For mixed arrays, only the dead-branch elements are wrong:
```python
x = tf.constant([-1.0, 0.0, 1.0, 4.0])
# Eager gradient: [0, 0, 0.5, 0.25]
# XLA gradient:   [nan, nan, 0.5, 0.25]
```
Practical Impact
The `tf.where` guarding pattern is the standard way to write safe gradient functions in TensorFlow:
```python
# Safe Euclidean distance (avoids sqrt(0) gradient explosion)
def safe_distance(x, y):
    d_sq = tf.reduce_sum((x - y) ** 2)
    return tf.where(d_sq > 0, tf.sqrt(d_sq), 0.0)

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([1.0, 2.0, 3.0])  # identical points
# Eager: gradient = [0, 0, 0] (correct)
# XLA:   gradient = [nan, nan, nan] (WRONG)
```
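For completeness, a short sketch (using `safe_distance`, `x`, and `y` from above) that reproduces the divergence; the printed values reflect the observations in this report:

```python
def distance_grad(x, y):
    with tf.GradientTape() as t:
        t.watch(x)
        d = safe_distance(x, y)
    return t.gradient(d, x)

print(distance_grad(x, y).numpy())                                 # eager: [0. 0. 0.]
print(tf.function(distance_grad, jit_compile=True)(x, y).numpy())  # XLA:   [nan nan nan]
```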
This pattern appears in:
- Distance computations (contrastive loss, triplet loss)
- Normalization (safe L2 normalize)
- Any function with a removable singularity guarded by `tf.where`
Code that works correctly in eager mode produces NaN gradients when compiled with XLA, with no warning.
Root Cause
XLA lowers `tf.where` to HLO `Select`. During backpropagation, the gradient of `Select` is:

```
grad_true_branch  = upstream * cast(condition)
grad_false_branch = upstream * cast(~condition)
```
But XLA first computes `grad_f(x)` and `grad_g(x)` for both branches regardless of the condition, then multiplies by the condition mask. When `grad_f(x)` is NaN/inf at a boundary point, `NaN * 0 = NaN` under IEEE 754, so the dead-branch gradient leaks through.
TF eager mode handles this correctly by computing the gradient only through the selected branch.
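A quick NumPy check (illustrative only) of the IEEE 754 behavior at the core of the leak:

```python
import numpy as np

# Masking by multiplication cannot erase non-finite values:
print(np.float32(np.inf) * np.float32(0.0))  # nan
print(np.float32(np.nan) * np.float32(0.0))  # nan
```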
Cross-framework Check
JAX (0.9.2) returns NaN in both eager and jit: `jnp.where` also lowers to HLO `Select`, so the same gradient issue appears. JAX is at least consistent (both modes return NaN), while TF is inconsistent: eager is correct, XLA is wrong.
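For reference, a minimal JAX sketch of the same check (assumes `jax` is installed; it mirrors the TF repro below):

```python
import jax
import jax.numpy as jnp

# Same guarded-sqrt pattern; jnp.where lowers to HLO Select.
f = lambda x: jnp.where(x > 0, jnp.sqrt(x), 0.0)

print(jax.grad(f)(0.0))           # nan (eager)
print(jax.jit(jax.grad(f))(0.0))  # nan (jit): consistent with eager
```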
Environment
- TensorFlow 2.22.0-dev20260429
- Affects both CPU and GPU XLA
- Forward pass: eager and XLA agree (both correct)
- Gradient: diverges (eager=0, XLA=NaN)
- JAX 0.9.2: NaN in both eager and jit (consistent, but same root cause)
Recommendation
Fix in TF's gradient implementation for `tf.where` under XLA: use `select(cond, grad_f, 0)` instead of `grad_f * cast(cond)` to avoid `NaN * 0 = NaN` leakage. Alternatively, lower the gradient computation to HLO `Conditional` instead of `Select` so the dead-branch gradient is never evaluated.
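Until a fix lands, a user-side workaround is the well-known "double-where" trick. The sketch below (names are illustrative) clamps the input first so the dead branch is evaluated at a point with a finite derivative:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def safe_sqrt_grad(x):
    with tf.GradientTape() as t:
        t.watch(x)
        # Inner where: replace non-positive inputs with 1, where sqrt'(1) = 0.5
        # is finite, so masking the dead branch yields 0 instead of NaN.
        safe_x = tf.where(x > 0, x, tf.ones_like(x))
        # Outer where: same forward result as the original pattern.
        y = tf.where(x > 0, tf.sqrt(safe_x), tf.zeros_like(x))
    return t.gradient(y, x)

print(safe_sqrt_grad(tf.constant([0.0])).numpy())  # [0.] even under XLA
```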
Standalone code to reproduce the issue
```python
import tensorflow as tf

x = tf.constant([0.0])

# Eager: gradient = 0 (correct: x=0 selects the zero branch)
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.where(x > 0, tf.sqrt(x), tf.zeros_like(x))
print(t.gradient(y, x).numpy())  # [0.]

# XLA: gradient = NaN (WRONG: sqrt branch gradient leaks)
@tf.function(jit_compile=True)
def grad_xla(x):
    with tf.GradientTape() as t:
        t.watch(x)
        y = tf.where(x > 0, tf.sqrt(x), tf.zeros_like(x))
    return t.gradient(y, x)

print(grad_xla(x).numpy())  # [nan]
```
Observed on TensorFlow 2.22.0-dev20260429.
Relevant log output
No response