Unverified commit b81b47be, authored by Thomas Baumann, committed by GitHub

Added a few things to the NCCL communicator (#503)

parent bf940bd4
Pipeline #233471 passed
@@ -27,7 +27,7 @@ class NCCLComm(object):
         Args:
             Name (str): Name of the requested attribute
         """
-        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
+        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
             cp.cuda.get_current_stream().synchronize()
         return getattr(self.commMPI, name)
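Note on this hunk: the names added to the exemption list (Create_cart, Is_inter, Get_topology) are pure topology queries that never touch GPU buffers, so forwarding them to MPI needs no stream synchronization. Below is a minimal sketch of the delegation pattern the hunk extends, reconstructed from the visible context; the class name NCCLCommSketch and its constructor are placeholders, not the repository's code.

import cupy as cp

class NCCLCommSketch:
    def __init__(self, commMPI):
        self.commMPI = commMPI  # the wrapped mpi4py communicator

    def __getattr__(self, name):
        # Anything not overloaded on this class falls through to the MPI
        # communicator. Pure queries are forwarded as-is; for everything else
        # the current CUDA stream is synchronized first, since the MPI call
        # may read GPU buffers that pending kernels are still writing.
        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
            cp.cuda.get_current_stream().synchronize()
        return getattr(self.commMPI, name)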
@@ -71,6 +71,26 @@ class NCCLComm(object):
         else:
             raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')
 
+    def reduce(self, sendobj, op=MPI.SUM, root=0):
+        sync = False
+        if hasattr(sendobj, 'data'):
+            if hasattr(sendobj.data, 'ptr'):
+                sync = True
+        if sync:
+            cp.cuda.Device().synchronize()
+
+        return self.commMPI.reduce(sendobj, op=op, root=root)
+
+    def allreduce(self, sendobj, op=MPI.SUM):
+        sync = False
+        if hasattr(sendobj, 'data'):
+            if hasattr(sendobj.data, 'ptr'):
+                sync = True
+        if sync:
+            cp.cuda.Device().synchronize()
+
+        return self.commMPI.allreduce(sendobj, op=op)
+
     def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
         if not hasattr(sendbuf.data, 'ptr'):
             return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
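Why the new lowercase reduce and allreduce synchronize: mpi4py's lowercase methods serialize the Python object via pickle, and pickling a CuPy array copies device memory to the host, so any kernel still writing to that array must finish first. The object is identified as a CuPy array by probing for the .data.ptr attribute pair, as illustrated below (needs_device_sync is a hypothetical helper for illustration, not part of the diff).

import cupy as cp
import numpy as np

def needs_device_sync(obj):
    # CuPy arrays carry a MemoryPointer at obj.data with a raw .ptr address;
    # NumPy arrays expose obj.data as a memoryview, which has no .ptr.
    return hasattr(obj, 'data') and hasattr(obj.data, 'ptr')

assert needs_device_sync(cp.ones(3)) is True
assert needs_device_sync(np.ones(3)) is False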
@@ -113,3 +133,7 @@ class NCCLComm(object):
         stream = cp.cuda.get_current_stream()
 
         self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
+
+    def Barrier(self):
+        cp.cuda.get_current_stream().synchronize()
+        self.commMPI.Barrier()
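Taken together, here is a minimal usage sketch of the changed methods. It assumes the class lives at pySDC.helpers.NCCL_communicator and that its constructor wraps an mpi4py communicator; neither detail is shown in this diff.

from mpi4py import MPI
import cupy as cp

from pySDC.helpers.NCCL_communicator import NCCLComm  # assumed import path

comm = NCCLComm(comm=MPI.COMM_WORLD)  # assumed constructor signature

# rank passes through __getattr__ without a device sync (it is on the
# exemption list); allreduce syncs the device first because the operand is
# a CuPy array, then falls back to the pickle-based mpi4py path.
local = cp.ones(4) * comm.rank
total = comm.allreduce(local.sum())

# Barrier synchronizes the current CUDA stream before the MPI barrier, so
# every rank's queued GPU work is finished once the barrier returns.
comm.Barrier()

if comm.rank == 0:
    print('sum over all ranks:', total)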