diff --git a/README.rst b/README.rst index 90b6e21..8db4e83 100644 --- a/README.rst +++ b/README.rst @@ -54,7 +54,7 @@ forwarding :: support -And finally there is a sample which shows how to copy a file from local to +There is a sample which shows how to copy a file from local to remote machine. You can also define owner and mode of the target :: >>> fd = open('test.txt', 'w') @@ -67,3 +67,16 @@ remote machine. You can also define owner and mode of the target :: Hello world >>> print conn.run('ls -l /tmp/test.txt').stdout -rw-rw-rw- 1 nobody nogroup ... /tmp/test.txt + + +You can also pass file-like objects instead of filenames to scp method. Behind +the scenes the method creates temporary files for you, send them to remote +target and then removes everything which has been created:: + + >>> from StringIO import StringIO + >>> data = StringIO('test') + >>> from openssh_wrapper import SSHConnection + >>> conn = SSHConnection('localhost', login='root') + >>> conn.scp((data, ), target='/tmp/test.txt', mode='0644') + >>> print open('/tmp/test.txt').read() + test diff --git a/openssh_wrapper.py b/openssh_wrapper.py index 5894a8f..431f601 100644 --- a/openssh_wrapper.py +++ b/openssh_wrapper.py @@ -2,7 +2,7 @@ """ This is a wrapper around the openssh binaries ssh and scp. """ -import re, os, subprocess, signal, pipes +import re, os, subprocess, signal, pipes, tempfile, shutil __all__ = 'SSHConnection SSHResult SSHError'.split() @@ -129,7 +129,13 @@ def scp(self, files, target, mode=None, owner=None): understandable by chown). Makes sence only if you open your connection as root. """ - scp_command = self.scp_command(files, target) + filenames, tmpdir = self.convert_files_to_filenames(files) + + def cleanup_tmp_dir(): + if tmpdir: + shutil.rmtree(tmpdir, ignore_errors=True) + + scp_command = self.scp_command(filenames, target) pipe = subprocess.Popen(scp_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=self.get_env()) @@ -142,27 +148,55 @@ def scp(self, files, target, mode=None, owner=None): # pipe.terminate() # only in python 2.6 allowed os.kill(pipe.pid, signal.SIGTERM) signal.alarm(0) # disable alarm + cleanup_tmp_dir() raise SSHError(stderr=str(exc)) signal.alarm(0) # disable alarm returncode = pipe.returncode if returncode != 0: # ssh client error + cleanup_tmp_dir() raise SSHError(err.strip()) if mode or owner: - targets = self.get_scp_targets(files, target) # XXX: files VS filenames + targets = self.get_scp_targets(filenames, target) if mode: cmd_chunks = ['chmod', mode] + targets cmd = ' '.join([pipes.quote(chunk) for chunk in cmd_chunks]) result = self.run(cmd) if result.returncode: + cleanup_tmp_dir() raise SSHError(result.stderr.strip()) if owner: cmd_chunks = ['chown', owner] + targets cmd = ' '.join([pipes.quote(chunk) for chunk in cmd_chunks]) result = self.run(cmd) if result.returncode: + cleanup_tmp_dir() raise SSHError(result.stderr.strip()) + def convert_files_to_filenames(self, files): + """ + Check for every file in list and save it locally to send to remote side, if needed + """ + filenames = [] + tmpdir = None + for file_obj in files: + if isinstance(file_obj, basestring): + filenames.append(file_obj) + else: + if not tmpdir: + tmpdir = tempfile.mkdtemp() + if hasattr(file_obj, 'name'): + basename = os.path.basename(file_obj.name) + tmpname = os.path.join(tmpdir, basename) + fd = open(tmpname, 'w') + fd.write(file_obj.read()) + fd.close() + else: + tmpfd, tmpname = tempfile.mkstemp(dir=tmpdir) + os.write(tmpfd, file_obj.read()) + os.close(tmpfd) + filenames.append(tmpname) + return filenames, tmpdir def get_scp_targets(self, filenames, target): diff --git a/tests.py b/tests.py index 1a2cf03..1e774d6 100644 --- a/tests.py +++ b/tests.py @@ -57,7 +57,7 @@ class TestSCP(object): def setUp(self): self.c = SSHConnection('localhost', login='root') - self.c.run('rm -f /tmp/*.py') + self.c.run('rm -f /tmp/*.py /tmp/test*.txt') def test_scp(self): self.c.scp((test_file, ), target='/tmp') @@ -80,3 +80,15 @@ def test_owner(self): stat = os.stat('/tmp/tests.py') eq_(stat.st_uid, uid) eq_(stat.st_gid, gid) + + def test_file_descriptors(self): + from StringIO import StringIO + # name is set explicitly as target + fd1 = StringIO('test') + self.c.scp((fd1, ), target='/tmp/test1.txt', mode='0644') + eq_(open('/tmp/test1.txt').read(), 'test') + # name is set explicitly in the name option + fd2 = StringIO('test') + fd2.name = 'test2.txt' + self.c.scp((fd2, ), target='/tmp', mode='0644') + eq_(open('/tmp/test2.txt').read(), 'test')