Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 264 additions & 0 deletions bindings/sftp/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sftp

import (
"errors"
"fmt"
"os"
"sync"

sftpClient "github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)

type Client struct {
sshClient *ssh.Client
sftpClient *sftpClient.Client
address string
config *ssh.ClientConfig
lock sync.RWMutex
rLock sync.Mutex
Comment on lines +31 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need two locks?

Copy link
Contributor Author

@javier-aliaga javier-aliaga Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because:

  • lock

    • Ensures that only one goroutine at a time tries to reconnect, and verifies with ping() whether the connection is still valid
  • rlock

    • We do want to prevent multiple goroutines trying to use the client while swapping the connection and ensure the consistency of the client. We cannot swap the client until we have the full lock

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps renaming the locks if this is the best path forwards would be helpful e.g.:

  • reconnectLock
  • operationsLock

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can we not have a single read write lock?

  • rlock: when using the client in the happy path
  • lock: when reconnecting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with that is that we do not know whether the connection is valid, so inside the reconnect we have a ping method that sends a simple request to the server.

So imagine you have a broken connection and 1000 goroutines pile up waiting to reconnect: once the first one heals the connection, the other 999 will just ping the server and exit (all while holding a lock). These 999 goroutines will not block the readers. If you have a single lock for reconnections, then every new goroutine will be blocked.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to achieve the right behaviour (and simplify the logic) with the following pseudo code

r.RLock()
client.Call
if err != nil { needsReconnect.Store(true) }
r.RUnlock()
if err != nil {
  r.Lock()
  if needsReconnect.Load() { RECONNECT(); needsReconnect.Store(false) }
  r.Unlock()
  GOTO BEGIN
}

}

// newClient dials the SSH server at address with config and opens an SFTP
// session on top of that connection.
//
// It returns an error when address is empty or config is nil, when the SSH
// dial fails, or when the SFTP session cannot be created. In the last case
// the freshly dialed SSH connection is closed before returning so it is not
// leaked. The address and config are retained on the Client so the connection
// can be re-established later.
func newClient(address string, config *ssh.ClientConfig) (*Client, error) {
	if config == nil || address == "" {
		return nil, errors.New("sftp binding error: client not initialized")
	}

	conn, dialErr := newSSHClient(address, config)
	if dialErr != nil {
		return nil, dialErr
	}

	session, sessionErr := sftpClient.NewClient(conn)
	if sessionErr == nil {
		client := &Client{
			sshClient:  conn,
			sftpClient: session,
			address:    address,
			config:     config,
		}
		return client, nil
	}

	// The SFTP layer could not be started; drop the SSH connection
	// (best effort) and surface the SFTP error.
	_ = conn.Close()
	return nil, fmt.Errorf("sftp binding error: error create sftp client: %w", sessionErr)
}

// Close shuts down the SFTP session and then the SSH connection it runs on.
//
// Both handles are read and closed while holding the write side of c.lock, so
// Close cannot race with the pointer swap performed during a reconnect and
// waits for in-flight operations (which hold the read lock) to finish.
//
// The SFTP session is closed first so that its result reflects the session
// shutdown itself rather than a transport that was already torn down. The SSH
// close remains best effort and its error is discarded.
//
// Calling Close on a nil *Client is a no-op; this happens when construction
// failed and the owner still calls Close during teardown.
//
// Returns the error from closing the SFTP session, if any.
func (c *Client) Close() error {
	if c == nil {
		return nil
	}

	c.lock.Lock()
	defer c.lock.Unlock()

	var sftpErr error
	if c.sftpClient != nil {
		sftpErr = c.sftpClient.Close()
	}
	if c.sshClient != nil {
		// Best effort: the transport may already be gone, and callers only
		// ever observed the SFTP close result.
		_ = c.sshClient.Close()
	}

	return sftpErr
}

// list reads the directory at path and returns its entries.
//
// The read runs under the client's read lock and is routed through
// withReconnection, so a failure that looks like a lost connection triggers
// one reconnect followed by one retry. On error the returned slice is nil.
func (c *Client) list(path string) ([]os.FileInfo, error) {
	var entries []os.FileInfo

	readDir := func() error {
		c.lock.RLock()
		defer c.lock.RUnlock()

		result, err := c.sftpClient.ReadDir(path)
		entries = result
		return err
	}

	if err := withReconnection(c, readDir); err != nil {
		return nil, err
	}
	return entries, nil
}

// create makes sure the parent directory of path exists and then creates the
// file at path.
//
// It returns the open file handle together with the file-name component of
// path. Both remote calls run under the client's read lock and are routed
// through withReconnection, so a failure that looks like a lost connection
// triggers one reconnect followed by one retry of the whole sequence.
func (c *Client) create(path string) (*sftpClient.File, string, error) {
	dir, fileName := sftpClient.Split(path)

	var created *sftpClient.File

	mkdirAndCreate := func() error {
		c.lock.RLock()
		defer c.lock.RUnlock()

		if err := c.sftpClient.MkdirAll(dir); err != nil {
			return fmt.Errorf("sftp binding error: error create dir %s: %w", dir, err)
		}

		handle, err := c.sftpClient.Create(path)
		created = handle
		if err != nil {
			return fmt.Errorf("sftp binding error: error create file %s: %w", path, err)
		}
		return nil
	}

	if err := withReconnection(c, mkdirAndCreate); err != nil {
		return nil, "", err
	}
	return created, fileName, nil
}

// get opens the remote file at path and returns its handle.
//
// The open runs under the client's read lock and is routed through
// withReconnection, so a failure that looks like a lost connection triggers
// one reconnect followed by one retry. On error the returned handle is nil.
func (c *Client) get(path string) (*sftpClient.File, error) {
	var opened *sftpClient.File

	open := func() error {
		c.lock.RLock()
		defer c.lock.RUnlock()

		handle, err := c.sftpClient.Open(path)
		opened = handle
		return err
	}

	if err := withReconnection(c, open); err != nil {
		return nil, err
	}
	return opened, nil
}

// delete removes the remote entry at path.
//
// The removal runs under the client's read lock and is routed through
// withReconnection, so a failure that looks like a lost connection triggers
// one reconnect followed by one retry.
func (c *Client) delete(path string) error {
	remove := func() error {
		c.lock.RLock()
		defer c.lock.RUnlock()
		return c.sftpClient.Remove(path)
	}

	return withReconnection(c, remove)
}

// ping issues a cheap request (Getwd) over the current SFTP session, under
// the client's read lock, and reports whether it failed. The working
// directory itself is discarded; only the error matters.
func (c *Client) ping() error {
	c.lock.RLock()
	defer c.lock.RUnlock()

	_, err := c.sftpClient.Getwd()
	return err
}

// withReconnection runs fn and, if it fails with an error that
// shouldReconnect classifies as connection-related, reconnects the client and
// runs fn exactly one more time.
//
// Returns nil when fn succeeds (on either attempt), the original error when
// it is not worth a reconnect, the original error joined with the reconnect
// error when reconnecting fails, or the error from the second attempt.
func withReconnection(c *Client, fn func() error) error {
	firstErr := fn()
	switch {
	case firstErr == nil:
		return nil
	case !shouldReconnect(firstErr):
		return firstErr
	}

	if reconnectErr := doReconnect(c); reconnectErr != nil {
		return errors.Join(firstErr, reconnectErr)
	}

	// Single retry on the fresh connection; its outcome is final.
	return fn()
}

// doReconnect replaces the client's SSH and SFTP handles with freshly dialed
// ones when the current connection no longer answers a ping.
//
// It returns nil when the ping shows that no reconnect is needed or when the
// swap succeeded, and the dial / SFTP-session error otherwise (in which case
// the existing handles are left untouched).
//
// Two locks cooperate here:
//
// 1) c.rLock (sync.Mutex) — reconnect serialization:
// - Held for the whole function, so only one goroutine at a time performs
// the reconnect sequence (ping/check, dial SSH, create SFTP client),
// preventing a thundering herd of concurrent reconnect attempts.
// - A goroutine that waited behind a successful reconnect re-runs the ping
// once it gets the lock, sees a working connection and returns nil
// without dialing again.
// - Does NOT protect day-to-day client usage; it only coordinates who
// is allowed to perform a reconnect.
//
// 2) c.lock (sync.RWMutex) — data-plane safety and atomic swap:
// - Guards reads/writes of the active client handles (sshClient, sftpClient).
// - Regular operations (and the ping below) hold RLock while using the
// clients.
// - Reconnect takes Lock only around the pointer assignments; the old
// clients are closed after unlocking to keep the critical section small.
//
// Why not a single RWMutex?
// - If c.lock were held in write mode while dialing/handshaking, every
// operation would be blocked for the entire network round trip.
// - Separating concerns allows: (a) a fast, minimal swap under c.lock, and
// (b) serialized reconnect work under c.rLock without blocking readers.
func doReconnect(c *Client) error {
c.rLock.Lock()
defer c.rLock.Unlock()
Comment on lines +205 to +206
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is hard to follow the logic when we have 2 locks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment to explain why is each lock used, let me know what do you think


// Re-check the connection now that we own rLock: another goroutine may have
// reconnected while we were waiting. A nil error, or an error that
// shouldReconnect does not treat as connection-related, means there is
// nothing to do.
err := c.ping()
if !shouldReconnect(err) {
return nil
}

// Build the replacement connection without holding c.lock, so operations
// are not blocked while the network handshake is in progress.
sshClient, err := newSSHClient(c.address, c.config)
if err != nil {
return err
}

newSftpClient, err := sftpClient.NewClient(sshClient)
if err != nil {
// Do not leak the SSH connection that was just dialed.
_ = sshClient.Close()
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
}

// Swap under short lock; close old clients after unlocking.
c.lock.Lock()
oldSftp := c.sftpClient
oldSSH := c.sshClient
c.sftpClient = newSftpClient
c.sshClient = sshClient
c.lock.Unlock()

// Best-effort cleanup of the replaced handles; close errors are discarded.
if oldSftp != nil {
_ = oldSftp.Close()
}
if oldSSH != nil {
_ = oldSSH.Close()
}

return nil
}

// newSSHClient dials address over TCP using config and returns the resulting
// SSH client. A dial failure is wrapped with the binding's error prefix.
func newSSHClient(address string, config *ssh.ClientConfig) (*ssh.Client, error) {
	conn, err := ssh.Dial("tcp", address, config)
	if err == nil {
		return conn, nil
	}
	return nil, fmt.Errorf("sftp binding error: error create ssh client: %w", err)
}

// shouldReconnect returns true if the error looks like a transport-level
// failure, i.e. one where dialing a new connection and retrying may help.
//
// A nil error never requires a reconnect. Errors that describe a logical
// outcome of the request (missing file, permission denied, unsupported
// operation) are reported as false: the server answered, so the connection is
// fine and a retry on a new connection would fail the same way. Everything
// else is treated as a connectivity problem.
//
// Both the SFTP status sentinels and the corresponding os sentinels are
// checked, because the pkg/sftp client normalises "no such file" and
// "permission denied" statuses into os.ErrNotExist and os.ErrPermission
// before returning them; checking only the SFTP sentinels would let those
// common logical errors trigger a needless reconnect.
func shouldReconnect(err error) bool {
	if err == nil {
		return false
	}

	// Errors that are logical, not connectivity (avoid reconnect).
	logical := []error{
		sftpClient.ErrSSHFxPermissionDenied,
		sftpClient.ErrSSHFxNoSuchFile,
		sftpClient.ErrSSHFxOpUnsupported,
		os.ErrPermission,
		os.ErrNotExist,
	}
	for _, target := range logical {
		if errors.Is(err, target) {
			return false
		}
	}

	return true
}
55 changes: 31 additions & 24 deletions bindings/sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sftp

import (
Expand Down Expand Up @@ -25,9 +38,9 @@ const (

// Sftp is a binding for file operations on sftp server.
type Sftp struct {
metadata *sftpMetadata
logger logger.Logger
sftpClient *sftpClient.Client
metadata *sftpMetadata
logger logger.Logger
c *Client
}

// sftpMetadata defines the sftp metadata.
Expand Down Expand Up @@ -115,19 +128,12 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
HostKeyCallback: hostKeyCallback,
}

sshClient, err := ssh.Dial("tcp", m.Address, config)
if err != nil {
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
}

newSftpClient, err := sftpClient.NewClient(sshClient)
sftp.metadata = m
sftp.c, err = newClient(m.Address, config)
if err != nil {
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
return fmt.Errorf("sftp binding error: create sftp client error: %w", err)
}

sftp.metadata = m
sftp.sftpClient = newSftpClient

return nil
}

Expand Down Expand Up @@ -161,14 +167,9 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
return nil, fmt.Errorf("sftp binding error: %w", err)
}

