soft-serve

git clone git://git.lin.moe/fork/soft-serve.git

  1package ssrf
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"net/http"
  8	"net/http/httptest"
  9	"testing"
 10	"time"
 11)
 12
 13func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {
 14	client := NewSecureClient()
 15	transport := client.Transport.(*http.Transport)
 16
 17	tests := []struct {
 18		name    string
 19		addr    string
 20		wantErr bool
 21	}{
 22		{"block loopback", "127.0.0.1:80", true},
 23		{"block private 10.x", "10.0.0.1:80", true},
 24		{"block link-local", "169.254.169.254:80", true},
 25		{"block CGNAT", "100.64.0.1:80", true},
 26		{"allow public IP", "8.8.8.8:80", false},
 27	}
 28
 29	for _, tt := range tests {
 30		t.Run(tt.name, func(t *testing.T) {
 31			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 32			defer cancel()
 33
 34			conn, err := transport.DialContext(ctx, "tcp", tt.addr)
 35			if conn != nil {
 36				conn.Close()
 37			}
 38
 39			if tt.wantErr {
 40				if err == nil {
 41					t.Errorf("expected error for %s, got none", tt.addr)
 42				}
 43			} else {
 44				if err != nil && errors.Is(err, ErrPrivateIP) {
 45					t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err)
 46				}
 47			}
 48		})
 49	}
 50}
 51
 52func TestNewSecureClientBlocksPrivateHostnames(t *testing.T) {
 53	client := NewSecureClient()
 54	transport := client.Transport.(*http.Transport)
 55
 56	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 57	defer cancel()
 58
 59	// "localhost" resolves to 127.0.0.1 (loopback) -- must be blocked.
 60	// This exercises the hostname resolution path in DialContext:
 61	// net.LookupIP("localhost") -> 127.0.0.1 -> isPrivateOrInternal -> blocked.
 62	conn, err := transport.DialContext(ctx, "tcp", "localhost:80")
 63	if conn != nil {
 64		conn.Close()
 65	}
 66	if !errors.Is(err, ErrPrivateIP) {
 67		t.Errorf("expected ErrPrivateIP for hostname resolving to loopback, got: %v", err)
 68	}
 69}
 70
 71func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {
 72	client := NewSecureClient()
 73	transport := client.Transport.(*http.Transport)
 74
 75	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 76	defer cancel()
 77
 78	conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80")
 79	if conn != nil {
 80		conn.Close()
 81	}
 82	if err == nil {
 83		t.Fatal("expected error for non-IP address, got none")
 84	}
 85	if errors.Is(err, ErrPrivateIP) {
 86		t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err)
 87	}
 88}
 89
 90func TestNewSecureClientBlocksRedirects(t *testing.T) {
 91	redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 92		http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)
 93	}))
 94	defer redirectServer.Close()
 95
 96	client := NewSecureClient()
 97	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)
 98	if err != nil {
 99		t.Fatalf("Failed to create request: %v", err)
100	}
101
102	resp, err := client.Do(req)
103	if err != nil {
104		// httptest uses 127.0.0.1, blocked by SSRF protection
105		if !errors.Is(err, ErrPrivateIP) {
106			t.Fatalf("Request failed with non-SSRF error: %v", err)
107		}
108		return
109	}
110	defer resp.Body.Close()
111
112	if resp.StatusCode != http.StatusFound {
113		t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)
114	}
115}
116
117func TestIsPrivateOrInternal(t *testing.T) {
118	tests := []struct {
119		ip   string
120		want bool
121	}{
122		// Public
123		{"8.8.8.8", false},
124		{"2001:4860:4860::8888", false},
125
126		// Loopback
127		{"127.0.0.1", true},
128		{"::1", true},
129
130		// Private ranges
131		{"10.0.0.1", true},
132		{"192.168.1.1", true},
133		{"172.16.0.1", true},
134
135		// Link-local (cloud metadata)
136		{"169.254.169.254", true},
137
138		// CGNAT boundaries
139		{"100.64.0.1", true},
140		{"100.127.255.255", true},
141
142		// IPv6-mapped IPv4 (bypass vector the old webhook code missed)
143		{"::ffff:127.0.0.1", true},
144		{"::ffff:169.254.169.254", true},
145		{"::ffff:8.8.8.8", false},
146
147		// Reserved
148		{"0.0.0.0", true},
149		{"240.0.0.1", true},
150	}
151
152	for _, tt := range tests {
153		t.Run(tt.ip, func(t *testing.T) {
154			ip := net.ParseIP(tt.ip)
155			if ip == nil {
156				t.Fatalf("failed to parse IP: %s", tt.ip)
157			}
158			if got := isPrivateOrInternal(ip); got != tt.want {
159				t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want)
160			}
161		})
162	}
163}
164
165func TestValidateURL(t *testing.T) {
166	tests := []struct {
167		name    string
168		url     string
169		wantErr bool
170		errType error
171	}{
172		// Valid
173		{"valid https", "https://1.1.1.1/webhook", false, nil},
174
175		// Scheme validation
176		{"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme},
177		{"no scheme", "example.com/webhook", true, ErrInvalidScheme},
178
179		// Localhost
180		{"localhost", "http://localhost/webhook", true, ErrPrivateIP},
181		{"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP},
182
183		// IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal)
184		{"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},
185		{"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP},
186
187		// Invalid URLs
188		{"empty", "", true, ErrInvalidURL},
189		{"missing hostname", "http:///webhook", true, ErrInvalidURL},
190	}
191
192	for _, tt := range tests {
193		t.Run(tt.name, func(t *testing.T) {
194			err := ValidateURL(tt.url)
195			if (err != nil) != tt.wantErr {
196				t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)
197				return
198			}
199			if tt.wantErr && tt.errType != nil {
200				if !errors.Is(err, tt.errType) {
201					t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType)
202				}
203			}
204		})
205	}
206}
207
208func TestIsLocalhost(t *testing.T) {
209	tests := []struct {
210		hostname string
211		want     bool
212	}{
213		{"localhost", true},
214		{"LOCALHOST", true},
215		{"test.localhost", true},
216		{"example.com", false},
217		{"localhost.com", false},
218	}
219
220	for _, tt := range tests {
221		t.Run(tt.hostname, func(t *testing.T) {
222			if got := isLocalhost(tt.hostname); got != tt.want {
223				t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)
224			}
225		})
226	}
227}