Skip to content

Commit

Permalink
Merge pull request #621 from Jakuje/large-scp
Browse files Browse the repository at this point in the history
  • Loading branch information
webknjaz committed Jun 27, 2024
2 parents c43aad2 + 985f44d commit 3888862
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/changelog-fragments/621.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Downloading files larger than 64kB over SCP no longer fails -- by :user:`Jakuje`.
24 changes: 16 additions & 8 deletions src/pylibsshext/scp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ from pylibsshext.errors cimport LibsshSCPException
from pylibsshext.session cimport get_libssh_session


SCP_MAX_CHUNK = 65536


cdef class SCP:
def __cinit__(self, session):
self.session = session
Expand Down Expand Up @@ -122,7 +125,9 @@ cdef class SCP:
size = libssh.ssh_scp_request_get_size(scp)
mode = libssh.ssh_scp_request_get_permissions(scp)

read_buffer = <char *>PyMem_Malloc(size)
# cap the buffer size to reasonable number -- libssh will not return the whole data at once anyway
read_buffer_size = min(size, SCP_MAX_CHUNK)
read_buffer = <char *>PyMem_Malloc(read_buffer_size)
if read_buffer is NULL:
raise LibsshSCPException("Memory allocation error")

Expand All @@ -131,14 +136,17 @@ cdef class SCP:
if rc == libssh.SSH_ERROR:
raise LibsshSCPException("Failed to start read request: %s" % self._get_ssh_error_str())

# Read the file
rc = libssh.ssh_scp_read(scp, read_buffer, size)
if rc == libssh.SSH_ERROR:
raise LibsshSCPException("Error receiving file data: %s" % self._get_ssh_error_str())

py_file_bytes = read_buffer[:size]
remaining_bytes_to_read = size
with open(local_file, "wb") as f:
f.write(py_file_bytes)
while remaining_bytes_to_read > 0:
requested_read_bytes = min(remaining_bytes_to_read, read_buffer_size)
read_bytes = libssh.ssh_scp_read(scp, read_buffer, requested_read_bytes)
if read_bytes == libssh.SSH_ERROR:
raise LibsshSCPException("Error receiving file data: %s" % self._get_ssh_error_str())

py_file_bytes = read_buffer[:read_bytes]
f.write(py_file_bytes)
remaining_bytes_to_read -= read_bytes
if mode >= 0:
os.chmod(local_file, mode)

Expand Down
45 changes: 45 additions & 0 deletions tests/unit/scp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""Tests suite for scp."""

import os
import random
import string
import uuid

import pytest
Expand Down Expand Up @@ -75,3 +77,46 @@ def test_copy_from_non_existent_remote_path(path_to_non_existent_src_file, ssh_s
error_msg = '^Error receiving information about file:'
with pytest.raises(LibsshSCPException, match=error_msg):
ssh_scp.get(str(path_to_non_existent_src_file), os.devnull)


@pytest.fixture
def pre_existing_file_path(tmp_path):
"""Return local path for a pre-populated file."""
path = tmp_path / 'pre-existing-file.txt'
path.write_bytes(b'whatever')
return path


def test_get_existing_local(pre_existing_file_path, src_path, ssh_scp, transmit_payload):
"""Check that SCP file download works and overwrites local file if it exists."""
ssh_scp.get(str(src_path), str(pre_existing_file_path))
assert pre_existing_file_path.read_bytes() == transmit_payload


@pytest.fixture
def large_payload():
"""Generate a large 65537 byte (64kB+1B) test payload."""
random_char_kilobyte = [ord(random.choice(string.printable)) for _ in range(1024)]
full_bytes_number = 64
a_64kB_chunk = bytes(random_char_kilobyte * full_bytes_number)
the_last_byte = random.choice(random_char_kilobyte).to_bytes(length=1, byteorder='big')
return a_64kB_chunk + the_last_byte


@pytest.fixture
def src_path_large(tmp_path, large_payload):
"""Return a remote path that to a 65537 byte-sized file.
Typical single-read chunk size is 64kB in ``libssh`` so
the test needs a file that would overflow that to trigger
the read loop.
"""
path = tmp_path / 'large.txt'
path.write_bytes(large_payload)
return path


def test_get_large(dst_path, src_path_large, ssh_scp, large_payload):
"""Check that SCP file download gets over 64kB of data."""
ssh_scp.get(str(src_path_large), str(dst_path))
assert dst_path.read_bytes() == large_payload

0 comments on commit 3888862

Please sign in to comment.