1
0
mirror of https://github.com/mikf/gallery-dl.git synced 2024-11-22 18:53:21 +01:00
gallery-dl/gallery_dl/util.py

568 lines
16 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# Copyright 2017-2018 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
2017-03-28 13:12:44 +02:00
"""Utility functions and classes"""
import re
2017-03-28 13:12:44 +02:00
import os
import sys
import hmac
import time
import base64
import random
import shutil
import string
2017-09-27 21:18:34 +02:00
import _string
import hashlib
2018-01-29 22:13:06 +01:00
import sqlite3
import datetime
2017-12-03 01:38:24 +01:00
import itertools
import urllib.parse
from . import text, exception
def parse_range(rangespec):
    """Parse an integer range string and return the resulting ranges

    Groups are separated by commas; each group is either a single integer
    or 'begin-end', where a missing begin defaults to 1 and a missing end
    to sys.maxsize. Reversed bounds are swapped and groups that cannot be
    parsed as integers are skipped.

    Examples
        parse_range("-2,4,6-8,10-") -> [(1,2), (4,4), (6,8), (10,INTMAX)]
        parse_range(" - 3 , 4- 4, 2-6") -> [(1,3), (4,4), (2,6)]
    """
    parsed = []
    for part in rangespec.split(","):
        lower, dash, upper = part.partition("-")
        try:
            if dash:
                start = int(lower) if lower.strip() else 1
                stop = int(upper) if upper.strip() else sys.maxsize
            else:
                start = stop = int(lower)
        except ValueError:
            # not an integer (or empty group): ignore this group
            continue
        parsed.append((start, stop) if start <= stop else (stop, start))
    return parsed
def optimize_range(ranges):
    """Simplify/Combine a parsed list of ranges

    Overlapping and directly adjacent (end + 1 == next begin) ranges are
    merged into one. The result is ordered by lower bound. The list passed
    in is never reordered or otherwise modified; a list with at most one
    element is returned as is.

    Examples
        optimize_range([(2,4), (4,6), (5,8)]) -> [(2,8)]
        optimize_range([(1,1), (2,2), (3,6), (8,9)]) -> [(1,6), (8,9)]
    """
    if len(ranges) <= 1:
        return ranges

    # sorted() instead of list.sort(): leave the caller's list untouched
    riter = iter(sorted(ranges))
    result = []

    beg, end = next(riter)
    for lower, upper in riter:
        if lower > end + 1:
            # gap between the current range and the next one: emit it
            result.append((beg, end))
            beg, end = lower, upper
        elif upper > end:
            # overlapping or adjacent: extend the current range
            end = upper
    result.append((beg, end))
    return result
def bdecode(data, alphabet="0123456789"):
    """Decode a base-N encoded string ( N = len(alphabet) )

    Each character's position in 'alphabet' is its digit value; the most
    significant digit comes first. A character missing from 'alphabet'
    raises ValueError (from str.index).
    """
    radix = len(alphabet)
    value = 0
    for char in data:
        value = value * radix + alphabet.index(char)
    return value
2017-12-02 01:24:49 +01:00
def parse_bytes(value, suffixes="bkmgtp"):
    """Convert a bytes-amount ("500k", "2.5M", ...) to int

    The last character, if it is one of 'suffixes' (case-insensitive),
    selects a multiplier of 1024 ** <index in suffixes>. Returns 0 for
    empty input and for anything that cannot be converted to a finite
    number.
    """
    if not value:
        # empty input has no last character to inspect
        return 0

    last = value[-1].lower()
    if last in suffixes:
        mul = 1024 ** suffixes.index(last)
        value = value[:-1]
    else:
        mul = 1

    try:
        return round(float(value) * mul)
    except (ValueError, OverflowError):
        # ValueError: not a number / NaN; OverflowError: infinity
        return 0
2017-12-03 01:38:24 +01:00
def advance(iterable, num):
    """Advance the iterable by 'num' steps

    Returns an iterator over 'iterable' from which the first 'num' items
    have already been consumed (fewer if it is shorter than that).
    """
    it = iter(iterable)
    # an islice with start == stop yields nothing, but asking it for an
    # item makes it consume 'num' items from the underlying iterator
    skipper = itertools.islice(it, num, num)
    next(skipper, None)
    return it
