soft-serve

git clone git://git.lin.moe/fork/soft-serve.git

  1package ssrf
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"net/http"
  9	"net/url"
 10	"slices"
 11	"strings"
 12	"time"
 13)
 14
// Sentinel errors returned by this package. Compare against them with
// errors.Is, since some call sites wrap them with extra detail.
var (
	// ErrInvalidURL reports a URL that is empty, cannot be parsed, lacks a
	// hostname, or whose hostname cannot be resolved.
	ErrInvalidURL = errors.New("invalid URL")
	// ErrInvalidScheme reports a URL whose scheme is anything other than
	// http or https.
	ErrInvalidScheme = errors.New("URL must use http or https scheme")
	// ErrPrivateIP reports that a destination was refused because it is a
	// private, internal, or reserved address.
	ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed")
)
 23
 24// NewSecureClient returns an HTTP client with SSRF protection.
 25// It validates resolved IPs at dial time to block connections to private
 26// and internal networks. Hostnames are resolved and the validated IP is
 27// used directly in the dial call to prevent DNS rebinding (TOCTOU between
 28// validation and connection). Redirects are disabled to match the webhook
 29// client convention and prevent redirect-based SSRF.
 30func NewSecureClient() *http.Client {
 31	return &http.Client{
 32		Timeout: 30 * time.Second,
 33		Transport: &http.Transport{
 34			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 35				host, port, err := net.SplitHostPort(addr)
 36				if err != nil {
 37					return nil, err //nolint:wrapcheck
 38				}
 39
 40				ip := net.ParseIP(host)
 41				if ip == nil {
 42					ips, err := net.LookupIP(host) //nolint
 43					if err != nil {
 44						return nil, fmt.Errorf("DNS resolution failed for host %s: %v", host, err)
 45					}
 46					if len(ips) == 0 {
 47						return nil, fmt.Errorf("no IP addresses found for host: %s", host)
 48					}
 49					ip = ips[0] // Use the first resolved IP address
 50				}
 51				if isPrivateOrInternal(ip) {
 52					return nil, fmt.Errorf("%w", ErrPrivateIP)
 53				}
 54
 55				dialer := &net.Dialer{
 56					Timeout:   10 * time.Second,
 57					KeepAlive: 30 * time.Second,
 58				}
 59				// Dial using the validated IP to prevent DNS rebinding.
 60				// Without this, the dialer resolves the hostname again
 61				// independently, and the second resolution could return
 62				// a different (private) IP.
 63				return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
 64			},
 65			MaxIdleConns:          100,
 66			IdleConnTimeout:       90 * time.Second,
 67			TLSHandshakeTimeout:   10 * time.Second,
 68			ExpectContinueTimeout: 1 * time.Second,
 69		},
 70		CheckRedirect: func(*http.Request, []*http.Request) error {
 71			return http.ErrUseLastResponse
 72		},
 73	}
 74}
 75
