diff --git a/README.md b/README.md index a87b52832a..9f542844e0 100644 --- a/README.md +++ b/README.md @@ -1795,6 +1795,7 @@ The following extractors use this feature: * `player_client`: Clients to extract video data from. The currently available clients are `web`, `web_safari`, `web_embedded`, `web_music`, `web_creator`, `mweb`, `ios`, `android`, `android_vr`, `tv` and `tv_embedded`. By default, `tv,ios,web` is used, or `tv,web` is used when authenticating with cookies. The `web_music` client is added for `music.youtube.com` URLs when logged-in cookies are used. The `web_embedded` client is added for age-restricted videos but only works if the video is embeddable. The `tv_embedded` and `web_creator` clients are added for age-restricted videos if account age-verification is required. Some clients, such as `web` and `web_music`, require a `po_token` for their formats to be downloadable. Some clients, such as `web_creator`, will only work with authentication. Not all clients support authentication via cookies. You can use `default` for the default clients, or you can use `all` for all clients (not recommended). You can prefix a client with `-` to exclude it, e.g. `youtube:player_client=default,-ios` * `player_skip`: Skip some network requests that are generally needed for robust extraction. One or more of `configs` (skip client configs), `webpage` (skip initial webpage), `js` (skip js player), `initial_data` (skip initial data/next ep request). While these options can help reduce the number of requests needed or avoid some rate-limiting, they could cause issues such as missing formats or metadata. See [#860](https://github.com/yt-dlp/yt-dlp/pull/860) and [#12826](https://github.com/yt-dlp/yt-dlp/issues/12826) for more details * `player_params`: YouTube player parameters to use for player requests. Will overwrite any default ones set by yt-dlp. +* `player_js_variant`: The player javascript variant to use for signature and nsig deciphering. 
The known variants are: `main`, `tce`, `tv`, `tv_es6`, `phone`, `tablet`. Only `main` is recommended as a possible workaround; the others are for debugging purposes. The default is to use what is prescribed by the site, and can be selected with `actual` * `comment_sort`: `top` or `new` (default) - choose comment sorting mode (on YouTube's side) * `max_comments`: Limit the amount of comments to gather. Comma-separated list of integers representing `max-comments,max-parents,max-replies,max-replies-per-thread`. Default is `all,all,all,all` * E.g. `all,all,1000,10` will get a maximum of 1000 replies total, with up to 10 replies per thread. `1000,all,100` will get a maximum of 1000 comments, with a maximum of 100 replies total @@ -1805,7 +1806,11 @@ The following extractors use this feature: * `data_sync_id`: Overrides the account Data Sync ID used in Innertube API requests. This may be needed if you are using an account with `youtube:player_skip=webpage,configs` or `youtubetab:skip=webpage` * `visitor_data`: Overrides the Visitor Data used in Innertube API requests. This should be used with `player_skip=webpage,configs` and without cookies. Note: this may have adverse effects if used improperly. If a session from a browser is wanted, you should pass cookies instead (which contain the Visitor ID) * `po_token`: Proof of Origin (PO) Token(s) to use. Comma seperated list of PO Tokens in the format `CLIENT.CONTEXT+PO_TOKEN`, e.g. `youtube:po_token=web.gvs+XXX,web.player=XXX,web_safari.gvs+YYY`. Context can be either `gvs` (Google Video Server URLs) or `player` (Innertube player request) -* `player_js_variant`: The player javascript variant to use for signature and nsig deciphering. The known variants are: `main`, `tce`, `tv`, `tv_es6`, `phone`, `tablet`. Only `main` is recommended as a possible workaround; the others are for debugging purposes. 
The default is to use what is prescribed by the site, and can be selected with `actual` +* `pot_trace`: Enable debug logging for PO Token fetching. Either `true` or `false` (default) +* `fetch_pot`: Policy to use for fetching a PO Token from providers. One of `always` (always try to fetch a PO Token regardless of whether the client requires one for the given context), `never` (never fetch a PO Token), or `auto` (default; only fetch a PO Token if the client requires one for the given context) + +#### youtubepot-webpo +* `bind_to_visitor_id`: Whether to use the Visitor ID instead of Visitor Data for caching WebPO tokens. Either `true` (default) or `false` #### youtubetab (YouTube playlists, channels, feeds, etc.) * `skip`: One or more of `webpage` (skip initial webpage download), `authcheck` (allow the download of playlists requiring authentication when no initial webpage is downloaded. This may cause unwanted behavior, see [#1122](https://github.com/yt-dlp/yt-dlp/pull/1122) for more details) diff --git a/test/test_YoutubeDL.py b/test/test_YoutubeDL.py index 708a04f92d..91312e4e5f 100644 --- a/test/test_YoutubeDL.py +++ b/test/test_YoutubeDL.py @@ -1435,6 +1435,27 @@ class TestYoutubeDL(unittest.TestCase): FakeYDL().close() assert all_plugins_loaded.value + def test_close_hooks(self): + # Should call all registered close hooks on close + close_hook_called = False + close_hook_two_called = False + + def close_hook(): + nonlocal close_hook_called + close_hook_called = True + + def close_hook_two(): + nonlocal close_hook_two_called + close_hook_two_called = True + + ydl = FakeYDL() + ydl.add_close_hook(close_hook) + ydl.add_close_hook(close_hook_two) + + ydl.close() + self.assertTrue(close_hook_called, 'Close hook was not called') + self.assertTrue(close_hook_two_called, 'Close hook two was not called') + if __name__ == '__main__': unittest.main() diff --git a/test/test_networking_utils.py b/test/test_networking_utils.py index 204fe87bda..a2feacba71 100644 ---
a/test/test_networking_utils.py +++ b/test/test_networking_utils.py @@ -20,7 +20,6 @@ from yt_dlp.networking._helper import ( add_accept_encoding_header, get_redirect_method, make_socks_proxy_opts, - select_proxy, ssl_load_certs, ) from yt_dlp.networking.exceptions import ( @@ -28,7 +27,7 @@ from yt_dlp.networking.exceptions import ( IncompleteRead, ) from yt_dlp.socks import ProxyType -from yt_dlp.utils.networking import HTTPHeaderDict +from yt_dlp.utils.networking import HTTPHeaderDict, select_proxy TEST_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/test_pot/conftest.py b/test/test_pot/conftest.py new file mode 100644 index 0000000000..ff0667e928 --- /dev/null +++ b/test/test_pot/conftest.py @@ -0,0 +1,71 @@ +import collections + +import pytest + +from yt_dlp import YoutubeDL +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.extractor.common import InfoExtractor +from yt_dlp.extractor.youtube.pot._provider import IEContentProviderLogger +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest, PoTokenContext +from yt_dlp.utils.networking import HTTPHeaderDict + + +class MockLogger(IEContentProviderLogger): + + log_level = IEContentProviderLogger.LogLevel.TRACE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.messages = collections.defaultdict(list) + + def trace(self, message: str): + self.messages['trace'].append(message) + + def debug(self, message: str): + self.messages['debug'].append(message) + + def info(self, message: str): + self.messages['info'].append(message) + + def warning(self, message: str, *, once=False): + self.messages['warning'].append(message) + + def error(self, message: str): + self.messages['error'].append(message) + + +@pytest.fixture +def ie() -> InfoExtractor: + ydl = YoutubeDL() + return ydl.get_info_extractor('Youtube') + + +@pytest.fixture +def logger() -> MockLogger: + return MockLogger() + + +@pytest.fixture() +def pot_request() -> PoTokenRequest: + return 
PoTokenRequest( + context=PoTokenContext.GVS, + innertube_context={'client': {'clientName': 'WEB'}}, + innertube_host='youtube.com', + session_index=None, + player_url=None, + is_authenticated=False, + video_webpage=None, + + visitor_data='example-visitor-data', + data_sync_id='example-data-sync-id', + video_id='example-video-id', + + request_cookiejar=YoutubeDLCookieJar(), + request_proxy=None, + request_headers=HTTPHeaderDict(), + request_timeout=None, + request_source_address=None, + request_verify_tls=True, + + bypass_cache=False, + ) diff --git a/test/test_pot/test_pot_builtin_memorycache.py b/test/test_pot/test_pot_builtin_memorycache.py new file mode 100644 index 0000000000..ea19fbe29f --- /dev/null +++ b/test/test_pot/test_pot_builtin_memorycache.py @@ -0,0 +1,117 @@ +import threading +import time +from collections import OrderedDict +import pytest +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider, BuiltinIEContentProvider +from yt_dlp.utils import bug_reports_message +from yt_dlp.extractor.youtube.pot._builtin.memory_cache import MemoryLRUPCP, memorylru_preference, initialize_global_cache +from yt_dlp.version import __version__ +from yt_dlp.extractor.youtube.pot._registry import _pot_cache_providers, _pot_memory_cache + + +class TestMemoryLRUPCS: + + def test_base_type(self): + assert issubclass(MemoryLRUPCP, IEContentProvider) + assert issubclass(MemoryLRUPCP, BuiltinIEContentProvider) + + @pytest.fixture + def pcp(self, ie, logger) -> MemoryLRUPCP: + return MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), max_size)) + + def test_is_registered(self): + assert _pot_cache_providers.value.get('MemoryLRU') == MemoryLRUPCP + + def test_initialization(self, pcp): + assert pcp.PROVIDER_NAME == 'memory' + assert pcp.PROVIDER_VERSION == __version__ + assert pcp.BUG_REPORT_MESSAGE == bug_reports_message(before='') + assert pcp.is_available() + + def test_store_and_get(self, pcp): + 
pcp.store('key1', 'value1', int(time.time()) + 60) + assert pcp.get('key1') == 'value1' + assert len(pcp.cache) == 1 + + def test_store_ignore_expired(self, pcp): + pcp.store('key1', 'value1', int(time.time()) - 1) + assert len(pcp.cache) == 0 + assert pcp.get('key1') is None + assert len(pcp.cache) == 0 + + def test_store_override_existing_key(self, ie, logger): + MAX_SIZE = 2 + pcp = MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), MAX_SIZE)) + pcp.store('key1', 'value1', int(time.time()) + 60) + pcp.store('key2', 'value2', int(time.time()) + 60) + assert len(pcp.cache) == 2 + pcp.store('key1', 'value2', int(time.time()) + 60) + # Ensure that the override key gets added to the end of the cache instead of in the same position + pcp.store('key3', 'value3', int(time.time()) + 60) + assert pcp.get('key1') == 'value2' + + def test_store_ignore_expired_existing_key(self, pcp): + pcp.store('key1', 'value2', int(time.time()) + 60) + pcp.store('key1', 'value1', int(time.time()) - 1) + assert len(pcp.cache) == 1 + assert pcp.get('key1') == 'value2' + assert len(pcp.cache) == 1 + + def test_get_key_expired(self, pcp): + pcp.store('key1', 'value1', int(time.time()) + 60) + assert pcp.get('key1') == 'value1' + assert len(pcp.cache) == 1 + pcp.cache['key1'] = ('value1', int(time.time()) - 1) + assert pcp.get('key1') is None + assert len(pcp.cache) == 0 + + def test_lru_eviction(self, ie, logger): + MAX_SIZE = 2 + provider = MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), MAX_SIZE)) + provider.store('key1', 'value1', int(time.time()) + 5) + provider.store('key2', 'value2', int(time.time()) + 5) + assert len(provider.cache) == 2 + + assert provider.get('key1') == 'value1' + + provider.store('key3', 'value3', int(time.time()) + 5) + assert len(provider.cache) == 2 + + assert provider.get('key2') is None + + provider.store('key4', 'value4', int(time.time()) + 5) + assert 
len(provider.cache) == 2 + + assert provider.get('key1') is None + assert provider.get('key3') == 'value3' + assert provider.get('key4') == 'value4' + + def test_delete(self, pcp): + pcp.store('key1', 'value1', int(time.time()) + 5) + assert len(pcp.cache) == 1 + assert pcp.get('key1') == 'value1' + pcp.delete('key1') + assert len(pcp.cache) == 0 + assert pcp.get('key1') is None + + def test_use_global_cache_default(self, ie, logger): + pcp = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + assert pcp.cache is _pot_memory_cache.value['cache'] + assert pcp.lock is _pot_memory_cache.value['lock'] + + pcp2 = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == pcp2.max_size == _pot_memory_cache.value['max_size'] == 25 + assert pcp.cache is pcp2.cache is _pot_memory_cache.value['cache'] + assert pcp.lock is pcp2.lock is _pot_memory_cache.value['lock'] + + def test_fail_max_size_change_global(self, ie, logger): + pcp = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + with pytest.raises(ValueError, match='Cannot change max_size of initialized global memory cache'): + initialize_global_cache(50) + + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + + def test_memory_lru_preference(self, pcp, ie, pot_request): + assert memorylru_preference(pcp, pot_request) == 10000 diff --git a/test/test_pot/test_pot_builtin_utils.py b/test/test_pot/test_pot_builtin_utils.py new file mode 100644 index 0000000000..1682e42a16 --- /dev/null +++ b/test/test_pot/test_pot_builtin_utils.py @@ -0,0 +1,46 @@ +import pytest +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenContext, + +) + +from yt_dlp.extractor.youtube.pot.utils import get_webpo_content_binding, ContentBindingType + + +class TestGetWebPoContentBinding: + + @pytest.mark.parametrize('client_name, context, is_authenticated, expected', [ + *[(client, context, is_authenticated, expected) for client in [ + 
'WEB', 'MWEB', 'TVHTML5', 'WEB_EMBEDDED_PLAYER', 'WEB_CREATOR', 'TVHTML5_SIMPLY_EMBEDDED_PLAYER'] + for context, is_authenticated, expected in [ + (PoTokenContext.GVS, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + (PoTokenContext.PLAYER, False, ('example-video-id', ContentBindingType.VIDEO_ID)), + (PoTokenContext.GVS, True, ('example-data-sync-id', ContentBindingType.DATASYNC_ID)), + ]], + ('WEB_REMIX', PoTokenContext.GVS, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + ('WEB_REMIX', PoTokenContext.PLAYER, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + ('ANDROID', PoTokenContext.GVS, False, (None, None)), + ('IOS', PoTokenContext.GVS, False, (None, None)), + ]) + def test_get_webpo_content_binding(self, pot_request, client_name, context, is_authenticated, expected): + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + assert get_webpo_content_binding(pot_request) == expected + + def test_extract_visitor_id(self, pot_request): + pot_request.visitor_data = 'CgsxMjNhYmNYWVpfLSiA4s%2DqBg%3D%3D' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == ('123abcXYZ_-', ContentBindingType.VISITOR_ID) + + def test_invalid_visitor_id(self, pot_request): + # visitor id not alphanumeric (i.e. 
protobuf extraction failed) + pot_request.visitor_data = 'CggxMjM0NTY3OCiA4s-qBg%3D%3D' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) + + def test_no_visitor_id(self, pot_request): + pot_request.visitor_data = 'KIDiz6oG' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) + + def test_invalid_base64(self, pot_request): + pot_request.visitor_data = 'invalid-base64' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) diff --git a/test/test_pot/test_pot_builtin_webpospec.py b/test/test_pot/test_pot_builtin_webpospec.py new file mode 100644 index 0000000000..c5fb6f3820 --- /dev/null +++ b/test/test_pot/test_pot_builtin_webpospec.py @@ -0,0 +1,92 @@ +import pytest + +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider, BuiltinIEContentProvider +from yt_dlp.extractor.youtube.pot.cache import CacheProviderWritePolicy +from yt_dlp.utils import bug_reports_message +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + +) +from yt_dlp.version import __version__ + +from yt_dlp.extractor.youtube.pot._builtin.webpo_cachespec import WebPoPCSP +from yt_dlp.extractor.youtube.pot._registry import _pot_pcs_providers + + +@pytest.fixture() +def pot_request(pot_request) -> PoTokenRequest: + pot_request.visitor_data = 'CgsxMjNhYmNYWVpfLSiA4s%2DqBg%3D%3D' # visitor_id=123abcXYZ_- + return pot_request + + +class TestWebPoPCSP: + def test_base_type(self): + assert issubclass(WebPoPCSP, IEContentProvider) + assert issubclass(WebPoPCSP, BuiltinIEContentProvider) + + def test_init(self, ie, logger): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + assert pcs.PROVIDER_NAME == 'webpo' + assert pcs.PROVIDER_VERSION == __version__ + assert pcs.BUG_REPORT_MESSAGE == 
bug_reports_message(before='') + assert pcs.is_available() + + def test_is_registered(self): + assert _pot_pcs_providers.value.get('WebPo') == WebPoPCSP + + @pytest.mark.parametrize('client_name, context, is_authenticated', [ + ('ANDROID', PoTokenContext.GVS, False), + ('IOS', PoTokenContext.GVS, False), + ('IOS', PoTokenContext.PLAYER, False), + ]) + def test_not_supports(self, ie, logger, pot_request, client_name, context, is_authenticated): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + assert pcs.generate_cache_spec(pot_request) is None + + @pytest.mark.parametrize('client_name, context, is_authenticated, remote_host, source_address, request_proxy, expected', [ + *[(client, context, is_authenticated, remote_host, source_address, request_proxy, expected) for client in [ + 'WEB', 'MWEB', 'TVHTML5', 'WEB_EMBEDDED_PLAYER', 'WEB_CREATOR', 'TVHTML5_SIMPLY_EMBEDDED_PLAYER'] + for context, is_authenticated, remote_host, source_address, request_proxy, expected in [ + (PoTokenContext.GVS, False, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '123abcXYZ_-', 'cbt': 'visitor_id'}), + (PoTokenContext.PLAYER, False, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '123abcXYZ_-', 'cbt': 'video_id'}), + (PoTokenContext.GVS, True, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': 'example-data-sync-id', 'cbt': 'datasync_id'}), + ]], + ('WEB_REMIX', PoTokenContext.PLAYER, False, 'example-remote-host', 
'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '123abcXYZ_-', 'cbt': 'visitor_id'}), + ('WEB', PoTokenContext.GVS, False, None, None, None, {'t': 'webpo', 'cb': '123abcXYZ_-', 'cbt': 'visitor_id', 'ip': None, 'sa': None, 'px': None}), + ('TVHTML5', PoTokenContext.PLAYER, False, None, None, 'http://example.com', {'t': 'webpo', 'cb': '123abcXYZ_-', 'cbt': 'video_id', 'ip': None, 'sa': None, 'px': 'http://example.com'}), + + ]) + def test_generate_key_bindings(self, ie, logger, pot_request, client_name, context, is_authenticated, remote_host, source_address, request_proxy, expected): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + pot_request.innertube_context['client']['remoteHost'] = remote_host + pot_request.request_source_address = source_address + pot_request.request_proxy = request_proxy + pot_request.video_id = '123abcXYZ_-' # same as visitor id to test type + + assert pcs.generate_cache_spec(pot_request).key_bindings == expected + + def test_no_bind_visitor_id(self, ie, logger, pot_request): + # Should not bind to visitor id if setting is set to False + pcs = WebPoPCSP(ie=ie, logger=logger, settings={'bind_to_visitor_id': ['false']}) + pot_request.innertube_context['client']['clientName'] = 'WEB' + pot_request.context = PoTokenContext.GVS + pot_request.is_authenticated = False + assert pcs.generate_cache_spec(pot_request).key_bindings == {'t': 'webpo', 'ip': None, 'sa': None, 'px': None, 'cb': 'CgsxMjNhYmNYWVpfLSiA4s%2DqBg%3D%3D', 'cbt': 'visitor_data'} + + def test_default_ttl(self, ie, logger, pot_request): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + assert pcs.generate_cache_spec(pot_request).default_ttl == 6 * 60 * 60 # should default to 6 hours + + def 
test_write_policy(self, ie, logger, pot_request): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.context = PoTokenContext.GVS + assert pcs.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_ALL + pot_request.context = PoTokenContext.PLAYER + assert pcs.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_FIRST diff --git a/test/test_pot/test_pot_director.py b/test/test_pot/test_pot_director.py new file mode 100644 index 0000000000..bbfdd0e98e --- /dev/null +++ b/test/test_pot/test_pot_director.py @@ -0,0 +1,1529 @@ +from __future__ import annotations +import abc +import base64 +import dataclasses +import hashlib +import json +import time +import pytest + +from yt_dlp.extractor.youtube.pot._provider import BuiltinIEContentProvider, IEContentProvider + +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + PoTokenProviderError, + PoTokenProviderRejectedRequest, +) +from yt_dlp.extractor.youtube.pot._director import ( + PoTokenCache, + validate_cache_spec, + clean_pot, + validate_response, + PoTokenRequestDirector, + provider_display_list, +) + +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + PoTokenCacheProvider, + CacheProviderWritePolicy, + PoTokenCacheProviderError, +) + + +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenResponse, + PoTokenProvider, +) + + +class BaseMockPoTokenProvider(PoTokenProvider, abc.ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.available_called_times = 0 + self.request_called_times = 0 + self.close_called = False + + def is_available(self) -> bool: + self.available_called_times += 1 + return True + + def request_pot(self, *args, **kwargs): + self.request_called_times += 1 + return super().request_pot(*args, **kwargs) + + def close(self): + self.close_called = True + super().close() + + +class 
ExamplePTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS, ) + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + if request.data_sync_id == 'example': + return PoTokenResponse(request.video_id) + return PoTokenResponse(EXAMPLE_PO_TOKEN) + + +def success_ptp(response: PoTokenResponse | None = None, key: str | None = None): + class SuccessPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'success' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://success.example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS,) + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return response or PoTokenResponse(EXAMPLE_PO_TOKEN) + + if key: + SuccessPTP.PROVIDER_KEY = key + return SuccessPTP + + +@pytest.fixture +def pot_provider(ie, logger): + return success_ptp()(ie=ie, logger=logger, settings={}) + + +class UnavailablePTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unavailable' + BUG_REPORT_LOCATION = 'https://unavailable.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def is_available(self) -> bool: + super().is_available() + return False + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderError('something went wrong') + + +class UnsupportedPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unsupported' + BUG_REPORT_LOCATION = 'https://unsupported.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('unsupported request') + + +class ErrorPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'error' + BUG_REPORT_LOCATION = 'https://error.example.com/issues' + _SUPPORTED_CLIENTS = None + 
_SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + expected = request.video_id == 'expected' + raise PoTokenProviderError('an error occurred', expected=expected) + + +class UnexpectedErrorPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unexpected_error' + BUG_REPORT_LOCATION = 'https://unexpected.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise ValueError('an unexpected error occurred') + + +class InvalidPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'invalid' + BUG_REPORT_LOCATION = 'https://invalid.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + if request.video_id == 'invalid_type': + return 'invalid-response' + else: + return PoTokenResponse('example-token?', expires_at='123') + + +class BaseMockCacheSpecProvider(PoTokenCacheSpecProvider, abc.ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.generate_called_times = 0 + self.is_available_called_times = 0 + self.close_called = False + + def is_available(self) -> bool: + self.is_available_called_times += 1 + return super().is_available() + + def generate_cache_spec(self, request: PoTokenRequest): + self.generate_called_times += 1 + + def close(self): + self.close_called = True + super().close() + + +class ExampleCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + ) + + +class UnavailableCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'unavailable' + PROVIDER_VERSION = '0.0.1' + + 
def is_available(self) -> bool: + super().is_available() + return False + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return None + + +class UnsupportedCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'unsupported' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return None + + +class InvalidSpecCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'invalid' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return 'invalid-spec' + + +class ErrorSpecCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'invalid' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + raise ValueError('something went wrong') + + +class BaseMockCacheProvider(PoTokenCacheProvider, abc.ABC): + BUG_REPORT_MESSAGE = 'example bug report message' + + def __init__(self, *args, available=True, **kwargs): + super().__init__(*args, **kwargs) + self.store_calls = 0 + self.delete_calls = 0 + self.get_calls = 0 + self.available_called_times = 0 + self.available = available + + def is_available(self) -> bool: + self.available_called_times += 1 + return self.available + + def store(self, *args, **kwargs): + self.store_calls += 1 + + def delete(self, *args, **kwargs): + self.delete_calls += 1 + + def get(self, *args, **kwargs): + self.get_calls += 1 + + def close(self): + self.close_called = True + super().close() + + +class ErrorPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'error' + + def store(self, *args, **kwargs): + super().store(*args, **kwargs) + raise PoTokenCacheProviderError('something went wrong') + + def get(self, *args, **kwargs): + super().get(*args, **kwargs) + raise PoTokenCacheProviderError('something went wrong') + + +class 
UnexpectedErrorPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'unexpected_error' + + def store(self, *args, **kwargs): + super().store(*args, **kwargs) + raise ValueError('something went wrong') + + def get(self, *args, **kwargs): + super().get(*args, **kwargs) + raise ValueError('something went wrong') + + +class MockMemoryPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'memory' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cache = {} + + def store(self, key, value, expires_at): + super().store(key, value, expires_at) + self.cache[key] = (value, expires_at) + + def delete(self, key): + super().delete(key) + self.cache.pop(key, None) + + def get(self, key): + super().get(key) + return self.cache.get(key, [None])[0] + + +def create_memory_pcp(ie, logger, provider_key='memory', provider_name='memory', available=True): + cache = MockMemoryPCP(ie, logger, {}, available=available) + cache.PROVIDER_KEY = provider_key + cache.PROVIDER_NAME = provider_name + return cache + + +@pytest.fixture +def memorypcp(ie, logger) -> MockMemoryPCP: + return create_memory_pcp(ie, logger) + + +@pytest.fixture +def pot_cache(ie, logger): + class MockPoTokenCache(PoTokenCache): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_calls = 0 + self.store_calls = 0 + self.close_called = False + + def get(self, *args, **kwargs): + self.get_calls += 1 + return super().get(*args, **kwargs) + + def store(self, *args, **kwargs): + self.store_calls += 1 + return super().store(*args, **kwargs) + + def close(self): + self.close_called = True + super().close() + + return MockPoTokenCache( + cache_providers=[MockMemoryPCP(ie, logger, {})], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie, logger, settings={})], + logger=logger, + ) + + +EXAMPLE_PO_TOKEN = base64.urlsafe_b64encode(b'example-token').decode() + + +class TestPoTokenCache: + + def test_cache_success(self, memorypcp, pot_request, ie, logger): + cache = PoTokenCache( + 
cache_providers=[memorypcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + + cached_response = cache.get(pot_request) + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert cache.get(dataclasses.replace(pot_request, video_id='another-video-id')) is None + + def test_unsupported_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + unsupported_provider = UnsupportedCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unsupported_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 1 + cache.store(pot_request, response) + assert len(memorypcp.cache) == 0 + assert unsupported_provider.generate_called_times == 2 + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 3 + assert len(logger.messages.get('error', [])) == 0 + + def test_unsupported_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + unsupported_provider = UnsupportedCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unsupported_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert unsupported_provider.generate_called_times == 2 + assert 
example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert unsupported_provider.generate_called_times == 3 + assert example_provider.generate_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert len(logger.messages.get('error', [])) == 0 + + def test_invalid_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[InvalidSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + + assert cache.get(pot_request) is None + + assert 'PoTokenCacheSpecProvider "InvalidSpecCacheSpecProvider" generate_cache_spec() returned invalid spec invalid-spec; please report this issue to the provider developer at (developer has not provided a bug report location) .' 
in logger.messages['error'] + + def test_invalid_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + + invalid_provider = InvalidSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[invalid_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert 'PoTokenCacheSpecProvider "InvalidSpecCacheSpecProvider" generate_cache_spec() returned invalid spec invalid-spec; please report this issue to the provider developer at (developer has not provided a bug report location) .' 
in logger.messages['error'] + + def test_unavailable_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + unavailable_provider = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unavailable_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert unavailable_provider.generate_called_times == 0 + + def test_unavailable_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + unavailable_provider = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unavailable_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 2 + assert example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 3 + assert example_provider.generate_called_times == 3 + assert example_provider.is_available_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + def test_unexpected_error_cache_spec(self, memorypcp, pot_request, ie, logger): + error_provider = ErrorSpecCacheSpecProviderPCSP(ie=ie, 
logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[error_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_provider.generate_called_times == 3 + assert error_provider.is_available_called_times == 3 + + assert 'Error occurred with "invalid" PO Token cache spec provider: ValueError(\'something went wrong\'); please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_unexpected_error_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + error_provider = ErrorSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[error_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert error_provider.generate_called_times == 1 + assert error_provider.is_available_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert error_provider.generate_called_times == 2 + assert error_provider.is_available_called_times == 2 + assert example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert error_provider.generate_called_times == 3 + assert error_provider.is_available_called_times == 3 + assert example_provider.generate_called_times == 3 + assert example_provider.is_available_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert 'Error occurred with "invalid" PO Token cache spec provider: 
ValueError(\'something went wrong\'); please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_key_bindings_spec_provider(self, memorypcp, pot_request, ie, logger): + + class ExampleProviderPCSP(PoTokenCacheSpecProvider): + PROVIDER_NAME = 'example' + + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + key_bindings={'v': request.video_id}, + default_ttl=60, + ) + + class ExampleProviderTwoPCSP(ExampleProviderPCSP): + pass + + example_provider = ExampleProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider_two = ExampleProviderTwoPCSP(ie=ie, logger=logger, settings={}) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[example_provider], + logger=logger, + ) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert len(memorypcp.cache) == 1 + assert hashlib.sha256( + f"{{'_dlp_cache': 'v1', '_p': 'ExampleProvider', 'v': '{pot_request.video_id}'}}".encode()).hexdigest() in memorypcp.cache + + # The second spec provider returns the exact same key bindings as the first one, + # however the PoTokenCache should use the provider key to differentiate between them + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[example_provider_two], + logger=logger, + ) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert len(memorypcp.cache) == 2 + assert hashlib.sha256( + f"{{'_dlp_cache': 'v1', '_p': 'ExampleProviderTwo', 'v': '{pot_request.video_id}'}}".encode()).hexdigest() in memorypcp.cache + + def test_cache_provider_preferences(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + 
cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert len(pcp_one.cache) == 1 + assert len(pcp_two.cache) == 0 + + assert cache.get(pot_request) + assert pcp_one.get_calls == 1 + assert pcp_two.get_calls == 0 + + standard_preference_called = False + pcp_one_preference_claled = False + + def standard_preference(provider, request, *_, **__): + nonlocal standard_preference_called + standard_preference_called = True + assert isinstance(provider, PoTokenCacheProvider) + assert isinstance(request, PoTokenRequest) + return 1 + + def pcp_one_preference(provider, request, *_, **__): + nonlocal pcp_one_preference_claled + pcp_one_preference_claled = True + assert isinstance(provider, PoTokenCacheProvider) + assert isinstance(request, PoTokenRequest) + if provider.PROVIDER_KEY == pcp_one.PROVIDER_KEY: + return -100 + return 0 + + # test that it can handle multiple preferences + cache.cache_provider_preferences.append(standard_preference) + cache.cache_provider_preferences.append(pcp_one_preference) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert cache.get(pot_request) + assert len(pcp_one.cache) == len(pcp_two.cache) == 1 + assert pcp_two.get_calls == pcp_one.get_calls == 1 + assert pcp_one.store_calls == pcp_two.store_calls == 1 + assert standard_preference_called + assert pcp_one_preference_claled + + def test_secondary_cache_provider_hit(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + # Given the lower priority provider has the
cache hit, store the response in the higher priority provider + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert cache.get(pot_request) + + cache.cache_providers[pcp_one.PROVIDER_KEY] = pcp_one + + def pcp_one_pref(provider, *_, **__): + if provider.PROVIDER_KEY == pcp_one.PROVIDER_KEY: + return 1 + return -1 + + cache.cache_provider_preferences.append(pcp_one_pref) + + assert cache.get(pot_request) + assert pcp_one.get_calls == 1 + assert pcp_two.get_calls == 2 + # Should write back to pcp_one (now the highest priority cache provider) + assert pcp_one.store_calls == pcp_two.store_calls == 1 + assert 'Writing PO Token response to highest priority cache provider' in logger.messages['trace'] + + def test_cache_provider_no_hits(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + assert cache.get(pot_request) is None + assert pcp_one.get_calls == pcp_two.get_calls == 1 + + def test_get_invalid_po_token_response(self, pot_request, ie, logger): + # Test various scenarios where the po token response stored in the cache provider is invalid + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + valid_response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, valid_response) + assert len(pcp_one.cache) == len(pcp_two.cache) == 1 + # Overwrite the valid response with an invalid one in the cache + pcp_one.store(next(iter(pcp_one.cache.keys())), 
json.dumps(dataclasses.asdict(PoTokenResponse(None))), int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 1 + assert pcp_one.delete_calls == 1 # Invalid response should be deleted from cache + assert pcp_one.store_calls == 3 # Since response was fetched from second cache provider, it should be stored in the first one + assert len(pcp_one.cache) == 1 + assert 'Invalid PO Token response retrieved from cache provider "memory": {"po_token": null, "expires_at": null}; example bug report message' in logger.messages['error'] + + # Overwrite the valid response with an invalid json in the cache + pcp_one.store(next(iter(pcp_one.cache.keys())), 'invalid-json', int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 2 + assert pcp_one.delete_calls == 2 + assert pcp_one.store_calls == 5 # 3 + 1 store we made in the test + 1 store from lower priority cache provider + assert len(pcp_one.cache) == 1 + + assert 'Invalid PO Token response retrieved from cache provider "memory": invalid-json; example bug report message' in logger.messages['error'] + + # Valid json, but missing required fields + pcp_one.store(next(iter(pcp_one.cache.keys())), '{"unknown_param": 0}', int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 3 + assert pcp_one.delete_calls == 3 + assert pcp_one.store_calls == 7 # 5 + 1 store from test + 1 store from lower priority cache provider + assert len(pcp_one.cache) == 1 + + assert 'Invalid PO Token response retrieved from cache provider "memory": {"unknown_param": 0}; example bug report message' in logger.messages['error'] + + def test_store_invalid_po_token_response(self, pot_request, ie, logger): + # Should not store an invalid po token response + pcp_one = create_memory_pcp(ie, logger, 
provider_key='memory_pcp_one') + + cache = PoTokenCache( + cache_providers=[pcp_one], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(po_token=EXAMPLE_PO_TOKEN, expires_at=80)) + assert cache.get(pot_request) is None + assert pcp_one.store_calls == 0 + assert 'Invalid PO Token response provided to PoTokenCache.store()' in logger.messages['error'][0] + + def test_store_write_policy(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert pcp_one.store_calls == 1 + assert pcp_two.store_calls == 0 + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_ALL) + assert pcp_one.store_calls == 2 + assert pcp_two.store_calls == 1 + + def test_store_write_first_policy_cache_spec(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + class WriteFirstPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_FIRST, + ) + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[WriteFirstPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert pcp_one.store_calls == 1 + 
assert pcp_two.store_calls == 0 + + def test_store_write_all_policy_cache_spec(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + class WriteAllPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_ALL, + ) + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[WriteAllPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert pcp_one.store_calls == 1 + assert pcp_two.store_calls == 1 + + def test_expires_at_pot_response(self, pot_request, memorypcp, ie, logger): + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=10000000000) + cache.store(pot_request, response) + assert next(iter(memorypcp.cache.values()))[1] == 10000000000 + + def test_expires_at_default_spec(self, pot_request, memorypcp, ie, logger): + + class TtlPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=10000000000, + ) + + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[TtlPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert next(iter(memorypcp.cache.values()))[1] >= 10000000000 + + def test_cache_provider_error_no_fallback(self, pot_request, ie, logger): + error_pcp = 
ErrorPCP(ie, logger, {}) + cache = PoTokenCache( + cache_providers=[error_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 1 + + assert logger.messages['warning'].count("Error from \"error\" PO Token cache provider: PoTokenCacheProviderError('something went wrong'); example bug report message") == 2 + + def test_cache_provider_error_fallback(self, pot_request, ie, logger): + error_pcp = ErrorPCP(ie, logger, {}) + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + + cache = PoTokenCache( + cache_providers=[error_pcp, memory_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + + # 1. Store fails for error_pcp, stored in memory_pcp + # 2. Get fails for error_pcp, fetched from memory_pcp + # 3. Since fetched from lower priority, it should be stored in the highest priority cache provider + # 4. Store fails in error_pcp. 
Since write policy is WRITE_FIRST, it should not try to store in memory_pcp regardless of if the store in error_pcp fails + + assert cache.get(pot_request) + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 2 # since highest priority, when fetched from lower priority, it should be stored in the highest priority cache provider + assert memory_pcp.get_calls == 1 + assert memory_pcp.store_calls == 1 + + assert logger.messages['warning'].count("Error from \"error\" PO Token cache provider: PoTokenCacheProviderError('something went wrong'); example bug report message") == 3 + + def test_cache_provider_unexpected_error_no_fallback(self, pot_request, ie, logger): + error_pcp = UnexpectedErrorPCP(ie, logger, {}) + cache = PoTokenCache( + cache_providers=[error_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 1 + + assert logger.messages['error'].count("Error occurred with \"unexpected_error\" PO Token cache provider: ValueError('something went wrong'); example bug report message") == 2 + + def test_cache_provider_unexpected_error_fallback(self, pot_request, ie, logger): + error_pcp = UnexpectedErrorPCP(ie, logger, {}) + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + + cache = PoTokenCache( + cache_providers=[error_pcp, memory_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + + # 1. Store fails for error_pcp, stored in memory_pcp + # 2. Get fails for error_pcp, fetched from memory_pcp + # 3. Since fetched from lower priority, it should be stored in the highest priority cache provider + # 4. Store fails in error_pcp. 
Since write policy is WRITE_FIRST, it should not try to store in memory_pcp regardless of if the store in error_pcp fails + + assert cache.get(pot_request) + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 2 # since highest priority, when fetched from lower priority, it should be stored in the highest priority cache provider + assert memory_pcp.get_calls == 1 + assert memory_pcp.store_calls == 1 + + assert logger.messages['error'].count("Error occurred with \"unexpected_error\" PO Token cache provider: ValueError('something went wrong'); example bug report message") == 3 + + def test_cache_provider_unavailable_no_fallback(self, pot_request, ie, logger): + provider = create_memory_pcp(ie, logger, available=False) + + cache = PoTokenCache( + cache_providers=[provider], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert provider.get_calls == 0 + assert provider.store_calls == 0 + assert provider.available_called_times + + def test_cache_provider_unavailable_fallback(self, pot_request, ie, logger): + provider_unavailable = create_memory_pcp(ie, logger, provider_key='unavailable', provider_name='unavailable', available=False) + provider_available = create_memory_pcp(ie, logger, provider_key='available', provider_name='available') + + cache = PoTokenCache( + cache_providers=[provider_unavailable, provider_available], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is not None + assert provider_unavailable.get_calls == 0 + assert provider_unavailable.store_calls == 0 + assert provider_available.get_calls == 1 + assert provider_available.store_calls == 1 + assert 
provider_unavailable.available_called_times + assert provider_available.available_called_times + + # should not even try to use the provider for the request + assert 'Attempting to fetch a PO Token response from "unavailable" provider' not in logger.messages['trace'] + assert 'Attempting to fetch a PO Token response from "available" provider' not in logger.messages['trace'] + + def test_available_not_called(self, ie, pot_request, logger): + # Test that the available method is not called when provider higher in the list is available + provider_unavailable = create_memory_pcp( + ie, logger, provider_key='unavailable', provider_name='unavailable', available=False) + provider_available = create_memory_pcp(ie, logger, provider_key='available', provider_name='available') + + logger.log_level = logger.LogLevel.INFO + + cache = PoTokenCache( + cache_providers=[provider_available, provider_unavailable], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response, write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert cache.get(pot_request) is not None + assert provider_unavailable.get_calls == 0 + assert provider_unavailable.store_calls == 0 + assert provider_available.get_calls == 1 + assert provider_available.store_calls == 1 + assert provider_unavailable.available_called_times == 0 + assert provider_available.available_called_times + assert 'PO Token Cache Providers: available-0.0.0 (external), unavailable-0.0.0 (external, unavailable)' not in logger.messages.get('trace', []) + + def test_available_called_trace(self, ie, pot_request, logger): + # But if logging level is trace should call available (as part of debug logging) + provider_unavailable = create_memory_pcp( + ie, logger, provider_key='unavailable', provider_name='unavailable', available=False) + provider_available = create_memory_pcp(ie, logger, provider_key='available', 
provider_name='available') + + logger.log_level = logger.LogLevel.TRACE + + cache = PoTokenCache( + cache_providers=[provider_available, provider_unavailable], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response, write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert cache.get(pot_request) is not None + assert provider_unavailable.get_calls == 0 + assert provider_unavailable.store_calls == 0 + assert provider_available.get_calls == 1 + assert provider_available.store_calls == 1 + assert provider_unavailable.available_called_times + assert provider_available.available_called_times + assert 'PO Token Cache Providers: available-0.0.0 (external), unavailable-0.0.0 (external, unavailable)' in logger.messages.get('trace', []) + + def test_close(self, ie, pot_request, logger): + # Should call close on the cache providers and cache specs + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + memory2_pcp = create_memory_pcp(ie, logger, provider_key='memory2') + + spec1 = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + spec2 = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + + cache = PoTokenCache( + cache_providers=[memory2_pcp, memory_pcp], + cache_spec_providers=[spec1, spec2], + logger=logger, + ) + + cache.close() + assert memory_pcp.close_called + assert memory2_pcp.close_called + assert spec1.close_called + assert spec2.close_called + + +class TestPoTokenRequestDirector: + + def test_request_pot_success(self, ie, pot_request, pot_cache, pot_provider, logger): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + + def test_request_and_cache(self, ie, pot_request, pot_cache, pot_provider, logger): + director = 
PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 1 + assert pot_cache.get_calls == 1 + assert pot_cache.store_calls == 1 + + # Second request, should be cached + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.get_calls == 2 + assert pot_cache.store_calls == 1 + assert pot_provider.request_called_times == 1 + + def test_bypass_cache(self, ie, pot_request, pot_cache, logger, pot_provider): + pot_request.bypass_cache = True + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 1 + assert pot_cache.get_calls == 0 + assert pot_cache.store_calls == 1 + + # Second request, should not get from cache + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 2 + assert pot_cache.get_calls == 0 + assert pot_cache.store_calls == 2 + + # POT is still cached, should get from cache + pot_request.bypass_cache = False + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 2 + assert pot_cache.get_calls == 1 + assert pot_cache.store_calls == 2 + + def test_clean_pot_generate(self, ie, pot_request, pot_cache, logger): + # Token should be cleaned before returning + base_token = base64.urlsafe_b64encode(b'token').decode() + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(base_token + '?extra=params'))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == base_token + assert 
provider.request_called_times == 1 + + # Confirm the cleaned version was stored in the cache + cached_token = pot_cache.get(pot_request) + assert cached_token.po_token == base_token + + def test_clean_pot_cache(self, ie, pot_request, pot_cache, logger, pot_provider): + # Token retrieved from cache should be cleaned before returning + base_token = base64.urlsafe_b64encode(b'token').decode() + pot_cache.store(pot_request, PoTokenResponse(base_token + '?extra=params')) + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == base_token + assert pot_cache.get_calls == 1 + assert pot_provider.request_called_times == 0 + + def test_cache_expires_at_none(self, ie, pot_request, pot_cache, logger, pot_provider): + # Should cache if expires_at=None in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=None))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 1 + assert pot_cache.get(pot_request).po_token == EXAMPLE_PO_TOKEN + + def test_cache_expires_at_positive(self, ie, pot_request, pot_cache, logger, pot_provider): + # Should cache if expires_at is a positive number in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=99999999999))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 1 + assert pot_cache.get(pot_request).po_token == EXAMPLE_PO_TOKEN + + @pytest.mark.parametrize('expires_at', [0, -1]) + def test_not_cache_expires_at(self, ie, pot_request, pot_cache, logger, pot_provider, 
expires_at): + # Should not cache if expires_at <= 0 in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=expires_at))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 0 + assert pot_cache.get(pot_request) is None + + def test_no_providers(self, ie, pot_request, pot_cache, logger): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + response = director.get_po_token(pot_request) + assert response is None + + def test_try_cache_no_providers(self, ie, pot_request, pot_cache, logger): + # Should still try the cache even if no providers are configured + pot_cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + + def test_close(self, ie, pot_request, pot_cache, pot_provider, logger): + # Should call close on the pot cache and any providers + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + provider2 = UnavailablePTP(ie, logger, {}) + director.register_provider(pot_provider) + director.register_provider(provider2) + + director.close() + assert pot_provider.close_called + assert provider2.close_called + assert pot_cache.close_called + + def test_pot_provider_preferences(self, pot_request, pot_cache, ie, logger): + pot_request.bypass_cache = True + provider_two_pot = base64.urlsafe_b64encode(b'token2').decode() + + example_provider = success_ptp(response=PoTokenResponse(EXAMPLE_PO_TOKEN), key='exampleone')(ie, logger, settings={}) + example_provider_two = success_ptp(response=PoTokenResponse(provider_two_pot), key='exampletwo')(ie, logger, settings={}) + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + 
director.register_provider(example_provider) + director.register_provider(example_provider_two) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert example_provider.request_called_times == 1 + assert example_provider_two.request_called_times == 0 + + standard_preference_called = False + example_preference_called = False + + # Test that the provider preferences are respected + def standard_preference(provider, request, *_, **__): + nonlocal standard_preference_called + standard_preference_called = True + assert isinstance(provider, PoTokenProvider) + assert isinstance(request, PoTokenRequest) + return 1 + + def example_preference(provider, request, *_, **__): + nonlocal example_preference_called + example_preference_called = True + assert isinstance(provider, PoTokenProvider) + assert isinstance(request, PoTokenRequest) + if provider.PROVIDER_KEY == example_provider.PROVIDER_KEY: + return -100 + return 0 + + # test that it can handle multiple preferences + director.register_preference(example_preference) + director.register_preference(standard_preference) + + response = director.get_po_token(pot_request) + assert response == provider_two_pot + assert example_provider.request_called_times == 1 + assert example_provider_two.request_called_times == 1 + assert standard_preference_called + assert example_preference_called + + def test_unsupported_request_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnsupportedPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + + def test_unsupported_request_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one does not support the request + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + 
provider = UnsupportedPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert 'PO Token Provider "unsupported" rejected this request, trying next available provider. Reason: unsupported request' in logger.messages['trace'] + + def test_unavailable_request_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 0 + assert provider.available_called_times + + def test_unavailable_request_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one is unavailable + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 0 + assert provider.available_called_times + assert pot_provider.request_called_times == 1 + assert pot_provider.available_called_times + # should not even try use the provider for the request + assert 'Attempting to fetch a PO Token from "unavailable" provider' not in logger.messages['trace'] + assert 'Attempting to fetch a PO Token from "success" provider' in logger.messages['trace'] + + def test_available_not_called(self, ie, logger, pot_cache, pot_request, pot_provider): + # Test that the available method is not called when provider higher in the list is available + logger.log_level = logger.LogLevel.INFO + director = 
PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(pot_provider) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 0 + assert provider.available_called_times == 0 + assert pot_provider.request_called_times == 1 + assert pot_provider.available_called_times == 2 + assert 'PO Token Providers: success-0.0.1 (external), unavailable-0.0.0 (external, unavailable)' not in logger.messages.get('trace', []) + + def test_available_called_trace(self, ie, logger, pot_cache, pot_request, pot_provider): + # But if logging level is trace should call available (as part of debug logging) + logger.log_level = logger.LogLevel.TRACE + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(pot_provider) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 0 + assert provider.available_called_times == 1 + assert pot_provider.request_called_times == 1 + assert pot_provider.available_called_times == 3 + assert 'PO Token Providers: success-0.0.1 (external), unavailable-0.0.0 (external, unavailable)' in logger.messages['trace'] + + def test_provider_error_no_fallback_unexpected(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + pot_request.video_id = 'unexpected' + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred'); please report this issue to the provider developer at https://error.example.com/issues ." 
in logger.messages['warning'] + + def test_provider_error_no_fallback_expected(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + pot_request.video_id = 'expected' + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred')" in logger.messages['warning'] + + def test_provider_error_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one raises an error + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred'); please report this issue to the provider developer at https://error.example.com/issues ." in logger.messages['warning'] + + def test_provider_unexpected_error_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnexpectedErrorPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Unexpected error when fetching PO Token from \"unexpected_error\" provider: ValueError('an unexpected error occurred'); please report this issue to the provider developer at https://unexpected.example.com/issues ." 
in logger.messages['error'] + + def test_provider_unexpected_error_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one raises an unexpected error + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnexpectedErrorPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert "Unexpected error when fetching PO Token from \"unexpected_error\" provider: ValueError('an unexpected error occurred'); please report this issue to the provider developer at https://unexpected.example.com/issues ." in logger.messages['error'] + + def test_invalid_po_token_response_type(self, ie, logger, pot_cache, pot_request, pot_provider): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = InvalidPTP(ie, logger, {}) + director.register_provider(provider) + + pot_request.video_id = 'invalid_type' + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert 'Invalid PO Token response received from "invalid" provider: invalid-response; please report this issue to the provider developer at https://invalid.example.com/issues .' 
in logger.messages['error'] + + # Should fallback to next available provider + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 2 + assert pot_provider.request_called_times == 1 + + def test_invalid_po_token_response(self, ie, logger, pot_cache, pot_request, pot_provider): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = InvalidPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Invalid PO Token response received from \"invalid\" provider: PoTokenResponse(po_token='example-token?', expires_at='123'); please report this issue to the provider developer at https://invalid.example.com/issues ." in logger.messages['error'] + + # Should fallback to next available provider + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 2 + assert pot_provider.request_called_times == 1 + + def test_copy_request_provider(self, ie, logger, pot_cache, pot_request): + + class BadProviderPTP(BaseMockPoTokenProvider): + _SUPPORTED_CONTEXTS = None + _SUPPORTED_CLIENTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + # Providers should not modify the request object, but we should guard against it + request.video_id = 'bad' + raise PoTokenProviderRejectedRequest('bad request') + + class GoodProviderPTP(BaseMockPoTokenProvider): + _SUPPORTED_CONTEXTS = None + _SUPPORTED_CLIENTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return PoTokenResponse(base64.urlsafe_b64encode(request.video_id.encode()).decode()) + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + bad_provider = BadProviderPTP(ie, 
logger, {}) + good_provider = GoodProviderPTP(ie, logger, {}) + + director.register_provider(bad_provider) + director.register_provider(good_provider) + + pot_request.video_id = 'good' + response = director.get_po_token(pot_request) + assert response == base64.urlsafe_b64encode(b'good').decode() + assert bad_provider.request_called_times == 1 + assert good_provider.request_called_times == 1 + assert pot_request.video_id == 'good' + + +@pytest.mark.parametrize('spec, expected', [ + (None, False), + (PoTokenCacheSpec(key_bindings={'v': 'video-id'}, default_ttl=60, write_policy=None), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': 'video-id'}, default_ttl='invalid'), False), # type: ignore + (PoTokenCacheSpec(key_bindings='invalid', default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={2: 'video-id'}, default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': 2}, default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': None}, default_ttl=60), False), # type: ignore + + (PoTokenCacheSpec(key_bindings={'v': 'video_id', 'e': None}, default_ttl=60), True), + (PoTokenCacheSpec(key_bindings={'v': 'video_id'}, default_ttl=60, write_policy=CacheProviderWritePolicy.WRITE_FIRST), True), +]) +def test_validate_cache_spec(spec, expected): + assert validate_cache_spec(spec) == expected + + +@pytest.mark.parametrize('po_token', [ + 'invalid-token?', + '123', +]) +def test_clean_pot_fail(po_token): + with pytest.raises(ValueError, match='Invalid PO Token'): + clean_pot(po_token) + + +@pytest.mark.parametrize('po_token,expected', [ + ('TwAA/+8=', 'TwAA_-8='), + ('TwAA%5F%2D9VA6Q92v%5FvEQ4==?extra-param=2', 'TwAA_-9VA6Q92v_vEQ4='), +]) +def test_clean_pot(po_token, expected): + assert clean_pot(po_token) == expected + + +@pytest.mark.parametrize( + 'response, expected', + [ + (None, False), + (PoTokenResponse(None), False), + (PoTokenResponse(1), False), + (PoTokenResponse('invalid-token?'), False), + 
(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at='abc'), False), # type: ignore + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=100), False), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=time.time() + 10000.0), False), # type: ignore + (PoTokenResponse(EXAMPLE_PO_TOKEN), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=-1), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=0), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=int(time.time()) + 10000), True), + ], +) +def test_validate_pot_response(response, expected): + assert validate_response(response) == expected + + +def test_built_in_provider(ie, logger): + class BuiltinProviderDefaultT(BuiltinIEContentProvider, suffix='T'): + def is_available(self): + return True + + class BuiltinProviderCustomNameT(BuiltinIEContentProvider, suffix='T'): + PROVIDER_NAME = 'CustomName' + + def is_available(self): + return True + + class ExternalProviderDefaultT(IEContentProvider, suffix='T'): + def is_available(self): + return True + + class ExternalProviderCustomT(IEContentProvider, suffix='T'): + PROVIDER_NAME = 'custom' + PROVIDER_VERSION = '5.4b2' + + def is_available(self): + return True + + class ExternalProviderUnavailableT(IEContentProvider, suffix='T'): + def is_available(self) -> bool: + return False + + class BuiltinProviderUnavailableT(IEContentProvider, suffix='T'): + def is_available(self) -> bool: + return False + + built_in_default = BuiltinProviderDefaultT(ie=ie, logger=logger, settings={}) + built_in_custom_name = BuiltinProviderCustomNameT(ie=ie, logger=logger, settings={}) + built_in_unavailable = BuiltinProviderUnavailableT(ie=ie, logger=logger, settings={}) + external_default = ExternalProviderDefaultT(ie=ie, logger=logger, settings={}) + external_custom = ExternalProviderCustomT(ie=ie, logger=logger, settings={}) + external_unavailable = ExternalProviderUnavailableT(ie=ie, logger=logger, settings={}) + + assert provider_display_list([]) == 'none' + assert 
provider_display_list([built_in_default]) == 'BuiltinProviderDefault' + assert provider_display_list([external_unavailable]) == 'ExternalProviderUnavailable-0.0.0 (external, unavailable)' + assert provider_display_list([ + built_in_default, + built_in_custom_name, + external_default, + external_custom, + external_unavailable, + built_in_unavailable], + ) == 'BuiltinProviderDefault, CustomName, ExternalProviderDefault-0.0.0 (external), custom-5.4b2 (external), ExternalProviderUnavailable-0.0.0 (external, unavailable), BuiltinProviderUnavailable-0.0.0 (external, unavailable)' diff --git a/test/test_pot/test_pot_framework.py b/test/test_pot/test_pot_framework.py new file mode 100644 index 0000000000..bc94653f4a --- /dev/null +++ b/test/test_pot/test_pot_framework.py @@ -0,0 +1,629 @@ +import pytest + +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.utils.networking import HTTPHeaderDict +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + ExternalRequestFeature, + +) + +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + CacheProviderWritePolicy, +) + +import yt_dlp.extractor.youtube.pot.cache as cache + +from yt_dlp.networking import Request +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenResponse, + PoTokenProvider, + PoTokenProviderRejectedRequest, + provider_bug_report_message, + register_provider, + register_preference, +) + +from yt_dlp.extractor.youtube.pot._registry import _pot_providers, _ptp_preferences, _pot_pcs_providers, _pot_cache_providers, _pot_cache_provider_preferences + + +class ExamplePTP(PoTokenProvider): + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS, ) + + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = ( + 
ExternalRequestFeature.PROXY_SCHEME_HTTP, + ExternalRequestFeature.PROXY_SCHEME_SOCKS5H, + ) + + def is_available(self) -> bool: + return True + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return PoTokenResponse('example-token', expires_at=123) + + +class ExampleCacheProviderPCP(PoTokenCacheProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def is_available(self) -> bool: + return True + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + +class ExampleCacheSpecProviderPCSP(PoTokenCacheSpecProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + key_bindings={'field': 'example-key'}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_FIRST, + ) + + +class TestPoTokenProvider: + + def test_base_type(self): + assert issubclass(PoTokenProvider, IEContentProvider) + + def test_create_provider_missing_fetch_method(self, ie, logger): + class MissingMethodsPTP(PoTokenProvider): + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPTP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_available_method(self, ie, logger): + class MissingMethodsPTP(PoTokenProvider): + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + with pytest.raises(TypeError): + MissingMethodsPTP(ie=ie, logger=logger, settings={}) + + def test_barebones_provider(self, ie, logger): + class BarebonesProviderPTP(PoTokenProvider): + def is_available(self) -> bool: + return True + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise 
PoTokenProviderRejectedRequest('Not implemented') + + provider = BarebonesProviderPTP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) .' + + def test_example_provider_success(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'Example' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + assert provider.is_available() + + response = provider.request_pot(pot_request) + + assert response.po_token == 'example-token' + assert response.expires_at == 123 + + def test_provider_unsupported_context(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.context = PoTokenContext.PLAYER + + with pytest.raises(PoTokenProviderRejectedRequest): + provider.request_pot(pot_request) + + def test_provider_unsupported_client(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = 'ANDROID' + + with pytest.raises(PoTokenProviderRejectedRequest): + provider.request_pot(pot_request) + + def test_provider_unsupported_proxy_scheme(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.request_proxy = 'socks4://example.com' + + with pytest.raises( + PoTokenProviderRejectedRequest, + match='External requests by "example" provider do not support proxy scheme "socks4". 
Supported proxy ' + 'schemes: http, socks5h', + ): + provider.request_pot(pot_request) + + pot_request.request_proxy = 'http://example.com' + + assert provider.request_pot(pot_request) + + def test_provider_ignore_external_request_features(self, ie, logger, pot_request): + class InternalPTP(ExamplePTP): + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = None + + provider = InternalPTP(ie=ie, logger=logger, settings={}) + + pot_request.request_proxy = 'socks5://example.com' + assert provider.request_pot(pot_request) + pot_request.request_source_address = '0.0.0.0' + assert provider.request_pot(pot_request) + + def test_provider_unsupported_external_request_source_address(self, ie, logger, pot_request): + class InternalPTP(ExamplePTP): + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = tuple() + + provider = InternalPTP(ie=ie, logger=logger, settings={}) + + pot_request.request_source_address = None + assert provider.request_pot(pot_request) + + pot_request.request_source_address = '0.0.0.0' + with pytest.raises( + PoTokenProviderRejectedRequest, + match='External requests by "example" provider do not support setting source address', + ): + provider.request_pot(pot_request) + + def test_provider_supported_external_request_source_address(self, ie, logger, pot_request): + class InternalPTP(ExamplePTP): + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = ( + ExternalRequestFeature.SOURCE_ADDRESS, + ) + + provider = InternalPTP(ie=ie, logger=logger, settings={}) + + pot_request.request_source_address = None + assert provider.request_pot(pot_request) + + pot_request.request_source_address = '0.0.0.0' + assert provider.request_pot(pot_request) + + def test_provider_unsupported_external_request_tls_verification(self, ie, logger, pot_request): + class InternalPTP(ExamplePTP): + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = tuple() + + provider = InternalPTP(ie=ie, logger=logger, settings={}) + + pot_request.request_verify_tls = True + assert provider.request_pot(pot_request) + + pot_request.request_verify_tls = 
False + with pytest.raises( + PoTokenProviderRejectedRequest, + match='External requests by "example" provider do not support ignoring TLS certificate failures', + ): + provider.request_pot(pot_request) + + def test_provider_supported_external_request_tls_verification(self, ie, logger, pot_request): + class InternalPTP(ExamplePTP): + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = ( + ExternalRequestFeature.DISABLE_TLS_VERIFICATION, + ) + + provider = InternalPTP(ie=ie, logger=logger, settings={}) + + pot_request.request_verify_tls = True + assert provider.request_pot(pot_request) + + pot_request.request_verify_tls = False + assert provider.request_pot(pot_request) + + def test_provider_request_webpage(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + cookiejar = YoutubeDLCookieJar() + pot_request.request_headers = HTTPHeaderDict({'User-Agent': 'example-user-agent'}) + pot_request.request_proxy = 'socks5://example-proxy.com' + pot_request.request_cookiejar = cookiejar + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + sent_request = provider._request_webpage(Request( + 'https://example.com', + ), pot_request=pot_request) + + assert sent_request.url == 'https://example.com' + assert sent_request.headers['User-Agent'] == 'example-user-agent' + assert sent_request.proxies == {'all': 'socks5://example-proxy.com'} + assert sent_request.extensions['cookiejar'] is cookiejar + assert 'Requesting webpage' in logger.messages['info'] + + def test_provider_request_webpage_override(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + cookiejar_request = YoutubeDLCookieJar() + pot_request.request_headers = HTTPHeaderDict({'User-Agent': 'example-user-agent'}) + pot_request.request_proxy = 'socks5://example-proxy.com' + pot_request.request_cookiejar = cookiejar_request + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + 
sent_request = provider._request_webpage(Request( + 'https://example.com', + headers={'User-Agent': 'override-user-agent-override'}, + proxies={'http': 'http://example-proxy-override.com'}, + extensions={'cookiejar': YoutubeDLCookieJar()}, + ), pot_request=pot_request, note='Custom requesting webpage') + + assert sent_request.url == 'https://example.com' + assert sent_request.headers['User-Agent'] == 'override-user-agent-override' + assert sent_request.proxies == {'http': 'http://example-proxy-override.com'} + assert sent_request.extensions['cookiejar'] is not cookiejar_request + assert 'Custom requesting webpage' in logger.messages['info'] + + def test_provider_request_webpage_no_log(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + sent_request = provider._request_webpage(Request( + 'https://example.com', + ), note=False) + + assert sent_request.url == 'https://example.com' + assert 'info' not in logger.messages + + def test_provider_request_webpage_no_pot_request(self, ie, logger): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + sent_request = provider._request_webpage(Request( + 'https://example.com', + ), pot_request=None) + + assert sent_request.url == 'https://example.com' + + def test_get_config_arg(self, ie, logger): + provider = ExamplePTP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + + assert provider._configuration_arg('abc') == ['123d'] + assert provider._configuration_arg('abc', default=['default']) == ['123d'] + assert provider._configuration_arg('ABC', default=['default']) == ['default'] + assert provider._configuration_arg('abc', casesense=True) == ['123D'] + assert provider._configuration_arg('xyz', casesense=False) == ['456a', '789b'] + + def 
test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenProvider): + PROVIDER_NAME = 'invalid-suffix' + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + def is_available(self) -> bool: + return True + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenCacheProvider: + + def test_base_type(self): + assert issubclass(PoTokenCacheProvider, IEContentProvider) + + def test_create_provider_missing_get_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_store_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_delete_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def store(self, key: str, value: str, expires_at: int): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_is_available_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def 
test_barebones_provider(self, ie, logger): + class BarebonesProviderPCP(PoTokenCacheProvider): + + def is_available(self) -> bool: + return True + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + provider = BarebonesProviderPCP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) .' + + def test_create_provider_example(self, ie, logger): + provider = ExampleCacheProviderPCP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'ExampleCacheProvider' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' 
+ assert provider.is_available() + + def test_get_config_arg(self, ie, logger): + provider = ExampleCacheProviderPCP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + assert provider._configuration_arg('abc') == ['123d'] + assert provider._configuration_arg('abc', default=['default']) == ['123d'] + assert provider._configuration_arg('ABC', default=['default']) == ['default'] + assert provider._configuration_arg('abc', casesense=True) == ['123D'] + assert provider._configuration_arg('xyz', casesense=False) == ['456a', '789b'] + + def test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenCacheProvider): + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenCacheSpecProvider: + + def test_base_type(self): + assert issubclass(PoTokenCacheSpecProvider, IEContentProvider) + + def test_create_provider_missing_supports_method(self, ie, logger): + class MissingMethodsPCS(PoTokenCacheSpecProvider): + pass + + with pytest.raises(TypeError): + MissingMethodsPCS(ie=ie, logger=logger, settings={}) + + def test_create_provider_barebones(self, ie, pot_request, logger): + class BarebonesProviderPCSP(PoTokenCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + default_ttl=100, + key_bindings={}, + ) + + provider = BarebonesProviderPCSP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) 
.' + assert provider.is_available() + assert provider.generate_cache_spec(request=pot_request).default_ttl == 100 + assert provider.generate_cache_spec(request=pot_request).key_bindings == {} + assert provider.generate_cache_spec(request=pot_request).write_policy == CacheProviderWritePolicy.WRITE_ALL + + def test_create_provider_example(self, ie, pot_request, logger): + provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'ExampleCacheSpecProvider' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + assert provider.is_available() + assert provider.generate_cache_spec(pot_request) + assert provider.generate_cache_spec(pot_request).key_bindings == {'field': 'example-key'} + assert provider.generate_cache_spec(pot_request).default_ttl == 60 + assert provider.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_FIRST + + def test_get_config_arg(self, ie, logger): + provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + + assert provider._configuration_arg('abc') == ['123d'] + assert provider._configuration_arg('abc', default=['default']) == ['123d'] + assert provider._configuration_arg('ABC', default=['default']) == ['default'] + assert provider._configuration_arg('abc', casesense=True) == ['123D'] + assert provider._configuration_arg('xyz', casesense=False) == ['456a', '789b'] + + def test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + return None + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenRequest: + def test_copy_request(self, 
pot_request): + copied_request = pot_request.copy() + + assert copied_request is not pot_request + assert copied_request.context == pot_request.context + assert copied_request.innertube_context == pot_request.innertube_context + assert copied_request.innertube_context is not pot_request.innertube_context + copied_request.innertube_context['client']['clientName'] = 'ANDROID' + assert pot_request.innertube_context['client']['clientName'] != 'ANDROID' + assert copied_request.innertube_host == pot_request.innertube_host + assert copied_request.session_index == pot_request.session_index + assert copied_request.player_url == pot_request.player_url + assert copied_request.is_authenticated == pot_request.is_authenticated + assert copied_request.visitor_data == pot_request.visitor_data + assert copied_request.data_sync_id == pot_request.data_sync_id + assert copied_request.video_id == pot_request.video_id + assert copied_request.request_cookiejar is pot_request.request_cookiejar + assert copied_request.request_proxy == pot_request.request_proxy + assert copied_request.request_headers == pot_request.request_headers + assert copied_request.request_headers is not pot_request.request_headers + assert copied_request.request_timeout == pot_request.request_timeout + assert copied_request.request_source_address == pot_request.request_source_address + assert copied_request.request_verify_tls == pot_request.request_verify_tls + assert copied_request.bypass_cache == pot_request.bypass_cache + + +def test_provider_bug_report_message(ie, logger): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + + message = provider_bug_report_message(provider) + assert message == '; please report this issue to the provider developer at https://example.com/issues .' 
+ + message_before = provider_bug_report_message(provider, before='custom message!') + assert message_before == 'custom message! Please report this issue to the provider developer at https://example.com/issues .' + + +def test_register_provider(ie): + + @register_provider + class UnavailableProviderPTP(PoTokenProvider): + def is_available(self) -> bool: + return False + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + assert _pot_providers.value.get('UnavailableProvider') == UnavailableProviderPTP + _pot_providers.value.pop('UnavailableProvider') + + +def test_register_pot_preference(ie): + before = len(_ptp_preferences.value) + + @register_preference(ExamplePTP) + def unavailable_preference(provider: PoTokenProvider, request: PoTokenRequest): + return 1 + + assert len(_ptp_preferences.value) == before + 1 + + +def test_register_cache_provider(ie): + + @cache.register_provider + class UnavailableCacheProviderPCP(PoTokenCacheProvider): + def is_available(self) -> bool: + return False + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + assert _pot_cache_providers.value.get('UnavailableCacheProvider') == UnavailableCacheProviderPCP + _pot_cache_providers.value.pop('UnavailableCacheProvider') + + +def test_register_cache_provider_spec(ie): + + @cache.register_spec + class UnavailableCacheProviderPCSP(PoTokenCacheSpecProvider): + def is_available(self) -> bool: + return False + + def generate_cache_spec(self, request: PoTokenRequest): + return None + + assert _pot_pcs_providers.value.get('UnavailableCacheProvider') == UnavailableCacheProviderPCSP + _pot_pcs_providers.value.pop('UnavailableCacheProvider') + + +def test_register_cache_provider_preference(ie): + before = len(_pot_cache_provider_preferences.value) + + @cache.register_preference(ExampleCacheProviderPCP) + def 
unavailable_preference(provider: PoTokenCacheProvider, request: PoTokenRequest): + return 1 + + assert len(_pot_cache_provider_preferences.value) == before + 1 + + +def test_logger_log_level(logger): + assert logger.LogLevel('INFO') == logger.LogLevel.INFO + assert logger.LogLevel('debuG') == logger.LogLevel.DEBUG + assert logger.LogLevel(10) == logger.LogLevel.DEBUG + assert logger.LogLevel('UNKNOWN') == logger.LogLevel.INFO diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 63e6e11b26..ea6264a0d6 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -640,6 +640,7 @@ class YoutubeDL: self._printed_messages = set() self._first_webpage_request = True self._post_hooks = [] + self._close_hooks = [] self._progress_hooks = [] self._postprocessor_hooks = [] self._download_retcode = 0 @@ -908,6 +909,11 @@ class YoutubeDL: """Add the post hook""" self._post_hooks.append(ph) + def add_close_hook(self, ch): + """Add a close hook, called when YoutubeDL.close() is called""" + assert callable(ch), 'Close hook must be callable' + self._close_hooks.append(ch) + def add_progress_hook(self, ph): """Add the download progress hook""" self._progress_hooks.append(ph) @@ -1016,6 +1022,9 @@ class YoutubeDL: self._request_director.close() del self._request_director + for close_hook in self._close_hooks: + close_hook() + def trouble(self, message=None, tb=None, is_error=True): """Determine action to take when a download problem appears. 
diff --git a/yt_dlp/extractor/youtube/_video.py b/yt_dlp/extractor/youtube/_video.py index 548e3aa93a..28fff19695 100644 --- a/yt_dlp/extractor/youtube/_video.py +++ b/yt_dlp/extractor/youtube/_video.py @@ -23,6 +23,8 @@ from ._base import ( _split_innertube_client, short_client_name, ) +from .pot._director import initialize_pot_director +from .pot.provider import PoTokenContext, PoTokenRequest from ..openload import PhantomJSwrapper from ...jsinterp import JSInterpreter from ...networking.exceptions import HTTPError @@ -66,6 +68,7 @@ from ...utils import ( urljoin, variadic, ) +from ...utils.networking import clean_headers, clean_proxies, select_proxy STREAMING_DATA_CLIENT_NAME = '__yt_dlp_client' STREAMING_DATA_INITIAL_PO_TOKEN = '__yt_dlp_po_token' @@ -1809,6 +1812,11 @@ class YoutubeIE(YoutubeBaseInfoExtractor): super().__init__(*args, **kwargs) self._code_cache = {} self._player_cache = {} + self._pot_director = None + + def _real_initialize(self): + super()._real_initialize() + self._pot_director = initialize_pot_director(self) def _prepare_live_from_start_formats(self, formats, video_id, live_start_time, url, webpage_url, smuggled_data, is_live): lock = threading.Lock() @@ -2855,7 +2863,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor): continue def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, visitor_data=None, - data_sync_id=None, session_index=None, player_url=None, video_id=None, **kwargs): + data_sync_id=None, session_index=None, player_url=None, video_id=None, webpage=None, **kwargs): """ Fetch a PO Token for a given client and context. This function will validate required parameters for a given context and client. @@ -2869,10 +2877,14 @@ class YoutubeIE(YoutubeBaseInfoExtractor): @param session_index: session index. @param player_url: player URL. @param video_id: video ID. + @param webpage: video webpage. @param kwargs: Additional arguments to pass down. May be more added in the future. @return: The fetched PO Token. 
None if it could not be fetched. """ + # TODO(future): This validation should be moved into pot framework. + # Some sort of middleware or validation provider perhaps? + # GVS WebPO Token is bound to visitor_data / Visitor ID when logged out. # Must have visitor_data for it to function. if player_url and context == _PoTokenContext.GVS and not visitor_data and not self.is_authenticated: @@ -2894,6 +2906,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor): f'Got a GVS PO Token for {client} client, but missing Data Sync ID for account. Formats may not work.' f'You may need to pass a Data Sync ID with --extractor-args "youtube:data_sync_id=XXX"') + self.write_debug(f'{video_id}: Retrieved a {context.value} PO Token for {client} client from config') return config_po_token # Require GVS WebPO Token if logged in for external fetching @@ -2903,7 +2916,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor): f'You may need to pass a Data Sync ID with --extractor-args "youtube:data_sync_id=XXX"') return - return self._fetch_po_token( + po_token = self._fetch_po_token( client=client, context=context.value, ytcfg=ytcfg, @@ -2912,11 +2925,66 @@ class YoutubeIE(YoutubeBaseInfoExtractor): session_index=session_index, player_url=player_url, video_id=video_id, + video_webpage=webpage, **kwargs, ) + if po_token: + self.write_debug(f'{video_id}: Retrieved a {context.value} PO Token for {client} client') + return po_token + def _fetch_po_token(self, client, **kwargs): - """(Unstable) External PO Token fetch stub""" + context = kwargs.get('context') + + # Avoid fetching PO Tokens when not required + fetch_pot_policy = self._configuration_arg('fetch_pot', [''], ie_key=YoutubeIE)[0] + if fetch_pot_policy not in ('never', 'auto', 'always'): + fetch_pot_policy = 'auto' + if ( + fetch_pot_policy == 'never' + or ( + fetch_pot_policy == 'auto' + and _PoTokenContext(context) not in self._get_default_ytcfg(client)['PO_TOKEN_REQUIRED_CONTEXTS'] + ) + ): + return None + + headers = 
self.get_param('http_headers').copy() + proxies = self._downloader.proxies.copy() + clean_headers(headers) + clean_proxies(proxies, headers) + + innertube_host = self._select_api_hostname(None, default_client=client) + + pot_request = PoTokenRequest( + context=PoTokenContext(context), + innertube_context=traverse_obj(kwargs, ('ytcfg', 'INNERTUBE_CONTEXT')), + innertube_host=innertube_host, + internal_client_name=client, + session_index=kwargs.get('session_index'), + player_url=kwargs.get('player_url'), + video_webpage=kwargs.get('video_webpage'), + is_authenticated=self.is_authenticated, + visitor_data=kwargs.get('visitor_data'), + data_sync_id=kwargs.get('data_sync_id'), + video_id=kwargs.get('video_id'), + request_cookiejar=self._downloader.cookiejar, + + # All requests that would need to be proxied should be in the + # context of www.youtube.com or the innertube host + request_proxy=( + select_proxy('https://www.youtube.com', proxies) + or select_proxy(f'https://{innertube_host}', proxies) + ), + request_headers=headers, + request_timeout=self.get_param('socket_timeout'), + request_verify_tls=not self.get_param('nocheckcertificate'), + request_source_address=self.get_param('source_address'), + + bypass_cache=False, + ) + + return self._pot_director.get_po_token(pot_request) @staticmethod def _is_agegated(player_response): @@ -3074,8 +3142,9 @@ class YoutubeIE(YoutubeBaseInfoExtractor): 'video_id': video_id, 'data_sync_id': data_sync_id if self.is_authenticated else None, 'player_url': player_url if require_js_player else None, + 'webpage': webpage, 'session_index': self._extract_session_index(master_ytcfg, player_ytcfg), - 'ytcfg': player_ytcfg, + 'ytcfg': player_ytcfg or self._get_default_ytcfg(client), } player_po_token = self.fetch_po_token( diff --git a/yt_dlp/extractor/youtube/pot/README.md b/yt_dlp/extractor/youtube/pot/README.md new file mode 100644 index 0000000000..f39e290710 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/README.md @@ -0,0 +1,309 @@ 
+# YoutubeIE PO Token Provider Framework + +As part of the YouTube extractor, we have a framework for providing PO Tokens programmatically. This can be used by plugins. + +Refer to the [PO Token Guide](https://github.com/yt-dlp/yt-dlp/wiki/PO-Token-Guide) for more information on PO Tokens. + +> [!TIP] +> If publishing a PO Token Provider plugin to GitHub, add the [yt-dlp-pot-provider](https://github.com/topics/yt-dlp-pot-provider) topic to your repository to help users find it. + + +## Public APIs + +- `yt_dlp.extractor.youtube.pot.cache` +- `yt_dlp.extractor.youtube.pot.provider` +- `yt_dlp.extractor.youtube.pot.utils` + +Everything else is internal-only and no guarantees are made about the API stability. + +> [!WARNING] +> We will try our best to maintain stability with the public APIs. +> However, due to the nature of extractors and YouTube, we may need to remove or change APIs in the future. +> If you are using these APIs outside yt-dlp plugins, please account for this by importing them safely. + +## PO Token Provider + +`yt_dlp.extractor.youtube.pot.provider` + +```python +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + PoTokenProvider, + PoTokenResponse, + PoTokenProviderError, + PoTokenProviderRejectedRequest, + register_provider, + register_preference, + ExternalRequestFeature, +) +from yt_dlp.networking.common import Request +from yt_dlp.extractor.youtube.pot.utils import get_webpo_content_binding +from yt_dlp.utils import traverse_obj +from yt_dlp.networking.exceptions import RequestError +import json + + +@register_provider +class MyPoTokenProviderPTP(PoTokenProvider): # Provider class name must end with "PTP" + PROVIDER_VERSION = '0.2.1' + # Define a unique display name for the provider + PROVIDER_NAME = 'my-provider' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + # -- Validation shortcuts. Set these to None to disable. -- + + # Innertube Client Name. 
+ # For example, "WEB", "ANDROID", "TVHTML5". + # For a list of WebPO client names, + # see yt_dlp.extractor.youtube.pot.utils.WEBPO_CLIENTS. + # Also see yt_dlp.extractor.youtube._base.INNERTUBE_CLIENTS + # for a list of client names currently supported by the YouTube extractor. + _SUPPORTED_CLIENTS = ('WEB', 'TVHTML5') + + _SUPPORTED_CONTEXTS = ( + PoTokenContext.GVS, + ) + + # If your provider makes external requests to websites (e.g. to youtube.com) + # using another library or service (i.e., not _request_webpage), + # set the request features that are supported here. + # If only using _request_webpage to make external requests, set this to None. + _SUPPORTED_EXTERNAL_REQUEST_FEATURES = ( + ExternalRequestFeature.PROXY_SCHEME_HTTP, + ExternalRequestFeature.SOURCE_ADDRESS, + ExternalRequestFeature.DISABLE_TLS_VERIFICATION + ) + + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. + + IMPORTANT: This method SHOULD NOT make any network requests or perform any expensive operations. + + Since this is called multiple times, we recommend caching the result. + """ + return True + + def close(self): + # Optional close hook, called when YoutubeDL is closed. + pass + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + # ℹ️ If you need to validate the request before making the request to the external source. + # Raise yt_dlp.extractor.youtube.pot.provider.PoTokenProviderRejectedRequest if the request is not supported. + if request.is_authenticated: + raise PoTokenProviderRejectedRequest( + 'This provider does not support authenticated requests' + ) + + # ℹ️ Settings are pulled from extractor args passed to yt-dlp with the key `youtubepot-{PROVIDER_KEY}`. 
+ # For this example, the extractor arg would be: + # `--extractor-args "youtubepot-mypotokenprovider:url=https://custom.example.com/get_pot"` + external_provider_url = self._configuration_arg( + 'url', default=['https://provider.example.com/get_pot'])[0] + + # See below for logging guidelines + self.logger.trace(f'Using external provider URL: {external_provider_url}') + + # You should use the internal HTTP client to make requests where possible, + # as it will handle cookies and other networking settings passed to yt-dlp. + try: + # See docstring in _request_webpage method for request tips + response = self._request_webpage( + Request(external_provider_url, data=json.dumps({ + 'content_binding': get_webpo_content_binding(request), + 'proxy': request.request_proxy, + 'headers': request.request_headers, + 'source_address': request.request_source_address, + 'verify_tls': request.request_verify_tls, + # Important: If your provider has its own caching, please respect `bypass_cache`. + # This may be used in the future to request a fresh PO Token if required. + 'do_not_cache': request.bypass_cache, + }).encode(), proxies={'all': None}), + pot_request=request, + note=( + f'Requesting {request.context.value} PO Token ' + f'for {request.internal_client_name} client from external provider'), + ) + + except RequestError as e: + # ℹ️ If there is an error, raise PoTokenProviderError. + # You can specify whether it is expected or not. If it is unexpected, + # the log will include a link to the bug report location (BUG_REPORT_LOCATION). 
+ raise PoTokenProviderError( + 'Networking error while fetching to get PO Token from external provider', + expected=True + ) from e + + # Note: PO Token is expected to be base64url encoded + po_token = traverse_obj(response, 'po_token') + if not po_token: + raise PoTokenProviderError( + 'Bad PO Token Response from external provider', + expected=False + ) + + return PoTokenResponse( + po_token=po_token, + # Optional, add a custom expiration timestamp for the token. Use for caching. + # By default, yt-dlp will use the default ttl from a registered cache spec (see below) + # Set to 0 or -1 to not cache this response. + expires_at=None, + ) + + +# If there are multiple PO Token Providers that can handle the same PoTokenRequest, +# you can define a preference function to increase/decrease the priority of providers. + +@register_preference(MyPoTokenProviderPTP) +def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest) -> int: + return 50 +``` + +## Logging Guidelines + +- Use the `self.logger` object to log messages. +- When making HTTP requests or any other expensive operation, use `self.logger.info` to log a message to standard non-verbose output. + - This lets users know what is happening when a time-expensive operation is taking place. + - It is recommended to include the PO Token context and internal client name in the message if possible. + - For example, `self.logger.info(f'Requesting {request.context.value} PO Token for {request.internal_client_name} client from external provider')`. +- Use `self.logger.debug` to log a message to the verbose output (`--verbose`). + - For debugging information visible to users posting verbose logs. + - Try to not log too much, prefer using trace logging for detailed debug messages. +- Use `self.logger.trace` to log a message to the PO Token debug output (`--extractor-args "youtube:pot_trace=true"`). + - Log as much as you like here as needed for debugging your provider. 
+- Avoid logging PO Tokens or any sensitive information to debug or info output. + +## Debugging + +- Use `-v --extractor-args "youtube:pot_trace=true"` to enable PO Token debug output. + +## Caching + +> [!WARNING] +> The following describes more advanced features that most users/developers will not need to use. + +> [!IMPORTANT] +> yt-dlp currently has a built-in LRU Memory Cache Provider and a cache spec provider for WebPO Tokens. +> You should only need to implement cache providers if you want an external cache, or a cache spec if you are handling non-WebPO Tokens. + +### Cache Providers + +`yt_dlp.extractor.youtube.pot.cache` + +```python +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + register_preference, + register_provider +) + +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest + + +@register_provider +class MyCacheProviderPCP(PoTokenCacheProvider): # Provider class name must end with "PCP" + PROVIDER_VERSION = '0.1.0' + # Define a unique display name for the provider + PROVIDER_NAME = 'my-cache-provider' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. + + IMPORTANT: This method SHOULD NOT make any network requests or perform any expensive operations. + + Since this is called multiple times, we recommend caching the result. + """ + return True + + def get(self, key: str): + # ℹ️ Similar to PO Token Providers, Cache Providers and Cache Spec Providers + # are passed down extractor args matching key `youtubepot-{PROVIDER_KEY}`. + some_setting = self._configuration_arg('some_setting', default=['default_value'])[0] + return self.my_cache.get(key) + + def store(self, key: str, value: str, expires_at: int): + # ⚠ expires_at MUST be respected. + # Cache entries should not be returned if they have expired. 
+ self.my_cache.store(key, value, expires_at) + + def delete(self, key: str): + self.my_cache.delete(key) + + def close(self): + # Optional close hook, called when the YoutubeDL instance is closed. + pass + +# If there are multiple PO Token Cache Providers available, you can +# define a preference function to increase/decrease the priority of providers. + +# IMPORTANT: Provider preferences should be ordered by cache lookup time (fastest first). +# For example, a memory cache should have a higher preference than a disk cache. + +# VERY IMPORTANT: yt-dlp has a built-in memory cache with a priority of 10000. +# Your cache provider's preference should be lower than this. + + +@register_preference(MyCacheProviderPCP) +def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest) -> int: + return 50 +``` + +### Cache Specs + +`yt_dlp.extractor.youtube.pot.cache` + +These are used to provide information on how to cache a particular PO Token Request. +You might have a different cache spec for different kinds of PO Tokens. + +```python +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + CacheProviderWritePolicy, + register_spec, +) +from yt_dlp.utils import traverse_obj +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest + + +@register_spec +class MyCacheSpecProviderPCSP(PoTokenCacheSpecProvider): # Provider class name must end with "PCSP" + PROVIDER_VERSION = '0.1.0' + # Define a unique display name for the provider + PROVIDER_NAME = 'mycachespec' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + def generate_cache_spec(self, request: PoTokenRequest): + + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) + if client_name != 'ANDROID': + # ℹ️ If the request is not supported by the cache spec, return None + return None + + # Generate a cache spec for the request + return PoTokenCacheSpec( + # Key bindings to uniquely identify the request. These are used to generate a cache key. 
+ key_bindings={ + 'client_name': client_name, + 'content_binding': 'unique_content_binding', + 'ip': traverse_obj(request.innertube_context, ('client', 'remoteHost')), + 'source_address': request.request_source_address, + 'proxy': request.request_proxy, + }, + # Default Cache TTL in seconds + default_ttl=21600, + + # Optional: Specify a write policy. + # WRITE_FIRST will write to the highest priority provider only, + # whereas WRITE_ALL will write to all providers. + # WRITE_FIRST may be useful if the PO Token is short-lived + # and there is no use writing to all providers. + write_policy=CacheProviderWritePolicy.WRITE_ALL, + ) +``` \ No newline at end of file diff --git a/yt_dlp/extractor/youtube/pot/__init__.py b/yt_dlp/extractor/youtube/pot/__init__.py new file mode 100644 index 0000000000..febcee0104 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/__init__.py @@ -0,0 +1,3 @@ +# Trigger import of built-in providers +from ._builtin.memory_cache import MemoryLRUPCP as _MemoryLRUPCP # noqa: F401 +from ._builtin.webpo_cachespec import WebPoPCSP as _WebPoPCSP # noqa: F401 diff --git a/yt_dlp/extractor/youtube/pot/_builtin/__init__.py b/yt_dlp/extractor/youtube/pot/_builtin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py b/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py new file mode 100644 index 0000000000..9c913e8c98 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import datetime as dt +import typing +from threading import Lock + +from yt_dlp.extractor.youtube.pot._provider import BuiltinIEContentProvider +from yt_dlp.extractor.youtube.pot._registry import _pot_memory_cache +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + register_preference, + register_provider, +) + + +def initialize_global_cache(max_size: int): + if _pot_memory_cache.value.get('cache') is None: + 
_pot_memory_cache.value['cache'] = {} + _pot_memory_cache.value['lock'] = Lock() + _pot_memory_cache.value['max_size'] = max_size + + if _pot_memory_cache.value['max_size'] != max_size: + raise ValueError('Cannot change max_size of initialized global memory cache') + + return ( + _pot_memory_cache.value['cache'], + _pot_memory_cache.value['lock'], + _pot_memory_cache.value['max_size'], + ) + + +@register_provider +class MemoryLRUPCP(PoTokenCacheProvider, BuiltinIEContentProvider): + PROVIDER_NAME = 'memory' + DEFAULT_CACHE_SIZE = 25 + + def __init__( + self, + *args, + initialize_cache: typing.Callable[[int], tuple[dict[str, tuple[str, int]], Lock, int]] = initialize_global_cache, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.cache, self.lock, self.max_size = initialize_cache(self.DEFAULT_CACHE_SIZE) + + def is_available(self) -> bool: + return True + + def get(self, key: str) -> str | None: + with self.lock: + if key not in self.cache: + return None + value, expires_at = self.cache.pop(key) + if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): + return None + self.cache[key] = (value, expires_at) + return value + + def store(self, key: str, value: str, expires_at: int): + with self.lock: + if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): + return + if key in self.cache: + self.cache.pop(key) + self.cache[key] = (value, expires_at) + if len(self.cache) > self.max_size: + oldest_key = next(iter(self.cache)) + self.cache.pop(oldest_key) + + def delete(self, key: str): + with self.lock: + self.cache.pop(key, None) + + +@register_preference(MemoryLRUPCP) +def memorylru_preference(*_, **__): + # Memory LRU Cache SHOULD be the highest priority + return 10000 diff --git a/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py b/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py new file mode 100644 index 0000000000..426b815c7e --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py @@ -0,0 +1,48 
@@ +from __future__ import annotations + +from yt_dlp.extractor.youtube.pot._provider import BuiltinIEContentProvider +from yt_dlp.extractor.youtube.pot.cache import ( + CacheProviderWritePolicy, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + register_spec, +) +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, +) +from yt_dlp.extractor.youtube.pot.utils import ContentBindingType, get_webpo_content_binding +from yt_dlp.utils import traverse_obj + + +@register_spec +class WebPoPCSP(PoTokenCacheSpecProvider, BuiltinIEContentProvider): + PROVIDER_NAME = 'webpo' + + def generate_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: + bind_to_visitor_id = self._configuration_arg( + 'bind_to_visitor_id', default=['true'])[0] == 'true' + + content_binding, content_binding_type = get_webpo_content_binding( + request, bind_to_visitor_id=bind_to_visitor_id) + + if not content_binding or not content_binding_type: + return None + + write_policy = CacheProviderWritePolicy.WRITE_ALL + if content_binding_type == ContentBindingType.VIDEO_ID: + write_policy = CacheProviderWritePolicy.WRITE_FIRST + + return PoTokenCacheSpec( + key_bindings={ + 't': 'webpo', + 'cb': content_binding, + 'cbt': content_binding_type.value, + 'ip': traverse_obj(request.innertube_context, ('client', 'remoteHost')), + 'sa': request.request_source_address, + 'px': request.request_proxy, + }, + # Integrity token response usually states it has a ttl of 12 hours (43200 seconds). + # We will default to 6 hours to be safe. 
+ default_ttl=21600, + write_policy=write_policy, + ) diff --git a/yt_dlp/extractor/youtube/pot/_director.py b/yt_dlp/extractor/youtube/pot/_director.py new file mode 100644 index 0000000000..aaf1d5290a --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_director.py @@ -0,0 +1,468 @@ +from __future__ import annotations + +import base64 +import binascii +import dataclasses +import datetime as dt +import hashlib +import json +import typing +import urllib.parse +from collections.abc import Iterable + +from yt_dlp.extractor.youtube.pot._provider import ( + BuiltinIEContentProvider, + IEContentProvider, + IEContentProviderLogger, +) +from yt_dlp.extractor.youtube.pot._registry import ( + _pot_cache_provider_preferences, + _pot_cache_providers, + _pot_pcs_providers, + _pot_providers, + _ptp_preferences, +) +from yt_dlp.extractor.youtube.pot.cache import ( + CacheProviderWritePolicy, + PoTokenCacheProvider, + PoTokenCacheProviderError, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, +) +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenProvider, + PoTokenProviderError, + PoTokenProviderRejectedRequest, + PoTokenRequest, + PoTokenResponse, + provider_bug_report_message, +) +from yt_dlp.utils import bug_reports_message, format_field, join_nonempty + +if typing.TYPE_CHECKING: + from yt_dlp.extractor.youtube.pot.cache import CacheProviderPreference + from yt_dlp.extractor.youtube.pot.provider import Preference + + +class YoutubeIEContentProviderLogger(IEContentProviderLogger): + def __init__(self, ie, prefix, log_level: IEContentProviderLogger.LogLevel | None = None): + self.__ie = ie + self.prefix = prefix + self.log_level = log_level if log_level is not None else self.LogLevel.INFO + + def _format_msg(self, message: str): + prefixstr = format_field(self.prefix, None, '[%s] ') + return f'{prefixstr}{message}' + + def trace(self, message: str): + if self.log_level <= self.LogLevel.TRACE: + self.__ie.write_debug(self._format_msg('TRACE: ' + message)) + + def 
debug(self, message: str): + if self.log_level <= self.LogLevel.DEBUG: + self.__ie.write_debug(self._format_msg(message)) + + def info(self, message: str): + if self.log_level <= self.LogLevel.INFO: + self.__ie.to_screen(self._format_msg(message)) + + def warning(self, message: str, *, once=False): + if self.log_level <= self.LogLevel.WARNING: + self.__ie.report_warning(self._format_msg(message), only_once=once) + + def error(self, message: str): + if self.log_level <= self.LogLevel.ERROR: + self.__ie._downloader.report_error(self._format_msg(message), is_error=False) + + +class PoTokenCache: + + def __init__( + self, + logger: IEContentProviderLogger, + cache_providers: list[PoTokenCacheProvider], + cache_spec_providers: list[PoTokenCacheSpecProvider], + cache_provider_preferences: list[CacheProviderPreference] | None = None, + ): + self.cache_providers: dict[str, PoTokenCacheProvider] = { + provider.PROVIDER_KEY: provider for provider in (cache_providers or [])} + self.cache_provider_preferences: list[CacheProviderPreference] = cache_provider_preferences or [] + self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = { + provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])} + self.logger = logger + + def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]: + """Sorts available cache providers by preference, given a request""" + preferences = { + provider: sum(pref(provider, request) for pref in self.cache_provider_preferences) + for provider in self.cache_providers.values() + } + if self.logger.log_level <= self.logger.LogLevel.TRACE: + # calling is_available() for every PO Token provider upfront may have some overhead + self.logger.trace(f'PO Token Cache Providers: {provider_display_list(self.cache_providers.values())}') + self.logger.trace('Cache Provider preferences for this request: {}'.format(', '.join( + f'{provider.PROVIDER_KEY}={pref}' for provider, pref in preferences.items()))) + + 
return ( + provider for provider in sorted( + self.cache_providers.values(), key=preferences.get, reverse=True) if provider.is_available()) + + def _get_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: + for provider in self.cache_spec_providers.values(): + if not provider.is_available(): + continue + try: + spec = provider.generate_cache_spec(request) + if not spec: + continue + if not validate_cache_spec(spec): + self.logger.error( + f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() ' + f'returned invalid spec {spec}{provider_bug_report_message(provider)}') + continue + spec = dataclasses.replace(spec, _provider=provider) + self.logger.trace( + f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"') + return spec + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: ' + f'{e!r}{provider_bug_report_message(provider)}') + continue + return None + + def _generate_key_bindings(self, spec: PoTokenCacheSpec) -> dict[str, str]: + bindings_cleaned = { + **{k: v for k, v in spec.key_bindings.items() if v is not None}, + # Allow us to invalidate caches if such need arises + '_dlp_cache': 'v1', + } + if spec._provider: + bindings_cleaned['_p'] = spec._provider.PROVIDER_KEY + self.logger.trace(f'Generated cache key bindings: {bindings_cleaned}') + return bindings_cleaned + + def _generate_key(self, bindings: dict) -> str: + binding_string = ''.join(repr(dict(sorted(bindings.items())))) + return hashlib.sha256(binding_string.encode()).hexdigest() + + def get(self, request: PoTokenRequest) -> PoTokenResponse | None: + spec = self._get_cache_spec(request) + if not spec: + self.logger.trace('No cache spec available for this request, unable to fetch from cache') + return None + + cache_key = self._generate_key(self._generate_key_bindings(spec)) + self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}') + + for 
idx, provider in enumerate(self._get_cache_providers(request)): + try: + self.logger.trace( + f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider') + cache_response = provider.get(cache_key) + if not cache_response: + continue + try: + po_token_response = PoTokenResponse(**json.loads(cache_response)) + except (TypeError, ValueError, json.JSONDecodeError): + po_token_response = None + if not validate_response(po_token_response): + self.logger.error( + f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": ' + f'{cache_response}{provider_bug_report_message(provider)}') + provider.delete(cache_key) + continue + self.logger.trace( + f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: ' + f'{po_token_response}') + if idx > 0: + # Write back to the highest priority cache provider, + # so we stop trying to fetch from lower priority providers + self.logger.trace('Writing PO Token response to highest priority cache provider') + self.store(request, po_token_response, write_policy=CacheProviderWritePolicy.WRITE_FIRST) + + return po_token_response + except PoTokenCacheProviderError as e: + self.logger.warning( + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + continue + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider)}', + ) + continue + return None + + def store( + self, + request: PoTokenRequest, + response: PoTokenResponse, + write_policy: CacheProviderWritePolicy | None = None, + ): + spec = self._get_cache_spec(request) + if not spec: + self.logger.trace('No cache spec available for this request. 
Not caching.') + return + + if not validate_response(response): + self.logger.error( + f'Invalid PO Token response provided to PoTokenCache.store(): ' + f'{response}{bug_reports_message()}') + return + + cache_key = self._generate_key(self._generate_key_bindings(spec)) + self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}') + + default_expires_at = int(dt.datetime.now(dt.timezone.utc).timestamp()) + spec.default_ttl + cache_response = dataclasses.replace(response, expires_at=response.expires_at or default_expires_at) + + write_policy = write_policy or spec.write_policy + self.logger.trace(f'Using write policy: {write_policy}') + + for idx, provider in enumerate(self._get_cache_providers(request)): + try: + self.logger.trace( + f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider ' + f'(key={cache_key}, expires_at={cache_response.expires_at})') + provider.store( + key=cache_key, + value=json.dumps(dataclasses.asdict(cache_response)), + expires_at=cache_response.expires_at) + except PoTokenCacheProviderError as e: + self.logger.warning( + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider)}') + + # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails + if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST: + return + + def close(self): + for provider in self.cache_providers.values(): + provider.close() + for spec_provider in self.cache_spec_providers.values(): + spec_provider.close() + + +class PoTokenRequestDirector: + + def __init__(self, logger: IEContentProviderLogger, cache: PoTokenCache): + self.providers: dict[str, PoTokenProvider] = {} + self.preferences: list[Preference] = [] + 
self.cache = cache + self.logger = logger + + def register_provider(self, provider: PoTokenProvider): + self.providers[provider.PROVIDER_KEY] = provider + + def register_preference(self, preference: Preference): + self.preferences.append(preference) + + def _get_providers(self, request: PoTokenRequest) -> Iterable[PoTokenProvider]: + """Sorts available providers by preference, given a request""" + preferences = { + provider: sum(pref(provider, request) for pref in self.preferences) + for provider in self.providers.values() + } + if self.logger.log_level <= self.logger.LogLevel.TRACE: + # calling is_available() for every PO Token provider upfront may have some overhead + self.logger.trace(f'PO Token Providers: {provider_display_list(self.providers.values())}') + self.logger.trace('Provider preferences for this request: {}'.format(', '.join( + f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items()))) + + return ( + provider for provider in sorted( + self.providers.values(), key=preferences.get, reverse=True) + if provider.is_available() + ) + + def _get_po_token(self, request) -> PoTokenResponse | None: + for provider in self._get_providers(request): + try: + self.logger.trace( + f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider') + response = provider.request_pot(request.copy()) + except PoTokenProviderRejectedRequest as e: + self.logger.trace( + f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, ' + f'trying next available provider. 
Reason: {e}') + continue + except PoTokenProviderError as e: + self.logger.warning( + f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + continue + except Exception as e: + self.logger.error( + f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: ' + f'{e!r}{provider_bug_report_message(provider)}') + continue + + self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}') + + if not validate_response(response): + self.logger.error( + f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: ' + f'{response}{provider_bug_report_message(provider)}') + continue + + return response + + self.logger.trace('No PO Token providers were able to provide a valid PO Token') + return None + + def get_po_token(self, request: PoTokenRequest) -> str | None: + if not request.bypass_cache: + if pot_response := self.cache.get(request): + return clean_pot(pot_response.po_token) + + if not self.providers: + self.logger.trace('No PO Token providers registered') + return None + + pot_response = self._get_po_token(request) + if not pot_response: + return None + + pot_response.po_token = clean_pot(pot_response.po_token) + + if pot_response.expires_at is None or pot_response.expires_at > 0: + self.cache.store(request, pot_response) + else: + self.logger.trace( + f'PO Token response will not be cached (expires_at={pot_response.expires_at})') + + return pot_response.po_token + + def close(self): + for provider in self.providers.values(): + provider.close() + self.cache.close() + + +EXTRACTOR_ARG_PREFIX = 'youtubepot' + + +def initialize_pot_director(ie): + assert ie._downloader is not None, 'Downloader not set' + + enable_trace = ie._configuration_arg( + 'pot_trace', ['false'], ie_key='youtube', casesense=False)[0] == 'true' + + if enable_trace: + log_level = IEContentProviderLogger.LogLevel.TRACE + elif 
ie.get_param('verbose', False): + log_level = IEContentProviderLogger.LogLevel.DEBUG + else: + log_level = IEContentProviderLogger.LogLevel.INFO + + def get_provider_logger_and_settings(provider, logger_key): + logger_prefix = f'{logger_key}:{provider.PROVIDER_NAME}' + extractor_key = f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}' + return ( + YoutubeIEContentProviderLogger(ie, logger_prefix, log_level=log_level), + ie.get_param('extractor_args', {}).get(extractor_key, {})) + + cache_providers = [] + for cache_provider in _pot_cache_providers.value.values(): + logger, settings = get_provider_logger_and_settings(cache_provider, 'pot:cache') + cache_providers.append(cache_provider(ie, logger, settings)) + cache_spec_providers = [] + for cache_spec_provider in _pot_pcs_providers.value.values(): + logger, settings = get_provider_logger_and_settings(cache_spec_provider, 'pot:cache:spec') + cache_spec_providers.append(cache_spec_provider(ie, logger, settings)) + + cache = PoTokenCache( + logger=YoutubeIEContentProviderLogger(ie, 'pot:cache', log_level=log_level), + cache_providers=cache_providers, + cache_spec_providers=cache_spec_providers, + cache_provider_preferences=list(_pot_cache_provider_preferences.value), + ) + + director = PoTokenRequestDirector( + logger=YoutubeIEContentProviderLogger(ie, 'pot', log_level=log_level), + cache=cache, + ) + + ie._downloader.add_close_hook(director.close) + + for provider in _pot_providers.value.values(): + logger, settings = get_provider_logger_and_settings(provider, 'pot') + director.register_provider(provider(ie, logger, settings)) + + for preference in _ptp_preferences.value: + director.register_preference(preference) + + if director.logger.log_level <= director.logger.LogLevel.DEBUG: + # calling is_available() for every PO Token provider upfront may have some overhead + director.logger.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}') + director.logger.debug(f'PO Token Cache 
Providers: {provider_display_list(cache.cache_providers.values())}') + director.logger.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}') + director.logger.trace(f'Registered {len(director.preferences)} provider preferences') + director.logger.trace(f'Registered {len(cache.cache_provider_preferences)} cache provider preferences') + + return director + + +def provider_display_list(providers: Iterable[IEContentProvider]): + def provider_display_name(provider): + display_str = join_nonempty( + provider.PROVIDER_NAME, + provider.PROVIDER_VERSION if not isinstance(provider, BuiltinIEContentProvider) else None) + statuses = [] + if not isinstance(provider, BuiltinIEContentProvider): + statuses.append('external') + if not provider.is_available(): + statuses.append('unavailable') + if statuses: + display_str += f' ({", ".join(statuses)})' + return display_str + + return ', '.join(provider_display_name(provider) for provider in providers) or 'none' + + +def clean_pot(po_token: str): + # Clean and validate the PO Token. This will strip invalid characters off + # (e.g. 
additional url params the user may accidentally include) + try: + return base64.urlsafe_b64encode( + base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode() + except (binascii.Error, ValueError): + raise ValueError('Invalid PO Token') + + +def validate_response(response: PoTokenResponse | None): + if ( + not isinstance(response, PoTokenResponse) + or not isinstance(response.po_token, str) + or not response.po_token + ): # noqa: SIM103 + return False + + try: + clean_pot(response.po_token) + except ValueError: + return False + + if not isinstance(response.expires_at, int): + return response.expires_at is None + + return response.expires_at <= 0 or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp()) + + +def validate_cache_spec(spec: PoTokenCacheSpec): + return ( + isinstance(spec, PoTokenCacheSpec) + and isinstance(spec.write_policy, CacheProviderWritePolicy) + and isinstance(spec.default_ttl, int) + and isinstance(spec.key_bindings, dict) + and all(isinstance(k, str) for k in spec.key_bindings) + and all(v is None or isinstance(v, str) for v in spec.key_bindings.values()) + and bool([v for v in spec.key_bindings.values() if v is not None]) + ) diff --git a/yt_dlp/extractor/youtube/pot/_provider.py b/yt_dlp/extractor/youtube/pot/_provider.py new file mode 100644 index 0000000000..af7034d227 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_provider.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import abc +import enum +import functools + +from yt_dlp.extractor.common import InfoExtractor +from yt_dlp.utils import NO_DEFAULT, bug_reports_message, classproperty, traverse_obj +from yt_dlp.version import __version__ + +# xxx: these could be generalized outside YoutubeIE eventually + + +class IEContentProviderLogger(abc.ABC): + + class LogLevel(enum.IntEnum): + TRACE = 0 + DEBUG = 10 + INFO = 20 + WARNING = 30 + ERROR = 40 + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.upper() + if 
value in dir(cls): + return cls[value] + + return cls.INFO + + log_level = LogLevel.INFO + + @abc.abstractmethod + def trace(self, message: str): + pass + + @abc.abstractmethod + def debug(self, message: str): + pass + + @abc.abstractmethod + def info(self, message: str): + pass + + @abc.abstractmethod + def warning(self, message: str, *, once=False): + pass + + @abc.abstractmethod + def error(self, message: str): + pass + + +class IEContentProviderError(Exception): + def __init__(self, msg=None, expected=False): + super().__init__(msg) + self.expected = expected + + +class IEContentProvider(abc.ABC): + PROVIDER_VERSION: str = '0.0.0' + BUG_REPORT_LOCATION: str = '(developer has not provided a bug report location)' + + def __init__( + self, + ie: InfoExtractor, + logger: IEContentProviderLogger, + settings: dict[str, list[str]], *_, **__, + ): + self.ie = ie + self.settings = settings or {} + self.logger = logger + super().__init__() + + @classmethod + def __init_subclass__(cls, *, suffix=None, **kwargs): + if suffix: + cls._PROVIDER_KEY_SUFFIX = suffix + return super().__init_subclass__(**kwargs) + + @classproperty + def PROVIDER_NAME(cls) -> str: + return cls.__name__[:-len(cls._PROVIDER_KEY_SUFFIX)] + + @classproperty + def BUG_REPORT_MESSAGE(cls): + return f'please report this issue to the provider developer at {cls.BUG_REPORT_LOCATION} .' + + @classproperty + def PROVIDER_KEY(cls) -> str: + assert hasattr(cls, '_PROVIDER_KEY_SUFFIX'), 'Content Provider implementation must define a suffix for the provider key' + assert cls.__name__.endswith(cls._PROVIDER_KEY_SUFFIX), f'PoTokenProvider class names must end with "{cls._PROVIDER_KEY_SUFFIX}"' + return cls.__name__[:-len(cls._PROVIDER_KEY_SUFFIX)] + + @abc.abstractmethod + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. 
+ + IMPORTANT: This method should not make any network requests or perform any expensive operations. + It is called multiple times. + """ + raise NotImplementedError + + def close(self): # noqa: B027 + pass + + def _configuration_arg(self, key, default=NO_DEFAULT, *, casesense=False): + """ + @returns A list of values for the setting given by "key" + or "default" if no such key is present + @param default The default value to return when the key is not present (default: []) + @param casesense When false, the values are converted to lower case + """ + val = traverse_obj(self.settings, key) + if val is None: + return [] if default is NO_DEFAULT else default + return list(val) if casesense else [x.lower() for x in val] + + +class BuiltinIEContentProvider(IEContentProvider, abc.ABC): + PROVIDER_VERSION = __version__ + BUG_REPORT_MESSAGE = bug_reports_message(before='') + + +def register_provider_generic( + provider, + base_class, + registry, +): + """Generic function to register a provider class""" + assert issubclass(provider, base_class), f'{provider} must be a subclass of {base_class.__name__}' + assert provider.PROVIDER_KEY not in registry, f'{base_class.__name__} {provider.PROVIDER_KEY} already registered' + registry[provider.PROVIDER_KEY] = provider + return provider + + +def register_preference_generic( + base_class, + registry, + *providers, +): + """Generic function to register a preference for a provider""" + assert all(issubclass(provider, base_class) for provider in providers) + + def outer(preference): + @functools.wraps(preference) + def inner(provider, *args, **kwargs): + if not providers or isinstance(provider, providers): + return preference(provider, *args, **kwargs) + return 0 + registry.add(inner) + return preference + return outer diff --git a/yt_dlp/extractor/youtube/pot/_registry.py b/yt_dlp/extractor/youtube/pot/_registry.py new file mode 100644 index 0000000000..c72a622c12 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_registry.py @@ -0,0 
+1,8 @@ +from yt_dlp.globals import Indirect + +_pot_providers = Indirect({}) +_ptp_preferences = Indirect(set()) +_pot_pcs_providers = Indirect({}) +_pot_cache_providers = Indirect({}) +_pot_cache_provider_preferences = Indirect(set()) +_pot_memory_cache = Indirect({}) diff --git a/yt_dlp/extractor/youtube/pot/cache.py b/yt_dlp/extractor/youtube/pot/cache.py new file mode 100644 index 0000000000..6d69316adc --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/cache.py @@ -0,0 +1,97 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import abc +import dataclasses +import enum +import typing + +from yt_dlp.extractor.youtube.pot._provider import ( + IEContentProvider, + IEContentProviderError, + register_preference_generic, + register_provider_generic, +) +from yt_dlp.extractor.youtube.pot._registry import ( + _pot_cache_provider_preferences, + _pot_cache_providers, + _pot_pcs_providers, +) +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest + + +class PoTokenCacheProviderError(IEContentProviderError): + """An error occurred in a PO Token cache provider""" + + +class PoTokenCacheProvider(IEContentProvider, abc.ABC, suffix='PCP'): + @abc.abstractmethod + def get(self, key: str) -> str | None: + pass + + @abc.abstractmethod + def store(self, key: str, value: str, expires_at: int): + pass + + @abc.abstractmethod + def delete(self, key: str): + pass + + +class CacheProviderWritePolicy(enum.Enum): + WRITE_ALL = enum.auto() # Write to all cache providers + WRITE_FIRST = enum.auto() # Write to only the first cache provider + + +@dataclasses.dataclass +class PoTokenCacheSpec: + key_bindings: dict[str, str | None] + default_ttl: int + write_policy: CacheProviderWritePolicy = CacheProviderWritePolicy.WRITE_ALL + + # Internal + _provider: PoTokenCacheSpecProvider | None = None + + +class PoTokenCacheSpecProvider(IEContentProvider, abc.ABC, suffix='PCSP'): + + def is_available(self) -> bool: + return True + + @abc.abstractmethod + def generate_cache_spec(self, 
request: PoTokenRequest) -> PoTokenCacheSpec | None: + """Generate a cache spec for the given request""" + pass + + +def register_provider(provider: type[PoTokenCacheProvider]): + """Register a PoTokenCacheProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenCacheProvider, + registry=_pot_cache_providers.value, + ) + + +def register_spec(provider: type[PoTokenCacheSpecProvider]): + """Register a PoTokenCacheSpecProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenCacheSpecProvider, + registry=_pot_pcs_providers.value, + ) + + +def register_preference( + *providers: type[PoTokenCacheProvider]) -> typing.Callable[[CacheProviderPreference], CacheProviderPreference]: + """Register a preference for a PoTokenCacheProvider""" + return register_preference_generic( + PoTokenCacheProvider, + _pot_cache_provider_preferences.value, + *providers, + ) + + +if typing.TYPE_CHECKING: + CacheProviderPreference = typing.Callable[[PoTokenCacheProvider, PoTokenRequest], int] diff --git a/yt_dlp/extractor/youtube/pot/provider.py b/yt_dlp/extractor/youtube/pot/provider.py new file mode 100644 index 0000000000..53af92d30b --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/provider.py @@ -0,0 +1,280 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import abc +import copy +import dataclasses +import enum +import functools +import typing +import urllib.parse + +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.extractor.youtube.pot._provider import ( + IEContentProvider, + IEContentProviderError, + register_preference_generic, + register_provider_generic, +) +from yt_dlp.extractor.youtube.pot._registry import _pot_providers, _ptp_preferences +from yt_dlp.networking import Request, Response +from yt_dlp.utils import traverse_obj +from yt_dlp.utils.networking import HTTPHeaderDict + +__all__ = [ + 'ExternalRequestFeature', + 'PoTokenContext', + 'PoTokenProvider', + 'PoTokenProviderError', + 
'PoTokenProviderRejectedRequest', + 'PoTokenRequest', + 'PoTokenResponse', + 'provider_bug_report_message', + 'register_preference', + 'register_provider', +] + + +class PoTokenContext(enum.Enum): + GVS = 'gvs' + PLAYER = 'player' + + +@dataclasses.dataclass +class PoTokenRequest: + # YouTube parameters + context: PoTokenContext + innertube_context: InnertubeContext + innertube_host: str | None = None + session_index: str | None = None + player_url: str | None = None + is_authenticated: bool = False + video_webpage: str | None = None + internal_client_name: str | None = None + + # Content binding parameters + visitor_data: str | None = None + data_sync_id: str | None = None + video_id: str | None = None + + # Networking parameters + request_cookiejar: YoutubeDLCookieJar = dataclasses.field(default_factory=YoutubeDLCookieJar) + request_proxy: str | None = None + request_headers: HTTPHeaderDict = dataclasses.field(default_factory=HTTPHeaderDict) + request_timeout: float | None = None + request_source_address: str | None = None + request_verify_tls: bool = True + + # Generate a new token, do not use a cached token + # The token should still be cached for future requests + bypass_cache: bool = False + + def copy(self): + return dataclasses.replace( + self, + request_headers=HTTPHeaderDict(self.request_headers), + innertube_context=copy.deepcopy(self.innertube_context), + ) + + +@dataclasses.dataclass +class PoTokenResponse: + po_token: str + expires_at: int | None = None + + +class PoTokenProviderRejectedRequest(IEContentProviderError): + """Reject the PoTokenRequest (cannot handle the request)""" + + +class PoTokenProviderError(IEContentProviderError): + """An error occurred while fetching a PO Token""" + + +class ExternalRequestFeature(enum.Enum): + PROXY_SCHEME_HTTP = enum.auto() + PROXY_SCHEME_HTTPS = enum.auto() + PROXY_SCHEME_SOCKS4 = enum.auto() + PROXY_SCHEME_SOCKS4A = enum.auto() + PROXY_SCHEME_SOCKS5 = enum.auto() + PROXY_SCHEME_SOCKS5H = enum.auto() + 
SOURCE_ADDRESS = enum.auto() + DISABLE_TLS_VERIFICATION = enum.auto() + + +class PoTokenProvider(IEContentProvider, abc.ABC, suffix='PTP'): + + # Set to None to disable the check + _SUPPORTED_CONTEXTS: tuple[PoTokenContext] | None = () + + # Innertube Client Name. + # For example, "WEB", "ANDROID", "TVHTML5". + # For a list of WebPO client names, see yt_dlp.extractor.youtube.pot.utils.WEBPO_CLIENTS. + # Also see yt_dlp.extractor.youtube._base.INNERTUBE_CLIENTS + # for a list of client names currently supported by the YouTube extractor. + _SUPPORTED_CLIENTS: tuple[str] | None = () + + # If making external requests to websites (i.e. to youtube.com) + # using another library or service (i.e., not _request_webpage), + # add the request features that are supported. + # If only using _request_webpage to make external requests, set this to None. + _SUPPORTED_EXTERNAL_REQUEST_FEATURES: tuple[ExternalRequestFeature] | None = () + + def __validate_request(self, request: PoTokenRequest): + if not self.is_available(): + raise PoTokenProviderRejectedRequest(f'{self.PROVIDER_NAME} is not available') + + # Validate request using built-in settings + if ( + self._SUPPORTED_CONTEXTS is not None + and request.context not in self._SUPPORTED_CONTEXTS + ): + raise PoTokenProviderRejectedRequest( + f'PO Token Context "{request.context}" is not supported by {self.PROVIDER_NAME}') + + if self._SUPPORTED_CLIENTS is not None: + client_name = traverse_obj( + request.innertube_context, ('client', 'clientName')) + if client_name not in self._SUPPORTED_CLIENTS: + raise PoTokenProviderRejectedRequest( + f'Client "{client_name}" is not supported by {self.PROVIDER_NAME}. 
' + f'Supported clients: {", ".join(self._SUPPORTED_CLIENTS) or "none"}') + + self.__validate_external_request_features(request) + + @functools.cached_property + def _supported_proxy_schemes(self): + return { + scheme: feature + for scheme, feature in { + 'http': ExternalRequestFeature.PROXY_SCHEME_HTTP, + 'https': ExternalRequestFeature.PROXY_SCHEME_HTTPS, + 'socks4': ExternalRequestFeature.PROXY_SCHEME_SOCKS4, + 'socks4a': ExternalRequestFeature.PROXY_SCHEME_SOCKS4A, + 'socks5': ExternalRequestFeature.PROXY_SCHEME_SOCKS5, + 'socks5h': ExternalRequestFeature.PROXY_SCHEME_SOCKS5H, + }.items() + if feature in (self._SUPPORTED_EXTERNAL_REQUEST_FEATURES or []) + } + + def __validate_external_request_features(self, request: PoTokenRequest): + if self._SUPPORTED_EXTERNAL_REQUEST_FEATURES is None: + return + + if request.request_proxy: + scheme = urllib.parse.urlparse(request.request_proxy).scheme + if scheme.lower() not in self._supported_proxy_schemes: + raise PoTokenProviderRejectedRequest( + f'External requests by "{self.PROVIDER_NAME}" provider do not ' + f'support proxy scheme "{scheme}". 
Supported proxy schemes: ' + f'{", ".join(self._supported_proxy_schemes) or "none"}') + + if ( + request.request_source_address + and ExternalRequestFeature.SOURCE_ADDRESS not in self._SUPPORTED_EXTERNAL_REQUEST_FEATURES + ): + raise PoTokenProviderRejectedRequest( + f'External requests by "{self.PROVIDER_NAME}" provider ' + f'do not support setting source address') + + if ( + not request.request_verify_tls + and ExternalRequestFeature.DISABLE_TLS_VERIFICATION not in self._SUPPORTED_EXTERNAL_REQUEST_FEATURES + ): + raise PoTokenProviderRejectedRequest( + f'External requests by "{self.PROVIDER_NAME}" provider ' + f'do not support ignoring TLS certificate failures') + + def request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + self.__validate_request(request) + return self._real_request_pot(request) + + @abc.abstractmethod + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + """To be implemented by subclasses""" + pass + + # Helper functions + + def _request_webpage(self, request: Request, pot_request: PoTokenRequest | None = None, note=None, **kwargs) -> Response: + """Make a request using the internal HTTP Client. + Use this instead of calling requests, urllib3 or other HTTP client libraries directly! + + YouTube cookies will be automatically applied if this request is made to YouTube. + + @param request: The request to make + @param pot_request: The PoTokenRequest to use. Request parameters will be merged from it. + @param note: Custom log message to display when making the request. Set to `False` to disable logging. + + Tips: + - Disable proxy (e.g. if calling local service): Request(..., proxies={'all': None}) + - Set request timeout: Request(..., extensions={'timeout': 5.0}) + """ + req = request.copy() + + # Merge some ctx request settings into the request + # Most of these will already be used by the configured ydl instance, + # however, the YouTube extractor may override some. 
+ if pot_request is not None: + req.headers = HTTPHeaderDict(pot_request.request_headers, req.headers) + req.proxies = req.proxies or ({'all': pot_request.request_proxy} if pot_request.request_proxy else {}) + + if pot_request.request_cookiejar is not None: + req.extensions['cookiejar'] = req.extensions.get('cookiejar', pot_request.request_cookiejar) + + if note is not False: + self.logger.info(str(note) if note else 'Requesting webpage') + return self.ie._downloader.urlopen(req) + + +def register_provider(provider: type[PoTokenProvider]): + """Register a PoTokenProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenProvider, + registry=_pot_providers.value, + ) + + +def provider_bug_report_message(provider: IEContentProvider, before=';'): + msg = provider.BUG_REPORT_MESSAGE + + before = before.rstrip() + if not before or before.endswith(('.', '!', '?')): + msg = msg[0].title() + msg[1:] + + return f'{before} {msg}' if before else msg + + +def register_preference(*providers: type[PoTokenProvider]) -> typing.Callable[[Preference], Preference]: + """Register a preference for a PoTokenProvider""" + return register_preference_generic( + PoTokenProvider, + _ptp_preferences.value, + *providers, + ) + + +if typing.TYPE_CHECKING: + Preference = typing.Callable[[PoTokenProvider, PoTokenRequest], int] + __all__.append('Preference') + + # Barebones innertube context. There may be more fields. 
+ class ClientInfo(typing.TypedDict, total=False): + hl: str | None + gl: str | None + remoteHost: str | None + deviceMake: str | None + deviceModel: str | None + visitorData: str | None + userAgent: str | None + clientName: str + clientVersion: str + osName: str | None + osVersion: str | None + + class InnertubeContext(typing.TypedDict, total=False): + client: ClientInfo + request: dict + user: dict diff --git a/yt_dlp/extractor/youtube/pot/utils.py b/yt_dlp/extractor/youtube/pot/utils.py new file mode 100644 index 0000000000..1c0db243bf --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/utils.py @@ -0,0 +1,73 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import base64 +import contextlib +import enum +import re +import urllib.parse + +from yt_dlp.extractor.youtube.pot.provider import PoTokenContext, PoTokenRequest +from yt_dlp.utils import traverse_obj + +__all__ = ['WEBPO_CLIENTS', 'ContentBindingType', 'get_webpo_content_binding'] + +WEBPO_CLIENTS = ( + 'WEB', + 'MWEB', + 'TVHTML5', + 'WEB_EMBEDDED_PLAYER', + 'WEB_CREATOR', + 'WEB_REMIX', + 'TVHTML5_SIMPLY_EMBEDDED_PLAYER', +) + + +class ContentBindingType(enum.Enum): + VISITOR_DATA = 'visitor_data' + DATASYNC_ID = 'datasync_id' + VIDEO_ID = 'video_id' + VISITOR_ID = 'visitor_id' + + +def get_webpo_content_binding( + request: PoTokenRequest, + webpo_clients=WEBPO_CLIENTS, + bind_to_visitor_id=False, +) -> tuple[str | None, ContentBindingType | None]: + + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) + if not client_name or client_name not in webpo_clients: + return None, None + + if request.context == PoTokenContext.GVS or client_name in ('WEB_REMIX', ): + if request.is_authenticated: + return request.data_sync_id, ContentBindingType.DATASYNC_ID + else: + if bind_to_visitor_id: + visitor_id = _extract_visitor_id(request.visitor_data) + if visitor_id: + return visitor_id, ContentBindingType.VISITOR_ID + return request.visitor_data, ContentBindingType.VISITOR_DATA + 
+ elif request.context == PoTokenContext.PLAYER or client_name != 'WEB_REMIX': + return request.video_id, ContentBindingType.VIDEO_ID + + return None, None + + +def _extract_visitor_id(visitor_data): + if not visitor_data: + return None + + # Attempt to extract the visitor ID from the visitor_data protobuf + # xxx: ideally should use a protobuf parser + with contextlib.suppress(Exception): + visitor_id = base64.urlsafe_b64decode( + urllib.parse.unquote_plus(visitor_data))[2:13].decode() + # check that visitor id is all letters and numbers + if re.fullmatch(r'[A-Za-z0-9_-]{11}', visitor_id): + return visitor_id + + return None diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py index c800f2c095..747879da87 100644 --- a/yt_dlp/networking/_curlcffi.py +++ b/yt_dlp/networking/_curlcffi.py @@ -6,7 +6,8 @@ import math import re import urllib.parse -from ._helper import InstanceStoreMixin, select_proxy +from ._helper import InstanceStoreMixin +from ..utils.networking import select_proxy from .common import ( Features, Request, diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py index b86d3606d8..ef9c8bafab 100644 --- a/yt_dlp/networking/_helper.py +++ b/yt_dlp/networking/_helper.py @@ -13,7 +13,6 @@ import urllib.request from .exceptions import RequestError from ..dependencies import certifi from ..socks import ProxyType, sockssocket -from ..utils import format_field, traverse_obj if typing.TYPE_CHECKING: from collections.abc import Iterable @@ -82,19 +81,6 @@ def make_socks_proxy_opts(socks_proxy): } -def select_proxy(url, proxies): - """Unified proxy selector for all backends""" - url_components = urllib.parse.urlparse(url) - if 'no' in proxies: - hostport = url_components.hostname + format_field(url_components.port, None, ':%s') - if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): - return - elif urllib.request.proxy_bypass(hostport): # check system settings - return - - return 
traverse_obj(proxies, url_components.scheme or 'http', 'all') - - def get_redirect_method(method, status): """Unified redirect method handling""" diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 5b6b264a68..d02e976b57 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -10,7 +10,7 @@ import warnings from ..dependencies import brotli, requests, urllib3 from ..utils import bug_reports_message, int_or_none, variadic -from ..utils.networking import normalize_url +from ..utils.networking import normalize_url, select_proxy if requests is None: raise ImportError('requests module is not installed') @@ -41,7 +41,6 @@ from ._helper import ( create_socks_proxy_socket, get_redirect_method, make_socks_proxy_opts, - select_proxy, ) from .common import ( Features, diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index a188b35f57..cb7a430bb3 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -26,7 +26,6 @@ from ._helper import ( create_socks_proxy_socket, get_redirect_method, make_socks_proxy_opts, - select_proxy, ) from .common import Features, RequestHandler, Response, register_rh from .exceptions import ( @@ -41,7 +40,7 @@ from .exceptions import ( from ..dependencies import brotli from ..socks import ProxyError as SocksProxyError from ..utils import update_url_query -from ..utils.networking import normalize_url +from ..utils.networking import normalize_url, select_proxy SUPPORTED_ENCODINGS = ['gzip', 'deflate'] CONTENT_DECODE_ERRORS = [zlib.error, OSError] diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index d29f8e45a9..fd8730ac7e 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -11,8 +11,8 @@ from ._helper import ( create_connection, create_socks_proxy_socket, make_socks_proxy_opts, - select_proxy, ) +from ..utils.networking import select_proxy from .common import Features, Response, 
register_rh from .exceptions import ( CertificateVerifyError, diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py index 542abace87..9fcab6456f 100644 --- a/yt_dlp/utils/networking.py +++ b/yt_dlp/utils/networking.py @@ -10,7 +10,8 @@ import urllib.request if typing.TYPE_CHECKING: T = typing.TypeVar('T') -from ._utils import NO_DEFAULT, remove_start +from ._utils import NO_DEFAULT, remove_start, format_field +from .traversal import traverse_obj def random_user_agent(): @@ -278,3 +279,16 @@ def normalize_url(url): query=escape_rfc3986(url_parsed.query), fragment=escape_rfc3986(url_parsed.fragment), ).geturl() + + +def select_proxy(url, proxies): + """Unified proxy selector for all backends""" + url_components = urllib.parse.urlparse(url) + if 'no' in proxies: + hostport = url_components.hostname + format_field(url_components.port, None, ':%s') + if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): + return + elif urllib.request.proxy_bypass(hostport): # check system settings + return + + return traverse_obj(proxies, url_components.scheme or 'http', 'all')