Skip to content

Commit e827aa4

Browse files
authored
Merge pull request #4008 from norio-nomura/portfwd-support-host-socket
`portfwd`: support HostSocket
2 parents c85e72c + ef88218 commit e827aa4

File tree

7 files changed

+81
-19
lines changed

7 files changed

+81
-19
lines changed

hack/test-port-forwarding.pl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
use warnings;
1717

1818
use Config qw(%Config);
19+
use File::Spec::Functions qw(catfile);
1920
use IO::Handle qw();
2021
use Socket qw(inet_ntoa);
2122
use Sys::Hostname qw(hostname);
@@ -33,6 +34,11 @@
3334
chomp $ipv4;
3435
}
3536

37+
my $instDir = qx(limactl list "$instance" --yq .dir);
38+
chomp $instDir;
39+
# platform independent way to add trailing path separator
40+
my $sockDir = catfile($instDir, "sock", "");
41+
3642
# If $instance is a filename, add our portForwards to it to enable testing
3743
if (-f $instance) {
3844
open(my $fh, "+< $instance") or die "Can't open $instance for read/write: $!";
@@ -94,14 +100,17 @@
94100
s/sshLocalPort/$sshLocalPort/g;
95101
s/ipv4/$ipv4/g;
96102
s/ipv6/$ipv6/g;
103+
s/sockDir\//$sockDir/g;
97104
# forward: 127.0.0.1 899 → 127.0.0.1 799
98105
# ignore: 127.0.0.2 8888
99-
/^(forward|ignore):\s+([0-9.:]+)\s+(\d+)(?:\s+→)?(?:\s+([0-9.:]+)(?:\s+(\d+))?)?/;
106+
/^(forward|ignore):\s+([0-9.:]+)\s+(\d+)(?:\s+→)?(?:\s+(?:([0-9.:]+)(?:\s+(\d+))|(\S+))?)?/;
100107
die "Cannot parse test '$_'" unless $1;
101-
my %test; @test{qw(mode guest_ip guest_port host_ip host_port)} = ($1, $2, $3, $4, $5);
108+
my %test; @test{qw(mode guest_ip guest_port host_ip host_port host_socket)} = ($1, $2, $3, $4, $5, $6);
109+
102110
$test{host_ip} ||= "127.0.0.1";
103111
$test{host_port} ||= $test{guest_port};
104-
if ($test{mode} eq "forward" && $test{host_port} < 1024 && $Config{osname} ne "darwin") {
112+
$test{host_socket} ||= "";
113+
if ($test{mode} eq "forward" && $test{host_socket} eq "" && $test{host_port} < 1024 && $Config{osname} ne "darwin") {
105114
printf "🚧 Not supported on $Config{osname}: # $_\n";
106115
next;
107116
}
@@ -113,9 +122,13 @@
113122
printf "🚧 Not supported for $instanceType machines: # $_\n";
114123
next;
115124
}
125+
if ($test{host_socket} ne "" && $Config{osname} eq "cygwin") {
126+
printf "🚧 Not supported on $Config{osname}: # $_\n";
127+
next;
128+
}
116129

117130
my $remote = JoinHostPort($test{guest_ip},$test{guest_port});
118-
my $local = JoinHostPort($test{host_ip},$test{host_port});
131+
my $local = $test{host_socket} eq "" ? JoinHostPort($test{host_ip},$test{host_port}) : $test{host_socket};
119132
if ($test{mode} eq "ignore") {
120133
$test{log_msg} = "Not forwarding TCP $remote";
121134
}
@@ -163,7 +176,7 @@
163176
# Try to reach each listener from the host
164177
foreach my $test (@test) {
165178
next if $test->{host_port} == $sshLocalPort;
166-
my $nc = "nc -w 1 $test->{host_ip} $test->{host_port}";
179+
my $nc = $test->{host_socket} eq "" ? "nc -w 1 $test->{host_ip} $test->{host_port}" : "nc -w 1 -U $test->{host_socket}";
167180
open(my $netcat, "| $nc") or die "Can't run '$nc': $!";
168181
print $netcat "$test->{log_msg}\n";
169182
# Don't check for errors on close; macOS nc seems to return non-zero exit code even on success
@@ -175,7 +188,7 @@
175188
seek($log, $ha_log_size, 0) or die "Can't seek $ha_log to $ha_log_size: $!";
176189
my %seen;
177190
while (<$log>) {
178-
$seen{$1}++ if /(Forwarding TCP from .*? to (\d.*?|\[.*?\]):\d+)/;
191+
$seen{$1}++ if /(Forwarding TCP from .*? to ((\d.*?|\[.*?\]):\d+|\/[^"]+))/;
179192
$seen{$1}++ if /(Not forwarding TCP .*?:\d+)/;
180193
}
181194
close $log or die;
@@ -342,3 +355,8 @@ sub JoinHostPort {
342355
- guestIPMustBeZero: true
343356
guestPort: 8888
344357
hostIP: 0.0.0.0
358+
359+
- guestPort: 5000
360+
hostSocket: port5000.sock
361+
362+
# forward: 127.0.0.1 5000 → sockDir/port5000.sock

pkg/limayaml/defaults.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -902,14 +902,6 @@ func FillPortForwardDefaults(rule *limatype.PortForward, instDir string, user li
902902
rule.GuestPortRange[1] = rule.GuestPort
903903
}
904904
}
905-
if rule.HostPortRange[0] == 0 && rule.HostPortRange[1] == 0 {
906-
if rule.HostPort == 0 {
907-
rule.HostPortRange = rule.GuestPortRange
908-
} else {
909-
rule.HostPortRange[0] = rule.HostPort
910-
rule.HostPortRange[1] = rule.HostPort
911-
}
912-
}
913905
if rule.GuestSocket != "" {
914906
if out, err := executeGuestTemplate(rule.GuestSocket, instDir, user, param); err == nil {
915907
rule.GuestSocket = out.String()
@@ -926,6 +918,13 @@ func FillPortForwardDefaults(rule *limatype.PortForward, instDir string, user li
926918
if !filepath.IsAbs(rule.HostSocket) {
927919
rule.HostSocket = filepath.Join(instDir, filenames.SocketDir, rule.HostSocket)
928920
}
921+
} else if rule.HostPortRange[0] == 0 && rule.HostPortRange[1] == 0 {
922+
if rule.HostPort == 0 {
923+
rule.HostPortRange = rule.GuestPortRange
924+
} else {
925+
rule.HostPortRange[0] = rule.HostPort
926+
rule.HostPortRange[1] = rule.HostPort
927+
}
929928
}
930929
}
931930

pkg/limayaml/defaults_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ func TestFillDefault(t *testing.T) {
268268
expect.PortForwards[2].HostPort = 8888
269269
expect.PortForwards[2].HostPortRange = [2]int{8888, 8888}
270270

271+
expect.PortForwards[3].HostPortRange = [2]int{0, 0}
271272
expect.PortForwards[3].GuestSocket = fmt.Sprintf("%s | %s | %s | %s", user.HomeDir, user.Uid, user.Username, y.Param["ONE"])
272273
expect.PortForwards[3].HostSocket = fmt.Sprintf("%s | %s | %s | %s | %s | %s", hostHome, instDir, instName, currentUser.Uid, currentUser.Username, y.Param["ONE"])
273274

pkg/limayaml/validate.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
314314
if err := validatePort(fmt.Sprintf("%s.guestPortRange[%d]", field, j), rule.GuestPortRange[j]); err != nil {
315315
errs = errors.Join(errs, err)
316316
}
317-
if err := validatePort(fmt.Sprintf("%s.hostPortRange[%d]", field, j), rule.HostPortRange[j]); err != nil {
318-
errs = errors.Join(errs, err)
317+
if rule.HostSocket == "" {
318+
if err := validatePort(fmt.Sprintf("%s.hostPortRange[%d]", field, j), rule.HostPortRange[j]); err != nil {
319+
errs = errors.Join(errs, err)
320+
}
319321
}
320322
}
321323
if rule.GuestPortRange[0] > rule.GuestPortRange[1] {
@@ -324,9 +326,6 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
324326
if rule.HostPortRange[0] > rule.HostPortRange[1] {
325327
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange[1]` must be greater than or equal to field `%s.hostPortRange[0]`", field, field))
326328
}
327-
if rule.GuestPortRange[1]-rule.GuestPortRange[0] != rule.HostPortRange[1]-rule.HostPortRange[0] {
328-
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange` must specify the same number of ports as field `%s.guestPortRange`", field, field))
329-
}
330329
if rule.GuestSocket != "" {
331330
if !path.IsAbs(rule.GuestSocket) {
332331
errs = errors.Join(errs, fmt.Errorf("field `%s.guestSocket` must be an absolute path, but is %q", field, rule.GuestSocket))
@@ -343,7 +342,10 @@ func Validate(y *limatype.LimaYAML, warn bool) error {
343342
if rule.GuestSocket == "" && rule.GuestPortRange[1]-rule.GuestPortRange[0] > 0 {
344343
errs = errors.Join(errs, fmt.Errorf("field `%s.hostSocket` can only be mapped from a single port or socket. not a range", field))
345344
}
345+
} else if rule.GuestPortRange[1]-rule.GuestPortRange[0] != rule.HostPortRange[1]-rule.HostPortRange[0] {
346+
errs = errors.Join(errs, fmt.Errorf("field `%s.hostPortRange` must specify the same number of ports as field `%s.guestPortRange`", field, field))
346347
}
348+
347349
if len(rule.HostSocket) >= osutil.UnixPathMax {
348350
errs = errors.Join(errs, fmt.Errorf("field `%s.hostSocket` must be less than UNIX_PATH_MAX=%d characters, but is %d",
349351
field, osutil.UnixPathMax, len(rule.HostSocket)))

pkg/portfwd/listener.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"errors"
99
"fmt"
1010
"net"
11+
"os"
12+
"path/filepath"
1113
"strings"
1214
"sync"
1315

@@ -146,3 +148,13 @@ func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentcl
146148
func key(protocol, hostAddress, guestAddress string) string {
147149
return fmt.Sprintf("%s-%s-%s", protocol, hostAddress, guestAddress)
148150
}
151+
152+
func prepareUnixSocket(hostSocket string) error {
153+
if err := os.RemoveAll(hostSocket); err != nil {
154+
return fmt.Errorf("can't clean up %q: %w", hostSocket, err)
155+
}
156+
if err := os.MkdirAll(filepath.Dir(hostSocket), 0o755); err != nil {
157+
return fmt.Errorf("can't create directory for local socket %q: %w", hostSocket, err)
158+
}
159+
return nil
160+
}

pkg/portfwd/listener_darwin.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,26 @@ import (
77
"context"
88
"fmt"
99
"net"
10+
"path/filepath"
1011
"strconv"
1112

1213
"github.com/sirupsen/logrus"
1314
)
1415

1516
func Listen(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.Listener, error) {
17+
if filepath.IsAbs(hostAddress) {
18+
// Handle Unix domain sockets
19+
if err := prepareUnixSocket(hostAddress); err != nil {
20+
return nil, err
21+
}
22+
var lc net.ListenConfig
23+
unixLis, err := lc.Listen(ctx, "unix", hostAddress)
24+
if err != nil {
25+
logrus.WithError(err).Errorf("failed to listen unix: %v", hostAddress)
26+
return nil, err
27+
}
28+
return unixLis, nil
29+
}
1630
localIPStr, localPortStr, _ := net.SplitHostPort(hostAddress)
1731
localIP := net.ParseIP(localIPStr)
1832
localPort, _ := strconv.Atoi(localPortStr)

pkg/portfwd/listener_others.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,25 @@ package portfwd
88
import (
99
"context"
1010
"net"
11+
"path/filepath"
12+
13+
"github.com/sirupsen/logrus"
1114
)
1215

1316
func Listen(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.Listener, error) {
17+
if filepath.IsAbs(hostAddress) {
18+
// Handle Unix domain sockets
19+
if err := prepareUnixSocket(hostAddress); err != nil {
20+
return nil, err
21+
}
22+
var lc net.ListenConfig
23+
unixLis, err := lc.Listen(ctx, "unix", hostAddress)
24+
if err != nil {
25+
logrus.WithError(err).Errorf("failed to listen unix: %v", hostAddress)
26+
return nil, err
27+
}
28+
return unixLis, nil
29+
}
1430
return listenConfig.Listen(ctx, "tcp", hostAddress)
1531
}
1632

0 commit comments

Comments
 (0)