From 8c3783706d9d4fc626f1732bda3ad0a77368b694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20F=C3=A4hrmann?= Date: Fri, 20 Nov 2015 19:54:07 +0100 Subject: [PATCH] allow multiple extractors per module --- gallery_dl/extractor/__init__.py | 33 +++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/gallery_dl/extractor/__init__.py b/gallery_dl/extractor/__init__.py index 066cd18a..4c278641 100644 --- a/gallery_dl/extractor/__init__.py +++ b/gallery_dl/extractor/__init__.py @@ -47,10 +47,10 @@ modules = [ def find(url): """Find extractor suitable for handling the given url""" - for pattern, module, klass in _list_patterns(): + for pattern, info, klass in _list_patterns(): match = re.match(pattern, url) if match: - return klass(match), module.info + return klass(match), info return None, None # -------------------------------------------------------------------- @@ -60,15 +60,30 @@ _cache = [] _module_iter = iter(modules) def _list_patterns(): - """Yield all available (pattern, module, klass) tuples""" + """Yield all available (pattern, info, class) tuples""" for entry in _cache: yield entry for module_name in _module_iter: module = importlib.import_module("."+module_name, __package__) - klass = getattr(module, module.info["extractor"]) - userpatterns = config.get(("extractor", module_name, "pattern"), default=[]) - for pattern in userpatterns + module.info["pattern"]: - etuple = (pattern, module, klass) - _cache.append(etuple) - yield etuple + try: + klass = getattr(module, module.info["extractor"]) + userpatterns = config.get(("extractor", module_name, "pattern"), default=[]) + for pattern in userpatterns + module.info["pattern"]: + etuple = (pattern, module.info, klass) + _cache.append(etuple) + yield etuple + except AttributeError: + for klass in _get_classes(module): + for pattern in klass.pattern: + etuple = (pattern, klass.info, klass) + _cache.append(etuple) + yield etuple + +def _get_classes(module): + """Return a list of all extractor classes in a module""" + return [ + klass for klass in module.__dict__.values() if ( + hasattr(klass, "info") and klass.__module__ == module.__name__ + ) + ]