def combine_dict(a, b):
    """Recursively combine the contents of b into a

    Where both sides hold a dict under the same key, the two are merged
    recursively; in every other case the value from 'b' replaces the one
    in 'a'. 'a' is modified in place and returned.
    """
    for key, value in b.items():
        mergeable = (
            key in a and
            isinstance(value, dict) and
            isinstance(a[key], dict)
        )
        if mergeable:
            combine_dict(a[key], value)
        else:
            a[key] = value
    return a
def safe_int(value, default=0):
    """Safely convert value to integer

    Returns 'default' when 'value' is None, an empty string, or cannot be
    converted by int() (wrong type, malformed text, infinite float).
    """
    if value is None or value == "":
        return default
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")) - still "not convertible"
        return default
2017-10-26 00:04:28 +02:00
def expand_path(path):
    """Expand environment variables and tildes (~)

    'path' may be a string or an iterable of path segments, which are
    joined with os.path.join first. Falsy values (None, "", empty
    sequence) are returned unchanged.
    """
    if path:
        if not isinstance(path, str):
            path = os.path.join(*path)
        path = os.path.expandvars(os.path.expanduser(path))
    return path
def code_to_language(code, default=None):
    """Map an ISO 639-1 language code to its actual name

    The lookup in CODES is case-insensitive; 'default' is returned for
    unknown or falsy codes.
    """
    key = code.lower() if code else ""
    return CODES.get(key, default)
2017-03-28 13:12:44 +02:00
def language_to_code(lang, default=None):
    """Map a language name to its ISO 639-1 code

    'lang' is matched against the names in CODES after str.capitalize();
    'default' is returned if 'lang' is None or no name matches.
    """
    if lang is None:
        return default
    wanted = lang.capitalize()
    matches = (code for code, name in CODES.items() if name == wanted)
    return next(matches, default)
# Language names keyed by two-letter language code; used by
# code_to_language() and language_to_code().
# NOTE(review): ISO 639-1 lists Japanese as "ja"; the "jp" key is kept
# as it is because callers may rely on it - confirm before changing.
CODES = {
    "ar": "Arabic",
    "cs": "Czech",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "fi": "Finnish",
    "fr": "French",
    "he": "Hebrew",
    "hu": "Hungarian",
    "id": "Indonesian",
    "it": "Italian",
    "jp": "Japanese",
    "ko": "Korean",
    "ms": "Malay",
    "nl": "Dutch",
    "no": "Norwegian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sv": "Swedish",
    "th": "Thai",
    "tr": "Turkish",
    "vi": "Vietnamese",
    "zh": "Chinese",
}

# Names of extractors singled out as "special"; how they are treated is
# decided by code outside this block (TODO confirm against callers).
SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"}
2017-03-28 13:12:44 +02:00
def build_predicate(predicates):
    """Collapse a list of predicates into a single callable

    Returns an always-True predicate for an empty list, the predicate
    itself for a single-element list and a ChainPredicate otherwise.
    Every predicate is called as predicate(url, kwds).
    """
    if not predicates:
        return lambda url, kwds: True
    if len(predicates) == 1:
        return predicates[0]
    return ChainPredicate(predicates)
class RangePredicate():
    """Predicate; True if the current index is in the given range"""

    def __init__(self, ranges):
        """Store 'ranges', a list of (lower, upper) tuples

        'lower' and 'upper' are taken from the first and last tuple; the
        call counter 'index' starts at 0.
        """
        self.ranges = ranges
        self.index = 0

        if ranges:
            first, last = ranges[0], ranges[-1]
            self.lower, self.upper = first[0], last[1]
        else:
            self.lower, self.upper = 0, 0

    def __call__(self, url, kwds):
        """Count this call and test the new index against all ranges

        Raises exception.StopExtraction once the index exceeds 'upper'.
        """
        self.index += 1

        if self.index > self.upper:
            raise exception.StopExtraction()

        return any(
            lower <= self.index <= upper
            for lower, upper in self.ranges
        )
2017-03-28 13:12:44 +02:00
class UniquePredicate():
    """Predicate; True if given URL has not been encountered before"""

    def __init__(self):
        # every URL that has been accepted so far
        self.urls = set()

    def __call__(self, url, kwds):
        """Return False for repeated URLs, True otherwise

        URLs starting with "text:" always pass and are not recorded.
        """
        if url.startswith("text:"):
            return True
        if url in self.urls:
            return False
        self.urls.add(url)
        return True
