diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index c4fee0cfdd83..89ee0a702c3a 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,4 +1,6 @@ +import ipaddress import os +import socket import tempfile from typing import Any, Callable from urllib.parse import unquote, urlparse @@ -11,6 +13,48 @@ from .import_utils import BACKENDS_MAPPING, is_imageio_available +_BLOCKED_IPV4 = [ + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("100.64.0.0/10"), + ipaddress.ip_network("0.0.0.0/8"), +] +_BLOCKED_IPV6 = [ + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network("fe80::/10"), +] + + +def _validate_url(url: str) -> None: + """Raise ValueError if url resolves to a private/reserved IP address (SSRF prevention).""" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise ValueError(f"URL scheme {parsed.scheme!r} is not allowed.") + if not parsed.hostname: + raise ValueError(f"URL has no hostname: {url!r}") + try: + infos = socket.getaddrinfo(parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)) + except socket.gaierror as exc: + raise ValueError(f"Cannot resolve hostname {parsed.hostname!r}: {exc}") from exc + for _family, _, _, _, sockaddr in infos: + try: + addr = ipaddress.ip_address(sockaddr[0]) + if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped: + addr = addr.ipv4_mapped + nets = _BLOCKED_IPV4 if isinstance(addr, ipaddress.IPv4Address) else _BLOCKED_IPV6 + if any(addr in net for net in nets): + raise ValueError( + f"URL {url!r} resolves to private/reserved IP address {addr}. " + "Requests to internal network addresses are not allowed." + ) + except ValueError: + raise + + def load_image( image: str | PIL.Image.Image, convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None ) -> PIL.Image.Image: @@ -30,6 +74,7 @@ def load_image( """ if isinstance(image, str): if image.startswith("http://") or image.startswith("https://"): + _validate_url(image) image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw) elif os.path.isfile(image): image = PIL.Image.open(image) @@ -82,6 +127,7 @@ def load_video( ) if is_url: + _validate_url(video) response = requests.get(video, stream=True) if response.status_code != 200: raise ValueError(f"Failed to download video. Status code: {response.status_code}")