Commit c98551cd authored by Michael Langguth

Ensure that the all-reduce of the total loss is performed after the gradients have been applied during optimization.

parent cc3d91e5
Pipeline #88254 failed
@@ -153,11 +153,12 @@ class VanillaConvLstmVideoPredictionModel(object):
         # If MonitoredTrainingSession is used, one must make use of hvd.BroadcastGlobalVariablesHook(0);
         # if not, hvd.broadcast_global_variables after initialization of global variables is sufficient.
         # self.hooks = [hvd.BroadcastGlobalVariablesHook(0)]
-        with tf.control_dependencies([grads]):
-            self.total_loss = hvd.allreduce(self.total_loss)
         self.bcast_op = hvd.broadcast_global_variables(0)
         # Apply gradients during the training iteration
         self.train_op = opt.apply_gradients(grads_and_vars=grads, global_step=self.global_step)
+        # Average the total loss across workers
+        with tf.control_dependencies([self.train_op]):
+            self.total_loss = hvd.allreduce(self.total_loss)
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
...
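For context, the following is a minimal, self-contained sketch (not taken from the repository) of the TF1.x + Horovod pattern this commit establishes: the optimizer is wrapped in hvd.DistributedOptimizer so per-worker gradients are averaged, the gradients are applied first, and the reported total loss is all-reduced across workers only after the train op via tf.control_dependencies. The dummy variable, loss, and Adam learning rate below are placeholders, not the model's actual code.

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Dummy stand-in for the model's per-worker scalar loss (placeholder only).
w = tf.get_variable("w", shape=[], initializer=tf.zeros_initializer())
total_loss = tf.square(w - 1.0)

global_step = tf.train.get_or_create_global_step()

# Wrap the optimizer so per-worker gradients are averaged before the update.
opt = hvd.DistributedOptimizer(tf.train.AdamOptimizer(1e-4))
grads = opt.compute_gradients(total_loss)

# Apply the (already averaged) gradients during the training iteration.
train_op = opt.apply_gradients(grads_and_vars=grads, global_step=global_step)

# As in the commit: only after the train op has run, average the reported
# loss across workers so every rank sees the same value.
with tf.control_dependencies([train_op]):
    total_loss = hvd.allreduce(total_loss)

# Broadcast the initial variable states from rank 0 to all other workers.
bcast_op = hvd.broadcast_global_variables(0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(bcast_op)
    for _ in range(10):
        loss_value, _ = sess.run([total_loss, train_op])

Making the allreduce depend on the train op (rather than on the raw gradient list, as in the removed lines) keeps the loss averaging out of the gradient path and ensures the value that gets reported is reduced only once the update has actually been applied.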