class FilterPredicate():
    """Predicate; True if evaluating the given expression returns True"""

    # names made available to every filter expression
    globalsdict = {
        "safe_int": safe_int,
        "urlsplit": urllib.parse.urlsplit,
        "datetime": datetime.datetime,
        "re": re,
    }

    def __init__(self, codeobj):
        # code object (or expression string) handed to eval() on each call
        self.codeobj = codeobj

    def __call__(self, url, kwds):
        """Evaluate the expression with 'kwds' as its local namespace

        Any exception raised during evaluation is re-raised as
        exception.FilterError.
        """
        # NOTE(review): eval() runs arbitrary code; the expression must
        # only ever come from a trusted source - confirm against callers.
        try:
            result = eval(self.codeobj, self.globalsdict, kwds)
        except Exception as exc:
            raise exception.FilterError(exc)
        return result
class ChainPredicate():
    """Predicate; True if all of its predicates return True"""

    def __init__(self, predicates):
        self.predicates = predicates

    def __call__(self, url, kwds):
        """Call the predicates in order, stopping at the first falsy one"""
        return all(check(url, kwds) for check in self.predicates)
class ExtendedUrl():
    """URL with attached config key-value pairs"""

    def __init__(self, url, gconf, lconf):
        self.value = url
        self.gconfig = gconf
        self.lconfig = lconf

    def __str__(self):
        """Return the plain URL"""
        return self.value
2017-09-27 21:18:34 +02:00
class Formatter():
    """Custom, trimmed-down version of string.Formatter

    This string formatter implementation is a mostly performance-optimized
    variant of the original string.Formatter class. Unnecessary features have
    been removed (positional arguments, unused argument check) and new
    formatting options have been added.

    Extra Conversions:
    - "l": calls str.lower on the target value
    - "u": calls str.upper
    - "c": calls str.capitalize
    - "C": calls string.capwords
    - Example: {f!l} -> "example"; {f!u} -> "EXAMPLE"

    Extra Format Specifiers:
    - "?BEFORE/AFTER/":
        Adds BEFORE and AFTER to the actual value if it evaluates to True.
        Otherwise the whole replacement field becomes an empty string.
        Example: {f:?-+/+-/} -> "-+Example+-" (if "f" contains "Example")
                             -> ""            (if "f" is None, 0, "")
    """
    conversions = {
        "l": str.lower,
        "u": str.upper,
        "c": str.capitalize,
        "C": string.capwords,
        "s": str,
        "r": repr,
        "a": ascii,
    }

    def __init__(self, default=None):
        # value substituted for fields whose name is missing from kwargs
        self.kwdefault = default

    def vformat(self, format_string, kwargs):
        """Apply 'kwargs' to the initial format_string and return its result"""
        parts = []

        for literal, name, spec, conv in \
                _string.formatter_parser(format_string):
            if literal:
                parts.append(literal)
            if not name:
                continue

            value = self.get_field(name, kwargs)
            if conv:
                value = self.conversions[conv](value)
            if spec:
                # the format spec may itself contain replacement fields
                value = self.format_field(value, spec.format_map(kwargs))
            else:
                value = str(value)
            parts.append(value)

        return "".join(parts)

    @staticmethod
    def format_field(value, format_spec):
        """Format 'value' according to 'format_spec'"""
        if format_spec[0] != "?":
            return format(value, format_spec)
        if not value:
            return ""
        # "?BEFORE/AFTER/SPEC": the leading "?" is part of the first piece
        prefix, suffix, spec = format_spec.split("/", 2)
        return prefix[1:] + format(value, spec) + suffix

    def get_field(self, field_name, kwargs):
        """Return value with key 'field_name' from 'kwargs'"""
        key, accessors = _string.formatter_field_name_split(field_name)

        if key not in kwargs:
            return self.kwdefault

        value = kwargs[key]
        for is_attr, accessor in accessors:
            value = getattr(value, accessor) if is_attr else value[accessor]
        return value