// isPrivateOrInternal reports whether ip must not be contacted by an
// SSRF-protected client. That covers loopback, private, link-local,
// multicast, unspecified, and special-purpose or reserved ranges.
//
// Some IPv6 addresses carry an IPv4 address, and those are judged by the
// IPv4 address they contain:
//   - IPv4-mapped (::ffff:a.b.c.d)
//   - NAT64 well-known prefix (64:ff9b::/96)
//   - 6to4 (2002::/16)
//
// This stops an internal IPv4 target from being smuggled inside an IPv6
// literal.
//
// A nil or malformed IP (neither 4 nor 16 bytes long) is reported as
// internal, so callers fail closed instead of treating garbage as public.
func isPrivateOrInternal(ip net.IP) bool {
	// Normalize IPv4-mapped IPv6 (e.g. ::ffff:127.0.0.1) to IPv4 form
	// so all checks apply consistently.
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	} else if len(ip) != net.IPv6len {
		return true
	}

	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
		return true
	}

	if len(ip) == net.IPv4len {
		// 0.0.0.0/8
		if ip[0] == 0 {
			return true
		}
		// 100.64.0.0/10 (Shared Address Space / CGNAT)
		if ip[0] == 100 && ip[1] >= 64 && ip[1] <= 127 {
			return true
		}
		// 192.0.0.0/24 (IETF Protocol Assignments)
		if ip[0] == 192 && ip[1] == 0 && ip[2] == 0 {
			return true
		}
		// 192.0.2.0/24 (TEST-NET-1)
		if ip[0] == 192 && ip[1] == 0 && ip[2] == 2 {
			return true
		}
		// 198.18.0.0/15 (benchmarking)
		if ip[0] == 198 && (ip[1] == 18 || ip[1] == 19) {
			return true
		}
		// 198.51.100.0/24 (TEST-NET-2)
		if ip[0] == 198 && ip[1] == 51 && ip[2] == 100 {
			return true
		}
		// 203.0.113.0/24 (TEST-NET-3)
		if ip[0] == 203 && ip[1] == 0 && ip[2] == 113 {
			return true
		}
		// 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast)
		return ip[0] >= 240
	}

	// From here on, ip is a 16-byte IPv6 address that is not IPv4-mapped.
	inPrefix := func(prefix string, bits int) bool {
		return ip.Mask(net.CIDRMask(bits, 8*net.IPv6len)).Equal(net.ParseIP(prefix))
	}
	switch {
	case inPrefix("64:ff9b::", 96):
		// NAT64 well-known prefix (RFC 6052): the low 32 bits are the IPv4
		// destination the translator will connect to.
		return isPrivateOrInternal(ip[12:16])
	case inPrefix("2002::", 16):
		// 6to4 (RFC 3056): bytes 2-5 embed the site's IPv4 address.
		return isPrivateOrInternal(ip[2:6])
	case inPrefix("::", 96), // IPv4-compatible, deprecated (RFC 4291)
		inPrefix("64:ff9b:1::", 48), // local-use NAT64 (RFC 8215)
		inPrefix("100::", 64),       // discard-only (RFC 6666)
		inPrefix("2001:db8::", 32),  // documentation (RFC 3849)
		inPrefix("fec0::", 10):      // site-local, deprecated (RFC 3879)
		return true
	}
	return false
}
126
127// ValidateURL validates that a URL is safe to make requests to.
128// It checks that the scheme is http/https, the hostname is not localhost,
129// and all resolved IPs are public.
130func ValidateURL(rawURL string) error {
131	if rawURL == "" {
132		return ErrInvalidURL
133	}
134
135	u, err := url.Parse(rawURL)
136	if err != nil {
137		return fmt.Errorf("%w: %v", ErrInvalidURL, err)
138	}
139
140	if u.Scheme != "http" && u.Scheme != "https" {
141		return ErrInvalidScheme
142	}
143
144	hostname := u.Hostname()
145	if hostname == "" {
146		return fmt.Errorf("%w: missing hostname", ErrInvalidURL)
147	}
148
149	if isLocalhost(hostname) {
150		return ErrPrivateIP
151	}
152
153	if ip := net.ParseIP(hostname); ip != nil {
154		if isPrivateOrInternal(ip) {
155			return ErrPrivateIP
156		}
157		return nil
158	}
159
160	ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname)
161	if err != nil {
162		return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err)
163	}
164
165	if slices.ContainsFunc(ips, func(addr net.IPAddr) bool {
166		return isPrivateOrInternal(addr.IP)
167	}) {
168		return ErrPrivateIP
169	}
170
171	return nil
172}
173
174// ValidateIPBeforeDial validates an IP address before establishing a connection.
175// This prevents DNS rebinding attacks by checking the resolved IP at dial time.
176func ValidateIPBeforeDial(ip net.IP) error {
177	if isPrivateOrInternal(ip) {
178		return ErrPrivateIP
179	}
180	return nil
181}
182
// isLocalhost reports whether hostname is a conventional name for the local
// machine: "localhost", "localhost.localdomain", or any subdomain of
// "localhost" (RFC 6761 reserves the .localhost domain for loopback).
//
// Matching is case-insensitive. A single trailing dot is ignored, so the
// fully-qualified forms ("localhost.", "api.localhost.") are also caught.
func isLocalhost(hostname string) bool {
	hostname = strings.TrimSuffix(strings.ToLower(hostname), ".")
	return hostname == "localhost" ||
		hostname == "localhost.localdomain" ||
		strings.HasSuffix(hostname, ".localhost")
}