mirror of
https://github.com/mikf/gallery-dl.git
synced 2024-11-21 18:22:30 +01:00
[civitai] support using internal tRPC API endpoints (#3706)
This commit is contained in:
parent
daa220370c
commit
3eb3564b5d
@ -1514,6 +1514,19 @@ Description
|
||||
``gallery``.
|
||||
|
||||
|
||||
extractor.civitai.api
|
||||
---------------------
|
||||
Type
|
||||
``string``
|
||||
Default
|
||||
``"rest"``
|
||||
Description
|
||||
Selects which API endpoints to use.
|
||||
|
||||
* ``"rest"``: `Public REST API <https://developer.civitai.com/docs/api/public-rest>`__
|
||||
* ``"trpc"``: Internal TRPC API
|
||||
|
||||
|
||||
extractor.civitai.api-key
|
||||
-------------------------
|
||||
Type
|
||||
|
@ -9,8 +9,9 @@
|
||||
"""Extractors for https://www.civitai.com/"""
|
||||
|
||||
from .common import Extractor, Message
|
||||
from .. import text
|
||||
from .. import text, util
|
||||
import itertools
|
||||
import time
|
||||
|
||||
BASE_PATTERN = r"(?:https?://)?civitai\.com"
|
||||
USER_PATTERN = BASE_PATTERN + r"/user/([^/?#]+)"
|
||||
@ -20,13 +21,18 @@ class CivitaiExtractor(Extractor):
|
||||
"""Base class for civitai extractors"""
|
||||
category = "civitai"
|
||||
root = "https://civitai.com"
|
||||
directory_fmt = ("{category}", "{username}", "images")
|
||||
directory_fmt = ("{category}", "{username|user[username]}", "images")
|
||||
filename_fmt = "{id}.{extension}"
|
||||
archive_fmt = "{hash}"
|
||||
request_interval = (0.5, 1.5)
|
||||
|
||||
def _init(self):
|
||||
self.api = CivitaiAPI(self)
|
||||
if self.config("api") == "trpc":
|
||||
self.log.debug("Using TRPC API")
|
||||
self.api = CivitaiTrpcAPI(self)
|
||||
else:
|
||||
self.log.debug("Using REST API")
|
||||
self.api = CivitaiRestAPI(self)
|
||||
|
||||
quality = self.config("quality")
|
||||
if quality:
|
||||
@ -94,11 +100,15 @@ class CivitaiModelExtractor(CivitaiExtractor):
|
||||
|
||||
def items(self):
|
||||
model_id, version_id = self.groups
|
||||
|
||||
model = self.api.model(model_id)
|
||||
creator = model["creator"]
|
||||
|
||||
if "user" in model:
|
||||
user = model["user"]
|
||||
del model["user"]
|
||||
else:
|
||||
user = model["creator"]
|
||||
del model["creator"]
|
||||
versions = model["modelVersions"]
|
||||
del model["creator"]
|
||||
del model["modelVersions"]
|
||||
|
||||
if version_id:
|
||||
@ -117,18 +127,18 @@ class CivitaiModelExtractor(CivitaiExtractor):
|
||||
data = {
|
||||
"model" : model,
|
||||
"version": version,
|
||||
"user" : creator,
|
||||
"user" : user,
|
||||
}
|
||||
|
||||
yield Message.Directory, data
|
||||
for file in self._extract_files(model, version):
|
||||
for file in self._extract_files(model, version, user):
|
||||
file.update(data)
|
||||
yield Message.Url, file["url"], file
|
||||
|
||||
def _extract_files(self, model, version):
|
||||
def _extract_files(self, model, version, user):
|
||||
filetypes = self.config("files")
|
||||
if filetypes is None:
|
||||
return self._extract_files_image(model, version)
|
||||
return self._extract_files_image(model, version, user)
|
||||
|
||||
generators = {
|
||||
"model" : self._extract_files_model,
|
||||
@ -140,11 +150,11 @@ class CivitaiModelExtractor(CivitaiExtractor):
|
||||
filetypes = filetypes.split(",")
|
||||
|
||||
return itertools.chain.from_iterable(
|
||||
generators[ft.rstrip("s")](model, version)
|
||||
generators[ft.rstrip("s")](model, version, user)
|
||||
for ft in filetypes
|
||||
)
|
||||
|
||||
def _extract_files_model(self, model, version):
|
||||
def _extract_files_model(self, model, version, user):
|
||||
return [
|
||||
{
|
||||
"num" : num,
|
||||
@ -159,17 +169,30 @@ class CivitaiModelExtractor(CivitaiExtractor):
|
||||
for num, file in enumerate(version["files"], 1)
|
||||
]
|
||||
|
||||
def _extract_files_image(self, model, version):
|
||||
def _extract_files_image(self, model, version, user):
|
||||
if "images" in version:
|
||||
images = version["images"]
|
||||
else:
|
||||
params = {
|
||||
"modelVersionId": version["id"],
|
||||
"prioritizedUserIds": [user["id"]],
|
||||
"period": "AllTime",
|
||||
"sort": "Most Reactions",
|
||||
"limit": 20,
|
||||
"pending": True,
|
||||
}
|
||||
images = self.api.images(params, defaults=False)
|
||||
|
||||
return [
|
||||
text.nameext_from_url(file["url"], {
|
||||
"num" : num,
|
||||
"file": file,
|
||||
"url" : self._url(file),
|
||||
})
|
||||
for num, file in enumerate(version["images"], 1)
|
||||
for num, file in enumerate(images, 1)
|
||||
]
|
||||
|
||||
def _extract_files_gallery(self, model, version):
|
||||
def _extract_files_gallery(self, model, version, user):
|
||||
params = {
|
||||
"modelId" : model["id"],
|
||||
"modelVersionId": version["id"],
|
||||
@ -202,7 +225,7 @@ class CivitaiImageExtractor(CivitaiExtractor):
|
||||
example = "https://civitai.com/images/12345"
|
||||
|
||||
def images(self):
|
||||
return self.api.images({"imageId": self.groups[0]})
|
||||
return self.api.image(self.groups[0])
|
||||
|
||||
|
||||
class CivitaiTagModelsExtractor(CivitaiExtractor):
|
||||
@ -273,7 +296,7 @@ class CivitaiUserImagesExtractor(CivitaiExtractor):
|
||||
return self.api.images(params)
|
||||
|
||||
|
||||
class CivitaiAPI():
|
||||
class CivitaiRestAPI():
|
||||
"""Interface for the Civitai Public REST API
|
||||
|
||||
https://developer.civitai.com/docs/api/public-rest
|
||||
@ -289,6 +312,11 @@ class CivitaiAPI():
|
||||
extractor.log.debug("Using api_key authentication")
|
||||
self.headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
def image(self, image_id):
|
||||
endpoint = "/v1/images"
|
||||
params = {"imageId": image_id}
|
||||
return self._pagination(endpoint, params)
|
||||
|
||||
def images(self, params):
|
||||
endpoint = "/v1/images"
|
||||
return self._pagination(endpoint, params)
|
||||
@ -324,3 +352,107 @@ class CivitaiAPI():
|
||||
except KeyError:
|
||||
return
|
||||
params = None
|
||||
|
||||
|
||||
class CivitaiTrpcAPI():
|
||||
"""Interface for the Civitai TRPC API"""
|
||||
|
||||
def __init__(self, extractor):
|
||||
self.extractor = extractor
|
||||
self.root = extractor.root + "/api/trpc/"
|
||||
self.headers = {
|
||||
"content-type" : "application/json",
|
||||
"x-client-version": "5.0.94",
|
||||
"x-client-date" : "",
|
||||
"x-client" : "web",
|
||||
"x-fingerprint" : "undefined",
|
||||
}
|
||||
|
||||
api_key = extractor.config("api-key")
|
||||
if api_key:
|
||||
extractor.log.debug("Using api_key authentication")
|
||||
self.headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
def image(self, image_id):
|
||||
endpoint = "image.get"
|
||||
params = {"id": int(image_id)}
|
||||
return (self._call(endpoint, params),)
|
||||
|
||||
def images(self, params, defaults=True):
|
||||
endpoint = "image.getInfinite"
|
||||
|
||||
if defaults:
|
||||
params_ = {
|
||||
"useIndex" : True,
|
||||
"period" : "AllTime",
|
||||
"sort" : "Newest",
|
||||
"types" : ["image"],
|
||||
"withMeta" : False, # Metadata Only
|
||||
"fromPlatform" : False, # Made On-Site
|
||||
"browsingLevel": 31,
|
||||
"include" : ["cosmetics"],
|
||||
}
|
||||
params_.update(params)
|
||||
else:
|
||||
params_ = params
|
||||
|
||||
return self._pagination(endpoint, params_)
|
||||
|
||||
def model(self, model_id):
|
||||
endpoint = "model.getById"
|
||||
params = {"id": int(model_id)}
|
||||
return self._call(endpoint, params)
|
||||
|
||||
def model_version(self, model_version_id):
|
||||
endpoint = "modelVersion.getById"
|
||||
params = {"id": int(model_version_id)}
|
||||
return self._call(endpoint, params)
|
||||
|
||||
def models(self, params, defaults=True):
|
||||
endpoint = "model.getAll"
|
||||
|
||||
if defaults:
|
||||
params_ = {
|
||||
"period" : "AllTime",
|
||||
"periodMode" : "published",
|
||||
"sort" : "Newest",
|
||||
"pending" : False,
|
||||
"hidden" : False,
|
||||
"followed" : False,
|
||||
"earlyAccess" : False,
|
||||
"fromPlatform" : False,
|
||||
"supportsGeneration": False,
|
||||
"browsingLevel": 31,
|
||||
}
|
||||
params_.update(params)
|
||||
else:
|
||||
params_ = params
|
||||
|
||||
return self._pagination(endpoint, params_)
|
||||
|
||||
def user(self, username):
|
||||
endpoint = "user.getCreator"
|
||||
params = {"username": username}
|
||||
return (self._call(endpoint, params),)
|
||||
|
||||
def _call(self, endpoint, params):
|
||||
url = self.root + endpoint
|
||||
headers = self.headers
|
||||
params = {"input": util.json_dumps({"json": params})}
|
||||
|
||||
headers["x-client-date"] = str(int(time.time() * 1000))
|
||||
response = self.extractor.request(url, headers=headers, params=params)
|
||||
|
||||
return response.json()["result"]["data"]["json"]
|
||||
|
||||
def _pagination(self, endpoint, params):
|
||||
while True:
|
||||
data = self._call(endpoint, params)
|
||||
yield from data["items"]
|
||||
|
||||
try:
|
||||
if not data["nextCursor"]:
|
||||
return
|
||||
params["cursor"] = data["nextCursor"]
|
||||
except KeyError:
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user