Skip to content

Commit 7c72dee

Browse files
committed
2.4.8-1
Fix ROCm#209: improve socket transport performance Split transfers over multiple sockets Launch multiple threads to drive sockets Detect AWS NICs and set nsockets/nthreads accordingly
1 parent 0ceaec9 commit 7c72dee

File tree

7 files changed

+425
-96
lines changed

7 files changed

+425
-96
lines changed

makefiles/version.mk

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
##### version
22
NCCL_MAJOR := 2
33
NCCL_MINOR := 4
4-
NCCL_PATCH := 7
4+
NCCL_PATCH := 8
55
NCCL_SUFFIX :=
66
PKG_REVISION := 1

src/bootstrap.cc

+130-22
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,145 @@
99
#include "utils.h"
1010
#include "bootstrap.h"
1111
#include "net.h"
12+
#include "socket.h"
1213
#include <unistd.h>
1314
#include <sys/types.h>
1415

1516
// Always use sockets for bootstrap
16-
ncclNet_t* ncclBootstrapNet = &ncclNetSocket;
17+
struct bootstrapNetHandle {
18+
union socketAddress connectAddr;
19+
};
20+
21+
struct bootstrapNetComm {
22+
int fd;
23+
};
24+
25+
/* Init functions */
26+
static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
27+
static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
28+
static int bootstrapNetIfs = -1;
29+
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
30+
31+
ncclResult_t bootstrapNetInit() {
32+
if (bootstrapNetIfs == -1) {
33+
pthread_mutex_lock(&bootstrapNetLock);
34+
if (bootstrapNetIfs == -1) {
35+
bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
36+
if (bootstrapNetIfs <= 0) {
37+
WARN("Bootstrap : no socket interface found");
38+
return ncclInternalError;
39+
} else {
40+
char line[1024];
41+
char addrline[1024];
42+
line[0] = '\0';
43+
for (int i=0; i<bootstrapNetIfs; i++) {
44+
snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
45+
socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
46+
}
47+
line[1023] = '\0';
48+
INFO(NCCL_INIT, "Bootstrap : Using%s", line);
49+
}
50+
}
51+
pthread_mutex_unlock(&bootstrapNetLock);
52+
}
53+
return ncclSuccess;
54+
}
55+
56+
static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
57+
NCCLCHECK(ncclCalloc(comm, 1));
58+
(*comm)->fd = -1;
59+
return ncclSuccess;
60+
}
61+
62+
static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
63+
if (dev >= bootstrapNetIfs) return ncclInternalError;
64+
memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr));
65+
return ncclSuccess;
66+
}
67+
68+
/* Socket Interface Selection type */
69+
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
70+
71+
static ncclResult_t bootstrapNetListen(int dev, void* opaqueHandle, void** listenComm) {
72+
struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
73+
static_assert(sizeof(struct bootstrapNetHandle) < NCCL_NET_HANDLE_MAXSIZE, "bootstrapNetHandle size too large");
74+
// if dev >= 0, listen based on dev
75+
if (dev >= 0) {
76+
NCCLCHECK(bootstrapNetGetSocketAddr(dev, &(handle->connectAddr)));
77+
} else if (dev == findSubnetIf) {
78+
// handle stores a remote address
79+
// need to find a local addr that is in the same network as the remote addr
80+
union socketAddress localAddr;
81+
char ifName[MAX_IF_NAME_SIZE];
82+
if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
83+
WARN("NET/Socket : No usable listening interface found");
84+
return ncclSystemError;
85+
}
86+
// pass the local address back
87+
memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
88+
} // Otherwise, handle stores a local address
89+
struct bootstrapNetComm* comm;
90+
NCCLCHECK(bootstrapNetNewComm(&comm));
91+
NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr));
92+
*listenComm = comm;
93+
return ncclSuccess;
94+
}
95+
96+
static ncclResult_t bootstrapNetConnect(int dev, void* opaqueHandle, void** sendComm) {
97+
struct bootstrapNetComm* comm;
98+
NCCLCHECK(bootstrapNetNewComm(&comm));
99+
struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
100+
NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
101+
*sendComm = comm;
102+
return ncclSuccess;
103+
}
104+
105+
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
106+
struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
107+
struct bootstrapNetComm* rComm;
108+
NCCLCHECK(bootstrapNetNewComm(&rComm));
109+
struct sockaddr_in sockaddr;
110+
socklen_t socklen = sizeof(struct sockaddr_in);
111+
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
112+
*recvComm = rComm;
113+
return ncclSuccess;
114+
}
17115

18-
static ncclResult_t bootstrapNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclBootstrapNet->listen(dev, handle, listenComm)); return ncclSuccess; }
19-
static ncclResult_t bootstrapNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclBootstrapNet->connect(dev, handle, sendComm)); return ncclSuccess; }
20-
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclBootstrapNet->accept(listenComm, recvComm)); return ncclSuccess; }
21-
static ncclResult_t bootstrapNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclBootstrapNet->test(request, done, size)); return ncclSuccess; }
22-
static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(ncclBootstrapNet->closeSend(sendComm)); return ncclSuccess; }
23-
static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(ncclBootstrapNet->closeRecv(recvComm)); return ncclSuccess; }
24-
static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(ncclBootstrapNet->closeListen(listenComm)); return ncclSuccess; }
116+
static ncclResult_t bootstrapNetClose(void* opaqueComm) {
117+
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm;
118+
if (comm) {
119+
close(comm->fd);
120+
free(comm);
121+
}
122+
return ncclSuccess;
123+
}
25124