2017-03-28 13:12:44 +02:00
class PathFormat():
    """Build and manage the target paths of an extractor's files

    Directory and filename are produced from format strings (taken from
    the extractor's config, falling back to the extractor's own defaults)
    and the keyword dicts passed to set_directory() and set_keywords().
    """

    def __init__(self, extractor):
        """Read format strings, base directory and skip mode

        The "skip" config value selects what exists() does for files that
        are already present:
            falsy    -> exists() always returns False
            "abort"  -> raise exception.StopExtraction
            "exit"   -> raise SystemExit (via sys.exit)
            other    -> return True
        """
        self.filename_fmt = extractor.config(
            "filename", extractor.filename_fmt)
        self.directory_fmt = extractor.config(
            "directory", extractor.directory_fmt)
        self.formatter = Formatter(extractor.config("keywords-default"))

        self.has_extension = False
        self.keywords = {}
        self.filename = ""
        self.directory = self.realdirectory = ""
        self.path = self.realpath = self.partpath = ""

        self.basedirectory = expand_path(
            extractor.config("base-directory", (".", "gallery-dl")))
        if os.altsep:
            # normalize alternative separators to the native one
            self.basedirectory = self.basedirectory.replace(
                os.altsep, os.sep)

        skip = extractor.config("skip", True)
        if skip:
            if skip == "abort":
                self._skipexc = exception.StopExtraction
            elif skip == "exit":
                # sys.exit rather than the builtin 'exit': the latter is
                # only injected by the 'site' module and may be missing
                self._skipexc = sys.exit
            else:
                self._skipexc = None
        else:
            # same parameter name as the method it shadows
            self.exists = lambda archive=None: False

    def open(self, mode="wb"):
        """Open file and return a corresponding file object"""
        return open(self.partpath or self.realpath, mode)

    def exists(self, archive=None):
        """Return True if the file exists on disk or in 'archive'

        Raises the configured skip exception instead of returning True
        if one was selected by the "skip" option.
        """
        if (self.has_extension and os.path.exists(self.realpath) or
                archive and archive.check(self.keywords)):
            if self._skipexc:
                raise self._skipexc()
            if not self.has_extension:
                # rebuild the path with an empty extension and drop the
                # trailing dot this may leave on 'path'
                self.set_extension("")
                if self.path[-1] == ".":
                    self.path = self.path[:-1]
            return True
        return False

    def set_directory(self, keywords):
        """Build directory path and create it if necessary

        Raises exception.FormatError if a directory segment cannot be
        formatted.
        """
        try:
            segments = [
                text.clean_path(
                    self.formatter.vformat(segment, keywords).strip())
                for segment in self.directory_fmt
            ]
        except Exception as exc:
            raise exception.FormatError(exc, "directory")

        self.directory = os.path.join(
            self.basedirectory,
            *segments
        )
        self.realdirectory = self.adjust_path(self.directory)
        os.makedirs(self.realdirectory, exist_ok=True)

    def set_keywords(self, keywords):
        """Set filename keywords"""
        self.keywords = keywords
        self.has_extension = bool(keywords.get("extension"))
        if self.has_extension:
            # without an extension the path is built later on
            self.build_path()

    def set_extension(self, extension, real=True):
        """Set the 'extension' keyword

        'real' is stored as 'has_extension'; pass False for a placeholder
        extension.
        """
        self.has_extension = real
        self.keywords["extension"] = extension
        self.build_path()

    def adjust_extension(self, extension):
        """Change filename extension of existing file"""
        oldpath = self.realpath
        self.set_extension(extension)
        if not self.partpath:
            # no .part file in use: the data was written to 'oldpath'
            os.replace(oldpath, self.realpath)

    def build_path(self):
        """Use filename-keywords and directory to build a full path

        Raises exception.FormatError if the filename cannot be formatted.
        """
        try:
            self.filename = text.clean_path(
                self.formatter.vformat(self.filename_fmt, self.keywords))
        except Exception as exc:
            raise exception.FormatError(exc, "filename")

        filename = os.sep + self.filename
        self.path = self.directory + filename
        self.realpath = self.realdirectory + filename

    def part_enable(self, part_directory=None):
        """Enable .part file usage

        If 'part_directory' is given, the .part file is placed there
        instead of next to the final file.
        """
        if self.has_extension:
            self.partpath = self.realpath + ".part"
        else:
            # no real extension yet: use "part" as a placeholder
            self.set_extension("part", False)
            self.partpath = self.realpath
        if part_directory:
            self.partpath = os.path.join(
                part_directory,
                os.path.basename(self.partpath),
            )

    def part_size(self):
        """Return size of .part file"""
        if self.partpath:
            try:
                return os.stat(self.partpath).st_size
            except OSError:
                pass
        return 0

    def part_move(self):
        """Rename .part file to its actual filename"""
        try:
            os.replace(self.partpath, self.realpath)
            return
        except OSError:
            pass
        # renaming failed: fall back to copying and deleting the original
        shutil.copyfile(self.partpath, self.realpath)
        os.unlink(self.partpath)

    @staticmethod
    def adjust_path(path):
        """Enable longer-than-260-character paths on windows"""
        return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
