diff --git a/snmp-discovery/policy/manager.go b/snmp-discovery/policy/manager.go index 134c340..f640bac 100644 --- a/snmp-discovery/policy/manager.go +++ b/snmp-discovery/policy/manager.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "time" "github.com/netboxlabs/diode-sdk-go/diode" "github.com/netboxlabs/orb-discovery/snmp-discovery/config" @@ -193,7 +194,12 @@ func (m *Manager) StartPolicy(name string, policy config.Policy) error { m.logger.Info("Loaded device lookup extensions", "directory", policy.Config.LookupExtensionsDir) } - r, err := NewRunner(m.ctx, m.logger, name, policy, m.client, snmp.NewClient, &m.mappingConfig, m.manufacturers, deviceLookup) + // Create logger-aware ClientFactory wrapper + clientFactory := func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (snmp.Walker, error) { + return snmp.NewClient(host, port, retries, timeout, authentication, logger) + } + + r, err := NewRunner(m.ctx, m.logger, name, policy, m.client, clientFactory, &m.mappingConfig, m.manufacturers, deviceLookup) if err != nil { return err } diff --git a/snmp-discovery/policy/runner_test.go b/snmp-discovery/policy/runner_test.go index d47fc2b..dd6c46e 100644 --- a/snmp-discovery/policy/runner_test.go +++ b/snmp-discovery/policy/runner_test.go @@ -353,7 +353,7 @@ func TestRunnerWalkError(t *testing.T) { mockHost.On("Walk", mock.Anything, mock.Anything).Return(nil, errors.New("walk error")) // Create a mock client factory that returns the mock host - mockClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + mockClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockHost, nil } diff --git a/snmp-discovery/snmp/mocks.go b/snmp-discovery/snmp/mocks.go index 25c5503..5152987 100644 --- a/snmp-discovery/snmp/mocks.go +++ b/snmp-discovery/snmp/mocks.go @@ -1,6 +1,7 @@ package snmp import ( + "log/slog" "time" "github.com/gosnmp/gosnmp" @@ -51,6 +52,6 @@ func (n *FakeSNMPWalker) Walk(oid string, _ int) (map[string]PDU, error) { } // NewFakeSNMPWalker creates a new FakeSNMPWalker -func NewFakeSNMPWalker(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (Walker, error) { +func NewFakeSNMPWalker(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (Walker, error) { return &FakeSNMPWalker{}, nil } diff --git a/snmp-discovery/snmp/snmp.go b/snmp-discovery/snmp/snmp.go index efe8db0..b58ec62 100644 --- a/snmp-discovery/snmp/snmp.go +++ b/snmp-discovery/snmp/snmp.go @@ -1,6 +1,7 @@ package snmp import ( + "context" "fmt" "log/slog" "reflect" @@ -11,6 +12,21 @@ import ( "github.com/netboxlabs/orb-discovery/snmp-discovery/mapping" ) +// SlogAdapter adapts slog.Logger to implement gosnmp.LoggerInterface +type SlogAdapter struct { + logger *slog.Logger +} + +// Print implements gosnmp.LoggerInterface by logging at Debug level +func (s *SlogAdapter) Print(v ...interface{}) { + s.logger.Debug(fmt.Sprint(v...)) +} + +// Printf implements gosnmp.LoggerInterface by logging at Debug level +func (s *SlogAdapter) Printf(format string, v ...interface{}) { + s.logger.Debug(fmt.Sprintf(format, v...)) +} + // Host is a struct that represents an SNMP host type Host struct { address string @@ -39,7 +55,7 @@ func NewHost(host string, port uint16, retries int, timeout time.Duration, authe func (s *Host) Walk(objectIDs map[string]int) (mapping.ObjectIDValueMap, error) { s.logger.Info("Scanning", "host", s.address) - snmpClient, err := s.ClientFactory(s.address, s.port, s.retries, s.timeout, s.authentication) + snmpClient, err := s.ClientFactory(s.address, s.port, s.retries, s.timeout, s.authentication, s.logger) if err != nil { return nil, err } @@ -167,10 +183,16 @@ const ( ) // ClientFactory is a function that creates a new SNMPClient -type ClientFactory func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication) (Walker, error) +type ClientFactory func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (Walker, error) // NewClient creates a new SNMPClient for the given target host -func NewClient(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication) (Walker, error) { +func NewClient(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (Walker, error) { + // Check if debug logging is enabled + var gosnmpLogger gosnmp.Logger + if logger.Enabled(context.Background(), slog.LevelDebug) { + gosnmpLogger = gosnmp.NewLogger(&SlogAdapter{logger}) + } + switch authentication.ProtocolVersion { case ProtocolVersion1: return &Client{ @@ -181,6 +203,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut Version: gosnmp.Version1, Timeout: timeout, Retries: retries, + Logger: gosnmpLogger, }, }, nil case ProtocolVersion2c: @@ -192,6 +215,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut Version: gosnmp.Version2c, Timeout: timeout, Retries: retries, + Logger: gosnmpLogger, }, }, nil case ProtocolVersion3: @@ -211,6 +235,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut Timeout: timeout, Retries: retries, SecurityModel: gosnmp.UserSecurityModel, + Logger: gosnmpLogger, SecurityParameters: &gosnmp.UsmSecurityParameters{ UserName: authentication.Username, AuthenticationProtocol: authProtocol, diff --git a/snmp-discovery/snmp/snmp_test.go b/snmp-discovery/snmp/snmp_test.go index 8306b5d..ba022aa 100644 --- a/snmp-discovery/snmp/snmp_test.go +++ b/snmp-discovery/snmp/snmp_test.go @@ -78,8 +78,8 @@ func TestSNMPHost(t *testing.T) { t.Run("Successfully walks a host", func(t *testing.T) { // Setup - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { - fakeWalker, _ := snmp.NewFakeSNMPWalker("192.168.1.1", 161, 3, 1*time.Second, nil) + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, logger *slog.Logger) (snmp.Walker, error) { + fakeWalker, _ := snmp.NewFakeSNMPWalker("192.168.1.1", 161, 3, 1*time.Second, nil, logger) return fakeWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -110,7 +110,7 @@ func TestSNMPHost(t *testing.T) { interfaceSpeedOID + ".1": {Value: 1000000, Type: gosnmp.Integer, IdentifierSize: 1}, }, nil) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -136,7 +136,7 @@ func TestSNMPHost(t *testing.T) { ipAddressObjectID: {Value: "192.168.1.1", Type: gosnmp.IPAddress, IdentifierSize: 4}, }, nil) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -158,7 +158,7 @@ func TestSNMPHost(t *testing.T) { mockWalker := &MockSNMP{} mockWalker.On("Connect").Return(assert.AnError) mockWalker.On("Close").Return(nil) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -178,7 +178,7 @@ func TestSNMPHost(t *testing.T) { mockWalker.On("Connect").Return(nil) mockWalker.On("Close").Return(nil) mockWalker.On("Walk", mock.Anything, mock.Anything).Return(make(map[string]snmp.PDU), assert.AnError) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -207,7 +207,7 @@ func TestSNMPHost(t *testing.T) { interfaceSpeedOID + ".1": {Value: "invalid", Type: gosnmp.Asn1BER(255), IdentifierSize: 1}, // Invalid type }, nil) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return mockWalker, nil } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -232,7 +232,7 @@ func TestSNMPHost(t *testing.T) { mockWalker.On("Connect").Return(nil) mockWalker.On("Close").Return(nil) mockWalker.On("Walk", mock.Anything, mock.Anything).Return(nil, assert.AnError) - snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) { + snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) { return nil, fmt.Errorf("error creating client") } host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory) @@ -434,7 +434,8 @@ func TestNewClient(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - client, err := snmp.NewClient("192.168.1.1", 161, 3, 1*time.Second, tc.auth) + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + client, err := snmp.NewClient("192.168.1.1", 161, 3, 1*time.Second, tc.auth, logger) if tc.expectError { assert.Error(t, err)