26-
// Additional sync functions based on async + test for bootstrap, using host ptrs.
125+
static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; }
126+
static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; }
127+
static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; }
128+
129+
// Additional sync functions
27130
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
28-
void* request, *mhandle;
29-
NCCLCHECK(ncclBootstrapNet->regMr(sendComm, data, size, NCCL_PTR_HOST, &mhandle));
30-
NCCLCHECK(ncclBootstrapNet->isend(sendComm, data, size, mhandle, &request));
31-
NCCLCHECK(ncclBootstrapNet->deregMr(sendComm, mhandle));
32-
int done = 0;
33-
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
131+
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
132+
NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
133+
NCCLCHECK(socketSend(comm->fd, data, size));
34134
return ncclSuccess;
35135
}
36136
static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
37-
void* request, *mhandle;
38-
NCCLCHECK(ncclBootstrapNet->regMr(recvComm, data, size, NCCL_PTR_HOST, &mhandle));
39-
NCCLCHECK(ncclBootstrapNet->irecv(recvComm, data, size, mhandle, &request));
40-
NCCLCHECK(ncclBootstrapNet->deregMr(recvComm, mhandle));
41-
int done = 0;
42-
while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
137+
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
138+
int recvSize;
139+
NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
140+
if (recvSize > size) {
141+
WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
142+
return ncclInternalError;
143+
}
144+
NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
145+
return ncclSuccess;
146+
}
147+
148+
ncclResult_t bootstrapNetCreateHandle(void* opaqueHandle, const char* str) {
149+
struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
150+
NCCLCHECK(GetSocketAddrFromString(&handle->connectAddr, str));
43151
return ncclSuccess;
44152
}
45153

@@ -148,7 +256,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
148256

149257
char* env = getenv("NCCL_COMM_ID");
150258
if (env) {
151-
if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) {
259+
if (bootstrapNetCreateHandle(&id->extHandleRoot, env) != 0) {
152260
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
153261
return ncclInvalidArgument;
154262
}

src/include/bootstrap.h

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "nccl.h"
1111

12+
ncclResult_t bootstrapNetInit();
1213
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
1314
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
1415
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);

src/include/net.h

-6
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@
1313
extern ncclNet_t* ncclNet;
1414
typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
1515

16-
/* Socket Interface Selection type */
17-
typedef enum { findSubnetIf = -1,
18-
dontCareIf = -2
19-
} ncclSocketIfSl_t;
20-
2116
// Translation to external API
2217
static const char* ncclNetName() { return ncclNet->name; }
2318
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
@@ -36,7 +31,6 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS
3631
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
3732
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
3833

39-
extern ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str);
4034
extern ncclNet_t ncclNetIb;
4135
extern ncclNet_t ncclNetSocket;
4236

src/include/socket.h

+14-7
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
4242
return buf;
4343
}
4444

45-
static inline short socketToPort(struct sockaddr *saddr) {
45+
static inline uint16_t socketToPort(struct sockaddr *saddr) {
4646
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
4747
}
4848

@@ -161,7 +161,10 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
161161
}
162162

163163
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
164-
char line[1024], line_a[1024];
164+
#ifdef ENABLE_TRACE
165+
char line[1024];
166+
#endif
167+
char line_a[1024];
165168
int found = 0;
166169
struct ifaddrs *interfaces, *interface;
167170
getifaddrs(&interfaces);
@@ -185,7 +188,7 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
185188
// Store the interface name
186189
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
187190

188-
INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
191+
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
189192
found++;
190193
if (found == maxIfs) break;
191194
}
@@ -390,12 +393,12 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
390393

391394
#define NCCL_SOCKET_SEND 0
392395
#define NCCL_SOCKET_RECV 1
393-
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
396+
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
394397
int bytes = 0;
395398
char* data = (char*)ptr;
396399
do {
397-
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
398-
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
400+
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
401+
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
399402
if (op == NCCL_SOCKET_RECV && bytes == 0) {
400403
WARN("Net : Connection closed by remote peer");
401404
return ncclSystemError;
@@ -413,9 +416,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off
413416
return ncclSuccess;
414417
}
415418

419+
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
420+
return socketProgressOpt(op, fd, ptr, size, offset, 0);
421+
}
422+
416423
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
417424
while (*offset < size)
418-
NCCLCHECK(socketProgress(op, fd, ptr, size, offset));
425+
NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
419426
return ncclSuccess;
420427
}
421428

src/init.cc

+3-2
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,15 @@ ncclResult_t initNetPlugin(ncclNet_t** net) {
124124
}
125125

126126
ncclResult_t initNet() {
127-
// Always initialize sockets as we use it for bootstrap
128-
NCCLCHECK(initNet(&ncclNetSocket));
127+
// Always initialize bootstrap network
128+
NCCLCHECK(bootstrapNetInit());
129129

130130
NCCLCHECK(initNetPlugin(&ncclNet));
131131
if (ncclNet != NULL) return ncclSuccess;
132132
if (initNet(&ncclNetIb) == ncclSuccess) {
133133
ncclNet = &ncclNetIb;
134134
} else {
135+
NCCLCHECK(initNet(&ncclNetSocket));
135136
ncclNet = &ncclNetSocket;
136137
}
137138
return ncclSuccess;

0 commit comments

Comments
 (0)