class OAuthSession():
    """Minimal wrapper for requests.session objects to support OAuth 1.0"""

    def __init__(self, session, consumer_key, consumer_secret,
                 token=None, token_secret=None):
        """Store the wrapped session and the static OAuth parameters

        'consumer_secret' and 'token_secret' together form the signing
        key; a missing token secret is treated as an empty string.
        """
        self.session = session
        self.consumer_secret = consumer_secret
        self.token_secret = token_secret or ""

        # parameters sent with every request
        self.params = {}
        self.params["oauth_consumer_key"] = consumer_key
        self.params["oauth_token"] = token
        self.params["oauth_signature_method"] = "HMAC-SHA1"
        self.params["oauth_version"] = "1.0"

    def get(self, url, params):
        """Send a signed GET request and return the session's response

        'params' is not modified; the OAuth parameters are added to a
        copy of it.
        """
        # copy first: the caller's dict must not pick up oauth_* entries
        params = dict(params)
        params.update(self.params)
        params["oauth_nonce"] = self.nonce(16)
        params["oauth_timestamp"] = int(time.time())
        return self.session.get(url + self.sign(url, params))

    def sign(self, url, params):
        """Generate 'oauth_signature' value and return query string"""
        query = self.urlencode(params)
        # HMAC-SHA1 over "GET&<url>&<query>", keyed with both secrets
        message = self.concat("GET", url, query).encode()
        key = self.concat(self.consumer_secret, self.token_secret).encode()
        signature = hmac.new(key, message, hashlib.sha1).digest()
        return "?{}&oauth_signature={}".format(
            query, self.quote(base64.b64encode(signature).decode()))

    @staticmethod
    def concat(*args):
        """Percent-encode all arguments and join them with '&'"""
        return "&".join(OAuthSession.quote(item) for item in args)

    @staticmethod
    def nonce(N, alphabet=string.ascii_letters):
        """Return a random string of 'N' characters from 'alphabet'"""
        return "".join(random.choice(alphabet) for _ in range(N))

    @staticmethod
    def quote(value, quote=urllib.parse.quote):
        """Percent-encode 'value', leaving only '~' as extra safe char"""
        return quote(value, "~")

    @staticmethod
    def urlencode(params):
        """Build a query string from 'params', sorted by key

        Entries with a falsy value (e.g. an unset oauth_token) are left
        out.
        """
        return "&".join(
            OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value))
            for key, value in sorted(params.items()) if value
        )
2018-01-29 22:13:06 +01:00
class DownloadArchive():
    """SQLite-backed record of archive entries

    Entry keys are built from extractor.category followed by the
    "archive-format" config value (default: extractor.archive_fmt),
    filled in with an item's keyword dict.
    """

    def __init__(self, path, extractor):
        """Open (or create) the database at 'path'"""
        connection = sqlite3.connect(path)
        # isolation_level None: statements are committed immediately
        connection.isolation_level = None

        cursor = connection.cursor()
        cursor.execute("CREATE TABLE IF NOT EXISTS archive "
                       "(entry PRIMARY KEY) WITHOUT ROWID")
        self.cursor = cursor

        template = extractor.category + extractor.config(
            "archive-format", extractor.archive_fmt)
        self.keygen = template.format_map

    def check(self, kwdict):
        """Return a truthy row if the item described by 'kwdict' exists
        in the archive, None otherwise"""
        entry = self.keygen(kwdict)
        self.cursor.execute(
            "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (entry,))
        return self.cursor.fetchone()

    def add(self, kwdict):
        """Add item described by 'kwdict' to archive"""
        entry = self.keygen(kwdict)
        self.cursor.execute(
            "INSERT OR IGNORE INTO archive VALUES (?)", (entry,))