@@ -43,9 +43,9 @@
 
 from packaging import version
 if version.parse(keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-    from keras.optimizers import Optimizer
+    from keras import optimizers
 else:
-    from keras.optimizers.legacy import Optimizer
+    from keras.optimizers import legacy as optimizers
 
 _PRE_TF_2_2_0 = version.parse(tf.__version__) < version.parse("2.2.0")
 
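For context, a minimal standalone sketch of the version-gated alias this hunk introduces. The `-tf` → `+tf` rewrite is there because TF-bundled Keras builds report versions such as `2.2.4-tf`, which `packaging.version.parse` rejects; the alias name `optimizers` comes from the diff, everything else is illustrative:

```python
from packaging import version
import keras

# Rewriting "-tf" as the local-version label "+tf" yields a
# PEP 440-parseable string for TF-bundled Keras builds.
keras_version = version.parse(keras.__version__.replace("-tf", "+tf"))

if keras_version < version.parse("2.11"):
    # Pre-2.11: the classic optimizer classes live in keras.optimizers.
    from keras import optimizers
else:
    # Keras 2.11 swapped in a new optimizer implementation and moved the
    # old classes under keras.optimizers.legacy; aliasing the module keeps
    # every call site version-agnostic.
    from keras.optimizers import legacy as optimizers

opt = optimizers.Adam(learning_rate=0.001)  # same spelling on every version
```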
@@ -71,10 +71,7 @@ def __init__(self, *args, **kwargs):
 
     def test_train_model_lr_schedule(self):
         initial_lr = 0.1 * hvd.size()
-        if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-            opt = tf.keras.optimizers.Adam()
-        else:
-            opt = tf.keras.optimizers.legacy.Adam()
+        opt = optimizers.Adam()
         opt = hvd.DistributedOptimizer(opt)
 
         def linear_multiplier(epoch):
@@ -164,10 +161,7 @@ def test_sparse_as_dense(self):
 
     def test_sparse_as_dense_with_grad_aggregation(self):
         backward_passes_per_step = 2
-        if version.parse(keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-            opt = keras.optimizers.RMSprop(lr=0.0001)
-        else:
-            opt = keras.optimizers.legacy.RMSprop(lr=0.0001)
+        opt = optimizers.RMSprop(lr=0.0001)
         opt = hvd.DistributedOptimizer(
             opt,
             sparse_as_dense=True,
@@ -193,10 +187,7 @@ def test_sparse_as_dense_with_grad_aggregation(self):
     def test_grad_aggregation_with_inf_grad(self):
         backward_passes_per_step = 2
         step_count = tf.Variable(0, trainable=False, dtype=tf.int32)
-        if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-            opt = tf.keras.optimizers.SGD()
-        else:
-            opt = tf.keras.optimizers.legacy.SGD()
+        opt = optimizers.SGD()
         opt = hvd.DistributedOptimizer(
             opt,
             backward_passes_per_step=backward_passes_per_step,
@@ -221,10 +212,7 @@ def loss():
         assert tf.math.is_finite(grads_and_vars[0][0])
 
     def test_from_config(self):
-        if version.parse(keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-            opt = keras.optimizers.Adam()
-        else:
-            opt = keras.optimizers.legacy.Adam()
+        opt = optimizers.Adam()
         hopt = hvd.DistributedOptimizer(opt)
         cfg = hopt.get_config()
 
@@ -252,7 +240,7 @@ def test_elastic_state(self):
             [np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
              np.array([0.0, 0.0], dtype=np.float32)])
 
-        optimizer = Optimizer.Adam(0.001 * hvd.size())
+        optimizer = optimizers.Adam(0.001 * hvd.size())
 
         state = hvd.elastic.KerasState(
             model1,
@@ -543,10 +531,7 @@ def test_partial_distributed_optimizer(self):
         model.add(tf.keras.layers.Dense(2, input_shape=(3,), kernel_initializer=initializer, bias_initializer=initializer))
         model.add(tf.keras.layers.RepeatVector(3))
         model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3, kernel_initializer=initializer, bias_initializer=initializer)))
-        if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
-            opt = tf.keras.optimizers.Adam()
-        else:
-            opt = tf.keras.optimizers.legacy.Adam()
+        opt = optimizers.Adam()
         model.compile(loss=tf.keras.losses.MSE,
                       metrics=[tf.keras.metrics.categorical_accuracy])
 
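Taken together, the hunks reduce every optimizer construction to the same one-liner. A hedged sketch of the resulting pattern; the learning-rate scaling by world size and `backward_passes_per_step` mirror the tests above, while the model-free setup is illustrative:

```python
from packaging import version
import keras
import horovod.tensorflow.keras as hvd

# Same version-gated alias as at the top of the test file.
if version.parse(keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
    from keras import optimizers
else:
    from keras.optimizers import legacy as optimizers

hvd.init()

# Scale the base learning rate by the number of workers, as the tests do,
# then wrap the optimizer; backward_passes_per_step aggregates gradients
# locally for two steps before each allreduce.
opt = optimizers.SGD(learning_rate=0.01 * hvd.size())
opt = hvd.DistributedOptimizer(opt, backward_passes_per_step=2)
```

Pinning to the legacy classes on Keras 2.11+ presumably reflects that `hvd.DistributedOptimizer` was written against the pre-2.11 optimizer API.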