Skip to content

Commit 3016e67

Browse files
committed
bind: give same treatment as connect in #8544, dedup
It is good to propagate the underlying error so whether or not we use a process to deal with path length issues is not observable. Also, as these wrapper functions got more and more complex, the code duplication got worse and worse. The new `bindConnectProcHelper` function deduplicates them.
1 parent 1d89c7b commit 3016e67

File tree

1 file changed

+29
-47
lines changed

1 file changed

+29
-47
lines changed

src/libutil/unix-domain-socket.cc

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -38,52 +38,20 @@ AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
3838
return fdSocket;
3939
}
4040

41-
static struct sockaddr* safeSockAddrPointerCast(struct sockaddr_un *addr) {
42-
// Casting between types like these legacy C library interfaces require
43-
// is forbidden in C++.
44-
// To maintain backwards compatibility, the implementation of the
45-
// bind function contains some hints to the compiler that allow for this
46-
// special case.
47-
return reinterpret_cast<struct sockaddr *>(addr);
48-
}
4941

50-
void bind(int fd, const std::string & path)
42+
static void bindConnectProcHelper(
43+
std::string_view operationName, auto && operation,
44+
int fd, const std::string & path)
5145
{
52-
unlink(path.c_str());
53-
5446
struct sockaddr_un addr;
5547
addr.sun_family = AF_UNIX;
56-
auto psaddr {safeSockAddrPointerCast(&addr)};
57-
58-
if (path.size() + 1 >= sizeof(addr.sun_path)) {
59-
Pid pid = startProcess([&] {
60-
Path dir = dirOf(path);
61-
if (chdir(dir.c_str()) == -1)
62-
throw SysError("chdir to '%s' failed", dir);
63-
std::string base(baseNameOf(path));
64-
if (base.size() + 1 >= sizeof(addr.sun_path))
65-
throw Error("socket path '%s' is too long", base);
66-
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
67-
if (bind(fd, psaddr, sizeof(addr)) == -1)
68-
throw SysError("cannot bind to socket '%s'", path);
69-
_exit(0);
70-
});
71-
int status = pid.wait();
72-
if (status != 0)
73-
throw Error("cannot bind to socket '%s'", path);
74-
} else {
75-
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
76-
if (bind(fd, psaddr, sizeof(addr)) == -1)
77-
throw SysError("cannot bind to socket '%s'", path);
78-
}
79-
}
80-
8148

82-
void connect(int fd, const std::string & path)
83-
{
84-
struct sockaddr_un addr;
85-
addr.sun_family = AF_UNIX;
86-
auto psaddr {safeSockAddrPointerCast(&addr)};
49+
// Casting between types like these legacy C library interfaces
50+
// require is forbidden in C++. To maintain backwards
51+
// compatibility, the implementation of the bind/connect functions
52+
// contains some hints to the compiler that allow for this
53+
// special case.
54+
auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr);
8755

8856
if (path.size() + 1 >= sizeof(addr.sun_path)) {
8957
Pipe pipe;
@@ -98,8 +66,8 @@ void connect(int fd, const std::string & path)
9866
if (base.size() + 1 >= sizeof(addr.sun_path))
9967
throw Error("socket path '%s' is too long", base);
10068
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
101-
if (connect(fd, psaddr, sizeof(addr)) == -1)
102-
throw SysError("cannot connect to socket at '%s'", path);
69+
if (operation(fd, psaddr, sizeof(addr)) == -1)
70+
throw SysError("cannot %s to socket at '%s'", operationName, path);
10371
writeFull(pipe.writeSide.get(), "0\n");
10472
} catch (SysError & e) {
10573
writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo));
@@ -110,16 +78,30 @@ void connect(int fd, const std::string & path)
11078
pipe.writeSide.close();
11179
auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get())));
11280
if (!errNo || *errNo == -1)
113-
throw Error("cannot connect to socket at '%s'", path);
81+
throw Error("cannot %s to socket at '%s'", operationName, path);
11482
else if (*errNo > 0) {
11583
errno = *errNo;
116-
throw SysError("cannot connect to socket at '%s'", path);
84+
throw SysError("cannot %s to socket at '%s'", operationName, path);
11785
}
11886
} else {
11987
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
120-
if (connect(fd, psaddr, sizeof(addr)) == -1)
121-
throw SysError("cannot connect to socket at '%s'", path);
88+
if (operation(fd, psaddr, sizeof(addr)) == -1)
89+
throw SysError("cannot %s to socket at '%s'", operationName, path);
12290
}
12391
}
12492

93+
94+
void bind(int fd, const std::string & path)
95+
{
96+
unlink(path.c_str());
97+
98+
bindConnectProcHelper("bind", ::bind, fd, path);
99+
}
100+
101+
102+
void connect(int fd, const std::string & path)
103+
{
104+
bindConnectProcHelper("connect", ::connect, fd, path);
105+
}
106+
125107
}

0 commit comments

Comments
 (0)