diff --git a/test/test_download.py b/test/test_download.py
index 43b39c36b..fd7752cdd 100755
--- a/test/test_download.py
+++ b/test/test_download.py
@@ -10,10 +10,7 @@
import collections
import hashlib
-import http.client
import json
-import socket
-import urllib.error
from test.helper import (
assertGreaterEqual,
@@ -29,6 +26,7 @@
import yt_dlp.YoutubeDL # isort: split
from yt_dlp.extractor import get_info_extractor
+from yt_dlp.networking.exceptions import HTTPError, TransportError
from yt_dlp.utils import (
DownloadError,
ExtractorError,
@@ -162,8 +160,7 @@ def try_rm_tcs_files(tcs=None):
force_generic_extractor=params.get('force_generic_extractor', False))
except (DownloadError, ExtractorError) as err:
# Check if the exception is not a network related one
- if (err.exc_info[0] not in (urllib.error.URLError, socket.timeout, UnavailableVideoError, http.client.BadStatusLine)
- or (err.exc_info[0] == urllib.error.HTTPError and err.exc_info[1].code == 503)):
+            if not isinstance(err.exc_info[1], (TransportError, UnavailableVideoError)) or (isinstance(err.exc_info[1], HTTPError) and err.exc_info[1].status == 503):
err.msg = f'{getattr(err, "msg", err)} ({tname})'
raise
@@ -249,7 +246,7 @@ def try_rm_tcs_files(tcs=None):
# extractor returns full results even with extract_flat
res_tcs = [{'info_dict': e} for e in res_dict['entries']]
try_rm_tcs_files(res_tcs)
-
+ ydl.close()
return test_template
diff --git a/test/test_networking.py b/test/test_networking.py
index e4e66dce1..147a4ff49 100644
--- a/test/test_networking.py
+++ b/test/test_networking.py
@@ -3,32 +3,74 @@
# Allow direct execution
import os
import sys
-import unittest
+
+import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import functools
import gzip
+import http.client
import http.cookiejar
import http.server
+import inspect
import io
import pathlib
+import random
import ssl
import tempfile
import threading
+import time
import urllib.error
import urllib.request
+import warnings
import zlib
+from email.message import Message
+from http.cookiejar import CookieJar
-from test.helper import http_server_port
-from yt_dlp import YoutubeDL
+from test.helper import FakeYDL, http_server_port
from yt_dlp.dependencies import brotli
-from yt_dlp.utils import sanitized_Request, urlencode_postdata
-
-from .helper import FakeYDL
+from yt_dlp.networking import (
+ HEADRequest,
+ PUTRequest,
+ Request,
+ RequestDirector,
+ RequestHandler,
+ Response,
+)
+from yt_dlp.networking._urllib import UrllibRH
+from yt_dlp.networking.common import _REQUEST_HANDLERS
+from yt_dlp.networking.exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ IncompleteRead,
+ NoSupportingHandlers,
+ RequestError,
+ SSLError,
+ TransportError,
+ UnsupportedRequest,
+)
+from yt_dlp.utils._utils import _YDLLogger as FakeLogger
+from yt_dlp.utils.networking import HTTPHeaderDict
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
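+
+# Factory for a minimal proxy stub that replies "<name>: <path>", letting
+# tests tell which proxy (normal vs geo) served a given request.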
+def _build_proxy_handler(name):
+ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
+ proxy_name = name
+
+ def log_message(self, format, *args):
+ pass
+
+ def do_GET(self):
+ self.send_response(200)
+ self.send_header('Content-Type', 'text/plain; charset=utf-8')
+ self.end_headers()
+            self.wfile.write(f'{self.proxy_name}: {self.path}'.encode())
+ return HTTPTestRequestHandler
+
+
class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
@@ -36,7 +78,7 @@ def log_message(self, format, *args):
pass
def _headers(self):
- payload = str(self.headers).encode('utf-8')
+ payload = str(self.headers).encode()
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', str(len(payload)))
@@ -70,7 +112,7 @@ def _read_data(self):
return self.rfile.read(int(self.headers['Content-Length']))
def do_POST(self):
- data = self._read_data()
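+        # Echo the request headers after the body so redirect tests can check
+        # whether Content-Type/Content-Length survived the redirect.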
+ data = self._read_data() + str(self.headers).encode()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
@@ -89,7 +131,7 @@ def do_HEAD(self):
self._status(404)
def do_PUT(self):
- data = self._read_data()
+ data = self._read_data() + str(self.headers).encode()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
@@ -102,7 +144,7 @@ def do_GET(self):
payload = b''
self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8')
- self.send_header('Content-Length', str(len(payload))) # required for persistent connections
+ self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
elif self.path == '/vid.mp4':
@@ -126,10 +168,15 @@ def do_GET(self):
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
+ elif self.path.startswith('/redirect_loop'):
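+            # 301 back to the same path; the client must detect the loop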
+ self.send_response(301)
+ self.send_header('Location', self.path)
+ self.send_header('Content-Length', '0')
+ self.end_headers()
elif self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
- self._method('GET')
+ self._method('GET', str(self.headers).encode())
elif self.path.startswith('/headers'):
self._headers()
elif self.path.startswith('/308-to-headers'):
@@ -179,7 +226,32 @@ def do_GET(self):
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
-
+ elif self.path.startswith('/gen_'):
+ payload = b''
+ self.send_response(int(self.path[len('/gen_'):]))
+ self.send_header('Content-Type', 'text/html; charset=utf-8')
+ self.send_header('Content-Length', str(len(payload)))
+ self.end_headers()
+ self.wfile.write(payload)
+ elif self.path.startswith('/incompleteread'):
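+            # Advertise a large Content-Length, then close after an empty
+            # body so the client raises IncompleteRead.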
+ payload = b''
+ self.send_response(200)
+ self.send_header('Content-Type', 'text/html; charset=utf-8')
+ self.send_header('Content-Length', '234234')
+ self.end_headers()
+ self.wfile.write(payload)
+ self.finish()
+ elif self.path.startswith('/timeout_'):
+ time.sleep(int(self.path[len('/timeout_'):]))
+ self._headers()
+ elif self.path == '/source_address':
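+            # Echo the client's source IP so tests can assert the socket
+            # was bound to the requested source address.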
+ payload = str(self.client_address[0]).encode()
+ self.send_response(200)
+ self.send_header('Content-Type', 'text/html; charset=utf-8')
+ self.send_header('Content-Length', str(len(payload)))
+ self.end_headers()
+ self.wfile.write(payload)
+ self.finish()
else:
self._status(404)
@@ -198,334 +270,1099 @@ def send_header(self, keyword, value):
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
-class FakeLogger:
- def debug(self, msg):
- pass
-
- def warning(self, msg):
- pass
-
- def error(self, msg):
- pass
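+# Mirrors how RequestDirector drives a handler: validate() rejects
+# unsupported requests before any I/O, then send() performs the request.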
+def validate_and_send(rh, req):
+ rh.validate(req)
+ return rh.send(req)
-class TestHTTP(unittest.TestCase):
- def setUp(self):
- # HTTP server
- self.http_httpd = http.server.ThreadingHTTPServer(
+class TestRequestHandlerBase:
+ @classmethod
+ def setup_class(cls):
+ cls.http_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
- self.http_port = http_server_port(self.http_httpd)
- self.http_server_thread = threading.Thread(target=self.http_httpd.serve_forever)
+ cls.http_port = http_server_port(cls.http_httpd)
+ cls.http_server_thread = threading.Thread(target=cls.http_httpd.serve_forever)
# FIXME: we should probably stop the http server thread after each test
# See: https://github.com/yt-dlp/yt-dlp/pull/7094#discussion_r1199746041
- self.http_server_thread.daemon = True
- self.http_server_thread.start()
+ cls.http_server_thread.daemon = True
+ cls.http_server_thread.start()
# HTTPS server
certfn = os.path.join(TEST_DIR, 'testcert.pem')
- self.https_httpd = http.server.ThreadingHTTPServer(
+ cls.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
- self.https_httpd.socket = sslctx.wrap_socket(self.https_httpd.socket, server_side=True)
- self.https_port = http_server_port(self.https_httpd)
- self.https_server_thread = threading.Thread(target=self.https_httpd.serve_forever)
- self.https_server_thread.daemon = True
- self.https_server_thread.start()
+ cls.https_httpd.socket = sslctx.wrap_socket(cls.https_httpd.socket, server_side=True)
+ cls.https_port = http_server_port(cls.https_httpd)
+ cls.https_server_thread = threading.Thread(target=cls.https_httpd.serve_forever)
+ cls.https_server_thread.daemon = True
+ cls.https_server_thread.start()
- def test_nocheckcertificate(self):
- with FakeYDL({'logger': FakeLogger()}) as ydl:
- with self.assertRaises(urllib.error.URLError):
- ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
- with FakeYDL({'logger': FakeLogger(), 'nocheckcertificate': True}) as ydl:
- r = ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
- self.assertEqual(r.status, 200)
+@pytest.fixture
+def handler(request):
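+    # Indirect fixture: resolve the parametrized key (or RequestHandler
+    # subclass) to a handler factory; skip if the handler is unavailable.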
+ RH_KEY = request.param
+ if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
+ handler = RH_KEY
+ elif RH_KEY in _REQUEST_HANDLERS:
+ handler = _REQUEST_HANDLERS[RH_KEY]
+ else:
+ pytest.skip(f'{RH_KEY} request handler is not available')
+
+ return functools.partial(handler, logger=FakeLogger)
+
+
+class TestHTTPRequestHandler(TestRequestHandlerBase):
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_verify_cert(self, handler):
+ with handler() as rh:
+ with pytest.raises(CertificateVerifyError):
+ validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers'))
+
+ with handler(verify=False) as rh:
+ r = validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers'))
+ assert r.status == 200
r.close()
- def test_percent_encode(self):
- with FakeYDL() as ydl:
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_ssl_error(self, handler):
+ # HTTPS server with too old TLS version
+ # XXX: is there a better way to test this than to create a new server?
+ https_httpd = http.server.ThreadingHTTPServer(
+ ('127.0.0.1', 0), HTTPTestRequestHandler)
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ https_httpd.socket = sslctx.wrap_socket(https_httpd.socket, server_side=True)
+ https_port = http_server_port(https_httpd)
+ https_server_thread = threading.Thread(target=https_httpd.serve_forever)
+ https_server_thread.daemon = True
+ https_server_thread.start()
+
+ with handler(verify=False) as rh:
+ with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
+ validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
+ assert not issubclass(exc_info.type, CertificateVerifyError)
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_percent_encode(self, handler):
+ with handler() as rh:
# Unicode characters should be encoded with uppercase percent-encoding
- res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/中文.html'))
- self.assertEqual(res.status, 200)
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/中文.html'))
+ assert res.status == 200
res.close()
# don't normalize existing percent encodings
- res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/%c7%9f'))
- self.assertEqual(res.status, 200)
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/%c7%9f'))
+ assert res.status == 200
res.close()
- def test_unicode_path_redirection(self):
- with FakeYDL() as ydl:
- r = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
- self.assertEqual(r.url, f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html')
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_unicode_path_redirection(self, handler):
+ with handler() as rh:
+ r = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
+ assert r.url == f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html'
r.close()
- def test_redirect(self):
- with FakeYDL() as ydl:
- def do_req(redirect_status, method):
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_raise_http_error(self, handler):
+ with handler() as rh:
+ for bad_status in (400, 500, 599, 302):
+ with pytest.raises(HTTPError):
+ validate_and_send(rh, Request('http://127.0.0.1:%d/gen_%d' % (self.http_port, bad_status)))
+
+ # Should not raise an error
+ validate_and_send(rh, Request('http://127.0.0.1:%d/gen_200' % self.http_port)).close()
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_response_url(self, handler):
+ with handler() as rh:
+ # Response url should be that of the last url in redirect chain
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_301'))
+ assert res.url == f'http://127.0.0.1:{self.http_port}/method'
+ res.close()
+ res2 = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
+ assert res2.url == f'http://127.0.0.1:{self.http_port}/gen_200'
+ res2.close()
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_redirect(self, handler):
+ with handler() as rh:
+ def do_req(redirect_status, method, assert_no_content=False):
data = b'testdata' if method in ('POST', 'PUT') else None
- res = ydl.urlopen(sanitized_Request(
- f'http://127.0.0.1:{self.http_port}/redirect_{redirect_status}', method=method, data=data))
- return res.read().decode('utf-8'), res.headers.get('method', '')
+ res = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_{redirect_status}', method=method, data=data))
+
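+                # The test server echoes the request body followed by the
+                # request headers; if the redirected request re-sent the body,
+                # peel it off the front before treating the rest as headers.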
+ headers = b''
+ data_sent = b''
+ if data is not None:
+ data_sent += res.read(len(data))
+ if data_sent != data:
+ headers += data_sent
+ data_sent = b''
+
+ headers += res.read()
+
+ if assert_no_content or data is None:
+ assert b'Content-Type' not in headers
+ assert b'Content-Length' not in headers
+ else:
+ assert b'Content-Type' in headers
+ assert b'Content-Length' in headers
+
+ return data_sent.decode(), res.headers.get('method', '')
# A 303 must either use GET or HEAD for subsequent request
- self.assertEqual(do_req(303, 'POST'), ('', 'GET'))
- self.assertEqual(do_req(303, 'HEAD'), ('', 'HEAD'))
+ assert do_req(303, 'POST', True) == ('', 'GET')
+ assert do_req(303, 'HEAD') == ('', 'HEAD')
- self.assertEqual(do_req(303, 'PUT'), ('', 'GET'))
+ assert do_req(303, 'PUT', True) == ('', 'GET')
# 301 and 302 turn POST only into a GET
- # XXX: we should also test if the Content-Type and Content-Length headers are removed
- self.assertEqual(do_req(301, 'POST'), ('', 'GET'))
- self.assertEqual(do_req(301, 'HEAD'), ('', 'HEAD'))
- self.assertEqual(do_req(302, 'POST'), ('', 'GET'))
- self.assertEqual(do_req(302, 'HEAD'), ('', 'HEAD'))
+ assert do_req(301, 'POST', True) == ('', 'GET')
+ assert do_req(301, 'HEAD') == ('', 'HEAD')
+ assert do_req(302, 'POST', True) == ('', 'GET')
+ assert do_req(302, 'HEAD') == ('', 'HEAD')
- self.assertEqual(do_req(301, 'PUT'), ('testdata', 'PUT'))
- self.assertEqual(do_req(302, 'PUT'), ('testdata', 'PUT'))
+ assert do_req(301, 'PUT') == ('testdata', 'PUT')
+ assert do_req(302, 'PUT') == ('testdata', 'PUT')
# 307 and 308 should not change method
for m in ('POST', 'PUT'):
- self.assertEqual(do_req(307, m), ('testdata', m))
- self.assertEqual(do_req(308, m), ('testdata', m))
+ assert do_req(307, m) == ('testdata', m)
+ assert do_req(308, m) == ('testdata', m)
- self.assertEqual(do_req(307, 'HEAD'), ('', 'HEAD'))
- self.assertEqual(do_req(308, 'HEAD'), ('', 'HEAD'))
+ assert do_req(307, 'HEAD') == ('', 'HEAD')
+ assert do_req(308, 'HEAD') == ('', 'HEAD')
# These should not redirect and instead raise an HTTPError
for code in (300, 304, 305, 306):
- with self.assertRaises(urllib.error.HTTPError):
+ with pytest.raises(HTTPError):
do_req(code, 'GET')
- def test_content_type(self):
- # https://github.com/yt-dlp/yt-dlp/commit/379a4f161d4ad3e40932dcf5aca6e6fb9715ab28
- with FakeYDL({'nocheckcertificate': True}) as ydl:
- # method should be auto-detected as POST
- r = sanitized_Request(f'https://localhost:{self.https_port}/headers', data=urlencode_postdata({'test': 'test'}))
-
- headers = ydl.urlopen(r).read().decode('utf-8')
- self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
-
- # test http
- r = sanitized_Request(f'http://localhost:{self.http_port}/headers', data=urlencode_postdata({'test': 'test'}))
- headers = ydl.urlopen(r).read().decode('utf-8')
- self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
-
- def test_cookiejar(self):
- with FakeYDL() as ydl:
- ydl.cookiejar.set_cookie(http.cookiejar.Cookie(
- 0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
- False, '/headers', True, False, None, False, None, None, {}))
- data = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
- self.assertIn(b'Cookie: test=ytdlp', data)
-
- def test_passed_cookie_header(self):
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_request_cookie_header(self, handler):
# We should accept a Cookie header being passed as in normal headers and handle it appropriately.
- with FakeYDL() as ydl:
+ with handler() as rh:
# Specified Cookie header should be used
- res = ydl.urlopen(
- sanitized_Request(f'http://127.0.0.1:{self.http_port}/headers',
- headers={'Cookie': 'test=test'})).read().decode('utf-8')
- self.assertIn('Cookie: test=test', res)
+ res = validate_and_send(
+ rh, Request(
+ f'http://127.0.0.1:{self.http_port}/headers',
+ headers={'Cookie': 'test=test'})).read().decode()
+ assert 'Cookie: test=test' in res
# Specified Cookie header should be removed on any redirect
- res = ydl.urlopen(
- sanitized_Request(f'http://127.0.0.1:{self.http_port}/308-to-headers', headers={'Cookie': 'test=test'})).read().decode('utf-8')
- self.assertNotIn('Cookie: test=test', res)
+ res = validate_and_send(
+ rh, Request(
+ f'http://127.0.0.1:{self.http_port}/308-to-headers',
+ headers={'Cookie': 'test=test'})).read().decode()
+ assert 'Cookie: test=test' not in res
- # Specified Cookie header should override global cookiejar for that request
- ydl.cookiejar.set_cookie(http.cookiejar.Cookie(
- version=0, name='test', value='ytdlp', port=None, port_specified=False,
- domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
- path_specified=True, secure=False, expires=None, discard=False, comment=None,
- comment_url=None, rest={}))
+ # Specified Cookie header should override global cookiejar for that request
+ cookiejar = http.cookiejar.CookieJar()
+ cookiejar.set_cookie(http.cookiejar.Cookie(
+ version=0, name='test', value='ytdlp', port=None, port_specified=False,
+ domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
+ path_specified=True, secure=False, expires=None, discard=False, comment=None,
+ comment_url=None, rest={}))
- data = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/headers', headers={'Cookie': 'test=test'})).read()
- self.assertNotIn(b'Cookie: test=ytdlp', data)
- self.assertIn(b'Cookie: test=test', data)
+ with handler(cookiejar=cookiejar) as rh:
+ data = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers', headers={'cookie': 'test=test'})).read()
+ assert b'Cookie: test=ytdlp' not in data
+ assert b'Cookie: test=test' in data
- def test_no_compression_compat_header(self):
- with FakeYDL() as ydl:
- data = ydl.urlopen(
- sanitized_Request(
- f'http://127.0.0.1:{self.http_port}/headers',
- headers={'Youtubedl-no-compression': True})).read()
- self.assertIn(b'Accept-Encoding: identity', data)
- self.assertNotIn(b'youtubedl-no-compression', data.lower())
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_redirect_loop(self, handler):
+ with handler() as rh:
+ with pytest.raises(HTTPError, match='redirect loop'):
+ validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_loop'))
- def test_gzip_trailing_garbage(self):
- # https://github.com/ytdl-org/youtube-dl/commit/aa3e950764337ef9800c936f4de89b31c00dfcf5
- # https://github.com/ytdl-org/youtube-dl/commit/6f2ec15cee79d35dba065677cad9da7491ec6e6f
- with FakeYDL() as ydl:
- data = ydl.urlopen(sanitized_Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode('utf-8')
- self.assertEqual(data, '')
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_incompleteread(self, handler):
+ with handler(timeout=2) as rh:
+ with pytest.raises(IncompleteRead):
+ validate_and_send(rh, Request('http://127.0.0.1:%d/incompleteread' % self.http_port)).read()
- @unittest.skipUnless(brotli, 'brotli support is not installed')
- def test_brotli(self):
- with FakeYDL() as ydl:
- res = ydl.urlopen(
- sanitized_Request(
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_cookies(self, handler):
+ cookiejar = http.cookiejar.CookieJar()
+ cookiejar.set_cookie(http.cookiejar.Cookie(
+ 0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
+ False, '/headers', True, False, None, False, None, None, {}))
+
+ with handler(cookiejar=cookiejar) as rh:
+ data = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
+ assert b'Cookie: test=ytdlp' in data
+
+ # Per request
+ with handler() as rh:
+ data = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read()
+ assert b'Cookie: test=ytdlp' in data
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_headers(self, handler):
+
+ with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
+ # Global Headers
+ data = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
+ assert b'Test1: test' in data
+
+ # Per request headers, merged with global
+ data = validate_and_send(rh, Request(
+ f'http://127.0.0.1:{self.http_port}/headers', headers={'test2': 'changed', 'test3': 'test3'})).read()
+ assert b'Test1: test' in data
+ assert b'Test2: changed' in data
+ assert b'Test2: test2' not in data
+ assert b'Test3: test3' in data
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_timeout(self, handler):
+ with handler() as rh:
+ # Default timeout is 20 seconds, so this should go through
+ validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_3'))
+
+ with handler(timeout=0.5) as rh:
+ with pytest.raises(TransportError):
+ validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_1'))
+
+ # Per request timeout, should override handler timeout
+ validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_1', extensions={'timeout': 4}))
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_source_address(self, handler):
+ source_address = f'127.0.0.{random.randint(5, 255)}'
+ with handler(source_address=source_address) as rh:
+ data = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/source_address')).read().decode()
+ assert source_address == data
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_gzip_trailing_garbage(self, handler):
+ with handler() as rh:
+ data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
+ assert data == ''
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ @pytest.mark.skipif(not brotli, reason='brotli support is not installed')
+ def test_brotli(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(
f'http://127.0.0.1:{self.http_port}/content-encoding',
headers={'ytdl-encoding': 'br'}))
- self.assertEqual(res.headers.get('Content-Encoding'), 'br')
- self.assertEqual(res.read(), b'')
+ assert res.headers.get('Content-Encoding') == 'br'
+ assert res.read() == b''
- def test_deflate(self):
- with FakeYDL() as ydl:
- res = ydl.urlopen(
- sanitized_Request(
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_deflate(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(
f'http://127.0.0.1:{self.http_port}/content-encoding',
headers={'ytdl-encoding': 'deflate'}))
- self.assertEqual(res.headers.get('Content-Encoding'), 'deflate')
- self.assertEqual(res.read(), b'')
+ assert res.headers.get('Content-Encoding') == 'deflate'
+ assert res.read() == b''
- def test_gzip(self):
- with FakeYDL() as ydl:
- res = ydl.urlopen(
- sanitized_Request(
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_gzip(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(
f'http://127.0.0.1:{self.http_port}/content-encoding',
headers={'ytdl-encoding': 'gzip'}))
- self.assertEqual(res.headers.get('Content-Encoding'), 'gzip')
- self.assertEqual(res.read(), b'')
+ assert res.headers.get('Content-Encoding') == 'gzip'
+ assert res.read() == b''
- def test_multiple_encodings(self):
- # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.4
- with FakeYDL() as ydl:
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_multiple_encodings(self, handler):
+ with handler() as rh:
for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
- res = ydl.urlopen(
- sanitized_Request(
+ res = validate_and_send(
+ rh, Request(
f'http://127.0.0.1:{self.http_port}/content-encoding',
headers={'ytdl-encoding': pair}))
- self.assertEqual(res.headers.get('Content-Encoding'), pair)
- self.assertEqual(res.read(), b'')
+ assert res.headers.get('Content-Encoding') == pair
+ assert res.read() == b''
- def test_unsupported_encoding(self):
- # it should return the raw content
- with FakeYDL() as ydl:
- res = ydl.urlopen(
- sanitized_Request(
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_unsupported_encoding(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(
f'http://127.0.0.1:{self.http_port}/content-encoding',
headers={'ytdl-encoding': 'unsupported'}))
- self.assertEqual(res.headers.get('Content-Encoding'), 'unsupported')
- self.assertEqual(res.read(), b'raw')
+ assert res.headers.get('Content-Encoding') == 'unsupported'
+ assert res.read() == b'raw'
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_read(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers'))
+ assert res.readable()
+ assert res.read(1) == b'H'
+ assert res.read(3) == b'ost'
-class TestClientCert(unittest.TestCase):
- def setUp(self):
+class TestHTTPProxy(TestRequestHandlerBase):
+ @classmethod
+ def setup_class(cls):
+ super().setup_class()
+ # HTTP Proxy server
+ cls.proxy = http.server.ThreadingHTTPServer(
+ ('127.0.0.1', 0), _build_proxy_handler('normal'))
+ cls.proxy_port = http_server_port(cls.proxy)
+ cls.proxy_thread = threading.Thread(target=cls.proxy.serve_forever)
+ cls.proxy_thread.daemon = True
+ cls.proxy_thread.start()
+
+ # Geo proxy server
+ cls.geo_proxy = http.server.ThreadingHTTPServer(
+ ('127.0.0.1', 0), _build_proxy_handler('geo'))
+ cls.geo_port = http_server_port(cls.geo_proxy)
+ cls.geo_proxy_thread = threading.Thread(target=cls.geo_proxy.serve_forever)
+ cls.geo_proxy_thread.daemon = True
+ cls.geo_proxy_thread.start()
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_http_proxy(self, handler):
+ http_proxy = f'http://127.0.0.1:{self.proxy_port}'
+ geo_proxy = f'http://127.0.0.1:{self.geo_port}'
+
+ # Test global http proxy
+ # Test per request http proxy
+ # Test per request http proxy disables proxy
+ url = 'http://foo.com/bar'
+
+ # Global HTTP proxy
+ with handler(proxies={'http': http_proxy}) as rh:
+ res = validate_and_send(rh, Request(url)).read().decode()
+ assert res == f'normal: {url}'
+
+ # Per request proxy overrides global
+ res = validate_and_send(rh, Request(url, proxies={'http': geo_proxy})).read().decode()
+ assert res == f'geo: {url}'
+
+ # and setting to None disables all proxies for that request
+ real_url = f'http://127.0.0.1:{self.http_port}/headers'
+ res = validate_and_send(
+ rh, Request(real_url, proxies={'http': None})).read().decode()
+ assert res != f'normal: {real_url}'
+ assert 'Accept' in res
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_noproxy(self, handler):
+ with handler(proxies={'proxy': f'http://127.0.0.1:{self.proxy_port}'}) as rh:
+ # NO_PROXY
+ for no_proxy in (f'127.0.0.1:{self.http_port}', '127.0.0.1', 'localhost'):
+ nop_response = validate_and_send(
+ rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'no': no_proxy})).read().decode(
+ 'utf-8')
+ assert 'Accept' in nop_response
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_allproxy(self, handler):
+ url = 'http://foo.com/bar'
+ with handler() as rh:
+ response = validate_and_send(rh, Request(url, proxies={'all': f'http://127.0.0.1:{self.proxy_port}'})).read().decode(
+ 'utf-8')
+ assert response == f'normal: {url}'
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_http_proxy_with_idn(self, handler):
+ with handler(proxies={
+ 'http': f'http://127.0.0.1:{self.proxy_port}',
+ }) as rh:
+ url = 'http://中文.tw/'
+ response = rh.send(Request(url)).read().decode()
+ # b'xn--fiq228c' is '中文'.encode('idna')
+ assert response == 'normal: http://xn--fiq228c.tw/'
+
+
+class TestClientCertificate:
+
+ @classmethod
+ def setup_class(cls):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
- self.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
- cacertfn = os.path.join(self.certdir, 'ca.crt')
- self.httpd = http.server.HTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
+ cls.certdir = os.path.join(TEST_DIR, 'testdata', 'certificate')
+ cacertfn = os.path.join(cls.certdir, 'ca.crt')
+ cls.httpd = http.server.ThreadingHTTPServer(('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.load_verify_locations(cafile=cacertfn)
sslctx.load_cert_chain(certfn, None)
- self.httpd.socket = sslctx.wrap_socket(self.httpd.socket, server_side=True)
- self.port = http_server_port(self.httpd)
- self.server_thread = threading.Thread(target=self.httpd.serve_forever)
- self.server_thread.daemon = True
- self.server_thread.start()
+ cls.httpd.socket = sslctx.wrap_socket(cls.httpd.socket, server_side=True)
+ cls.port = http_server_port(cls.httpd)
+ cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
+ cls.server_thread.daemon = True
+ cls.server_thread.start()
- def _run_test(self, **params):
- ydl = YoutubeDL({
- 'logger': FakeLogger(),
+ def _run_test(self, handler, **handler_kwargs):
+ with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
# The test is of a check on the server side, so unaffected
- 'nocheckcertificate': True,
- **params,
+ verify=False,
+ **handler_kwargs,
+ ) as rh:
+ validate_and_send(rh, Request(f'https://127.0.0.1:{self.port}/video.html')).read().decode()
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_certificate_combined_nopass(self, handler):
+ self._run_test(handler, client_cert={
+ 'client_certificate': os.path.join(self.certdir, 'clientwithkey.crt'),
})
- r = ydl.extract_info(f'https://127.0.0.1:{self.port}/video.html')
- self.assertEqual(r['url'], f'https://127.0.0.1:{self.port}/vid.mp4')
- def test_certificate_combined_nopass(self):
- self._run_test(client_certificate=os.path.join(self.certdir, 'clientwithkey.crt'))
-
- def test_certificate_nocombined_nopass(self):
- self._run_test(client_certificate=os.path.join(self.certdir, 'client.crt'),
- client_certificate_key=os.path.join(self.certdir, 'client.key'))
-
- def test_certificate_combined_pass(self):
- self._run_test(client_certificate=os.path.join(self.certdir, 'clientwithencryptedkey.crt'),
- client_certificate_password='foobar')
-
- def test_certificate_nocombined_pass(self):
- self._run_test(client_certificate=os.path.join(self.certdir, 'client.crt'),
- client_certificate_key=os.path.join(self.certdir, 'clientencrypted.key'),
- client_certificate_password='foobar')
-
-
-def _build_proxy_handler(name):
- class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
- proxy_name = name
-
- def log_message(self, format, *args):
- pass
-
- def do_GET(self):
- self.send_response(200)
- self.send_header('Content-Type', 'text/plain; charset=utf-8')
- self.end_headers()
- self.wfile.write(f'{self.proxy_name}: {self.path}'.encode())
- return HTTPTestRequestHandler
-
-
-class TestProxy(unittest.TestCase):
- def setUp(self):
- self.proxy = http.server.HTTPServer(
- ('127.0.0.1', 0), _build_proxy_handler('normal'))
- self.port = http_server_port(self.proxy)
- self.proxy_thread = threading.Thread(target=self.proxy.serve_forever)
- self.proxy_thread.daemon = True
- self.proxy_thread.start()
-
- self.geo_proxy = http.server.HTTPServer(
- ('127.0.0.1', 0), _build_proxy_handler('geo'))
- self.geo_port = http_server_port(self.geo_proxy)
- self.geo_proxy_thread = threading.Thread(target=self.geo_proxy.serve_forever)
- self.geo_proxy_thread.daemon = True
- self.geo_proxy_thread.start()
-
- def test_proxy(self):
- geo_proxy = f'127.0.0.1:{self.geo_port}'
- ydl = YoutubeDL({
- 'proxy': f'127.0.0.1:{self.port}',
- 'geo_verification_proxy': geo_proxy,
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_certificate_nocombined_nopass(self, handler):
+ self._run_test(handler, client_cert={
+ 'client_certificate': os.path.join(self.certdir, 'client.crt'),
+ 'client_certificate_key': os.path.join(self.certdir, 'client.key'),
})
- url = 'http://foo.com/bar'
- response = ydl.urlopen(url).read().decode()
- self.assertEqual(response, f'normal: {url}')
- req = urllib.request.Request(url)
- req.add_header('Ytdl-request-proxy', geo_proxy)
- response = ydl.urlopen(req).read().decode()
- self.assertEqual(response, f'geo: {url}')
-
- def test_proxy_with_idn(self):
- ydl = YoutubeDL({
- 'proxy': f'127.0.0.1:{self.port}',
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_certificate_combined_pass(self, handler):
+ self._run_test(handler, client_cert={
+ 'client_certificate': os.path.join(self.certdir, 'clientwithencryptedkey.crt'),
+ 'client_certificate_password': 'foobar',
+ })
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_certificate_nocombined_pass(self, handler):
+ self._run_test(handler, client_cert={
+ 'client_certificate': os.path.join(self.certdir, 'client.crt'),
+ 'client_certificate_key': os.path.join(self.certdir, 'clientencrypted.key'),
+ 'client_certificate_password': 'foobar',
})
- url = 'http://中文.tw/'
- response = ydl.urlopen(url).read().decode()
- # b'xn--fiq228c' is '中文'.encode('idna')
- self.assertEqual(response, 'normal: http://xn--fiq228c.tw/')
-class TestFileURL(unittest.TestCase):
- # See https://github.com/ytdl-org/youtube-dl/issues/8227
- def test_file_urls(self):
+class TestUrllibRequestHandler(TestRequestHandlerBase):
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_file_urls(self, handler):
+ # See https://github.com/ytdl-org/youtube-dl/issues/8227
tf = tempfile.NamedTemporaryFile(delete=False)
tf.write(b'foobar')
tf.close()
- url = pathlib.Path(tf.name).as_uri()
- with FakeYDL() as ydl:
- self.assertRaisesRegex(
- urllib.error.URLError, 'file:// URLs are explicitly disabled in yt-dlp for security reasons', ydl.urlopen, url)
- with FakeYDL({'enable_file_urls': True}) as ydl:
- res = ydl.urlopen(url)
- self.assertEqual(res.read(), b'foobar')
+ req = Request(pathlib.Path(tf.name).as_uri())
+ with handler() as rh:
+ with pytest.raises(UnsupportedRequest):
+ rh.validate(req)
+
+ # Test that urllib never loaded FileHandler
+ with pytest.raises(TransportError):
+ rh.send(req)
+
+ with handler(enable_file_urls=True) as rh:
+ res = validate_and_send(rh, req)
+ assert res.read() == b'foobar'
res.close()
+
os.unlink(tf.name)
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_http_error_returns_content(self, handler):
+        # urllib HTTPError will try to close the underlying response if the reference to the HTTPError object is lost
+ def get_response():
+ with handler() as rh:
+                # even though the request failed, the error response body should still be readable
+ try:
+ validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_404'))
+ except HTTPError as e:
+ return e.response
-if __name__ == '__main__':
- unittest.main()
+ assert get_response().read() == b''
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_verify_cert_error_text(self, handler):
+ # Check the output of the error message
+ with handler() as rh:
+ with pytest.raises(
+ CertificateVerifyError,
+ match=r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed: self.signed certificate'
+ ):
+ validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers'))
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_httplib_validation_errors(self, handler):
+ with handler() as rh:
+
+ # https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
+ with pytest.raises(RequestError, match='method can\'t contain control characters') as exc_info:
+ validate_and_send(rh, Request('http://127.0.0.1', method='GET\n'))
+ assert not isinstance(exc_info.value, TransportError)
+
+ # https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1265
+ with pytest.raises(RequestError, match='URL can\'t contain control characters') as exc_info:
+                validate_and_send(rh, Request('http://127.0.0. 1', method='GET'))
+ assert not isinstance(exc_info.value, TransportError)
+
+ # https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1288C31-L1288C50
+ with pytest.raises(RequestError, match='Invalid header name') as exc_info:
+ validate_and_send(rh, Request('http://127.0.0.1', headers={'foo\n': 'bar'}))
+ assert not isinstance(exc_info.value, TransportError)
+
+
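+# Assert that validate() raises UnsupportedRequest exactly when the test
+# tables below expect a failure.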
+def run_validation(handler, fail, req, **handler_kwargs):
+ with handler(**handler_kwargs) as rh:
+ if fail:
+ with pytest.raises(UnsupportedRequest):
+ rh.validate(req)
+ else:
+ rh.validate(req)
+
+
+class TestRequestHandlerValidation:
+
+ class ValidationRH(RequestHandler):
+ def _send(self, request):
+ raise RequestError('test')
+
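+    # NoCheckRH opts out of every validation table, while HTTPSupportedRH
+    # supports only http:// URLs and declares no proxy support.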
+ class NoCheckRH(ValidationRH):
+ _SUPPORTED_FEATURES = None
+ _SUPPORTED_PROXY_SCHEMES = None
+ _SUPPORTED_URL_SCHEMES = None
+
+ class HTTPSupportedRH(ValidationRH):
+ _SUPPORTED_URL_SCHEMES = ('http',)
+
+ URL_SCHEME_TESTS = [
+ # scheme, expected to fail, handler kwargs
+ ('Urllib', [
+ ('http', False, {}),
+ ('https', False, {}),
+ ('data', False, {}),
+ ('ftp', False, {}),
+ ('file', True, {}),
+ ('file', False, {'enable_file_urls': True}),
+ ]),
+ (NoCheckRH, [('http', False, {})]),
+ (ValidationRH, [('http', True, {})])
+ ]
+
+ PROXY_SCHEME_TESTS = [
+ # scheme, expected to fail
+ ('Urllib', [
+ ('http', False),
+ ('https', True),
+ ('socks4', False),
+ ('socks4a', False),
+ ('socks5', False),
+ ('socks5h', False),
+ ('socks', True),
+ ]),
+ (NoCheckRH, [('http', False)]),
+ (HTTPSupportedRH, [('http', True)]),
+ ]
+
+ PROXY_KEY_TESTS = [
+ # key, expected to fail
+ ('Urllib', [
+ ('all', False),
+ ('unrelated', False),
+ ]),
+ (NoCheckRH, [('all', False)]),
+ (HTTPSupportedRH, [('all', True)]),
+ (HTTPSupportedRH, [('no', True)]),
+ ]
+
+ @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
+ (handler_tests[0], scheme, fail, handler_kwargs)
+ for handler_tests in URL_SCHEME_TESTS
+ for scheme, fail, handler_kwargs in handler_tests[1]
+
+ ], indirect=['handler'])
+ def test_url_scheme(self, handler, scheme, fail, handler_kwargs):
+ run_validation(handler, fail, Request(f'{scheme}://'), **(handler_kwargs or {}))
+
+ @pytest.mark.parametrize('handler,fail', [('Urllib', False)], indirect=['handler'])
+ def test_no_proxy(self, handler, fail):
+ run_validation(handler, fail, Request('http://', proxies={'no': '127.0.0.1,github.com'}))
+ run_validation(handler, fail, Request('http://'), proxies={'no': '127.0.0.1,github.com'})
+
+ @pytest.mark.parametrize('handler,proxy_key,fail', [
+ (handler_tests[0], proxy_key, fail)
+ for handler_tests in PROXY_KEY_TESTS
+ for proxy_key, fail in handler_tests[1]
+ ], indirect=['handler'])
+ def test_proxy_key(self, handler, proxy_key, fail):
+ run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
+ run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
+
+ @pytest.mark.parametrize('handler,scheme,fail', [
+ (handler_tests[0], scheme, fail)
+ for handler_tests in PROXY_SCHEME_TESTS
+ for scheme, fail in handler_tests[1]
+ ], indirect=['handler'])
+ def test_proxy_scheme(self, handler, scheme, fail):
+ run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'}))
+ run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'})
+
+ @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH], indirect=True)
+ def test_empty_proxy(self, handler):
+ run_validation(handler, False, Request('http://', proxies={'http': None}))
+ run_validation(handler, False, Request('http://'), proxies={'http': None})
+
+ @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_missing_proxy_scheme(self, handler, proxy_url):
+        run_validation(handler, True, Request('http://', proxies={'http': proxy_url}))
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_cookiejar_extension(self, handler):
+ run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'}))
+
+ @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
+ def test_timeout_extension(self, handler):
+ run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'}))
+
+ def test_invalid_request_type(self):
+ rh = self.ValidationRH(logger=FakeLogger())
+ for method in (rh.validate, rh.send):
+ with pytest.raises(TypeError, match='Expected an instance of Request'):
+ method('not a request')
+
+
+class FakeResponse(Response):
+ def __init__(self, request):
+ # XXX: we could make request part of standard response interface
+ self.request = request
+ super().__init__(fp=io.BytesIO(b''), headers={}, url=request.url)
+
+
+class FakeRH(RequestHandler):
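+    # Accepts any URL scheme; "ssl://<msg>" simulates a transport-level
+    # SSLError so error handling can be tested without real sockets.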
+
+ def _validate(self, request):
+ return
+
+ def _send(self, request: Request):
+ if request.url.startswith('ssl://'):
+ raise SSLError(request.url[len('ssl://'):])
+ return FakeResponse(request)
+
+
+class FakeRHYDL(FakeYDL):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._request_director = self.build_request_director([FakeRH])
+
+
+class TestRequestDirector:
+
+ def test_handler_operations(self):
+ director = RequestDirector(logger=FakeLogger())
+ handler = FakeRH(logger=FakeLogger())
+ director.add_handler(handler)
+ assert director.handlers.get(FakeRH.RH_KEY) is handler
+
+        # Adding a handler with the same RH_KEY should replace the existing one
+ handler2 = FakeRH(logger=FakeLogger())
+ director.add_handler(handler2)
+ assert director.handlers.get(FakeRH.RH_KEY) is not handler
+ assert director.handlers.get(FakeRH.RH_KEY) is handler2
+ assert len(director.handlers) == 1
+
+ class AnotherFakeRH(FakeRH):
+ pass
+ director.add_handler(AnotherFakeRH(logger=FakeLogger()))
+ assert len(director.handlers) == 2
+ assert director.handlers.get(AnotherFakeRH.RH_KEY).RH_KEY == AnotherFakeRH.RH_KEY
+
+ director.handlers.pop(FakeRH.RH_KEY, None)
+ assert director.handlers.get(FakeRH.RH_KEY) is None
+ assert len(director.handlers) == 1
+
+ # RequestErrors should passthrough
+ with pytest.raises(SSLError):
+ director.send(Request('ssl://something'))
+
+ def test_send(self):
+ director = RequestDirector(logger=FakeLogger())
+ with pytest.raises(RequestError):
+ director.send(Request('any://'))
+ director.add_handler(FakeRH(logger=FakeLogger()))
+ assert isinstance(director.send(Request('http://')), FakeResponse)
+
+ def test_unsupported_handlers(self):
+ director = RequestDirector(logger=FakeLogger())
+ director.add_handler(FakeRH(logger=FakeLogger()))
+
+ class SupportedRH(RequestHandler):
+ _SUPPORTED_URL_SCHEMES = ['http']
+
+ def _send(self, request: Request):
+ return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
+
+ # This handler should by default take preference over FakeRH
+ director.add_handler(SupportedRH(logger=FakeLogger()))
+ assert director.send(Request('http://')).read() == b'supported'
+ assert director.send(Request('any://')).read() == b''
+
+ director.handlers.pop(FakeRH.RH_KEY)
+ with pytest.raises(NoSupportingHandlers):
+ director.send(Request('any://'))
+
+ def test_unexpected_error(self):
+ director = RequestDirector(logger=FakeLogger())
+
+ class UnexpectedRH(FakeRH):
+ def _send(self, request: Request):
+ raise TypeError('something')
+
+ director.add_handler(UnexpectedRH(logger=FakeLogger))
+ with pytest.raises(NoSupportingHandlers, match=r'1 unexpected error'):
+ director.send(Request('any://'))
+
+ director.handlers.clear()
+ assert len(director.handlers) == 0
+
+ # Should not be fatal
+ director.add_handler(FakeRH(logger=FakeLogger()))
+ director.add_handler(UnexpectedRH(logger=FakeLogger))
+ assert director.send(Request('any://'))
+
+
+# XXX: do we want to move this to test_YoutubeDL.py?
+class TestYoutubeDLNetworking:
+
+ @staticmethod
+ def build_handler(ydl, handler: RequestHandler = FakeRH):
+ return ydl.build_request_director([handler]).handlers.get(handler.RH_KEY)
+
+ def test_compat_opener(self):
+ with FakeYDL() as ydl:
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=DeprecationWarning)
+ assert isinstance(ydl._opener, urllib.request.OpenerDirector)
+
+ @pytest.mark.parametrize('proxy,expected', [
+ ('http://127.0.0.1:8080', {'all': 'http://127.0.0.1:8080'}),
+ ('', {'all': '__noproxy__'}),
+ (None, {'http': 'http://127.0.0.1:8081', 'https': 'http://127.0.0.1:8081'}) # env, set https
+ ])
+ def test_proxy(self, proxy, expected):
+ old_http_proxy = os.environ.get('HTTP_PROXY')
+ try:
+ os.environ['HTTP_PROXY'] = 'http://127.0.0.1:8081' # ensure that provided proxies override env
+ with FakeYDL({'proxy': proxy}) as ydl:
+ assert ydl.proxies == expected
+ finally:
+ if old_http_proxy:
+ os.environ['HTTP_PROXY'] = old_http_proxy
+
+ def test_compat_request(self):
+ with FakeRHYDL() as ydl:
+ assert ydl.urlopen('test://')
+ urllib_req = urllib.request.Request('http://foo.bar', data=b'test', method='PUT', headers={'X-Test': '1'})
+ urllib_req.add_unredirected_header('Cookie', 'bob=bob')
+ urllib_req.timeout = 2
+
+ req = ydl.urlopen(urllib_req).request
+ assert req.url == urllib_req.get_full_url()
+ assert req.data == urllib_req.data
+ assert req.method == urllib_req.get_method()
+ assert 'X-Test' in req.headers
+ assert 'Cookie' in req.headers
+ assert req.extensions.get('timeout') == 2
+
+ with pytest.raises(AssertionError):
+ ydl.urlopen(None)
+
+ def test_extract_basic_auth(self):
+ with FakeRHYDL() as ydl:
+ res = ydl.urlopen(Request('http://user:pass@foo.bar'))
+ assert res.request.headers['Authorization'] == 'Basic dXNlcjpwYXNz'
+
+ def test_sanitize_url(self):
+ with FakeRHYDL() as ydl:
+ res = ydl.urlopen(Request('httpss://foo.bar'))
+ assert res.request.url == 'https://foo.bar'
+
+ def test_file_urls_error(self):
+ # use urllib handler
+ with FakeYDL() as ydl:
+ with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
+ ydl.urlopen('file://')
+
+ def test_legacy_server_connect_error(self):
+ with FakeRHYDL() as ydl:
+ for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
+ with pytest.raises(RequestError, match=r'Try using --legacy-server-connect'):
+ ydl.urlopen(f'ssl://{error}')
+
+ with pytest.raises(SSLError, match='testerror'):
+ ydl.urlopen('ssl://testerror')
+
+ @pytest.mark.parametrize('proxy_key,proxy_url,expected', [
+ ('http', '__noproxy__', None),
+ ('no', '127.0.0.1,foo.bar', '127.0.0.1,foo.bar'),
+ ('https', 'example.com', 'http://example.com'),
+ ('https', 'socks5://example.com', 'socks5h://example.com'),
+ ('http', 'socks://example.com', 'socks4://example.com'),
+ ('http', 'socks4://example.com', 'socks4://example.com'),
+ ])
+ def test_clean_proxy(self, proxy_key, proxy_url, expected):
+ # proxies should be cleaned in urlopen()
+ with FakeRHYDL() as ydl:
+ req = ydl.urlopen(Request('test://', proxies={proxy_key: proxy_url})).request
+ assert req.proxies[proxy_key] == expected
+
+ # and should also be cleaned when building the handler
+ env_key = f'{proxy_key.upper()}_PROXY'
+ old_env_proxy = os.environ.get(env_key)
+ try:
+ os.environ[env_key] = proxy_url # ensure that provided proxies override env
+ with FakeYDL() as ydl:
+ rh = self.build_handler(ydl)
+ assert rh.proxies[proxy_key] == expected
+ finally:
+ if old_env_proxy:
+ os.environ[env_key] = old_env_proxy
+
+ def test_clean_proxy_header(self):
+ with FakeRHYDL() as ydl:
+ req = ydl.urlopen(Request('test://', headers={'ytdl-request-proxy': '//foo.bar'})).request
+ assert 'ytdl-request-proxy' not in req.headers
+ assert req.proxies == {'all': 'http://foo.bar'}
+
+ with FakeYDL({'http_headers': {'ytdl-request-proxy': '//foo.bar'}}) as ydl:
+ rh = self.build_handler(ydl)
+ assert 'ytdl-request-proxy' not in rh.headers
+ assert rh.proxies == {'all': 'http://foo.bar'}
+
+ def test_clean_header(self):
+ with FakeRHYDL() as ydl:
+ res = ydl.urlopen(Request('test://', headers={'Youtubedl-no-compression': True}))
+ assert 'Youtubedl-no-compression' not in res.request.headers
+ assert res.request.headers.get('Accept-Encoding') == 'identity'
+
+ with FakeYDL({'http_headers': {'Youtubedl-no-compression': True}}) as ydl:
+ rh = self.build_handler(ydl)
+ assert 'Youtubedl-no-compression' not in rh.headers
+ assert rh.headers.get('Accept-Encoding') == 'identity'
+
+ def test_build_handler_params(self):
+ with FakeYDL({
+ 'http_headers': {'test': 'testtest'},
+ 'socket_timeout': 2,
+ 'proxy': 'http://127.0.0.1:8080',
+ 'source_address': '127.0.0.45',
+ 'debug_printtraffic': True,
+ 'compat_opts': ['no-certifi'],
+ 'nocheckcertificate': True,
+ 'legacy_server_connect': True,
+ }) as ydl:
+ rh = self.build_handler(ydl)
+ assert rh.headers.get('test') == 'testtest'
+ assert 'Accept' in rh.headers # ensure std_headers are still there
+ assert rh.timeout == 2
+ assert rh.proxies.get('all') == 'http://127.0.0.1:8080'
+ assert rh.source_address == '127.0.0.45'
+ assert rh.verbose is True
+ assert rh.prefer_system_certs is True
+ assert rh.verify is False
+ assert rh.legacy_ssl_support is True
+
+ @pytest.mark.parametrize('ydl_params', [
+ {'client_certificate': 'fakecert.crt'},
+ {'client_certificate': 'fakecert.crt', 'client_certificate_key': 'fakekey.key'},
+ {'client_certificate': 'fakecert.crt', 'client_certificate_key': 'fakekey.key', 'client_certificate_password': 'foobar'},
+ {'client_certificate_key': 'fakekey.key', 'client_certificate_password': 'foobar'},
+ ])
+ def test_client_certificate(self, ydl_params):
+ with FakeYDL(ydl_params) as ydl:
+ rh = self.build_handler(ydl)
+ assert rh._client_cert == ydl_params # XXX: Too bound to implementation
+
+ def test_urllib_file_urls(self):
+ with FakeYDL({'enable_file_urls': False}) as ydl:
+ rh = self.build_handler(ydl, UrllibRH)
+ assert rh.enable_file_urls is False
+
+ with FakeYDL({'enable_file_urls': True}) as ydl:
+ rh = self.build_handler(ydl, UrllibRH)
+ assert rh.enable_file_urls is True
+
+
+class TestRequest:
+
+ def test_query(self):
+ req = Request('http://example.com?q=something', query={'v': 'xyz'})
+ assert req.url == 'http://example.com?q=something&v=xyz'
+
+ req.update(query={'v': '123'})
+ assert req.url == 'http://example.com?q=something&v=123'
+ req.update(url='http://example.com', query={'v': 'xyz'})
+ assert req.url == 'http://example.com?v=xyz'
+
+ def test_method(self):
+ req = Request('http://example.com')
+ assert req.method == 'GET'
+ req.data = b'test'
+ assert req.method == 'POST'
+ req.data = None
+ assert req.method == 'GET'
+ req.data = b'test2'
+ req.method = 'PUT'
+ assert req.method == 'PUT'
+ req.data = None
+ assert req.method == 'PUT'
+ with pytest.raises(TypeError):
+ req.method = 1
+
+ def test_request_helpers(self):
+ assert HEADRequest('http://example.com').method == 'HEAD'
+ assert PUTRequest('http://example.com').method == 'PUT'
+
+ def test_headers(self):
+ req = Request('http://example.com', headers={'tesT': 'test'})
+ assert req.headers == HTTPHeaderDict({'test': 'test'})
+ req.update(headers={'teSt2': 'test2'})
+ assert req.headers == HTTPHeaderDict({'test': 'test', 'test2': 'test2'})
+
+ req.headers = new_headers = HTTPHeaderDict({'test': 'test'})
+ assert req.headers == HTTPHeaderDict({'test': 'test'})
+ assert req.headers is new_headers
+
+ # test converts dict to case insensitive dict
+ req.headers = new_headers = {'test2': 'test2'}
+ assert isinstance(req.headers, HTTPHeaderDict)
+ assert req.headers is not new_headers
+
+ with pytest.raises(TypeError):
+ req.headers = None
+
+ def test_data_type(self):
+ req = Request('http://example.com')
+ assert req.data is None
+ # test bytes is allowed
+ req.data = b'test'
+ assert req.data == b'test'
+ # test iterable of bytes is allowed
+ i = [b'test', b'test2']
+ req.data = i
+ assert req.data == i
+
+ # test file-like object is allowed
+ f = io.BytesIO(b'test')
+ req.data = f
+ assert req.data == f
+
+ # common mistake: test str not allowed
+ with pytest.raises(TypeError):
+ req.data = 'test'
+ assert req.data != 'test'
+
+ # common mistake: test dict is not allowed
+ with pytest.raises(TypeError):
+ req.data = {'test': 'test'}
+ assert req.data != {'test': 'test'}
+
+ def test_content_length_header(self):
+ req = Request('http://example.com', headers={'Content-Length': '0'}, data=b'')
+ assert req.headers.get('Content-Length') == '0'
+
+ req.data = b'test'
+ assert 'Content-Length' not in req.headers
+
+ req = Request('http://example.com', headers={'Content-Length': '10'})
+ assert 'Content-Length' not in req.headers
+
+ def test_content_type_header(self):
+ req = Request('http://example.com', headers={'Content-Type': 'test'}, data=b'test')
+ assert req.headers.get('Content-Type') == 'test'
+ req.data = b'test2'
+ assert req.headers.get('Content-Type') == 'test'
+ req.data = None
+ assert 'Content-Type' not in req.headers
+ req.data = b'test3'
+ assert req.headers.get('Content-Type') == 'application/x-www-form-urlencoded'
+
+ def test_proxies(self):
+ req = Request(url='http://example.com', proxies={'http': 'http://127.0.0.1:8080'})
+ assert req.proxies == {'http': 'http://127.0.0.1:8080'}
+
+ def test_extensions(self):
+ req = Request(url='http://example.com', extensions={'timeout': 2})
+ assert req.extensions == {'timeout': 2}
+
+ def test_copy(self):
+ req = Request(
+ url='http://example.com',
+ extensions={'cookiejar': CookieJar()},
+ headers={'Accept-Encoding': 'br'},
+ proxies={'http': 'http://127.0.0.1'},
+ data=[b'123']
+ )
+ req_copy = req.copy()
+ assert req_copy is not req
+ assert req_copy.url == req.url
+ assert req_copy.headers == req.headers
+ assert req_copy.headers is not req.headers
+ assert req_copy.proxies == req.proxies
+ assert req_copy.proxies is not req.proxies
+
+        # Data cannot be copied; the copy shares the same object
+ assert req_copy.data == req.data
+ assert req_copy.data is req.data
+
+ # Shallow copy extensions
+ assert req_copy.extensions is not req.extensions
+ assert req_copy.extensions['cookiejar'] == req.extensions['cookiejar']
+
+ # Subclasses are copied by default
+ class AnotherRequest(Request):
+ pass
+
+ req = AnotherRequest(url='http://127.0.0.1')
+ assert isinstance(req.copy(), AnotherRequest)
+
+ def test_url(self):
+ req = Request(url='https://фtest.example.com/ some spaceв?ä=c',)
+ assert req.url == 'https://xn--test-z6d.example.com/%20some%20space%D0%B2?%C3%A4=c'
+
+ assert Request(url='//example.com').url == 'http://example.com'
+
+ with pytest.raises(TypeError):
+ Request(url='https://').url = None
+
+
+class TestResponse:
+
+ @pytest.mark.parametrize('reason,status,expected', [
+ ('custom', 200, 'custom'),
+ (None, 404, 'Not Found'), # fallback status
+ ('', 403, 'Forbidden'),
+ (None, 999, None)
+ ])
+ def test_reason(self, reason, status, expected):
+ res = Response(io.BytesIO(b''), url='test://', headers={}, status=status, reason=reason)
+ assert res.reason == expected
+
+ def test_headers(self):
+ headers = Message()
+ headers.add_header('Test', 'test')
+ headers.add_header('Test', 'test2')
+ headers.add_header('content-encoding', 'br')
+ res = Response(io.BytesIO(b''), headers=headers, url='test://')
+ assert res.headers.get_all('test') == ['test', 'test2']
+ assert 'Content-Encoding' in res.headers
+
+ def test_get_header(self):
+ headers = Message()
+ headers.add_header('Set-Cookie', 'cookie1')
+ headers.add_header('Set-cookie', 'cookie2')
+ headers.add_header('Test', 'test')
+ headers.add_header('Test', 'test2')
+ res = Response(io.BytesIO(b''), headers=headers, url='test://')
+ assert res.get_header('test') == 'test, test2'
+ assert res.get_header('set-Cookie') == 'cookie1'
+ assert res.get_header('notexist', 'default') == 'default'
+
+ def test_compat(self):
+ res = Response(io.BytesIO(b''), url='test://', status=404, headers={'test': 'test'})
+ assert res.code == res.getcode() == res.status
+ assert res.geturl() == res.url
+ assert res.info() is res.headers
+ assert res.getheader('test') == res.get_header('test')
diff --git a/test/test_networking_utils.py b/test/test_networking_utils.py
new file mode 100644
index 000000000..f9f876af3
--- /dev/null
+++ b/test/test_networking_utils.py
@@ -0,0 +1,239 @@
+#!/usr/bin/env python3
+
+# Allow direct execution
+import os
+import sys
+
+import pytest
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import io
+import platform
+import random
+import ssl
+import urllib.error
+
+from yt_dlp.cookies import YoutubeDLCookieJar
+from yt_dlp.dependencies import certifi
+from yt_dlp.networking import Response
+from yt_dlp.networking._helper import (
+ InstanceStoreMixin,
+ add_accept_encoding_header,
+ get_redirect_method,
+ make_socks_proxy_opts,
+ select_proxy,
+ ssl_load_certs,
+)
+from yt_dlp.networking.exceptions import (
+ HTTPError,
+ IncompleteRead,
+ _CompatHTTPError,
+)
+from yt_dlp.socks import ProxyType
+from yt_dlp.utils.networking import HTTPHeaderDict
+
+TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class TestNetworkingUtils:
+
+ def test_select_proxy(self):
+ proxies = {
+ 'all': 'socks5://example.com',
+ 'http': 'http://example.com:1080',
+ 'no': 'bypass.example.com,yt-dl.org'
+ }
+
+ assert select_proxy('https://example.com', proxies) == proxies['all']
+ assert select_proxy('http://example.com', proxies) == proxies['http']
+ assert select_proxy('http://bypass.example.com', proxies) is None
+ assert select_proxy('https://yt-dl.org', proxies) is None
+
+ @pytest.mark.parametrize('socks_proxy,expected', [
+ ('socks5h://example.com', {
+ 'proxytype': ProxyType.SOCKS5,
+ 'addr': 'example.com',
+ 'port': 1080,
+ 'rdns': True,
+ 'username': None,
+ 'password': None
+ }),
+ ('socks5://user:@example.com:5555', {
+ 'proxytype': ProxyType.SOCKS5,
+ 'addr': 'example.com',
+ 'port': 5555,
+ 'rdns': False,
+ 'username': 'user',
+ 'password': ''
+ }),
+ ('socks4://u%40ser:pa%20ss@127.0.0.1:1080', {
+ 'proxytype': ProxyType.SOCKS4,
+ 'addr': '127.0.0.1',
+ 'port': 1080,
+ 'rdns': False,
+ 'username': 'u@ser',
+ 'password': 'pa ss'
+ }),
+ ('socks4a://:pa%20ss@127.0.0.1', {
+ 'proxytype': ProxyType.SOCKS4A,
+ 'addr': '127.0.0.1',
+ 'port': 1080,
+ 'rdns': True,
+ 'username': '',
+ 'password': 'pa ss'
+ })
+ ])
+ def test_make_socks_proxy_opts(self, socks_proxy, expected):
+ assert make_socks_proxy_opts(socks_proxy) == expected
+
+ def test_make_socks_proxy_unknown(self):
+ with pytest.raises(ValueError, match='Unknown SOCKS proxy version: socks'):
+ make_socks_proxy_opts('socks://127.0.0.1')
+
+ @pytest.mark.skipif(not certifi, reason='certifi is not installed')
+ def test_load_certifi(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ssl_load_certs(context, use_certifi=True)
+ context2.load_verify_locations(cafile=certifi.where())
+ assert context.get_ca_certs() == context2.get_ca_certs()
+
+ # Test load normal certs
+ # XXX: could there be a case where system certs are the same as certifi?
+ context3 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ssl_load_certs(context3, use_certifi=False)
+ assert context3.get_ca_certs() != context.get_ca_certs()
+
+ @pytest.mark.parametrize('method,status,expected', [
+ ('GET', 303, 'GET'),
+ ('HEAD', 303, 'HEAD'),
+ ('PUT', 303, 'GET'),
+ ('POST', 301, 'GET'),
+ ('HEAD', 301, 'HEAD'),
+ ('POST', 302, 'GET'),
+ ('HEAD', 302, 'HEAD'),
+ ('PUT', 302, 'PUT'),
+ ('POST', 308, 'POST'),
+ ('POST', 307, 'POST'),
+ ('HEAD', 308, 'HEAD'),
+ ('HEAD', 307, 'HEAD'),
+ ])
+ def test_get_redirect_method(self, method, status, expected):
+ assert get_redirect_method(method, status) == expected
+
+ @pytest.mark.parametrize('headers,supported_encodings,expected', [
+ ({'Accept-Encoding': 'br'}, ['gzip', 'br'], {'Accept-Encoding': 'br'}),
+ ({}, ['gzip', 'br'], {'Accept-Encoding': 'gzip, br'}),
+ ({'Content-type': 'application/json'}, [], {'Content-type': 'application/json', 'Accept-Encoding': 'identity'}),
+ ])
+ def test_add_accept_encoding_header(self, headers, supported_encodings, expected):
+ headers = HTTPHeaderDict(headers)
+ add_accept_encoding_header(headers, supported_encodings)
+ assert headers == HTTPHeaderDict(expected)
+
+
+class TestInstanceStoreMixin:
+
+ class FakeInstanceStoreMixin(InstanceStoreMixin):
+ def _create_instance(self, **kwargs):
+ return random.randint(0, 1000000)
+
+ def _close_instance(self, instance):
+ pass
+
+ def test_mixin(self):
+ mixin = self.FakeInstanceStoreMixin()
+ assert mixin._get_instance(d={'a': 1, 'b': 2, 'c': {'d', 4}}) == mixin._get_instance(d={'a': 1, 'b': 2, 'c': {'d', 4}})
+
+ assert mixin._get_instance(d={'a': 1, 'b': 2, 'c': {'e', 4}}) != mixin._get_instance(d={'a': 1, 'b': 2, 'c': {'d', 4}})
+
+        assert mixin._get_instance(d={'a': 1, 'b': 2, 'c': {'d', 4}}) != mixin._get_instance(d={'a': 1, 'b': 2, 'g': {'d', 4}})
+
+ assert mixin._get_instance(d={'a': 1}, e=[1, 2, 3]) == mixin._get_instance(d={'a': 1}, e=[1, 2, 3])
+
+ assert mixin._get_instance(d={'a': 1}, e=[1, 2, 3]) != mixin._get_instance(d={'a': 1}, e=[1, 2, 3, 4])
+
+ cookiejar = YoutubeDLCookieJar()
+ assert mixin._get_instance(b=[1, 2], c=cookiejar) == mixin._get_instance(b=[1, 2], c=cookiejar)
+
+ assert mixin._get_instance(b=[1, 2], c=cookiejar) != mixin._get_instance(b=[1, 2], c=YoutubeDLCookieJar())
+
+ # Different order
+ assert mixin._get_instance(c=cookiejar, b=[1, 2]) == mixin._get_instance(b=[1, 2], c=cookiejar)
+
+ m = mixin._get_instance(t=1234)
+ assert mixin._get_instance(t=1234) == m
+ mixin._clear_instances()
+ assert mixin._get_instance(t=1234) != m
+
+
+class TestNetworkingExceptions:
+
+ @staticmethod
+ def create_response(status):
+ return Response(fp=io.BytesIO(b'test'), url='http://example.com', headers={'tesT': 'test'}, status=status)
+
+ @pytest.mark.parametrize('http_error_class', [HTTPError, lambda r: _CompatHTTPError(HTTPError(r))])
+ def test_http_error(self, http_error_class):
+
+ response = self.create_response(403)
+ error = http_error_class(response)
+
+ assert error.status == 403
+ assert str(error) == error.msg == 'HTTP Error 403: Forbidden'
+ assert error.reason == response.reason
+ assert error.response is response
+
+ data = error.response.read()
+ assert data == b'test'
+        assert repr(error) == '<HTTPError 403: Forbidden>'
+
+ @pytest.mark.parametrize('http_error_class', [HTTPError, lambda *args, **kwargs: _CompatHTTPError(HTTPError(*args, **kwargs))])
+ def test_redirect_http_error(self, http_error_class):
+ response = self.create_response(301)
+ error = http_error_class(response, redirect_loop=True)
+ assert str(error) == error.msg == 'HTTP Error 301: Moved Permanently (redirect loop detected)'
+ assert error.reason == 'Moved Permanently'
+
+ def test_compat_http_error(self):
+ response = self.create_response(403)
+ error = _CompatHTTPError(HTTPError(response))
+ assert isinstance(error, HTTPError)
+ assert isinstance(error, urllib.error.HTTPError)
+
+ assert error.code == 403
+ assert error.getcode() == 403
+ assert error.hdrs is error.response.headers
+ assert error.info() is error.response.headers
+ assert error.headers is error.response.headers
+ assert error.filename == error.response.url
+ assert error.url == error.response.url
+ assert error.geturl() == error.response.url
+
+ # Passthrough file operations
+ assert error.read() == b'test'
+ assert not error.closed
+        # Technically Response operations are also passed through, though they should not be relied on.
+ assert error.get_header('test') == 'test'
+
+ @pytest.mark.skipif(
+ platform.python_implementation() == 'PyPy', reason='garbage collector works differently in pypy')
+ def test_compat_http_error_autoclose(self):
+ # Compat HTTPError should not autoclose response
+ response = self.create_response(403)
+ _CompatHTTPError(HTTPError(response))
+ assert not response.closed
+
+ def test_incomplete_read_error(self):
+ error = IncompleteRead(b'test', 3, cause='test')
+ assert isinstance(error, IncompleteRead)
+        assert repr(error) == '<IncompleteRead: 4 bytes read, 3 more expected>'
+ assert str(error) == error.msg == '4 bytes read, 3 more expected'
+ assert error.partial == b'test'
+ assert error.expected == 3
+ assert error.cause == 'test'
+
+ error = IncompleteRead(b'aaa')
+        assert repr(error) == '<IncompleteRead: 3 bytes read>'
+ assert str(error) == '3 bytes read'
diff --git a/test/test_utils.py b/test/test_utils.py
index 862c7d0f7..768edfd0c 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -51,6 +51,7 @@
escape_url,
expand_path,
extract_attributes,
+ extract_basic_auth,
find_xpath_attr,
fix_xml_ampersands,
float_or_none,
@@ -103,7 +104,6 @@
sanitize_filename,
sanitize_path,
sanitize_url,
- sanitized_Request,
shell_quote,
smuggle_url,
str_or_none,
@@ -132,6 +132,7 @@
xpath_text,
xpath_with_ns,
)
+from yt_dlp.utils.networking import HTTPHeaderDict
class TestUtil(unittest.TestCase):
@@ -2315,14 +2316,43 @@ def test_traverse_obj(self):
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
+ def test_http_header_dict(self):
+ headers = HTTPHeaderDict()
+ headers['ytdl-test'] = 1
+ self.assertEqual(list(headers.items()), [('Ytdl-Test', '1')])
+ headers['Ytdl-test'] = '2'
+ self.assertEqual(list(headers.items()), [('Ytdl-Test', '2')])
+ self.assertTrue('ytDl-Test' in headers)
+ self.assertEqual(str(headers), str(dict(headers)))
+ self.assertEqual(repr(headers), str(dict(headers)))
+
+ headers.update({'X-dlp': 'data'})
+ self.assertEqual(set(headers.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data')})
+ self.assertEqual(dict(headers), {'Ytdl-Test': '2', 'X-Dlp': 'data'})
+ self.assertEqual(len(headers), 2)
+ self.assertEqual(headers.copy(), headers)
+ headers2 = HTTPHeaderDict({'X-dlp': 'data3'}, **headers, **{'X-dlp': 'data2'})
+ self.assertEqual(set(headers2.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data2')})
+ self.assertEqual(len(headers2), 2)
+ headers2.clear()
+ self.assertEqual(len(headers2), 0)
+
+        # ensure headers specified later take precedence
+ headers3 = HTTPHeaderDict({'Ytdl-TeSt': 1}, {'Ytdl-test': 2})
+ self.assertEqual(set(headers3.items()), {('Ytdl-Test', '2')})
+ del headers3['ytdl-tesT']
+ self.assertEqual(dict(headers3), {})
+
+ headers4 = HTTPHeaderDict({'ytdl-test': 'data;'})
+ self.assertEqual(set(headers4.items()), {('Ytdl-Test', 'data;')})
+
def test_extract_basic_auth(self):
- auth_header = lambda url: sanitized_Request(url).get_header('Authorization')
- self.assertFalse(auth_header('http://foo.bar'))
- self.assertFalse(auth_header('http://:foo.bar'))
- self.assertEqual(auth_header('http://@foo.bar'), 'Basic Og==')
- self.assertEqual(auth_header('http://:pass@foo.bar'), 'Basic OnBhc3M=')
- self.assertEqual(auth_header('http://user:@foo.bar'), 'Basic dXNlcjo=')
- self.assertEqual(auth_header('http://user:pass@foo.bar'), 'Basic dXNlcjpwYXNz')
+ assert extract_basic_auth('http://:foo.bar') == ('http://:foo.bar', None)
+ assert extract_basic_auth('http://foo.bar') == ('http://foo.bar', None)
+ assert extract_basic_auth('http://@foo.bar') == ('http://foo.bar', 'Basic Og==')
+ assert extract_basic_auth('http://:pass@foo.bar') == ('http://foo.bar', 'Basic OnBhc3M=')
+ assert extract_basic_auth('http://user:@foo.bar') == ('http://foo.bar', 'Basic dXNlcjo=')
+ assert extract_basic_auth('http://user:pass@foo.bar') == ('http://foo.bar', 'Basic dXNlcjpwYXNz')
if __name__ == '__main__':
diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py
index 138646ebf..29a18aef0 100644
--- a/yt_dlp/YoutubeDL.py
+++ b/yt_dlp/YoutubeDL.py
@@ -4,7 +4,6 @@
import datetime
import errno
import fileinput
-import functools
import http.cookiejar
import io
import itertools
@@ -25,8 +24,8 @@
import unicodedata
from .cache import Cache
-from .compat import urllib # isort: split
-from .compat import compat_os_name, compat_shlex_quote
+from .compat import functools, urllib # isort: split
+from .compat import compat_os_name, compat_shlex_quote, urllib_req_to_req
from .cookies import LenientSimpleCookie, load_cookies
from .downloader import FFmpegFD, get_suitable_downloader, shorten_protocol_name
from .downloader.rtmp import rtmpdump_version
@@ -34,6 +33,15 @@
from .extractor.common import UnsupportedURLIE
from .extractor.openload import PhantomJSwrapper
from .minicurses import format_text
+from .networking import Request, RequestDirector
+from .networking.common import _REQUEST_HANDLERS
+from .networking.exceptions import (
+ HTTPError,
+ NoSupportingHandlers,
+ RequestError,
+ SSLError,
+ _CompatHTTPError,
+)
from .plugins import directories as plugin_directories
from .postprocessor import _PLUGIN_CLASSES as plugin_pps
from .postprocessor import (
@@ -78,7 +86,6 @@
MaxDownloadsReached,
Namespace,
PagedList,
- PerRequestProxyHandler,
PlaylistEntries,
Popen,
PostProcessingError,
@@ -87,9 +94,6 @@
SameFileError,
UnavailableVideoError,
UserNotLive,
- YoutubeDLCookieProcessor,
- YoutubeDLHandler,
- YoutubeDLRedirectHandler,
age_restricted,
args_to_str,
bug_reports_message,
@@ -102,6 +106,7 @@
error_to_compat_str,
escapeHTML,
expand_path,
+ extract_basic_auth,
filter_dict,
float_or_none,
format_bytes,
@@ -117,8 +122,6 @@
locked_file,
make_archive_id,
make_dir,
- make_HTTPS_handler,
- merge_headers,
network_exceptions,
number_of_digits,
orderedSet,
@@ -132,7 +135,6 @@
sanitize_filename,
sanitize_path,
sanitize_url,
- sanitized_Request,
std_headers,
str_or_none,
strftime_or_none,
@@ -151,7 +153,12 @@
write_json_file,
write_string,
)
-from .utils.networking import clean_headers
+from .utils._utils import _YDLLogger
+from .utils.networking import (
+ HTTPHeaderDict,
+ clean_headers,
+ clean_proxies,
+)
from .version import CHANNEL, RELEASE_GIT_HEAD, VARIANT, __version__
if compat_os_name == 'nt':
@@ -673,7 +680,9 @@ def process_color_policy(stream):
raise
self.params['compat_opts'] = set(self.params.get('compat_opts', ()))
- self.params['http_headers'] = merge_headers(std_headers, self.params.get('http_headers', {}))
+ self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers'))
+ self._request_director = self.build_request_director(
+ sorted(_REQUEST_HANDLERS.values(), key=lambda rh: rh.RH_NAME.lower()))
if auto_init and auto_init != 'no_verbose_header':
self.print_debug_header()
@@ -763,8 +772,6 @@ def check_deprecated(param, option, suggestion):
get_postprocessor(pp_def.pop('key'))(self, **pp_def),
when=when)
- self._setup_opener()
-
def preload_download_archive(fn):
"""Preload the archive, if any is specified"""
archive = set()
@@ -946,7 +953,11 @@ def save_cookies(self):
def __exit__(self, *args):
self.restore_console_title()
+ self.close()
+
+ def close(self):
self.save_cookies()
+ self._request_director.close()
def trouble(self, message=None, tb=None, is_error=True):
"""Determine action to take when a download problem appears.
@@ -2468,7 +2479,7 @@ def restore_last_token(self):
return _build_selector_function(parsed_selector)
def _calc_headers(self, info_dict):
- res = merge_headers(self.params['http_headers'], info_dict.get('http_headers') or {})
+ res = HTTPHeaderDict(self.params['http_headers'], info_dict.get('http_headers'))
clean_headers(res)
cookies = self.cookiejar.get_cookies_for_url(info_dict['url'])
if cookies:
@@ -3943,13 +3954,8 @@ def get_encoding(stream):
join_nonempty(*get_package_info(m)) for m in available_dependencies.values()
})) or 'none'))
- self._setup_opener()
- proxy_map = {}
- for handler in self._opener.handlers:
- if hasattr(handler, 'proxies'):
- proxy_map.update(handler.proxies)
- write_debug(f'Proxy map: {proxy_map}')
-
+ write_debug(f'Proxy map: {self.proxies}')
+ # write_debug(f'Request Handlers: {", ".join(rh.RH_NAME for rh in self._request_director.handlers)}')
for plugin_type, plugins in {'Extractor': plugin_ies, 'Post-Processor': plugin_pps}.items():
display_list = ['%s%s' % (
klass.__name__, '' if klass.__name__ == name else f' as {name}')
@@ -3977,53 +3983,21 @@ def get_encoding(stream):
'See https://yt-dl.org/update if you need help updating.' %
latest_version)
- def _setup_opener(self):
- if hasattr(self, '_opener'):
- return
- timeout_val = self.params.get('socket_timeout')
- self._socket_timeout = 20 if timeout_val is None else float(timeout_val)
+ @functools.cached_property
+ def proxies(self):
+ """Global proxy configuration"""
opts_proxy = self.params.get('proxy')
-
- cookie_processor = YoutubeDLCookieProcessor(self.cookiejar)
if opts_proxy is not None:
if opts_proxy == '':
- proxies = {}
- else:
- proxies = {'http': opts_proxy, 'https': opts_proxy}
+ opts_proxy = '__noproxy__'
+ proxies = {'all': opts_proxy}
else:
proxies = urllib.request.getproxies()
- # Set HTTPS proxy to HTTP one if given (https://github.com/ytdl-org/youtube-dl/issues/805)
+ # compat. Set HTTPS_PROXY to __noproxy__ to revert
if 'http' in proxies and 'https' not in proxies:
proxies['https'] = proxies['http']
- proxy_handler = PerRequestProxyHandler(proxies)
- debuglevel = 1 if self.params.get('debug_printtraffic') else 0
- https_handler = make_HTTPS_handler(self.params, debuglevel=debuglevel)
- ydlh = YoutubeDLHandler(self.params, debuglevel=debuglevel)
- redirect_handler = YoutubeDLRedirectHandler()
- data_handler = urllib.request.DataHandler()
-
- # When passing our own FileHandler instance, build_opener won't add the
- # default FileHandler and allows us to disable the file protocol, which
- # can be used for malicious purposes (see
- # https://github.com/ytdl-org/youtube-dl/issues/8227)
- file_handler = urllib.request.FileHandler()
-
- if not self.params.get('enable_file_urls'):
- def file_open(*args, **kwargs):
- raise urllib.error.URLError(
- 'file:// URLs are explicitly disabled in yt-dlp for security reasons. '
- 'Use --enable-file-urls to enable at your own risk.')
- file_handler.file_open = file_open
-
- opener = urllib.request.build_opener(
- proxy_handler, https_handler, cookie_processor, ydlh, redirect_handler, data_handler, file_handler)
-
- # Delete the default user-agent header, which would otherwise apply in
- # cases where our custom HTTP handler doesn't come into play
- # (See https://github.com/ytdl-org/youtube-dl/issues/1309 for details)
- opener.addheaders = []
- self._opener = opener
+ return proxies
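
For illustration, a rough sketch of the resulting mapping under each branch (the inline YoutubeDL construction here is for demonstration only):

    from yt_dlp import YoutubeDL

    # --proxy URL routes every protocol through the 'all' key
    assert YoutubeDL({'proxy': 'http://127.0.0.1:8080'}).proxies == {'all': 'http://127.0.0.1:8080'}
    # --proxy '' disables proxying entirely via the '__noproxy__' sentinel
    assert YoutubeDL({'proxy': ''}).proxies == {'all': '__noproxy__'}
    # with no 'proxy' param, urllib.request.getproxies() (i.e. the environment) is consulted
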
@functools.cached_property
def cookiejar(self):
@@ -4031,11 +4005,84 @@ def cookiejar(self):
return load_cookies(
self.params.get('cookiefile'), self.params.get('cookiesfrombrowser'), self)
+ @property
+ def _opener(self):
+ """
+ Get a urllib OpenerDirector from the Urllib handler (deprecated).
+ """
+        self.deprecation_warning('YoutubeDL._opener is deprecated, use YoutubeDL.urlopen()')
+ handler = self._request_director.handlers['Urllib']
+ return handler._get_instance(cookiejar=self.cookiejar, proxies=self.proxies)
+
def urlopen(self, req):
""" Start an HTTP download """
if isinstance(req, str):
- req = sanitized_Request(req)
- return self._opener.open(req, timeout=self._socket_timeout)
+ req = Request(req)
+ elif isinstance(req, urllib.request.Request):
+ req = urllib_req_to_req(req)
+ assert isinstance(req, Request)
+
+ # compat: Assume user:pass url params are basic auth
+ url, basic_auth_header = extract_basic_auth(req.url)
+ if basic_auth_header:
+ req.headers['Authorization'] = basic_auth_header
+ req.url = sanitize_url(url)
+
+ clean_proxies(proxies=req.proxies, headers=req.headers)
+ clean_headers(req.headers)
+
+ try:
+ return self._request_director.send(req)
+ except NoSupportingHandlers as e:
+ for ue in e.unsupported_errors:
+ if not (ue.handler and ue.msg):
+ continue
+ if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
+ raise RequestError(
+ 'file:// URLs are disabled by default in yt-dlp for security reasons. '
+ 'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue
+ raise
+ except SSLError as e:
+ if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):
+ raise RequestError('UNSAFE_LEGACY_RENEGOTIATION_DISABLED: Try using --legacy-server-connect', cause=e) from e
+ elif 'SSLV3_ALERT_HANDSHAKE_FAILURE' in str(e):
+ raise RequestError(
+ 'SSLV3_ALERT_HANDSHAKE_FAILURE: The server may not support the current cipher list. '
+ 'Try using --legacy-server-connect', cause=e) from e
+ raise
+ except HTTPError as e: # TODO: Remove in a future release
+ raise _CompatHTTPError(e) from e
+
+ def build_request_director(self, handlers):
+ logger = _YDLLogger(self)
+ headers = self.params.get('http_headers').copy()
+ proxies = self.proxies.copy()
+ clean_headers(headers)
+ clean_proxies(proxies, headers)
+
+ director = RequestDirector(logger=logger, verbose=self.params.get('debug_printtraffic'))
+ for handler in handlers:
+ director.add_handler(handler(
+ logger=logger,
+ headers=headers,
+ cookiejar=self.cookiejar,
+ proxies=proxies,
+ prefer_system_certs='no-certifi' in self.params['compat_opts'],
+ verify=not self.params.get('nocheckcertificate'),
+ **traverse_obj(self.params, {
+ 'verbose': 'debug_printtraffic',
+ 'source_address': 'source_address',
+ 'timeout': 'socket_timeout',
+ 'legacy_ssl_support': 'legacy_server_connect',
+ 'enable_file_urls': 'enable_file_urls',
+ 'client_cert': {
+ 'client_certificate': 'client_certificate',
+ 'client_certificate_key': 'client_certificate_key',
+ 'client_certificate_password': 'client_certificate_password',
+ },
+ }),
+ ))
+ return director
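
Putting urlopen() and the director together, a rough end-to-end sketch (illustrative only; it needs network access, and example.com is a placeholder):

    from yt_dlp import YoutubeDL
    from yt_dlp.networking import Request

    with YoutubeDL() as ydl:
        # str and urllib.request.Request inputs are converted to networking.Request;
        # the director then dispatches to a registered handler (currently UrllibRH)
        response = ydl.urlopen(Request('https://example.com', headers={'X-Test': '1'}))
        print(response.status, response.get_header('Content-Type'))
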
def encode(self, s):
if isinstance(s, bytes):
@@ -4188,7 +4235,7 @@ def _write_thumbnails(self, label, info_dict, filename, thumb_filename_base=None
else:
self.to_screen(f'[info] Downloading {thumb_display_id} ...')
try:
- uf = self.urlopen(sanitized_Request(t['url'], headers=t.get('http_headers', {})))
+ uf = self.urlopen(Request(t['url'], headers=t.get('http_headers', {})))
self.to_screen(f'[info] Writing {thumb_display_id} to: {thumb_filename}')
with open(encodeFilename(thumb_filename), 'wb') as thumbf:
shutil.copyfileobj(uf, thumbf)
diff --git a/yt_dlp/compat/__init__.py b/yt_dlp/compat/__init__.py
index c6c02541c..a41a80ebb 100644
--- a/yt_dlp/compat/__init__.py
+++ b/yt_dlp/compat/__init__.py
@@ -70,3 +70,13 @@ def compat_expanduser(path):
return userhome + path[i:]
else:
compat_expanduser = os.path.expanduser
+
+
+def urllib_req_to_req(urllib_request):
+ """Convert urllib Request to a networking Request"""
+ from ..networking import Request
+ from ..utils.networking import HTTPHeaderDict
+ return Request(
+ urllib_request.get_full_url(), data=urllib_request.data, method=urllib_request.get_method(),
+ headers=HTTPHeaderDict(urllib_request.headers, urllib_request.unredirected_hdrs),
+ extensions={'timeout': urllib_request.timeout} if hasattr(urllib_request, 'timeout') else None)
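
A quick sketch of what the shim preserves (the header name and values are arbitrary):

    import urllib.request

    from yt_dlp.compat import urllib_req_to_req

    old = urllib.request.Request(
        'http://example.com', data=b'payload', headers={'X-Old': 'api'}, method='PUT')
    new = urllib_req_to_req(old)
    assert new.method == 'PUT' and new.data == b'payload'
    assert new.headers.get('x-old') == 'api'  # HTTPHeaderDict lookups are case-insensitive
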
diff --git a/yt_dlp/downloader/http.py b/yt_dlp/downloader/http.py
index 7c5daea85..45d094721 100644
--- a/yt_dlp/downloader/http.py
+++ b/yt_dlp/downloader/http.py
@@ -1,12 +1,10 @@
-import http.client
import os
import random
-import socket
-import ssl
import time
import urllib.error
from .common import FileDownloader
+from ..networking.exceptions import CertificateVerifyError, TransportError
from ..utils import (
ContentTooShortError,
RetryManager,
@@ -21,14 +19,6 @@
write_xattr,
)
-RESPONSE_READ_EXCEPTIONS = (
- TimeoutError,
- socket.timeout, # compat: py < 3.10
- ConnectionError,
- ssl.SSLError,
- http.client.HTTPException
-)
-
class HttpFD(FileDownloader):
def real_download(self, filename, info_dict):
@@ -196,13 +186,9 @@ def establish_connection():
# Unexpected HTTP error
raise
raise RetryDownload(err)
- except urllib.error.URLError as err:
- if isinstance(err.reason, ssl.CertificateError):
- raise
- raise RetryDownload(err)
- # In urllib.request.AbstractHTTPHandler, the response is partially read on request.
- # Any errors that occur during this will not be wrapped by URLError
- except RESPONSE_READ_EXCEPTIONS as err:
+ except CertificateVerifyError:
+ raise
+ except TransportError as err:
raise RetryDownload(err)
def close_stream():
@@ -258,7 +244,7 @@ def retry(e):
try:
# Download and write
data_block = ctx.data.read(block_size if not is_test else min(block_size, data_len - byte_counter))
- except RESPONSE_READ_EXCEPTIONS as err:
+ except TransportError as err:
retry(err)
byte_counter += len(data_block)
diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py
index fe08839aa..63156d3ac 100644
--- a/yt_dlp/extractor/common.py
+++ b/yt_dlp/extractor/common.py
@@ -17,16 +17,22 @@
import sys
import time
import types
-import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree
from ..compat import functools # isort: split
-from ..compat import compat_etree_fromstring, compat_expanduser, compat_os_name
+from ..compat import (
+ compat_etree_fromstring,
+ compat_expanduser,
+ compat_os_name,
+ urllib_req_to_req,
+)
from ..cookies import LenientSimpleCookie
from ..downloader.f4m import get_base_url, remove_encrypted_media
from ..downloader.hls import HlsFD
+from ..networking.common import HEADRequest, Request
+from ..networking.exceptions import network_exceptions
from ..utils import (
IDENTITY,
JSON_LD_RE,
@@ -35,7 +41,6 @@
FormatSorter,
GeoRestrictedError,
GeoUtils,
- HEADRequest,
LenientJSONDecoder,
Popen,
RegexNotFoundError,
@@ -61,7 +66,6 @@
js_to_json,
mimetype2ext,
netrc_from_content,
- network_exceptions,
orderedSet,
parse_bitrate,
parse_codecs,
@@ -71,7 +75,6 @@
parse_resolution,
sanitize_filename,
sanitize_url,
- sanitized_Request,
smuggle_url,
str_or_none,
str_to_int,
@@ -83,8 +86,6 @@
unescapeHTML,
unified_strdate,
unified_timestamp,
- update_Request,
- update_url_query,
url_basename,
url_or_none,
urlhandle_detect_ext,
@@ -797,10 +798,12 @@ def __can_accept_status_code(err, expected_status):
def _create_request(self, url_or_request, data=None, headers=None, query=None):
if isinstance(url_or_request, urllib.request.Request):
- return update_Request(url_or_request, data=data, headers=headers, query=query)
- if query:
- url_or_request = update_url_query(url_or_request, query)
- return sanitized_Request(url_or_request, data, headers or {})
+ url_or_request = urllib_req_to_req(url_or_request)
+ elif not isinstance(url_or_request, Request):
+ url_or_request = Request(url_or_request)
+
+ url_or_request.update(data=data, headers=headers, query=query)
+ return url_or_request
def _request_webpage(self, url_or_request, video_id, note=None, errnote=None, fatal=True, data=None, headers=None, query=None, expected_status=None):
"""
@@ -838,12 +841,7 @@ def _request_webpage(self, url_or_request, video_id, note=None, errnote=None, fa
except network_exceptions as err:
if isinstance(err, urllib.error.HTTPError):
if self.__can_accept_status_code(err, expected_status):
- # Retain reference to error to prevent file object from
- # being closed before it can be read. Works around the
- # effects of
- # introduced in Python 3.4.1.
- err.fp._error = err
- return err.fp
+ return err.response
if errnote is False:
return False
diff --git a/yt_dlp/networking/__init__.py b/yt_dlp/networking/__init__.py
index e69de29bb..5e8876484 100644
--- a/yt_dlp/networking/__init__.py
+++ b/yt_dlp/networking/__init__.py
@@ -0,0 +1,13 @@
+# flake8: noqa: F401
+from .common import (
+ HEADRequest,
+ PUTRequest,
+ Request,
+ RequestDirector,
+ RequestHandler,
+ Response,
+)
+
+# isort: split
+# TODO: all request handlers should be safely imported
+from . import _urllib
diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py
index 367f3f444..a43c57bb4 100644
--- a/yt_dlp/networking/_helper.py
+++ b/yt_dlp/networking/_helper.py
@@ -1,13 +1,22 @@
from __future__ import annotations
import contextlib
+import functools
import ssl
import sys
+import typing
import urllib.parse
+import urllib.request
+from .exceptions import RequestError, UnsupportedRequest
from ..dependencies import certifi
from ..socks import ProxyType
-from ..utils import YoutubeDLError
+from ..utils import format_field, traverse_obj
+
+if typing.TYPE_CHECKING:
+ from collections.abc import Iterable
+
+ from ..utils.networking import HTTPHeaderDict
def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
@@ -23,11 +32,11 @@ def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
# enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
for storename in ('CA', 'ROOT'):
- _ssl_load_windows_store_certs(context, storename)
+ ssl_load_windows_store_certs(context, storename)
context.set_default_verify_paths()
-def _ssl_load_windows_store_certs(ssl_context, storename):
+def ssl_load_windows_store_certs(ssl_context, storename):
# Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
try:
certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename)
@@ -44,10 +53,18 @@ def make_socks_proxy_opts(socks_proxy):
url_components = urllib.parse.urlparse(socks_proxy)
if url_components.scheme.lower() == 'socks5':
socks_type = ProxyType.SOCKS5
- elif url_components.scheme.lower() in ('socks', 'socks4'):
+ rdns = False
+ elif url_components.scheme.lower() == 'socks5h':
+ socks_type = ProxyType.SOCKS5
+ rdns = True
+ elif url_components.scheme.lower() == 'socks4':
socks_type = ProxyType.SOCKS4
+ rdns = False
elif url_components.scheme.lower() == 'socks4a':
socks_type = ProxyType.SOCKS4A
+ rdns = True
+ else:
+ raise ValueError(f'Unknown SOCKS proxy version: {url_components.scheme.lower()}')
def unquote_if_non_empty(s):
if not s:
@@ -57,12 +74,25 @@ def unquote_if_non_empty(s):
'proxytype': socks_type,
'addr': url_components.hostname,
'port': url_components.port or 1080,
- 'rdns': True,
+ 'rdns': rdns,
'username': unquote_if_non_empty(url_components.username),
'password': unquote_if_non_empty(url_components.password),
}
+def select_proxy(url, proxies):
+ """Unified proxy selector for all backends"""
+ url_components = urllib.parse.urlparse(url)
+ if 'no' in proxies:
+ hostport = url_components.hostname + format_field(url_components.port, None, ':%s')
+ if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
+ return
+ elif urllib.request.proxy_bypass(hostport): # check system settings
+ return
+
+ return traverse_obj(proxies, url_components.scheme or 'http', 'all')
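
In other words: the scheme-specific key wins, 'all' is the fallback, and 'no' bypasses both (the values below are illustrative):

    from yt_dlp.networking._helper import select_proxy

    proxies = {'all': 'socks5://example.com', 'http': 'http://proxy:3128', 'no': 'localhost'}
    assert select_proxy('http://yt-dl.org', proxies) == 'http://proxy:3128'
    assert select_proxy('https://yt-dl.org', proxies) == 'socks5://example.com'
    assert select_proxy('http://localhost', proxies) is None  # bypassed via 'no'
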
+
+
def get_redirect_method(method, status):
"""Unified redirect method handling"""
@@ -126,14 +156,53 @@ def make_ssl_context(
client_certificate, keyfile=client_certificate_key,
password=client_certificate_password)
except ssl.SSLError:
- raise YoutubeDLError('Unable to load client certificate')
+ raise RequestError('Unable to load client certificate')
+ if getattr(context, 'post_handshake_auth', None) is not None:
+ context.post_handshake_auth = True
return context
-def add_accept_encoding_header(headers, supported_encodings):
- if supported_encodings and 'Accept-Encoding' not in headers:
- headers['Accept-Encoding'] = ', '.join(supported_encodings)
+class InstanceStoreMixin:
+ def __init__(self, **kwargs):
+ self.__instances = []
+        super().__init__(**kwargs)  # so that the mixin cooperates in either MRO order
- elif 'Accept-Encoding' not in headers:
- headers['Accept-Encoding'] = 'identity'
+ @staticmethod
+ def _create_instance(**kwargs):
+ raise NotImplementedError
+
+ def _get_instance(self, **kwargs):
+ for key, instance in self.__instances:
+ if key == kwargs:
+ return instance
+
+ instance = self._create_instance(**kwargs)
+ self.__instances.append((kwargs, instance))
+ return instance
+
+ def _close_instance(self, instance):
+ if callable(getattr(instance, 'close', None)):
+ instance.close()
+
+ def _clear_instances(self):
+ for _, instance in self.__instances:
+ self._close_instance(instance)
+ self.__instances.clear()
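
As a sketch, a subclass only needs `_create_instance`; instances are pooled by kwargs equality rather than hashing, so unhashable values (dicts, cookiejars) work as keys. `PooledThing` is hypothetical:

    from yt_dlp.networking._helper import InstanceStoreMixin

    class PooledThing(InstanceStoreMixin):
        def _create_instance(self, **kwargs):
            return object()  # stand-in for an expensive session/opener

    pool = PooledThing()
    a = pool._get_instance(proxies={'all': 'http://proxy'}, verify=True)
    b = pool._get_instance(proxies={'all': 'http://proxy'}, verify=True)
    assert a is b  # equal kwargs reuse the cached instance
    pool._clear_instances()  # closes and drops every cached instance
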
+
+
+def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
+ if 'Accept-Encoding' not in headers:
+ headers['Accept-Encoding'] = ', '.join(supported_encodings) or 'identity'
+
+
+def wrap_request_errors(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except UnsupportedRequest as e:
+ if e.handler is None:
+ e.handler = self
+ raise
+ return wrapper
diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py
index 1f5871ae6..2c5f09872 100644
--- a/yt_dlp/networking/_urllib.py
+++ b/yt_dlp/networking/_urllib.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import functools
import gzip
import http.client
@@ -9,26 +11,48 @@
import urllib.request
import urllib.response
import zlib
+from urllib.request import (
+ DataHandler,
+ FileHandler,
+ FTPHandler,
+ HTTPCookieProcessor,
+ HTTPDefaultErrorHandler,
+ HTTPErrorProcessor,
+ UnknownHandler,
+)
from ._helper import (
+ InstanceStoreMixin,
add_accept_encoding_header,
get_redirect_method,
make_socks_proxy_opts,
+ select_proxy,
+)
+from .common import Features, RequestHandler, Response, register
+from .exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ IncompleteRead,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
)
from ..dependencies import brotli
+from ..socks import ProxyError as SocksProxyError
from ..socks import sockssocket
from ..utils import escape_url, update_url_query
-from ..utils.networking import clean_headers, std_headers
SUPPORTED_ENCODINGS = ['gzip', 'deflate']
+CONTENT_DECODE_ERRORS = [zlib.error, OSError]
if brotli:
SUPPORTED_ENCODINGS.append('br')
+ CONTENT_DECODE_ERRORS.append(brotli.error)
-def _create_http_connection(ydl_handler, http_class, is_https, *args, **kwargs):
+def _create_http_connection(http_class, source_address, *args, **kwargs):
hc = http_class(*args, **kwargs)
- source_address = ydl_handler._params.get('source_address')
if source_address is not None:
# This is to workaround _create_connection() from socket where it will try all
@@ -73,7 +97,7 @@ def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_a
return hc
-class HTTPHandler(urllib.request.HTTPHandler):
+class HTTPHandler(urllib.request.AbstractHTTPHandler):
"""Handler for HTTP requests and responses.
This class, when installed with an OpenerDirector, automatically adds
@@ -88,21 +112,30 @@ class HTTPHandler(urllib.request.HTTPHandler):
public domain.
"""
- def __init__(self, params, *args, **kwargs):
- urllib.request.HTTPHandler.__init__(self, *args, **kwargs)
- self._params = params
+ def __init__(self, context=None, source_address=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._source_address = source_address
+ self._context = context
- def http_open(self, req):
- conn_class = http.client.HTTPConnection
-
- socks_proxy = req.headers.get('Ytdl-socks-proxy')
+ @staticmethod
+ def _make_conn_class(base, req):
+ conn_class = base
+ socks_proxy = req.headers.pop('Ytdl-socks-proxy', None)
if socks_proxy:
conn_class = make_socks_conn_class(conn_class, socks_proxy)
- del req.headers['Ytdl-socks-proxy']
+ return conn_class
+ def http_open(self, req):
+ conn_class = self._make_conn_class(http.client.HTTPConnection, req)
return self.do_open(functools.partial(
- _create_http_connection, self, conn_class, False),
- req)
+ _create_http_connection, conn_class, self._source_address), req)
+
+ def https_open(self, req):
+ conn_class = self._make_conn_class(http.client.HTTPSConnection, req)
+ return self.do_open(
+ functools.partial(
+ _create_http_connection, conn_class, self._source_address),
+ req, context=self._context)
@staticmethod
def deflate(data):
@@ -152,14 +185,6 @@ def http_request(self, req):
if url != url_escaped:
req = update_Request(req, url=url_escaped)
- for h, v in self._params.get('http_headers', std_headers).items():
- # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
- # The dict keys are capitalized because of this bug by urllib
- if h.capitalize() not in req.headers:
- req.add_header(h, v)
-
- clean_headers(req.headers)
- add_accept_encoding_header(req.headers, SUPPORTED_ENCODINGS)
return super().do_request_(req)
def http_response(self, req, resp):
@@ -207,16 +232,12 @@ class SocksConnection(base_class):
def connect(self):
self.sock = sockssocket()
self.sock.setproxy(**proxy_args)
- if isinstance(self.timeout, (int, float)):
+ if type(self.timeout) in (int, float): # noqa: E721
self.sock.settimeout(self.timeout)
self.sock.connect((self.host, self.port))
if isinstance(self, http.client.HTTPSConnection):
- if hasattr(self, '_context'): # Python > 2.6
- self.sock = self._context.wrap_socket(
- self.sock, server_hostname=self.host)
- else:
- self.sock = ssl.wrap_socket(self.sock)
+ self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
return SocksConnection
@@ -260,29 +281,25 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
unverifiable=True, method=new_method, data=new_data)
-class ProxyHandler(urllib.request.ProxyHandler):
+class ProxyHandler(urllib.request.BaseHandler):
+ handler_order = 100
+
def __init__(self, proxies=None):
+ self.proxies = proxies
# Set default handlers
- for type in ('http', 'https'):
- setattr(self, '%s_open' % type,
- lambda r, proxy='__noproxy__', type=type, meth=self.proxy_open:
- meth(r, proxy, type))
- urllib.request.ProxyHandler.__init__(self, proxies)
+ for type in ('http', 'https', 'ftp'):
+ setattr(self, '%s_open' % type, lambda r, meth=self.proxy_open: meth(r))
- def proxy_open(self, req, proxy, type):
- req_proxy = req.headers.get('Ytdl-request-proxy')
- if req_proxy is not None:
- proxy = req_proxy
- del req.headers['Ytdl-request-proxy']
-
- if proxy == '__noproxy__':
- return None # No Proxy
- if urllib.parse.urlparse(proxy).scheme.lower() in ('socks', 'socks4', 'socks4a', 'socks5'):
+ def proxy_open(self, req):
+ proxy = select_proxy(req.get_full_url(), self.proxies)
+ if proxy is None:
+ return
+ if urllib.parse.urlparse(proxy).scheme.lower() in ('socks4', 'socks4a', 'socks5', 'socks5h'):
req.add_header('Ytdl-socks-proxy', proxy)
+            # yt-dlp's http/https handlers handle wrapping the socket with socks
return None
return urllib.request.ProxyHandler.proxy_open(
- self, req, proxy, type)
+ self, req, proxy, None)
class PUTRequest(urllib.request.Request):
@@ -313,3 +330,129 @@ def update_Request(req, url=None, data=None, headers=None, query=None):
if hasattr(req, 'timeout'):
new_req.timeout = req.timeout
return new_req
+
+
+class UrllibResponseAdapter(Response):
+ """
+ HTTP Response adapter class for urllib addinfourl and http.client.HTTPResponse
+ """
+
+ def __init__(self, res: http.client.HTTPResponse | urllib.response.addinfourl):
+ # addinfourl: In Python 3.9+, .status was introduced and .getcode() was deprecated [1]
+ # HTTPResponse: .getcode() was deprecated, .status always existed [2]
+ # 1. https://docs.python.org/3/library/urllib.request.html#urllib.response.addinfourl.getcode
+ # 2. https://docs.python.org/3.10/library/http.client.html#http.client.HTTPResponse.status
+ super().__init__(
+ fp=res, headers=res.headers, url=res.url,
+ status=getattr(res, 'status', None) or res.getcode(), reason=getattr(res, 'reason', None))
+
+ def read(self, amt=None):
+ try:
+ return self.fp.read(amt)
+ except Exception as e:
+ handle_response_read_exceptions(e)
+ raise e
+
+
+def handle_sslerror(e: ssl.SSLError):
+ if not isinstance(e, ssl.SSLError):
+ return
+ if isinstance(e, ssl.SSLCertVerificationError):
+ raise CertificateVerifyError(cause=e) from e
+ raise SSLError(cause=e) from e
+
+
+def handle_response_read_exceptions(e):
+ if isinstance(e, http.client.IncompleteRead):
+ raise IncompleteRead(partial=e.partial, cause=e, expected=e.expected) from e
+ elif isinstance(e, ssl.SSLError):
+ handle_sslerror(e)
+ elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
+ # OSErrors raised here should mostly be network related
+ raise TransportError(cause=e) from e
+
+
+@register
+class UrllibRH(RequestHandler, InstanceStoreMixin):
+ _SUPPORTED_URL_SCHEMES = ('http', 'https', 'data', 'ftp')
+ _SUPPORTED_PROXY_SCHEMES = ('http', 'socks4', 'socks4a', 'socks5', 'socks5h')
+ _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
+ RH_NAME = 'urllib'
+
+ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ self.enable_file_urls = enable_file_urls
+ if self.enable_file_urls:
+ self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
+
+ def _create_instance(self, proxies, cookiejar):
+ opener = urllib.request.OpenerDirector()
+ handlers = [
+ ProxyHandler(proxies),
+ HTTPHandler(
+ debuglevel=int(bool(self.verbose)),
+ context=self._make_sslcontext(),
+ source_address=self.source_address),
+ HTTPCookieProcessor(cookiejar),
+ DataHandler(),
+ UnknownHandler(),
+ HTTPDefaultErrorHandler(),
+ FTPHandler(),
+ HTTPErrorProcessor(),
+ RedirectHandler(),
+ ]
+
+ if self.enable_file_urls:
+ handlers.append(FileHandler())
+
+ for handler in handlers:
+ opener.add_handler(handler)
+
+ # Delete the default user-agent header, which would otherwise apply in
+ # cases where our custom HTTP handler doesn't come into play
+ # (See https://github.com/ytdl-org/youtube-dl/issues/1309 for details)
+ opener.addheaders = []
+ return opener
+
+ def _send(self, request):
+ headers = self._merge_headers(request.headers)
+ add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+ urllib_req = urllib.request.Request(
+ url=request.url,
+ data=request.data,
+ headers=dict(headers),
+ method=request.method
+ )
+
+ opener = self._get_instance(
+ proxies=request.proxies or self.proxies,
+ cookiejar=request.extensions.get('cookiejar') or self.cookiejar
+ )
+ try:
+ res = opener.open(urllib_req, timeout=float(request.extensions.get('timeout') or self.timeout))
+ except urllib.error.HTTPError as e:
+ if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
+ # Prevent file object from being closed when urllib.error.HTTPError is destroyed.
+ e._closer.file = None
+ raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
+ raise # unexpected
+ except urllib.error.URLError as e:
+ cause = e.reason # NOTE: cause may be a string
+
+ # proxy errors
+ if 'tunnel connection failed' in str(cause).lower() or isinstance(cause, SocksProxyError):
+ raise ProxyError(cause=e) from e
+
+ handle_response_read_exceptions(cause)
+ raise TransportError(cause=e) from e
+ except (http.client.InvalidURL, ValueError) as e:
+ # Validation errors
+ # http.client.HTTPConnection raises ValueError in some validation cases
+ # such as if request method contains illegal control characters [1]
+ # 1. https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
+ raise RequestError(cause=e) from e
+ except Exception as e:
+ handle_response_read_exceptions(e)
+ raise # unexpected
+
+ return UrllibResponseAdapter(res)
diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py
new file mode 100644
index 000000000..e4b362827
--- /dev/null
+++ b/yt_dlp/networking/common.py
@@ -0,0 +1,522 @@
+from __future__ import annotations
+
+import abc
+import copy
+import enum
+import functools
+import io
+import typing
+import urllib.parse
+import urllib.request
+import urllib.response
+from collections.abc import Iterable, Mapping
+from email.message import Message
+from http import HTTPStatus
+from http.cookiejar import CookieJar
+
+from ._helper import make_ssl_context, wrap_request_errors
+from .exceptions import (
+ NoSupportingHandlers,
+ RequestError,
+ TransportError,
+ UnsupportedRequest,
+)
+from ..utils import (
+ bug_reports_message,
+ classproperty,
+ error_to_str,
+ escape_url,
+ update_url_query,
+)
+from ..utils.networking import HTTPHeaderDict
+
+if typing.TYPE_CHECKING:
+ RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+
+class RequestDirector:
+ """RequestDirector class
+
+    Helper class that, when given a request, forwards it to a RequestHandler that supports it.
+
+ @param logger: Logger instance.
+ @param verbose: Print debug request information to stdout.
+ """
+
+ def __init__(self, logger, verbose=False):
+ self.handlers: dict[str, RequestHandler] = {}
+ self.logger = logger # TODO(Grub4k): default logger
+ self.verbose = verbose
+
+ def close(self):
+ for handler in self.handlers.values():
+ handler.close()
+
+ def add_handler(self, handler: RequestHandler):
+ """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
+ assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
+ self.handlers[handler.RH_KEY] = handler
+
+ def _print_verbose(self, msg):
+ if self.verbose:
+ self.logger.stdout(f'director: {msg}')
+
+ def send(self, request: Request) -> Response:
+ """
+        Passes a request on to a suitable RequestHandler
+ """
+ if not self.handlers:
+ raise RequestError('No request handlers configured')
+
+ assert isinstance(request, Request)
+
+ unexpected_errors = []
+ unsupported_errors = []
+ # TODO (future): add a per-request preference system
+ for handler in reversed(list(self.handlers.values())):
+ self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
+ try:
+ handler.validate(request)
+ except UnsupportedRequest as e:
+ self._print_verbose(
+ f'"{handler.RH_NAME}" cannot handle this request (reason: {error_to_str(e)})')
+ unsupported_errors.append(e)
+ continue
+
+ self._print_verbose(f'Sending request via "{handler.RH_NAME}"')
+ try:
+ response = handler.send(request)
+ except RequestError:
+ raise
+ except Exception as e:
+ self.logger.error(
+ f'[{handler.RH_NAME}] Unexpected error: {error_to_str(e)}{bug_reports_message()}',
+ is_error=False)
+ unexpected_errors.append(e)
+ continue
+
+ assert isinstance(response, Response)
+ return response
+
+ raise NoSupportingHandlers(unsupported_errors, unexpected_errors)
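
Note the dispatch order: handlers are tried in reverse insertion order, and an UnsupportedRequest falls through to the next candidate. A contrived sketch (both stub classes are illustrative):

    import io

    from yt_dlp.networking import Request, RequestDirector, RequestHandler, Response

    class _NullLogger:  # minimal stand-in for the logger interface the director calls
        def stdout(self, msg): pass
        def error(self, msg, is_error=True): pass

    class DataOnlyRH(RequestHandler):
        _SUPPORTED_URL_SCHEMES = ('data',)

        def _send(self, request):
            return Response(io.BytesIO(b''), url=request.url, headers={})

    director = RequestDirector(logger=_NullLogger())
    director.add_handler(DataOnlyRH(logger=_NullLogger()))
    director.send(Request('data:,hello'))  # validated, then handled
    # director.send(Request('ftp://host')) would raise NoSupportingHandlers
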
+
+
+_REQUEST_HANDLERS = {}
+
+
+def register(handler):
+ """Register a RequestHandler class"""
+ assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
+ assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered'
+ _REQUEST_HANDLERS[handler.RH_KEY] = handler
+ return handler
+
+
+class Features(enum.Enum):
+ ALL_PROXY = enum.auto()
+ NO_PROXY = enum.auto()
+
+
+class RequestHandler(abc.ABC):
+
+ """Request Handler class
+
+    Request handlers are classes that, given a Request,
+ process the request from start to finish and return a Response.
+
+ Concrete subclasses need to redefine the _send(request) method,
+ which handles the underlying request logic and returns a Response.
+
+ RH_NAME class variable may contain a display name for the RequestHandler.
+ By default, this is generated from the class name.
+
+ The concrete request handler MUST have "RH" as the suffix in the class name.
+
+ All exceptions raised by a RequestHandler should be an instance of RequestError.
+ Any other exception raised will be treated as a handler issue.
+
+ If a Request is not supported by the handler, an UnsupportedRequest
+ should be raised with a reason.
+
+ By default, some checks are done on the request in _validate() based on the following class variables:
+    - `_SUPPORTED_URL_SCHEMES`: a tuple of supported url schemes.
+        Any Request with a url scheme not in this list will raise an UnsupportedRequest.
+
+    - `_SUPPORTED_PROXY_SCHEMES`: a tuple of supported proxy url schemes. Any Request that contains
+        a proxy url with a url scheme not in this list will raise an UnsupportedRequest.
+
+ - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+ The above may be set to None to disable the checks.
+
+ Parameters:
+ @param logger: logger instance
+ @param headers: HTTP Headers to include when sending requests.
+ @param cookiejar: Cookiejar to use for requests.
+ @param timeout: Socket timeout to use when sending requests.
+ @param proxies: Proxies to use for sending requests.
+ @param source_address: Client-side IP address to bind to for requests.
+ @param verbose: Print debug request and traffic information to stdout.
+ @param prefer_system_certs: Whether to prefer system certificates over other means (e.g. certifi).
+ @param client_cert: SSL client certificate configuration.
+ dict with {client_certificate, client_certificate_key, client_certificate_password}
+ @param verify: Verify SSL certificates
+ @param legacy_ssl_support: Enable legacy SSL options such as legacy server connect and older cipher support.
+
+ Some configuration options may be available for individual Requests too. In this case,
+ either the Request configuration option takes precedence or they are merged.
+
+ Requests may have additional optional parameters defined as extensions.
+ RequestHandler subclasses may choose to support custom extensions.
+
+ The following extensions are defined for RequestHandler:
+ - `cookiejar`: Cookiejar to use for this request
+ - `timeout`: socket timeout to use for this request
+
+ Apart from the url protocol, proxies dict may contain the following keys:
+ - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
+    - `no`: comma-separated list of hostnames (optionally with port) to not use a proxy for.
+ Note: a RequestHandler may not support these, as defined in `_SUPPORTED_FEATURES`.
+
+ """
+
+ _SUPPORTED_URL_SCHEMES = ()
+ _SUPPORTED_PROXY_SCHEMES = ()
+ _SUPPORTED_FEATURES = ()
+
+ def __init__(
+ self, *,
+ logger, # TODO(Grub4k): default logger
+ headers: HTTPHeaderDict = None,
+ cookiejar: CookieJar = None,
+ timeout: float | int | None = None,
+ proxies: dict = None,
+ source_address: str = None,
+ verbose: bool = False,
+ prefer_system_certs: bool = False,
+ client_cert: dict[str, str | None] = None,
+ verify: bool = True,
+ legacy_ssl_support: bool = False,
+ **_,
+ ):
+
+ self._logger = logger
+ self.headers = headers or {}
+ self.cookiejar = cookiejar if cookiejar is not None else CookieJar()
+ self.timeout = float(timeout or 20)
+ self.proxies = proxies or {}
+ self.source_address = source_address
+ self.verbose = verbose
+ self.prefer_system_certs = prefer_system_certs
+ self._client_cert = client_cert or {}
+ self.verify = verify
+ self.legacy_ssl_support = legacy_ssl_support
+ super().__init__()
+
+ def _make_sslcontext(self):
+ return make_ssl_context(
+ verify=self.verify,
+ legacy_support=self.legacy_ssl_support,
+ use_certifi=not self.prefer_system_certs,
+ **self._client_cert,
+ )
+
+ def _merge_headers(self, request_headers):
+ return HTTPHeaderDict(self.headers, request_headers)
+
+ def _check_url_scheme(self, request: Request):
+ scheme = urllib.parse.urlparse(request.url).scheme.lower()
+ if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
+ raise UnsupportedRequest(f'Unsupported url scheme: "{scheme}"')
+ return scheme # for further processing
+
+ def _check_proxies(self, proxies):
+ for proxy_key, proxy_url in proxies.items():
+ if proxy_url is None:
+ continue
+ if proxy_key == 'no':
+ if self._SUPPORTED_FEATURES is not None and Features.NO_PROXY not in self._SUPPORTED_FEATURES:
+ raise UnsupportedRequest('"no" proxy is not supported')
+ continue
+ if (
+ proxy_key == 'all'
+ and self._SUPPORTED_FEATURES is not None
+ and Features.ALL_PROXY not in self._SUPPORTED_FEATURES
+ ):
+ raise UnsupportedRequest('"all" proxy is not supported')
+
+                # Unlikely this handler will use this proxy, so ignore.
+                # This allows a proxy to be set for a protocol that one handler supports,
+                # without other handlers that do not support that protocol (and proxy) rejecting the request.
+ if self._SUPPORTED_URL_SCHEMES is not None and proxy_key not in (*self._SUPPORTED_URL_SCHEMES, 'all'):
+ continue
+
+ if self._SUPPORTED_PROXY_SCHEMES is None:
+ # Skip proxy scheme checks
+ continue
+
+ # Scheme-less proxies are not supported
+ if urllib.request._parse_proxy(proxy_url)[0] is None:
+ raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+
+ scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
+ if scheme not in self._SUPPORTED_PROXY_SCHEMES:
+ raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
+
+ def _check_cookiejar_extension(self, extensions):
+ if not extensions.get('cookiejar'):
+ return
+ if not isinstance(extensions['cookiejar'], CookieJar):
+ raise UnsupportedRequest('cookiejar is not a CookieJar')
+
+ def _check_timeout_extension(self, extensions):
+ if extensions.get('timeout') is None:
+ return
+ if not isinstance(extensions['timeout'], (float, int)):
+ raise UnsupportedRequest('timeout is not a float or int')
+
+ def _check_extensions(self, extensions):
+ self._check_cookiejar_extension(extensions)
+ self._check_timeout_extension(extensions)
+
+ def _validate(self, request):
+ self._check_url_scheme(request)
+ self._check_proxies(request.proxies or self.proxies)
+ self._check_extensions(request.extensions)
+
+ @wrap_request_errors
+ def validate(self, request: Request):
+ if not isinstance(request, Request):
+ raise TypeError('Expected an instance of Request')
+ self._validate(request)
+
+ @wrap_request_errors
+ def send(self, request: Request) -> Response:
+ if not isinstance(request, Request):
+ raise TypeError('Expected an instance of Request')
+ return self._send(request)
+
+ @abc.abstractmethod
+ def _send(self, request: Request):
+ """Handle a request from start to finish. Redefine in subclasses."""
+
+ def close(self):
+ pass
+
+ @classproperty
+ def RH_NAME(cls):
+ return cls.__name__[:-2]
+
+ @classproperty
+ def RH_KEY(cls):
+ assert cls.__name__.endswith('RH'), 'RequestHandler class names must end with "RH"'
+ return cls.__name__[:-2]
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
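
A minimal concrete handler, showing how the class variables drive the default validation (`ExampleRH` is hypothetical; the @register line is optional and shown only for completeness):

    from yt_dlp.networking import Request
    from yt_dlp.networking.common import Features, RequestHandler, register
    from yt_dlp.networking.exceptions import UnsupportedRequest

    @register  # stored in _REQUEST_HANDLERS under RH_KEY 'Example'
    class ExampleRH(RequestHandler):
        _SUPPORTED_URL_SCHEMES = ('http', 'https')
        _SUPPORTED_PROXY_SCHEMES = ('http',)
        _SUPPORTED_FEATURES = (Features.NO_PROXY,)

        def _send(self, request):
            raise NotImplementedError  # a real handler builds and returns a Response here

    rh = ExampleRH(logger=None)
    rh.validate(Request('https://example.com'))  # passes the default checks
    try:
        rh.validate(Request('https://example.com', proxies={'all': 'http://proxy'}))
    except UnsupportedRequest:
        pass  # rejected: Features.ALL_PROXY is not declared
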
+
+
+class Request:
+ """
+ Represents a request to be made.
+ Partially backwards-compatible with urllib.request.Request.
+
+ @param url: url to send. Will be sanitized.
+ @param data: payload data to send. Must be bytes, iterable of bytes, a file-like object or None
+ @param headers: headers to send.
+ @param proxies: proxy dict mapping of proto:proxy to use for the request and any redirects.
+ @param query: URL query parameters to update the url with.
+    @param method: HTTP method to use. If not specified, POST is used when payload data is present, otherwise GET
+ @param extensions: Dictionary of Request extensions to add, as supported by handlers.
+ """
+
+ def __init__(
+ self,
+ url: str,
+ data: RequestData = None,
+ headers: typing.Mapping = None,
+ proxies: dict = None,
+ query: dict = None,
+ method: str = None,
+ extensions: dict = None
+ ):
+
+ self._headers = HTTPHeaderDict()
+ self._data = None
+
+ if query:
+ url = update_url_query(url, query)
+
+ self.url = url
+ self.method = method
+ if headers:
+ self.headers = headers
+ self.data = data # note: must be done after setting headers
+ self.proxies = proxies or {}
+ self.extensions = extensions or {}
+
+ @property
+ def url(self):
+ return self._url
+
+ @url.setter
+ def url(self, url):
+ if not isinstance(url, str):
+ raise TypeError('url must be a string')
+ elif url.startswith('//'):
+ url = 'http:' + url
+ self._url = escape_url(url)
+
+ @property
+ def method(self):
+ return self._method or ('POST' if self.data is not None else 'GET')
+
+ @method.setter
+ def method(self, method):
+ if method is None:
+ self._method = None
+ elif isinstance(method, str):
+ self._method = method.upper()
+ else:
+ raise TypeError('method must be a string')
+
+ @property
+ def data(self):
+ return self._data
+
+ @data.setter
+ def data(self, data: RequestData):
+ # Try catch some common mistakes
+ if data is not None and (
+ not isinstance(data, (bytes, io.IOBase, Iterable)) or isinstance(data, (str, Mapping))
+ ):
+ raise TypeError('data must be bytes, iterable of bytes, or a file-like object')
+
+ if data == self._data and self._data is None:
+ self.headers.pop('Content-Length', None)
+
+ # https://docs.python.org/3/library/urllib.request.html#urllib.request.Request.data
+ if data != self._data:
+ if self._data is not None:
+ self.headers.pop('Content-Length', None)
+ self._data = data
+
+ if self._data is None:
+ self.headers.pop('Content-Type', None)
+
+ if 'Content-Type' not in self.headers and self._data is not None:
+ self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
+
+ @property
+ def headers(self) -> HTTPHeaderDict:
+ return self._headers
+
+ @headers.setter
+ def headers(self, new_headers: Mapping):
+ """Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
+ if isinstance(new_headers, HTTPHeaderDict):
+ self._headers = new_headers
+ elif isinstance(new_headers, Mapping):
+ self._headers = HTTPHeaderDict(new_headers)
+ else:
+ raise TypeError('headers must be a mapping')
+
+ def update(self, url=None, data=None, headers=None, query=None):
+ self.data = data or self.data
+ self.headers.update(headers or {})
+ self.url = update_url_query(url or self.url, query or {})
+
+ def copy(self):
+ return self.__class__(
+ url=self.url,
+ headers=copy.deepcopy(self.headers),
+ proxies=copy.deepcopy(self.proxies),
+ data=self._data,
+ extensions=copy.copy(self.extensions),
+ method=self._method,
+ )
+
+
+HEADRequest = functools.partial(Request, method='HEAD')
+PUTRequest = functools.partial(Request, method='PUT')
+
+
+class Response(io.IOBase):
+ """
+ Base class for HTTP response adapters.
+
+ By default, it provides a basic wrapper for a file-like response object.
+
+ Interface partially backwards-compatible with addinfourl and http.client.HTTPResponse.
+
+ @param fp: Original, file-like, response.
+ @param url: URL that this is a response of.
+ @param headers: response headers.
+ @param status: Response HTTP status code. Default is 200 OK.
+ @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+ """
+
+ def __init__(
+ self,
+ fp: typing.IO,
+ url: str,
+ headers: Mapping[str, str],
+ status: int = 200,
+ reason: str = None):
+
+ self.fp = fp
+ self.headers = Message()
+ for name, value in headers.items():
+ self.headers.add_header(name, value)
+ self.status = status
+ self.url = url
+ try:
+ self.reason = reason or HTTPStatus(status).phrase
+ except ValueError:
+ self.reason = None
+
+ def readable(self):
+ return self.fp.readable()
+
+ def read(self, amt: int = None) -> bytes:
+ # Expected errors raised here should be of type RequestError or subclasses.
+ # Subclasses should redefine this method with more precise error handling.
+ try:
+ return self.fp.read(amt)
+ except Exception as e:
+ raise TransportError(cause=e) from e
+
+ def close(self):
+ self.fp.close()
+ return super().close()
+
+ def get_header(self, name, default=None):
+ """Get header for name.
+        If there are multiple matching headers, return all separated by comma."""
+ headers = self.headers.get_all(name)
+ if not headers:
+ return default
+ if name.title() == 'Set-Cookie':
+ # Special case, only get the first one
+ # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.3-4.1
+ return headers[0]
+ return ', '.join(headers)
+
+    # The following methods are for compatibility reasons and are deprecated
+ @property
+ def code(self):
+ return self.status
+
+ def getcode(self):
+ return self.status
+
+ def geturl(self):
+ return self.url
+
+ def info(self):
+ return self.headers
+
+ def getheader(self, name, default=None):
+ return self.get_header(name, default)
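
Reviewer note: a minimal usage sketch of the Request/Response pair added above, using the import paths exercised in test_networking.py. The URLs are illustrative, the handler/director wiring is elided, and the asserts only restate behaviour visible in this hunk:

    import io

    from yt_dlp.networking import HEADRequest, Request, Response

    # Header keys are normalised case-insensitively via HTTPHeaderDict
    req = Request('https://example.com/api', headers={'X-Test': '1'})
    assert req.headers['x-test'] == '1'

    # Assigning data implies POST and a default Content-Type;
    # clearing it drops both, and the implied method reverts to GET
    req.data = b'payload'
    assert req.method == 'POST'
    assert req.headers['Content-Type'] == 'application/x-www-form-urlencoded'
    req.data = None
    assert req.method == 'GET'

    # Method-specific partials
    assert HEADRequest('https://example.com').method == 'HEAD'

    # Response wraps any file-like object
    res = Response(io.BytesIO(b'hello'), url='https://example.com',
                   headers={'Content-Type': 'text/plain'})
    assert res.status == 200 and res.read() == b'hello'
    assert res.get_header('Content-Type') == 'text/plain'
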
diff --git a/yt_dlp/networking/exceptions.py b/yt_dlp/networking/exceptions.py
index 89b484a22..6fe8afb92 100644
--- a/yt_dlp/networking/exceptions.py
+++ b/yt_dlp/networking/exceptions.py
@@ -1,9 +1,197 @@
-import http.client
-import socket
-import ssl
+from __future__ import annotations
+
+import typing
import urllib.error
-network_exceptions = [urllib.error.URLError, http.client.HTTPException, socket.error]
-if hasattr(ssl, 'CertificateError'):
- network_exceptions.append(ssl.CertificateError)
-network_exceptions = tuple(network_exceptions)
+from ..utils import YoutubeDLError
+
+if typing.TYPE_CHECKING:
+ from .common import RequestHandler, Response
+
+
+class RequestError(YoutubeDLError):
+ def __init__(
+ self,
+ msg: str | None = None,
+ cause: Exception | str | None = None,
+ handler: RequestHandler = None
+ ):
+ self.handler = handler
+ self.cause = cause
+ if not msg and cause:
+ msg = str(cause)
+ super().__init__(msg)
+
+
+class UnsupportedRequest(RequestError):
+ """raised when a handler cannot handle a request"""
+ pass
+
+
+class NoSupportingHandlers(RequestError):
+ """raised when no handlers can support a request for various reasons"""
+
+ def __init__(self, unsupported_errors: list[UnsupportedRequest], unexpected_errors: list[Exception]):
+ self.unsupported_errors = unsupported_errors or []
+ self.unexpected_errors = unexpected_errors or []
+
+        # Build a quick summary of the errors
+ err_handler_map = {}
+ for err in unsupported_errors:
+ err_handler_map.setdefault(err.msg, []).append(err.handler.RH_NAME)
+
+ reason_str = ', '.join([f'{msg} ({", ".join(handlers)})' for msg, handlers in err_handler_map.items()])
+ if unexpected_errors:
+ reason_str = ' + '.join(filter(None, [reason_str, f'{len(unexpected_errors)} unexpected error(s)']))
+
+ err_str = 'Unable to handle request'
+ if reason_str:
+ err_str += f': {reason_str}'
+
+ super().__init__(msg=err_str)
+
+
+class TransportError(RequestError):
+ """Network related errors"""
+
+
+class HTTPError(RequestError):
+ def __init__(self, response: Response, redirect_loop=False):
+ self.response = response
+ self.status = response.status
+ self.reason = response.reason
+ self.redirect_loop = redirect_loop
+ msg = f'HTTP Error {response.status}: {response.reason}'
+ if redirect_loop:
+ msg += ' (redirect loop detected)'
+
+ super().__init__(msg=msg)
+
+ def close(self):
+ self.response.close()
+
+ def __repr__(self):
+        return f'<HTTPError {self.status}: {self.reason}>'
+
+
+class IncompleteRead(TransportError):
+ def __init__(self, partial, expected=None, **kwargs):
+ self.partial = partial
+ self.expected = expected
+ msg = f'{len(partial)} bytes read'
+ if expected is not None:
+ msg += f', {expected} more expected'
+
+ super().__init__(msg=msg, **kwargs)
+
+ def __repr__(self):
+        return f'<IncompleteRead: {self.msg}>'
+
+
+class SSLError(TransportError):
+ pass
+
+
+class CertificateVerifyError(SSLError):
+ """Raised when certificate validated has failed"""
+ pass
+
+
+class ProxyError(TransportError):
+ pass
+
+
+class _CompatHTTPError(urllib.error.HTTPError, HTTPError):
+ """
+ Provides backwards compatibility with urllib.error.HTTPError.
+ Do not use this class directly, use HTTPError instead.
+ """
+
+ def __init__(self, http_error: HTTPError):
+ super().__init__(
+ url=http_error.response.url,
+ code=http_error.status,
+ msg=http_error.msg,
+ hdrs=http_error.response.headers,
+ fp=http_error.response
+ )
+ self._closer.file = None # Disable auto close
+ self._http_error = http_error
+ HTTPError.__init__(self, http_error.response, redirect_loop=http_error.redirect_loop)
+
+ @property
+ def status(self):
+ return self._http_error.status
+
+ @status.setter
+ def status(self, value):
+ return
+
+ @property
+ def reason(self):
+ return self._http_error.reason
+
+ @reason.setter
+ def reason(self, value):
+ return
+
+ @property
+ def headers(self):
+ return self._http_error.response.headers
+
+ @headers.setter
+ def headers(self, value):
+ return
+
+ def info(self):
+ return self.response.headers
+
+ def getcode(self):
+ return self.status
+
+ def geturl(self):
+ return self.response.url
+
+ @property
+ def code(self):
+ return self.status
+
+ @code.setter
+ def code(self, value):
+ return
+
+ @property
+ def url(self):
+ return self.response.url
+
+ @url.setter
+ def url(self, value):
+ return
+
+ @property
+ def hdrs(self):
+ return self.response.headers
+
+ @hdrs.setter
+ def hdrs(self, value):
+ return
+
+ @property
+ def filename(self):
+ return self.response.url
+
+ @filename.setter
+ def filename(self, value):
+ return
+
+ def __getattr__(self, name):
+ return super().__getattr__(name)
+
+ def __str__(self):
+ return str(self._http_error)
+
+ def __repr__(self):
+ return repr(self._http_error)
+
+
+network_exceptions = (HTTPError, TransportError)
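
Reviewer note: a sketch of how callers are expected to branch on the new hierarchy; it mirrors the test_download.py change above, which swaps the urllib/socket type checks for isinstance tests. classify() is a hypothetical helper for illustration, not part of this patch:

    from yt_dlp.networking.exceptions import (
        HTTPError,
        IncompleteRead,
        RequestError,
        TransportError,
    )

    def classify(err: RequestError) -> str:
        if isinstance(err, HTTPError):
            # The server responded, but with an error status
            return f'HTTP {err.status}: {err.reason}'
        if isinstance(err, IncompleteRead):
            # A TransportError subclass, so it must be tested first
            return f'short read ({len(err.partial)} bytes)'
        if isinstance(err, TransportError):
            return 'network-level failure'
        return err.msg or 'unknown request error'
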
diff --git a/yt_dlp/utils/_deprecated.py b/yt_dlp/utils/_deprecated.py
index ca0fb1614..e55d42354 100644
--- a/yt_dlp/utils/_deprecated.py
+++ b/yt_dlp/utils/_deprecated.py
@@ -10,16 +10,16 @@
from ._utils import preferredencoding
+from ..networking._urllib import HTTPHandler
# isort: split
+from .networking import random_user_agent, std_headers # noqa: F401
from ..networking._urllib import PUTRequest # noqa: F401
from ..networking._urllib import SUPPORTED_ENCODINGS, HEADRequest # noqa: F401
-from ..networking._urllib import HTTPHandler as YoutubeDLHandler # noqa: F401
from ..networking._urllib import ProxyHandler as PerRequestProxyHandler # noqa: F401
from ..networking._urllib import RedirectHandler as YoutubeDLRedirectHandler # noqa: F401
from ..networking._urllib import make_socks_conn_class, update_Request # noqa: F401
from ..networking.exceptions import network_exceptions # noqa: F401
-from .networking import random_user_agent, std_headers # noqa: F401
def encodeFilename(s, for_subprocess=False):
@@ -47,3 +47,12 @@ def decodeOption(optval):
def error_to_compat_str(err):
return str(err)
+
+
+class YoutubeDLHandler(HTTPHandler):
+ def __init__(self, params, *args, **kwargs):
+ self._params = params
+ super().__init__(*args, **kwargs)
+
+
+YoutubeDLHTTPSHandler = YoutubeDLHandler
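
Reviewer note: the shim keeps legacy two-argument construction working by swallowing the params dict before delegating to the new HTTPHandler. A sketch under that assumption (the context kwarg mirrors the make_HTTPS_handler call in the next hunk; the handler's full signature lives in yt_dlp/networking/_urllib.py):

    import ssl

    from yt_dlp.utils._deprecated import YoutubeDLHandler, YoutubeDLHTTPSHandler

    # Both legacy names resolve to the same class
    assert YoutubeDLHTTPSHandler is YoutubeDLHandler

    # Old call sites passed a params dict first; the shim just stores it
    handler = YoutubeDLHandler({'nocheckcertificate': True},
                               context=ssl.create_default_context())
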
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py
index d5704cadc..d0e328716 100644
--- a/yt_dlp/utils/_utils.py
+++ b/yt_dlp/utils/_utils.py
@@ -15,8 +15,6 @@
import hmac
import html.entities
import html.parser
-import http.client
-import http.cookiejar
import inspect
import io
import itertools
@@ -897,6 +895,7 @@ def formatSeconds(secs, delim=':', msec=False):
def make_HTTPS_handler(params, **kwargs):
+ from ._deprecated import YoutubeDLHTTPSHandler
from ..networking._helper import make_ssl_context
return YoutubeDLHTTPSHandler(params, context=make_ssl_context(
verify=not params.get('nocheckcertificate'),
@@ -1140,38 +1139,6 @@ class XAttrUnavailableError(YoutubeDLError):
pass
-class YoutubeDLHTTPSHandler(urllib.request.HTTPSHandler):
- def __init__(self, params, https_conn_class=None, *args, **kwargs):
- urllib.request.HTTPSHandler.__init__(self, *args, **kwargs)
- self._https_conn_class = https_conn_class or http.client.HTTPSConnection
- self._params = params
-
- def https_open(self, req):
- kwargs = {}
- conn_class = self._https_conn_class
-
- if hasattr(self, '_context'): # python > 2.6
- kwargs['context'] = self._context
- if hasattr(self, '_check_hostname'): # python 3.x
- kwargs['check_hostname'] = self._check_hostname
-
- socks_proxy = req.headers.get('Ytdl-socks-proxy')
- if socks_proxy:
- from ..networking._urllib import make_socks_conn_class
- conn_class = make_socks_conn_class(conn_class, socks_proxy)
- del req.headers['Ytdl-socks-proxy']
-
- from ..networking._urllib import _create_http_connection
- try:
- return self.do_open(
- functools.partial(_create_http_connection, self, conn_class, True), req, **kwargs)
- except urllib.error.URLError as e:
- if (isinstance(e.reason, ssl.SSLError)
- and getattr(e.reason, 'reason', None) == 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
- raise YoutubeDLError('SSLV3_ALERT_HANDSHAKE_FAILURE: Try using --legacy-server-connect')
- raise
-
-
def is_path_like(f):
return isinstance(f, (str, bytes, os.PathLike))
diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py
index 95b54fabe..ac355ddc8 100644
--- a/yt_dlp/utils/networking.py
+++ b/yt_dlp/utils/networking.py
@@ -1,4 +1,9 @@
+import collections
import random
+import urllib.parse
+import urllib.request
+
+from ._utils import remove_start
def random_user_agent():
@@ -46,15 +51,67 @@ def random_user_agent():
return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
-std_headers = {
+class HTTPHeaderDict(collections.UserDict, dict):
+ """
+ Store and access keys case-insensitively.
+    The constructor can take multiple dicts; keys from later dicts take priority.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ for dct in args:
+ if dct is not None:
+ self.update(dct)
+ self.update(kwargs)
+
+ def __setitem__(self, key, value):
+ super().__setitem__(key.title(), str(value))
+
+ def __getitem__(self, key):
+ return super().__getitem__(key.title())
+
+ def __delitem__(self, key):
+ super().__delitem__(key.title())
+
+ def __contains__(self, key):
+ return super().__contains__(key.title() if isinstance(key, str) else key)
+
+
+std_headers = HTTPHeaderDict({
'User-Agent': random_user_agent(),
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'Accept-Language': 'en-us,en;q=0.5',
'Sec-Fetch-Mode': 'navigate',
-}
+})
-def clean_headers(headers):
- if 'Youtubedl-no-compression' in headers: # compat
- del headers['Youtubedl-no-compression']
+def clean_proxies(proxies: dict, headers: HTTPHeaderDict):
+ req_proxy = headers.pop('Ytdl-Request-Proxy', None)
+ if req_proxy:
+ proxies.clear() # XXX: compat: Ytdl-Request-Proxy takes preference over everything, including NO_PROXY
+ proxies['all'] = req_proxy
+ for proxy_key, proxy_url in proxies.items():
+ if proxy_url == '__noproxy__':
+ proxies[proxy_key] = None
+ continue
+ if proxy_key == 'no': # special case
+ continue
+ if proxy_url is not None:
+ # Ensure proxies without a scheme are http.
+ proxy_scheme = urllib.request._parse_proxy(proxy_url)[0]
+ if proxy_scheme is None:
+ proxies[proxy_key] = 'http://' + remove_start(proxy_url, '//')
+
+ replace_scheme = {
+ 'socks5': 'socks5h', # compat: socks5 was treated as socks5h
+ 'socks': 'socks4' # compat: non-standard
+ }
+ if proxy_scheme in replace_scheme:
+ proxies[proxy_key] = urllib.parse.urlunparse(
+ urllib.parse.urlparse(proxy_url)._replace(scheme=replace_scheme[proxy_scheme]))
+
+
+def clean_headers(headers: HTTPHeaderDict):
+ if 'Youtubedl-No-Compression' in headers: # compat
+ del headers['Youtubedl-No-Compression']
headers['Accept-Encoding'] = 'identity'
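
Reviewer note: a sketch exercising the helpers in this hunk; the proxy URLs are illustrative:

    from yt_dlp.utils.networking import (
        HTTPHeaderDict,
        clean_headers,
        clean_proxies,
    )

    # Keys are title-cased on insertion; later dicts take priority
    headers = HTTPHeaderDict({'content-type': 'text/html'},
                             {'CONTENT-TYPE': 'application/json'})
    assert headers['Content-Type'] == 'application/json'

    # Scheme-less proxies default to http; socks5 maps to socks5h (compat)
    proxies = {'https': '127.0.0.1:3128', 'all': 'socks5://127.0.0.1:1080'}
    clean_proxies(proxies, HTTPHeaderDict())
    assert proxies['https'] == 'http://127.0.0.1:3128'
    assert proxies['all'] == 'socks5h://127.0.0.1:1080'

    # The legacy no-compression header now maps to Accept-Encoding: identity
    headers = HTTPHeaderDict({'Youtubedl-No-Compression': '1'})
    clean_headers(headers)
    assert headers['Accept-Encoding'] == 'identity'
    assert 'Youtubedl-No-Compression' not in headers
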