1
0
mirror of https://github.com/mikf/gallery-dl.git synced 2024-11-25 20:22:36 +01:00

[downloader] overhaul http and text modules

Get rid of the modular structure and simplify/specialize those modules.
This commit is contained in:
Mike Fährmann 2019-06-19 22:19:29 +02:00
parent 03e6876fbe
commit 179d112083
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88
4 changed files with 155 additions and 198 deletions

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann # Copyright 2014-2019 Mike Fährmann
# #
# This program is free software; you can redistribute it and/or modify # This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as # it under the terms of the GNU General Public License version 2 as
@ -9,23 +9,18 @@
"""Common classes and constants used by downloader modules.""" """Common classes and constants used by downloader modules."""
import os import os
import time
import logging import logging
from .. import config, util, exception from .. import config, util
from requests.exceptions import RequestException
from ssl import SSLError
class DownloaderBase(): class DownloaderBase():
"""Base class for downloaders""" """Base class for downloaders"""
scheme = "" scheme = ""
retries = 1
def __init__(self, extractor, output): def __init__(self, extractor, output):
self.session = extractor.session self.session = extractor.session
self.out = output self.out = output
self.log = logging.getLogger("downloader." + self.scheme) self.log = logging.getLogger("downloader." + self.scheme)
self.downloading = False
self.part = self.config("part", True) self.part = self.config("part", True)
self.partdir = self.config("part-directory") self.partdir = self.config("part-directory")
@ -34,137 +29,8 @@ class DownloaderBase():
os.makedirs(self.partdir, exist_ok=True) os.makedirs(self.partdir, exist_ok=True)
def config(self, key, default=None): def config(self, key, default=None):
"""Interpolate config value for 'key'""" """Interpolate downloader config value for 'key'"""
return config.interpolate(("downloader", self.scheme, key), default) return config.interpolate(("downloader", self.scheme, key), default)
def download(self, url, pathfmt): def download(self, url, pathfmt):
"""Download the resource at 'url' and write it to a file-like object""" """Write data from 'url' into the file specified by 'pathfmt'"""
try:
return self.download_impl(url, pathfmt)
except Exception:
print()
raise
finally:
# remove file from incomplete downloads
if self.downloading and not self.part:
try:
os.remove(pathfmt.temppath)
except (OSError, AttributeError):
pass
def download_impl(self, url, pathfmt):
"""Actual implementation of the download process

Repeatedly tries to fetch 'url' into the file described by 'pathfmt',
using the connect() / receive() / reset() / get_extension() hooks.

Returns True when the data was written, when connect() raises
exception.DownloadComplete, or when the target path already exists
after a missing filename extension has been filled in.
Returns False once 'tries' has reached self.retries, or when
connect() raises any exception other than DownloadRetry and
DownloadComplete.

NOTE(review): indentation was lost in this capture; the nesting
described in the comments below is inferred from the control-flow
keywords and should be confirmed against the real file.
"""
adj_ext = None
tries = 0
msg = ""
# write into a .part file (optionally inside self.partdir) when enabled
if self.part:
pathfmt.part_enable(self.partdir)
while True:
# let the subclass drop state left over from the previous attempt
self.reset()
# 'tries' is non-zero only after a failed attempt: report the reason,
# give up once the retry limit is reached, otherwise wait 'tries'
# seconds before the next attempt
if tries:
self.log.warning("%s (%d/%d)", msg, tries, self.retries)
if tries >= self.retries:
return False
time.sleep(tries)
tries += 1
# check for .part file
filesize = pathfmt.part_size()
# connect to (remote) source
try:
offset, size = self.connect(url, filesize)
except exception.DownloadRetry as exc:
# retryable failure: remember the reason for the next warning
msg = exc
continue
except exception.DownloadComplete:
# leave the retry loop and finish as a successful download
break
except Exception as exc:
self.log.warning(exc)
return False
# check response
# a falsy offset means the file is written from the start ("w+b");
# otherwise the existing file is opened for update ("r+b")
if not offset:
mode = "w+b"
if filesize:
self.log.info("Unable to resume partial download")
else:
mode = "r+b"
self.log.info("Resuming download at byte %d", offset)
# set missing filename extension
if not pathfmt.has_extension:
pathfmt.set_extension(self.get_extension())
# presumably only checked after the extension was just filled in,
# i.e. nested in the 'if' above -- TODO confirm
if pathfmt.exists():
pathfmt.temppath = ""
return True
self.out.start(pathfmt.path)
# flag read by download() to decide whether to remove an incomplete file
self.downloading = True
with pathfmt.open(mode) as file:
if offset:
file.seek(offset)
# download content
try:
self.receive(file)
except (RequestException, SSLError) as exc:
msg = exc
print()
continue
# check filesize
# a 'size' of 0 means the expected size is unknown and skips this check
if size and file.tell() < size:
msg = "filesize mismatch ({} < {})".format(
file.tell(), size)
continue
# check filename extension
adj_ext = self._check_extension(file, pathfmt)
break
self.downloading = False
# _check_extension() returned a different extension: apply it
if adj_ext:
pathfmt.set_extension(adj_ext)
return True
def connect(self, url, offset):
"""Connect to 'url' while respecting 'offset' if possible

Hook without an implementation in this base class.

Returns a 2-tuple containing the actual offset and expected filesize.
If the returned offset-value is greater than zero, all received data
will be appended to the existing .part file.
Return '0' as second tuple-field to indicate an unknown filesize.

download_impl() handles exceptions raised here as follows:
exception.DownloadRetry starts another attempt,
exception.DownloadComplete ends the download as successful,
and any other exception is logged and makes the download fail.
"""
def receive(self, file):
"""Write data to 'file'

Hook without an implementation in this base class.
download_impl() calls it with the opened target file and starts
another attempt when it raises RequestException or SSLError.
"""
def reset(self):
"""Reset internal state / cleanup

Hook without an implementation in this base class.
download_impl() calls it at the start of every attempt.
"""
def get_extension(self):
"""Return a filename extension appropriate for the current request

Hook without an implementation in this base class.
download_impl() passes the result to pathfmt.set_extension() when
'pathfmt' has no extension yet.
"""
@staticmethod
def _check_extension(file, pathfmt):
    """Compare the file's leading bytes with its filename extension

    Reads the first 8 bytes of 'file' and tests them with the
    FILETYPE_CHECK predicate registered for the current extension.

    Returns the name of another FILETYPE_CHECK entry whose predicate
    accepts those bytes, or None when the current extension has no
    predicate, fewer than 8 bytes could be read, the bytes match the
    current extension, or no other predicate accepts them.
    """
    current = pathfmt.keywords["extension"]
    if current not in FILETYPE_CHECK:
        return None

    file.seek(0)
    magic = file.read(8)

    # too short to judge, or already consistent with the extension
    if len(magic) < 8 or FILETYPE_CHECK[current](magic):
        return None

    # first other type, in table order, that recognizes the header
    return next(
        (name for name, matches in FILETYPE_CHECK.items()
         if name != current and matches(magic)),
        None,
    )
