Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion snmp-discovery/policy/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"time"

"github.com/netboxlabs/diode-sdk-go/diode"
"github.com/netboxlabs/orb-discovery/snmp-discovery/config"
Expand Down Expand Up @@ -193,7 +194,12 @@ func (m *Manager) StartPolicy(name string, policy config.Policy) error {
m.logger.Info("Loaded device lookup extensions", "directory", policy.Config.LookupExtensionsDir)
}

r, err := NewRunner(m.ctx, m.logger, name, policy, m.client, snmp.NewClient, &m.mappingConfig, m.manufacturers, deviceLookup)
// Create logger-aware ClientFactory wrapper
clientFactory := func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (snmp.Walker, error) {
return snmp.NewClient(host, port, retries, timeout, authentication, logger)
}

r, err := NewRunner(m.ctx, m.logger, name, policy, m.client, clientFactory, &m.mappingConfig, m.manufacturers, deviceLookup)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion snmp-discovery/policy/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func TestRunnerWalkError(t *testing.T) {
mockHost.On("Walk", mock.Anything, mock.Anything).Return(nil, errors.New("walk error"))

// Create a mock client factory that returns the mock host
mockClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
mockClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockHost, nil
}

Expand Down
3 changes: 2 additions & 1 deletion snmp-discovery/snmp/mocks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package snmp

import (
"log/slog"
"time"

"github.com/gosnmp/gosnmp"
Expand Down Expand Up @@ -51,6 +52,6 @@ func (n *FakeSNMPWalker) Walk(oid string, _ int) (map[string]PDU, error) {
}

// NewFakeSNMPWalker creates a new FakeSNMPWalker
func NewFakeSNMPWalker(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (Walker, error) {
func NewFakeSNMPWalker(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (Walker, error) {
return &FakeSNMPWalker{}, nil
}
31 changes: 28 additions & 3 deletions snmp-discovery/snmp/snmp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package snmp

import (
"context"
"fmt"
"log/slog"
"reflect"
Expand All @@ -11,6 +12,21 @@ import (
"github.com/netboxlabs/orb-discovery/snmp-discovery/mapping"
)

// SlogAdapter adapts slog.Logger to implement gosnmp.LoggerInterface
type SlogAdapter struct {
logger *slog.Logger
}

// Print implements gosnmp.LoggerInterface by logging at Debug level
func (s *SlogAdapter) Print(v ...interface{}) {
s.logger.Debug(fmt.Sprint(v...))
}

// Printf implements gosnmp.LoggerInterface by logging at Debug level
func (s *SlogAdapter) Printf(format string, v ...interface{}) {
s.logger.Debug(fmt.Sprintf(format, v...))
}

// Host is a struct that represents an SNMP host
type Host struct {
address string
Expand Down Expand Up @@ -39,7 +55,7 @@ func NewHost(host string, port uint16, retries int, timeout time.Duration, authe
func (s *Host) Walk(objectIDs map[string]int) (mapping.ObjectIDValueMap, error) {
s.logger.Info("Scanning", "host", s.address)

snmpClient, err := s.ClientFactory(s.address, s.port, s.retries, s.timeout, s.authentication)
snmpClient, err := s.ClientFactory(s.address, s.port, s.retries, s.timeout, s.authentication, s.logger)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -167,10 +183,16 @@ const (
)

// ClientFactory is a function that creates a new SNMPClient
type ClientFactory func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication) (Walker, error)
type ClientFactory func(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (Walker, error)

// NewClient creates a new SNMPClient for the given target host
func NewClient(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication) (Walker, error) {
func NewClient(host string, port uint16, retries int, timeout time.Duration, authentication *config.Authentication, logger *slog.Logger) (Walker, error) {
// Check if debug logging is enabled
var gosnmpLogger gosnmp.Logger
if logger.Enabled(context.Background(), slog.LevelDebug) {
gosnmpLogger = gosnmp.NewLogger(&SlogAdapter{logger})
}

switch authentication.ProtocolVersion {
case ProtocolVersion1:
return &Client{
Expand All @@ -181,6 +203,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut
Version: gosnmp.Version1,
Timeout: timeout,
Retries: retries,
Logger: gosnmpLogger,
},
}, nil
case ProtocolVersion2c:
Expand All @@ -192,6 +215,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut
Version: gosnmp.Version2c,
Timeout: timeout,
Retries: retries,
Logger: gosnmpLogger,
},
}, nil
case ProtocolVersion3:
Expand All @@ -211,6 +235,7 @@ func NewClient(host string, port uint16, retries int, timeout time.Duration, aut
Timeout: timeout,
Retries: retries,
SecurityModel: gosnmp.UserSecurityModel,
Logger: gosnmpLogger,
SecurityParameters: &gosnmp.UsmSecurityParameters{
UserName: authentication.Username,
AuthenticationProtocol: authProtocol,
Expand Down
19 changes: 10 additions & 9 deletions snmp-discovery/snmp/snmp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ func TestSNMPHost(t *testing.T) {

t.Run("Successfully walks a host", func(t *testing.T) {
// Setup
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
fakeWalker, _ := snmp.NewFakeSNMPWalker("192.168.1.1", 161, 3, 1*time.Second, nil)
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, logger *slog.Logger) (snmp.Walker, error) {
fakeWalker, _ := snmp.NewFakeSNMPWalker("192.168.1.1", 161, 3, 1*time.Second, nil, logger)
return fakeWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestSNMPHost(t *testing.T) {
interfaceSpeedOID + ".1": {Value: 1000000, Type: gosnmp.Integer, IdentifierSize: 1},
}, nil)

snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand All @@ -136,7 +136,7 @@ func TestSNMPHost(t *testing.T) {
ipAddressObjectID: {Value: "192.168.1.1", Type: gosnmp.IPAddress, IdentifierSize: 4},
}, nil)

snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand All @@ -158,7 +158,7 @@ func TestSNMPHost(t *testing.T) {
mockWalker := &MockSNMP{}
mockWalker.On("Connect").Return(assert.AnError)
mockWalker.On("Close").Return(nil)
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand All @@ -178,7 +178,7 @@ func TestSNMPHost(t *testing.T) {
mockWalker.On("Connect").Return(nil)
mockWalker.On("Close").Return(nil)
mockWalker.On("Walk", mock.Anything, mock.Anything).Return(make(map[string]snmp.PDU), assert.AnError)
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestSNMPHost(t *testing.T) {
interfaceSpeedOID + ".1": {Value: "invalid", Type: gosnmp.Asn1BER(255), IdentifierSize: 1}, // Invalid type
}, nil)

snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return mockWalker, nil
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand All @@ -232,7 +232,7 @@ func TestSNMPHost(t *testing.T) {
mockWalker.On("Connect").Return(nil)
mockWalker.On("Close").Return(nil)
mockWalker.On("Walk", mock.Anything, mock.Anything).Return(nil, assert.AnError)
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication) (snmp.Walker, error) {
snmpClientFactory := func(_ string, _ uint16, _ int, _ time.Duration, _ *config.Authentication, _ *slog.Logger) (snmp.Walker, error) {
return nil, fmt.Errorf("error creating client")
}
host := snmp.NewHost("192.168.1.1", 161, 3, 1*time.Second, nil, logger, snmpClientFactory)
Expand Down Expand Up @@ -434,7 +434,8 @@ func TestNewClient(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client, err := snmp.NewClient("192.168.1.1", 161, 3, 1*time.Second, tc.auth)
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
client, err := snmp.NewClient("192.168.1.1", 161, 3, 1*time.Second, tc.auth, logger)

if tc.expectError {
assert.Error(t, err)
Expand Down
Loading