From 2abb6662d0c6200a6a977e53c60ed7ae9d15c9ee Mon Sep 17 00:00:00 2001 From: Kirill Batalin Date: Sat, 22 Feb 2025 23:18:45 +0000 Subject: [PATCH] =?UTF-8?q?fix(common):=20Handle=20downloads=20from=20sour?= =?UTF-8?q?ces=20that=20don=E2=80=99t=20provide=20a=20"Content-Length"=20h?= =?UTF-8?q?eader?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In some cases, sources do not return the Content-Length header. This can happen due to a proxy or simply because the source does not include it by design. Currently, you-get either: - Attempts to download the next range with an invalid Range header (out of bounds), resulting in an HTTP error, or - Gets stuck in an infinite loop of attempts. This PR fixes the issue by adjusting the exit condition: If the expected file size is infinite (i.e., undetermined) and the buffer has been fully read, the loop exits successfully. Additionally, this PR includes other fixes to ensure proper functionality: - Correctly processes URLs when provided as a list with a single element. - Prevents deleting the temporary file in this case, as it is actually the final result, not a temporary file. A test has been added to reproduce and verify the issue. --- src/you_get/common.py | 14 ++++---- tests/test_common.py | 76 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 7 deletions(-) diff --git a/src/you_get/common.py b/src/you_get/common.py index 0b307ddee8..385a67e2d5 100755 --- a/src/you_get/common.py +++ b/src/you_get/common.py @@ -672,6 +672,7 @@ def url_save( headers=None, timeout=None, **kwargs ): tmp_headers = headers.copy() if headers is not None else {} + url = url[0] if type(url) is list and len(url) == 1 else url # When a referer specified with param refer, # the key must be 'Referer' for the hack here if refer is not None: @@ -807,9 +808,9 @@ def numreturn(a): except socket.timeout: pass if not buffer: - if is_chunked and received_chunk == range_length: + if is_chunked and (received_chunk == range_length or range_length == float('inf')): break - elif not is_chunked and received == file_size: # Download finished + elif not is_chunked and (received == file_size or range_length == float('inf')): # Download finished break # Unexpected termination. Retry request tmp_headers['Range'] = 'bytes=' + str(received - chunk_start) + '-' @@ -827,10 +828,11 @@ def numreturn(a): received, os.path.getsize(temp_filepath), temp_filepath ) - if os.access(filepath, os.W_OK): - # on Windows rename could fail if destination filepath exists - os.remove(filepath) - os.rename(temp_filepath, filepath) + if temp_filepath != filepath: + if os.access(filepath, os.W_OK): + # on Windows rename could fail if destination filepath exists + os.remove(filepath) + os.rename(temp_filepath, filepath) class SimpleProgressBar: diff --git a/tests/test_common.py b/tests/test_common.py index f1ef92629b..7f9b23e1cc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,11 +1,85 @@ #!/usr/bin/env python - +import http.server +import socketserver +import tempfile +import threading import unittest from you_get.common import * + class TestCommon(unittest.TestCase): def test_match1(self): self.assertEqual(match1('http://youtu.be/1234567890A', r'youtu.be/([^/]+)'), '1234567890A') self.assertEqual(match1('http://youtu.be/1234567890A', r'youtu.be/([^/]+)', r'youtu.(\w+)'), ['1234567890A', 'be']) + + +class TestDownloadUrlWithoutContentLength(unittest.TestCase): + def setUp(self): + self.server = ChunkedTestServer() + self.port = self.server.start() + + def tearDown(self): + self.server.stop() + + def test_server_response(self): + response = request.urlopen(f'http://localhost:{self.port}') + self.assertEqual(response.status, 200) + self.assertNotIn('Content-Length', response.headers) + + expected_data = b'First chunk of data\nSecond chunk of data\nLast chunk of data' + self.assertEqual(response.read(), expected_data) + + def test_url_save(self): + with tempfile.NamedTemporaryFile() as temp_file: + temp_path = temp_file.name + + try: + url_save([f'http://localhost:{self.port}'], temp_path, None) + + with open(temp_path, "r") as f: + expected_data = 'First chunk of data\nSecond chunk of data\nLast chunk of data' + self.assertEqual(f.read(), expected_data) + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +class ChunkedHTTPRequestHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header('Transfer-Encoding', 'chunked') + self.end_headers() + + # Send data in chunks + chunks = [b"First chunk of data\n", + b"Second chunk of data\n", + b"Last chunk of data"] + + for chunk in chunks: + self.wfile.write(f"{len(chunk):x}\r\n".encode()) + self.wfile.write(chunk) + self.wfile.write(b"\r\n") + + # Write the final chunk (zero-length chunk to indicate the end) + self.wfile.write(b"0\r\n\r\n") + + +class ChunkedTestServer: + def __init__(self, port=0): + self.port = port + self.server = socketserver.TCPServer(('localhost', port), ChunkedHTTPRequestHandler) + self.server_thread = None + + def start(self): + self.server_thread = threading.Thread(target=self.server.serve_forever) + self.server_thread.daemon = True + self.server_thread.start() + self.port = self.server.server_address[1] + return self.port + + def stop(self): + self.server.shutdown() + self.server.server_close() + self.server_thread.join()