dir, fileName := sftpClient.Split(path)
c := sftp.c

err = sftp.sftpClient.MkdirAll(dir)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error create dir %s: %w", dir, err)
}

file, err := sftp.sftpClient.Create(path)
file, fileName, err := c.create(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error create file %s: %w", path, err)
}
Expand Down Expand Up @@ -211,7 +212,9 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
return nil, fmt.Errorf("sftp binding error: %w", err)
}

files, err := sftp.sftpClient.ReadDir(path)
c := sftp.c

files, err := c.list(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error read dir %s: %w", path, err)
}
Expand Down Expand Up @@ -246,7 +249,9 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
return nil, fmt.Errorf("sftp binding error: %w", err)
}

file, err := sftp.sftpClient.Open(path)
c := sftp.c

file, err := c.get(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error open file %s: %w", path, err)
}
Expand All @@ -272,7 +277,9 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
return nil, fmt.Errorf("sftp binding error: %w", err)
}

err = sftp.sftpClient.Remove(path)
c := sftp.c

err = c.delete(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error remove file %s: %w", path, err)
}
Expand All @@ -296,7 +303,7 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
}

func (sftp *Sftp) Close() error {
return sftp.sftpClient.Close()
return sftp.c.Close()
}

func (metadata sftpMetadata) getPath(requestMetadata map[string]string) (path string, err error) {
Expand Down
Loading