Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 29 additions & 47 deletions src/libutil/unix-domain-socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,20 @@ AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
return fdSocket;
}

static struct sockaddr* safeSockAddrPointerCast(struct sockaddr_un *addr) {
// Casting between types like these legacy C library interfaces require
// is forbidden in C++.
// To maintain backwards compatibility, the implementation of the
// bind function contains some hints to the compiler that allow for this
// special case.
return reinterpret_cast<struct sockaddr *>(addr);
}

void bind(int fd, const std::string & path)
static void bindConnectProcHelper(
std::string_view operationName, auto && operation,
int fd, const std::string & path)
{
unlink(path.c_str());

struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};

if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pid pid = startProcess([&] {
Path dir = dirOf(path);
if (chdir(dir.c_str()) == -1)
throw SysError("chdir to '%s' failed", dir);
std::string base(baseNameOf(path));
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
_exit(0);
});
int status = pid.wait();
if (status != 0)
throw Error("cannot bind to socket '%s'", path);
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
}
}


void connect(int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};
// Casting between types like these legacy C library interfaces
// require is forbidden in C++. To maintain backwards
// compatibility, the implementation of the bind/connect functions
// contains some hints to the compiler that allow for this
// special case.
auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr);

if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pipe pipe;
Expand All @@ -98,8 +66,8 @@ void connect(int fd, const std::string & path)
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
writeFull(pipe.writeSide.get(), "0\n");
} catch (SysError & e) {
writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo));
Expand All @@ -110,16 +78,30 @@ void connect(int fd, const std::string & path)
pipe.writeSide.close();
auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get())));
if (!errNo || *errNo == -1)
throw Error("cannot connect to socket at '%s'", path);
throw Error("cannot %s to socket at '%s'", operationName, path);
else if (*errNo > 0) {
errno = *errNo;
throw SysError("cannot connect to socket at '%s'", path);
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
}


void bind(int fd, const std::string & path)
{
unlink(path.c_str());

bindConnectProcHelper("bind", ::bind, fd, path);
}


void connect(int fd, const std::string & path)
{
bindConnectProcHelper("connect", ::connect, fd, path);
}

}