# Maps a filename extension to a predicate that tests whether a file
# header (bytes) starts with that format's signature.
FILETYPE_CHECK = {
    "jpg": lambda h: h.startswith(b"\xff\xd8"),
    "png": lambda h: h.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"),
    # "GIF8", any version digit, then "a"
    "gif": lambda h: h.startswith(b"GIF8") and h[5] == ord("a"),
}

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann # Copyright 2014-2019 Mike Fährmann
# #
# This program is free software; you can redistribute it and/or modify # This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as # it under the terms of the GNU General Public License version 2 as
@ -8,11 +8,13 @@
"""Downloader module for http:// and https:// URLs""" """Downloader module for http:// and https:// URLs"""
import os
import time import time
import mimetypes import mimetypes
from requests.exceptions import ConnectionError, Timeout from requests.exceptions import RequestException, ConnectionError, Timeout
from ssl import SSLError
from .common import DownloaderBase from .common import DownloaderBase
from .. import text, exception from .. import text
class HttpDownloader(DownloaderBase): class HttpDownloader(DownloaderBase):
@ -20,11 +22,11 @@ class HttpDownloader(DownloaderBase):
def __init__(self, extractor, output): def __init__(self, extractor, output):
DownloaderBase.__init__(self, extractor, output) DownloaderBase.__init__(self, extractor, output)
self.response = None
self.retries = self.config("retries", extractor._retries) self.retries = self.config("retries", extractor._retries)
self.timeout = self.config("timeout", extractor._timeout) self.timeout = self.config("timeout", extractor._timeout)
self.verify = self.config("verify", extractor._verify) self.verify = self.config("verify", extractor._verify)
self.rate = self.config("rate") self.rate = self.config("rate")
self.downloading = False
self.chunk_size = 16384 self.chunk_size = 16384
if self.rate: if self.rate:
@ -34,41 +36,129 @@ class HttpDownloader(DownloaderBase):
elif self.rate < self.chunk_size: elif self.rate < self.chunk_size:
self.chunk_size = self.rate self.chunk_size = self.rate
def connect(self, url, offset): def download(self, url, pathfmt):
headers = {}
if offset:
headers["Range"] = "bytes={}-".format(offset)
try: try:
self.response = self.session.request( return self._download_impl(url, pathfmt)
"GET", url, stream=True, headers=headers, allow_redirects=True, except Exception:
timeout=self.timeout, verify=self.verify) print()
except (ConnectionError, Timeout) as exc: raise
raise exception.DownloadRetry(exc) finally:
# remove file from incomplete downloads
if self.downloading and not self.part:
try:
os.unlink(pathfmt.temppath)
except (OSError, AttributeError):
pass
code = self.response.status_code def _download_impl(self, url, pathfmt):
if code == 200: # OK response = None
offset = 0 adj_ext = None
size = self.response.headers.get("Content-Length") tries = 0
elif code == 206: # Partial Content msg = ""
size = self.response.headers["Content-Range"].rpartition("/")[2]
elif code == 416: # Requested Range Not Satisfiable
raise exception.DownloadComplete()
elif code == 429 or 500 <= code < 600: # Server Error
raise exception.DownloadRetry(
"{} Server Error: {} for url: {}".format(
code, self.response.reason, url))
else:
self.response.raise_for_status()
return offset, text.parse_int(size) if self.part:
pathfmt.part_enable(self.partdir)
def receive(self, file): while True:
if tries:
if response:
response.close()
self.log.warning("%s (%d/%d)", msg, tries, self.retries)
if tries >= self.retries:
return False
time.sleep(tries)
tries += 1
# check for .part file
filesize = pathfmt.part_size()
if filesize:
headers = {"Range": "bytes={}-".format(filesize)}
else:
headers = None
# connect to (remote) source
try:
response = self.session.request(
"GET", url, stream=True, headers=headers,
timeout=self.timeout, verify=self.verify)
except (ConnectionError, Timeout) as exc:
msg = str(exc)
continue
except Exception as exc:
self.log.warning("%s", exc)
return False
# check response
code = response.status_code
if code == 200: # OK
offset = 0
size = response.headers.get("Content-Length")
elif code == 206: # Partial Content
offset = filesize
size = response.headers["Content-Range"].rpartition("/")[2]
elif code == 416: # Requested Range Not Satisfiable
break
else:
msg = "{}: {} for url: {}".format(code, response.reason, url)
if code == 429 or 500 <= code < 600: # Server Error
continue
self.log.warning("%s", msg)
return False
size = text.parse_int(size)
# set missing filename extension
if not pathfmt.has_extension:
pathfmt.set_extension(self.get_extension(response))
if pathfmt.exists():
pathfmt.temppath = ""
return True
# set open mode
if not offset:
mode = "w+b"
if filesize:
self.log.info("Unable to resume partial download")
else:
mode = "r+b"
self.log.info("Resuming download at byte %d", offset)
# start downloading
self.out.start(pathfmt.path)
self.downloading = True
with pathfmt.open(mode) as file:
if offset:
file.seek(offset)
# download content
try:
self.receive(response, file)
except (RequestException, SSLError) as exc:
msg = str(exc)
print()
continue
# check filesize
if size and file.tell() < size:
msg = "filesize mismatch ({} < {})".format(
file.tell(), size)
continue
# check filename extension
adj_ext = self.check_extension(file, pathfmt)
break
self.downloading = False
if adj_ext:
pathfmt.set_extension(adj_ext)
return True
def receive(self, response, file):
if self.rate: if self.rate:
total = 0 # total amount of bytes received total = 0 # total amount of bytes received
start = time.time() # start time start = time.time() # start time
for data in self.response.iter_content(self.chunk_size): for data in response.iter_content(self.chunk_size):
file.write(data) file.write(data)
if self.rate: if self.rate:
@ -79,13 +169,8 @@ class HttpDownloader(DownloaderBase):
# sleep if less time passed than expected # sleep if less time passed than expected
time.sleep(expected - delta) time.sleep(expected - delta)
def reset(self): def get_extension(self, response):
if self.response: mtype = response.headers.get("Content-Type", "image/jpeg")
self.response.close()
self.response = None
def get_extension(self):
mtype = self.response.headers.get("Content-Type", "image/jpeg")
mtype = mtype.partition(";")[0] mtype = mtype.partition(";")[0]
if mtype in MIMETYPE_MAP: if mtype in MIMETYPE_MAP:
@ -100,6 +185,26 @@ class HttpDownloader(DownloaderBase):
"No filename extension found for MIME type '%s'", mtype) "No filename extension found for MIME type '%s'", mtype)
return "txt" return "txt"
@staticmethod
def check_extension(file, pathfmt):
    """Check the filename extension against the file header

    Reads the first 8 bytes of 'file' and tests them with the
    FILETYPE_CHECK predicate registered for the current extension.

    Returns the name of another FILETYPE_CHECK entry whose predicate
    accepts those bytes, or None when the current extension has no
    predicate, fewer than 8 bytes could be read, the bytes match the
    current extension, or no other predicate accepts them.
    """
    current = pathfmt.keywords["extension"]
    if current not in FILETYPE_CHECK:
        return None

    file.seek(0)
    magic = file.read(8)

    # too short to judge, or already consistent with the extension
    if len(magic) < 8 or FILETYPE_CHECK[current](magic):
        return None

    # first other type, in table order, that recognizes the header
    candidates = (
        name for name, matches in FILETYPE_CHECK.items()
        if name != current and matches(magic)
    )
    return next(candidates, None)
