1
0
mirror of https://github.com/mikf/gallery-dl.git synced 2024-11-25 20:22:36 +01:00

[reddit] don't send OAuth headers for file downloads (fixes #729)

This commit is contained in:
Mike Fährmann 2020-05-08 21:42:52 +02:00
parent ba42ec321c
commit 0bf0146bfe
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88
3 changed files with 16 additions and 12 deletions

View File

@ -1,5 +1,7 @@
# Changelog # Changelog
## Unreleased
## 1.13.6 - 2020-05-02 ## 1.13.6 - 2020-05-02
### Additions ### Additions
- [patreon] respect filters and sort order in query parameters ([#711](https://github.com/mikf/gallery-dl/issues/711)) - [patreon] respect filters and sort order in query parameters ([#711](https://github.com/mikf/gallery-dl/issues/711))

View File

@ -229,13 +229,12 @@ class RedditAPI():
user_agent = extractor.config("user-agent", self.USER_AGENT) user_agent = extractor.config("user-agent", self.USER_AGENT)
if (client_id == self.CLIENT_ID) ^ (user_agent == self.USER_AGENT): if (client_id == self.CLIENT_ID) ^ (user_agent == self.USER_AGENT):
self.client_id = None raise exception.StopExtraction(
self.log.warning(
"Conflicting values for 'client-id' and 'user-agent': " "Conflicting values for 'client-id' and 'user-agent': "
"overwrite either both or none of them.") "overwrite either both or none of them.")
else:
self.client_id = client_id self.client_id = client_id
extractor.session.headers["User-Agent"] = user_agent self.headers = {"User-Agent": user_agent}
def submission(self, submission_id): def submission(self, submission_id):
"""Fetch the (submission, comments)=-tuple for a submission id""" """Fetch the (submission, comments)=-tuple for a submission id"""
@ -277,13 +276,15 @@ class RedditAPI():
def authenticate(self): def authenticate(self):
"""Authenticate the application by requesting an access token""" """Authenticate the application by requesting an access token"""
access_token = self._authenticate_impl(self.refresh_token) self.headers["Authorization"] = \
self.extractor.session.headers["Authorization"] = access_token self._authenticate_impl(self.refresh_token)
@cache(maxage=3600, keyarg=1) @cache(maxage=3600, keyarg=1)
def _authenticate_impl(self, refresh_token=None): def _authenticate_impl(self, refresh_token=None):
"""Actual authenticate implementation""" """Actual authenticate implementation"""
url = "https://www.reddit.com/api/v1/access_token" url = "https://www.reddit.com/api/v1/access_token"
self.headers["Authorization"] = None
if refresh_token: if refresh_token:
self.log.info("Refreshing private access token") self.log.info("Refreshing private access token")
data = {"grant_type": "refresh_token", data = {"grant_type": "refresh_token",
@ -294,9 +295,9 @@ class RedditAPI():
"grants/installed_client"), "grants/installed_client"),
"device_id": "DO_NOT_TRACK_THIS_DEVICE"} "device_id": "DO_NOT_TRACK_THIS_DEVICE"}
auth = (self.client_id, "")
response = self.extractor.request( response = self.extractor.request(
url, method="POST", data=data, auth=auth, fatal=False) url, method="POST", headers=self.headers,
data=data, auth=(self.client_id, ""), fatal=False)
data = response.json() data = response.json()
if response.status_code != 200: if response.status_code != 200:
@ -307,9 +308,10 @@ class RedditAPI():
def _call(self, endpoint, params): def _call(self, endpoint, params):
url = "https://oauth.reddit.com" + endpoint url = "https://oauth.reddit.com" + endpoint
params["raw_json"] = 1 params["raw_json"] = "1"
self.authenticate() self.authenticate()
response = self.extractor.request(url, params=params, fatal=None) response = self.extractor.request(
url, params=params, headers=self.headers, fatal=None)
remaining = response.headers.get("x-ratelimit-remaining") remaining = response.headers.get("x-ratelimit-remaining")
if remaining and float(remaining) < 2: if remaining and float(remaining) < 2:

View File

@ -6,4 +6,4 @@
# it under the terms of the GNU General Public License version 2 as # it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation. # published by the Free Software Foundation.
__version__ = "1.13.6" __version__ = "1.14.0-dev"