Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ipaddress
import os
import socket
import tempfile
from typing import Any, Callable
from urllib.parse import unquote, urlparse
Expand All @@ -11,6 +13,48 @@
from .import_utils import BACKENDS_MAPPING, is_imageio_available


_BLOCKED_IPV4 = [
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("0.0.0.0/8"),
]
_BLOCKED_IPV6 = [
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
]


def _validate_url(url: str) -> None:
"""Raise ValueError if url resolves to a private/reserved IP address (SSRF prevention)."""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"URL scheme {parsed.scheme!r} is not allowed.")
if not parsed.hostname:
raise ValueError(f"URL has no hostname: {url!r}")
try:
infos = socket.getaddrinfo(parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80))
except socket.gaierror as exc:
raise ValueError(f"Cannot resolve hostname {parsed.hostname!r}: {exc}") from exc
for _family, _, _, _, sockaddr in infos:
try:
addr = ipaddress.ip_address(sockaddr[0])
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped:
addr = addr.ipv4_mapped
nets = _BLOCKED_IPV4 if isinstance(addr, ipaddress.IPv4Address) else _BLOCKED_IPV6
if any(addr in net for net in nets):
raise ValueError(
f"URL {url!r} resolves to private/reserved IP address {addr}. "
"Requests to internal network addresses are not allowed."
)
except ValueError:
raise


def load_image(
image: str | PIL.Image.Image, convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None
) -> PIL.Image.Image:
Expand All @@ -30,6 +74,7 @@ def load_image(
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
_validate_url(image)
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
Expand Down Expand Up @@ -82,6 +127,7 @@ def load_video(
)

if is_url:
_validate_url(video)
response = requests.get(video, stream=True)
if response.status_code != 200:
raise ValueError(f"Failed to download video. Status code: {response.status_code}")
Expand Down
Loading