1
0
mirror of https://github.com/mikf/gallery-dl.git synced 2024-11-22 10:42:34 +01:00
gallery-dl/gallery_dl/extractor/civitai.py
2024-10-27 15:46:00 +01:00

660 lines
21 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2024 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
"""Extractors for https://www.civitai.com/"""
from .common import Extractor, Message
from .. import text, util, exception
import itertools
import time
# Matches the civitai.com host with an optional scheme and an optional
# "www." prefix (the module docstring advertises www.civitai.com).
# BASE_PATTERN has no capture groups; USER_PATTERN adds group 1, the
# URL-quoted username (callers pass it through text.unquote()).
BASE_PATTERN = r"(?:https?://)?(?:www\.)?civitai\.com"
USER_PATTERN = BASE_PATTERN + r"/user/([^/?#]+)"
class CivitaiExtractor(Extractor):
    """Base class for civitai extractors"""
    category = "civitai"
    root = "https://civitai.com"
    directory_fmt = ("{category}", "{username|user[username]}", "images")
    filename_fmt = "{file[id]|id|filename}.{extension}"
    archive_fmt = "{file[uuid]|uuid}"
    request_interval = (0.5, 1.5)

    def _init(self):
        """Select the API backend and read the quality/metadata options.

        - 'api': "rest" selects CivitaiRestAPI, anything else CivitaiTrpcAPI
        - 'quality': a string, or a list of strings joined with ",", used
          as the transformation segment of image URLs
          (default "original=true")
        - 'metadata': a comma-separated string or a list; generation data
          is fetched per image when it contains "generation". Any other
          truthy non-list value also enables generation data.
        """
        if self.config("api") == "rest":
            self.log.debug("Using REST API")
            self.api = CivitaiRestAPI(self)
        else:
            self.log.debug("Using tRPC API")
            self.api = CivitaiTrpcAPI(self)

        quality = self.config("quality")
        if quality:
            if not isinstance(quality, str):
                quality = ",".join(quality)
            self._image_quality = quality
            # "original=true" files get a "png" extension, any other
            # quality setting a "jpg" one
            self._image_ext = ("png" if quality == "original=true" else "jpg")
        else:
            self._image_quality = "original=true"
            self._image_ext = "png"

        metadata = self.config("metadata")
        if metadata:
            if isinstance(metadata, str):
                metadata = metadata.split(",")
            elif not isinstance(metadata, (list, tuple)):
                metadata = ("generation",)
            self._meta_generation = ("generation" in metadata)
        else:
            self._meta_generation = False

    def items(self):
        """Yield messages for the first non-empty source.

        models(), posts() and images() are checked in that order; the
        first one returning a truthy value is processed. Models are queued
        for CivitaiModelExtractor; posts and images are downloaded directly.
        """
        models = self.models()
        if models:
            data = {"_extractor": CivitaiModelExtractor}
            for model in models:
                url = "{}/models/{}".format(self.root, model["id"])
                yield Message.Queue, url, data
            return

        posts = self.posts()
        if posts:
            for post in posts:

                if "images" in post:
                    images = post["images"]
                else:
                    images = self.api.images_post(post["id"])

                # always re-fetch the post to get its full metadata
                post = self.api.post(post["id"])
                post["date"] = text.parse_datetime(
                    post["publishedAt"], "%Y-%m-%dT%H:%M:%S.%fZ")
                data = {
                    "post": post,
                    "user": post["user"],
                }
                del post["user"]

                yield Message.Directory, data
                for file in self._image_results(images):
                    file.update(data)
                    yield Message.Url, file["url"], file
            return

        images = self.images()
        if images:
            for image in images:
                url = self._url(image)
                if self._meta_generation:
                    image["generation"] = self.api.image_generationdata(
                        image["id"])
                image["date"] = text.parse_datetime(
                    image["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ")
                text.nameext_from_url(url, image)
                image["extension"] = self._image_ext
                yield Message.Directory, image
                yield Message.Url, url, image
            return

    def models(self):
        """Return models to queue; empty by default"""
        return ()

    def posts(self):
        """Return posts to download; empty by default"""
        return ()

    def images(self):
        """Return images to download; empty by default"""
        return ()

    def _url(self, image):
        """Return the download URL for 'image' at the configured quality.

        Also stores the image's UUID in image["uuid"].

        A value containing "/" is treated as a full URL. It is assumed to
        end in ".../<uuid>/<transformation>/<name>", and its transformation
        segment is replaced. Any other value is treated as a bare UUID and
        expanded into an image.civitai.com URL. The file name comes from
        "name", or else from "id" plus "mimeType" (or the default
        extension).
        """
        url = image["url"]

        if "/" in url:
            parts = url.rsplit("/", 3)
            image["uuid"] = parts[1]
            parts[2] = self._image_quality
            return "/".join(parts)

        image["uuid"] = url
        name = image.get("name")
        if not name:
            mime = image.get("mimeType") or self._image_ext
            # "image/jpeg" -> "jpeg"; a bare extension is kept unchanged
            name = "{}.{}".format(image.get("id"), mime.rpartition("/")[2])
        return (
            "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{}/{}/{}".format(
                url, self._image_quality, name)
        )

    def _image_results(self, images):
        """Yield a file-metadata dict for each entry in 'images'.

        Each dict contains the 1-based "num", the raw "file" entry and the
        download "url" built by _url(), plus "filename"/"extension" from
        text.nameext_from_url(). When generation metadata is enabled, it
        is requested per image. Entries without an image ID skip that
        request (logged at debug level) instead of raising KeyError.
        """
        for num, file in enumerate(images, 1):
            data = text.nameext_from_url(file["url"], {
                "num" : num,
                "file": file,
                "url" : self._url(file),
            })
            if not data["extension"]:
                data["extension"] = self._image_ext
            if "id" not in file and data["filename"].isdecimal():
                # fall back to a purely numeric filename as image ID
                file["id"] = text.parse_int(data["filename"])

            if self._meta_generation:
                image_id = file.get("id")
                if image_id is None:
                    self.log.debug(
                        "Unable to fetch generation data for %s "
                        "(no image ID)", data["url"])
                else:
                    file["generation"] = self.api.image_generationdata(
                        image_id)
            yield data
class CivitaiModelExtractor(CivitaiExtractor):
    subcategory = "model"
    directory_fmt = ("{category}", "{user[username]}",
                     "{model[id]}{model[name]:? //}",
                     "{version[id]}{version[name]:? //}")
    pattern = BASE_PATTERN + r"/models/(\d+)(?:/?\?modelVersionId=(\d+))?"
    example = "https://civitai.com/models/12345/TITLE"

    def items(self):
        """Yield files for every version of a model, or for one version.

        When the URL names a modelVersionId, only that version is
        processed. It is taken from the model's version list if present;
        otherwise it is fetched separately.
        """
        model_id, version_id = self.groups
        model = self.api.model(model_id)

        # the uploader is stored under "user" or "creator", presumably
        # depending on the API backend - detach it from the model either way
        if "user" in model:
            user = model.pop("user")
        else:
            user = model.pop("creator")

        versions = model.pop("modelVersions")

        if version_id:
            version_id = int(version_id)
            for version in versions:
                if version["id"] == version_id:
                    break
            else:
                version = self.api.model_version(version_id)
            versions = (version,)

        for version in versions:
            version["date"] = text.parse_datetime(
                version["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ")

            data = {
                "model"  : model,
                "version": version,
                "user"   : user,
            }

            yield Message.Directory, data
            for file in self._extract_files(model, version, user):
                file.update(data)
                yield Message.Url, file["url"], file

    def _extract_files(self, model, version, user):
        """Return an iterable of file dicts for the configured 'files' types.

        'files' is a comma-separated string or a list of
        "model", "image" and "gallery" (plural forms accepted, surrounding
        whitespace ignored). Unsupported entries are reported and skipped.
        When 'files' is unset, only images are returned.
        """
        filetypes = self.config("files")
        if filetypes is None:
            return self._extract_files_image(model, version, user)

        generators = {
            "model"   : self._extract_files_model,
            "image"   : self._extract_files_image,
            "gallery" : self._extract_files_gallery,
            # "galleries".rstrip("s") == "gallerie"
            "gallerie": self._extract_files_gallery,
        }
        if isinstance(filetypes, str):
            filetypes = filetypes.split(",")

        selected = []
        for filetype in filetypes:
            key = filetype.strip().rstrip("s")
            if key in generators:
                selected.append(generators[key])
            else:
                self.log.warning("Unsupported 'files' type '%s'", filetype)

        return itertools.chain.from_iterable(
            generate(model, version, user)
            for generate in selected
        )

    def _extract_files_model(self, model, version, user):
        """Return download entries for the files in version["files"].

        Downloads use the file's "downloadUrl" (or the site's
        /api/download/models/<version id> endpoint), send the API client's
        Authorization header, and are checked by _validate_file_model().
        """
        files = []

        for num, file in enumerate(version["files"], 1):
            file["uuid"] = "model-{}-{}-{}".format(
                model["id"], version["id"], file["id"])
            files.append({
                "num"      : num,
                "file"     : file,
                "filename" : file["name"],
                "extension": "bin",
                "url"      : file.get("downloadUrl") or
                "{}/api/download/models/{}".format(
                    self.root, version["id"]),
                "_http_headers" : {
                    "Authorization": self.api.headers.get("Authorization")},
                "_http_validate": self._validate_file_model,
            })

        return files

    def _extract_files_image(self, model, version, user):
        """Return image results for a version.

        Uses the images embedded in 'version' when available, otherwise
        queries the images API with the uploader's images prioritized.
        """
        if "images" in version:
            images = version["images"]
        else:
            params = {
                "modelVersionId": version["id"],
                "prioritizedUserIds": [user["id"]],
                "period": "AllTime",
                "sort": "Most Reactions",
                "limit": 20,
                "pending": True,
            }
            images = self.api.images(params, defaults=False)

        return self._image_results(images)

    def _extract_files_gallery(self, model, version, user):
        """Return image results from the version's image gallery"""
        images = self.api.images_gallery(model, version, user)
        return self._image_results(images)

    def _validate_file_model(self, response):
        """Reject HTML responses to model downloads with a warning.

        An HTML body is taken to mean the download needs an 'api-key'.
        When the page contains an alert, its text is included in the
        message.
        """
        if response.headers.get("Content-Type", "").startswith("text/html"):
            alert = text.extr(
                response.text, 'mantine-Alert-message">', "</div></div></div>")
            if alert:
                msg = "\"{}\" - 'api-key' required".format(
                    text.remove_html(alert))
            else:
                msg = "'api-key' required to download this file"
            self.log.warning(msg)
            return False
        return True
class CivitaiImageExtractor(CivitaiExtractor):
    """Extractor for a single image page"""
    subcategory = "image"
    pattern = BASE_PATTERN + r"/images/(\d+)"
    example = "https://civitai.com/images/12345"

    def images(self):
        """Return the image whose ID was captured from the URL"""
        image_id = self.groups[0]
        return self.api.image(image_id)
class CivitaiPostExtractor(CivitaiExtractor):
    """Extractor for a single post"""
    subcategory = "post"
    directory_fmt = ("{category}", "{username|user[username]}", "posts",
                     "{post[id]}{post[title]:? //}")
    pattern = BASE_PATTERN + r"/posts/(\d+)"
    example = "https://civitai.com/posts/12345"

    def posts(self):
        """Return a one-element tuple holding a stub for the URL's post"""
        post_id = int(self.groups[0])
        stub = {"id": post_id}
        return (stub,)
class CivitaiTagExtractor(CivitaiExtractor):
    """Extractor for models with a given tag"""
    subcategory = "tag"
    pattern = BASE_PATTERN + r"/tag/([^/?&#]+)"
    example = "https://civitai.com/tag/TAG"

    def models(self):
        """Return models tagged with the URL's (unquoted) tag name"""
        tag_name = text.unquote(self.groups[0])
        return self.api.models_tag(tag_name)
class CivitaiSearchExtractor(CivitaiExtractor):
    """Extractor for model search results"""
    subcategory = "search"
    pattern = BASE_PATTERN + r"/search/models\?([^#]+)"
    example = "https://civitai.com/search/models?query=QUERY"

    def models(self):
        """Return models matching the URL's query-string parameters"""
        query_string = self.groups[0]
        return self.api.models(text.parse_query(query_string))
class CivitaiModelsExtractor(CivitaiExtractor):
    """Extractor for the model listing, optionally filtered by query"""
    subcategory = "models"
    pattern = BASE_PATTERN + r"/models(?:/?\?([^#]+))?(?:$|#)"
    example = "https://civitai.com/models"

    def models(self):
        """Return models matching the optional URL query string"""
        query_string = self.groups[0]
        return self.api.models(text.parse_query(query_string))
class CivitaiImagesExtractor(CivitaiExtractor):
    """Extractor for the image listing, optionally filtered by query"""
    subcategory = "images"
    pattern = BASE_PATTERN + r"/images(?:/?\?([^#]+))?(?:$|#)"
    example = "https://civitai.com/images"

    def images(self):
        """Return images matching the optional URL query string"""
        query_string = self.groups[0]
        return self.api.images(text.parse_query(query_string))
class CivitaiUserExtractor(CivitaiExtractor):
    """Dispatcher for a user's profile page"""
    subcategory = "user"
    pattern = USER_PATTERN + r"/?(?:$|\?|#)"
    example = "https://civitai.com/user/USER"

    def initialize(self):
        """Deliberately a no-op; this extractor only delegates to others.

        NOTE(review): presumably this skips the base class setup
        (API client creation etc.) - confirm against .common.Extractor.
        """

    def items(self):
        """Delegate to the user models/posts/images extractors"""
        profile = "{}/user/{}/".format(self.root, self.groups[0])
        targets = (
            (CivitaiUserModelsExtractor, profile + "models"),
            (CivitaiUserPostsExtractor , profile + "posts"),
            (CivitaiUserImagesExtractor, profile + "images"),
        )
        enabled_by_default = ("user-models", "user-posts")
        return self._dispatch_extractors(targets, enabled_by_default)
class CivitaiUserModelsExtractor(CivitaiExtractor):
    """Extractor for models uploaded by a user"""
    subcategory = "user-models"
    pattern = USER_PATTERN + r"/models/?(?:\?([^#]+))?"
    example = "https://civitai.com/user/USER/models"

    def models(self):
        """Return the user's models, filtered by the optional query"""
        username, query_string = self.groups
        query = text.parse_query(query_string)
        query["username"] = text.unquote(username)
        return self.api.models(query)
class CivitaiUserPostsExtractor(CivitaiExtractor):
    """Extractor for posts made by a user"""
    subcategory = "user-posts"
    directory_fmt = ("{category}", "{username|user[username]}", "posts",
                     "{post[id]}{post[title]:? //}")
    pattern = USER_PATTERN + r"/posts/?(?:\?([^#]+))?"
    example = "https://civitai.com/user/USER/posts"

    def posts(self):
        """Return the user's posts, filtered by the optional query"""
        username, query_string = self.groups
        query = text.parse_query(query_string)
        query["username"] = text.unquote(username)
        return self.api.posts(query)
class CivitaiUserImagesExtractor(CivitaiExtractor):
    """Extractor for images posted by, or reacted to by, a user"""
    subcategory = "user-images"
    pattern = USER_PATTERN + r"/images/?(?:\?([^#]+))?"
    example = "https://civitai.com/user/USER/images"

    def __init__(self, match):
        """Parse the URL query and pick the image source.

        With "section=reactions" in the query, the extractor switches to
        the "reactions" subcategory and lists reacted-to images instead.
        The query is stored before the base initializer runs.
        """
        query = text.parse_query_list(match.group(2))
        self.params = query
        if query.get("section") == "reactions":
            self.subcategory = "reactions"
            self.images = self.images_reactions
        CivitaiExtractor.__init__(self, match)

    def images(self):
        """Return images posted by the user named in the URL"""
        query = self.params
        query["username"] = text.unquote(self.groups[0])
        return self.api.images(query)

    def images_reactions(self):
        """Return images the authenticated account reacted to.

        Requires an api-key (Authorization header) or the
        __Secure-civitai-token cookie. Without an explicit "reactions"
        filter, all five reaction types are requested.
        """
        has_api_key = "Authorization" in self.api.headers
        if not has_api_key:
            token = self.cookies.get(
                "__Secure-civitai-token", domain=".civitai.com")
            if not token:
                raise exception.AuthorizationError(
                    "api-key or cookies required")

        query = self.params
        query["authed"] = True
        query["useIndex"] = False
        if "reactions" not in query:
            query["reactions"] = (
                "Like", "Dislike", "Heart", "Laugh", "Cry")
        elif isinstance(query["reactions"], str):
            # a single value from the query string -> one-element tuple
            query["reactions"] = (query["reactions"],)
        return self.api.images(query)
class CivitaiRestAPI():
    """Interface for the Civitai Public REST API

    https://developer.civitai.com/docs/api/public-rest
    """

    def __init__(self, extractor):
        self.extractor = extractor
        self.root = extractor.root + "/api"
        self.headers = {"Content-Type": "application/json"}

        api_key = extractor.config("api-key")
        if api_key:
            extractor.log.debug("Using api_key authentication")
            self.headers["Authorization"] = "Bearer " + api_key

        # 'nsfw' option: unset or true -> "X", other falsy -> "Safe",
        # any other value is passed through unchanged
        nsfw = extractor.config("nsfw")
        if nsfw is None or nsfw is True:
            nsfw = "X"
        elif not nsfw:
            nsfw = "Safe"
        self.nsfw = nsfw

    def image(self, image_id):
        """Return an iterator over the image with ID 'image_id'"""
        return self.images({
            "imageId": image_id,
        })

    def images(self, params, defaults=True):
        """Iterate over all images matching 'params' (modified in place).

        'defaults' is accepted for call compatibility with
        CivitaiTrpcAPI.images() (callers pass defaults=False). The
        configured 'nsfw' level is added either way unless 'params'
        already sets it.
        """
        endpoint = "/v1/images"
        if "nsfw" not in params:
            params["nsfw"] = self.nsfw
        return self._pagination(endpoint, params)

    def images_gallery(self, model, version, user):
        """Iterate over images of a specific model version"""
        return self.images({
            "modelId"       : model["id"],
            "modelVersionId": version["id"],
        })

    def model(self, model_id):
        """Return data for the model with ID 'model_id'"""
        endpoint = "/v1/models/{}".format(model_id)
        return self._call(endpoint)

    def model_version(self, model_version_id):
        """Return data for the model version 'model_version_id'"""
        endpoint = "/v1/model-versions/{}".format(model_version_id)
        return self._call(endpoint)

    def models(self, params, defaults=True):
        """Iterate over all models matching 'params'.

        'defaults' is accepted for call compatibility with
        CivitaiTrpcAPI.models() and has no effect here.
        """
        return self._pagination("/v1/models", params)

    def models_tag(self, tag):
        """Iterate over models tagged with 'tag'"""
        return self.models({"tag": tag})

    def _call(self, endpoint, params=None):
        """Send a GET request and return the decoded JSON response.

        'endpoint' is either a path below self.root (starting with "/")
        or an absolute URL.
        """
        if endpoint.startswith("/"):
            url = self.root + endpoint
        else:
            url = endpoint
        response = self.extractor.request(
            url, params=params, headers=self.headers)
        return response.json()

    def _pagination(self, endpoint, params):
        """Yield items from every result page, following "nextPage".

        Stops when the response has no metadata, or when "nextPage" is
        missing or empty.
        """
        while True:
            data = self._call(endpoint, params)
            yield from data["items"]

            metadata = data.get("metadata") or {}
            endpoint = metadata.get("nextPage")
            if not endpoint:
                return
            # presumably nextPage is a full URL including its query
            # string, so no extra parameters are sent
            params = None
class CivitaiTrpcAPI():
    """Interface for the Civitai tRPC API"""

    def __init__(self, extractor):
        self.extractor = extractor
        self.root = extractor.root + "/api/trpc/"
        self.headers = {
            "content-type"    : "application/json",
            "x-client-version": "5.0.211",
            "x-client-date"   : "",
            "x-client"        : "web",
            "x-fingerprint"   : "undefined",
        }

        api_key = extractor.config("api-key")
        if api_key:
            extractor.log.debug("Using api_key authentication")
            self.headers["Authorization"] = "Bearer " + api_key

        # 'nsfw' option -> "browsingLevel": unset or true -> 31,
        # other falsy -> 1, any other value is passed through unchanged
        nsfw = extractor.config("nsfw")
        if nsfw is None or nsfw is True:
            nsfw = 31
        elif not nsfw:
            nsfw = 1
        self.nsfw = nsfw

    def image(self, image_id):
        """Return a one-element tuple with the image 'image_id'"""
        endpoint = "image.get"
        params = {"id": int(image_id)}
        return (self._call(endpoint, params),)

    def image_generationdata(self, image_id):
        """Return generation data for the image 'image_id'"""
        endpoint = "image.getGenerationData"
        params = {"id": int(image_id)}
        return self._call(endpoint, params)

    def images(self, params, defaults=True):
        """Iterate over images matching 'params'.

        With 'defaults', 'params' is merged over a set of default filters
        (user values win). ID-like values are converted to int.
        """
        endpoint = "image.getInfinite"

        if defaults:
            params = self._merge_params(params, {
                "useIndex"     : True,
                "period"       : "AllTime",
                "sort"         : "Newest",
                "types"        : ["image"],
                "withMeta"     : False,  # Metadata Only
                "fromPlatform" : False,  # Made On-Site
                "browsingLevel": self.nsfw,
                "include"      : ["cosmetics"],
            })

        params = self._type_params(params)
        return self._pagination(endpoint, params)

    def images_gallery(self, model, version, user):
        """Iterate over all images of a model version's gallery posts"""
        endpoint = "image.getImagesAsPostsInfinite"
        params = {
            "period"        : "AllTime",
            "sort"          : "Newest",
            "modelVersionId": version["id"],
            "modelId"       : model["id"],
            "hidden"        : False,
            "limit"         : 50,
            "browsingLevel" : self.nsfw,
        }

        for post in self._pagination(endpoint, params):
            yield from post["images"]

    def images_post(self, post_id):
        """Iterate over the images of the post 'post_id'"""
        params = {
            "postId" : int(post_id),
            "pending": True,
        }
        return self.images(params)

    def model(self, model_id):
        """Return data for the model 'model_id'"""
        endpoint = "model.getById"
        params = {"id": int(model_id)}
        return self._call(endpoint, params)

    def model_version(self, model_version_id):
        """Return data for the model version 'model_version_id'"""
        endpoint = "modelVersion.getById"
        params = {"id": int(model_version_id)}
        return self._call(endpoint, params)

    def models(self, params, defaults=True):
        """Iterate over models matching 'params' (merged over defaults)"""
        endpoint = "model.getAll"

        if defaults:
            params = self._merge_params(params, {
                "period"      : "AllTime",
                "periodMode"  : "published",
                "sort"        : "Newest",
                "pending"     : False,
                "hidden"      : False,
                "followed"    : False,
                "earlyAccess" : False,
                "fromPlatform": False,
                "supportsGeneration": False,
                "browsingLevel": self.nsfw,
            })

        return self._pagination(endpoint, params)

    def models_tag(self, tag):
        """Iterate over models tagged with 'tag'"""
        return self.models({"tagname": tag})

    def post(self, post_id):
        """Return data for the post 'post_id'"""
        endpoint = "post.get"
        params = {"id": int(post_id)}
        return self._call(endpoint, params)

    def posts(self, params, defaults=True):
        """Iterate over posts matching 'params' (merged over defaults)"""
        endpoint = "post.getInfinite"
        # follow-up cursors are annotated as Date values
        meta = {"cursor": ("Date",)}

        if defaults:
            params = self._merge_params(params, {
                "browsingLevel": self.nsfw,
                "period"       : "AllTime",
                "periodMode"   : "published",
                "sort"         : "Newest",
                "followed"     : False,
                "draftOnly"    : False,
                "pending"      : True,
                "include"      : ["cosmetics"],
            })

        return self._pagination(endpoint, params, meta)

    def user(self, username):
        """Return a one-element tuple with creator data for 'username'"""
        endpoint = "user.getCreator"
        params = {"username": username}
        return (self._call(endpoint, params),)

    def _call(self, endpoint, params, meta=None):
        """Send a tRPC GET request and return result.data.json.

        'params' is JSON-encoded into the "input" query parameter,
        together with 'meta' (as "meta": {"values": meta}) when given.
        Updates the shared "x-client-date" header on each call.
        """
        url = self.root + endpoint
        headers = self.headers
        if meta:
            payload = {"json": params, "meta": {"values": meta}}
        else:
            payload = {"json": params}

        query = {"input": util.json_dumps(payload)}
        headers["x-client-date"] = str(int(time.time() * 1000))
        response = self.extractor.request(url, params=query, headers=headers)
        return response.json()["result"]["data"]["json"]

    def _pagination(self, endpoint, params, meta=None):
        """Yield all items of a cursor-paginated endpoint.

        'params' is updated in place with each page's cursor. 'meta' holds
        the value annotations sent with follow-up requests. A cursor
        already present in 'params' (e.g. from a URL query string) is
        used as-is and annotated with 'meta' as well.
        """
        if "cursor" in params:
            meta_ = meta
        else:
            params["cursor"] = None
            # the first request carries no cursor
            meta_ = {"cursor": ("undefined",)}

        while True:
            data = self._call(endpoint, params, meta_)
            yield from data["items"]

            cursor = data.get("nextCursor")
            if not cursor:
                return
            params["cursor"] = cursor
            meta_ = meta

    def _merge_params(self, params_user, params_default):
        """Return 'params_default' updated with (and overridden by)
        'params_user'"""
        params_default.update(params_user)
        return params_default

    def _type_params(self, params):
        """Convert ID-like values in 'params' to int, in place.

        Values may also be lists (e.g. from repeated query parameters);
        these are converted element by element.
        """
        for key in ("tags", "modelId", "modelVersionId"):
            if key in params:
                value = params[key]
                if isinstance(value, (list, tuple)):
                    params[key] = [int(v) for v in value]
                else:
                    params[key] = int(value)
        return params