# Maps a filename extension to a predicate that tests whether a file
# header (bytes) starts with that format's signature.
FILETYPE_CHECK = {
    "jpg": lambda h: h.startswith(b"\xff\xd8"),
    "png": lambda h: h.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"),
    # "GIF8", any version digit, then "a"
    "gif": lambda h: h.startswith(b"GIF8") and h[5] == ord("a"),
}
MIMETYPE_MAP = { MIMETYPE_MAP = {
"image/jpeg": "jpg", "image/jpeg": "jpg",

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann # Copyright 2014-2019 Mike Fährmann
# #
# This program is free software; you can redistribute it and/or modify # This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as # it under the terms of the GNU General Public License version 2 as
@ -14,24 +14,13 @@ from .common import DownloaderBase
class TextDownloader(DownloaderBase): class TextDownloader(DownloaderBase):
scheme = "text" scheme = "text"
def __init__(self, extractor, output): def download(self, url, pathfmt):
DownloaderBase.__init__(self, extractor, output) if self.part:
self.content = b"" pathfmt.part_enable(self.partdir)
self.out.start(pathfmt.path)
def connect(self, url, offset): with pathfmt.open("wb") as file:
data = url.encode() file.write(url.encode()[5:])
self.content = data[offset + 5:] return True
return offset, len(data) - 5
def receive(self, file):
file.write(self.content)
def reset(self):
self.content = b""
@staticmethod
def get_extension():
return "txt"
__downloader__ = TextDownloader __downloader__ = TextDownloader

View File

@ -134,9 +134,6 @@ class TestTextDownloader(TestDownloaderBase):
def test_text_offset(self): def test_text_offset(self):
self._run_test("text:foobar", "foo", "foobar", "txt", "txt") self._run_test("text:foobar", "foo", "foobar", "txt", "txt")
def test_text_extension(self):
self._run_test("text:foobar", None, "foobar", None, "txt")
def test_text_empty(self): def test_text_empty(self):
self._run_test("text:", None, "", "txt", "txt") self._run_test("text:", None, "", "txt", "txt")