1package ssrf23import (4 "context"5 "errors"6 "net"7 "net/http"8 "net/http/httptest"9 "testing"10 "time"11)1213func TestNewSecureClientBlocksPrivateIPs(t *testing.T) {14 client := NewSecureClient()15 transport := client.Transport.(*http.Transport)1617 tests := []struct {18 name string19 addr string20 wantErr bool21 }{22 {"block loopback", "127.0.0.1:80", true},23 {"block private 10.x", "10.0.0.1:80", true},24 {"block link-local", "169.254.169.254:80", true},25 {"block CGNAT", "100.64.0.1:80", true},26 {"allow public IP", "8.8.8.8:80", false},27 }2829 for _, tt := range tests {30 t.Run(tt.name, func(t *testing.T) {31 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)32 defer cancel()3334 conn, err := transport.DialContext(ctx, "tcp", tt.addr)35 if conn != nil {36 conn.Close()37 }3839 if tt.wantErr {40 if err == nil {41 t.Errorf("expected error for %s, got none", tt.addr)42 }43 } else {44 if err != nil && errors.Is(err, ErrPrivateIP) {45 t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err)46 }47 }48 })49 }50}5152func TestNewSecureClientBlocksPrivateHostnames(t *testing.T) {53 client := NewSecureClient()54 transport := client.Transport.(*http.Transport)5556 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)57 defer cancel()5859 // "localhost" resolves to 127.0.0.1 (loopback) -- must be blocked.60 // This exercises the hostname resolution path in DialContext:61 // net.LookupIP("localhost") -> 127.0.0.1 -> isPrivateOrInternal -> blocked.62 conn, err := transport.DialContext(ctx, "tcp", "localhost:80")63 if conn != nil {64 conn.Close()65 }66 if !errors.Is(err, ErrPrivateIP) {67 t.Errorf("expected ErrPrivateIP for hostname resolving to loopback, got: %v", err)68 }69}7071func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) {72 client := NewSecureClient()73 transport := client.Transport.(*http.Transport)7475 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)76 defer cancel()7778 conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80")79 if conn != nil {80 conn.Close()81 }82 if err == nil {83 t.Fatal("expected error for non-IP address, got none")84 }85 if errors.Is(err, ErrPrivateIP) {86 t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err)87 }88}8990func TestNewSecureClientBlocksRedirects(t *testing.T) {91 redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {92 http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound)93 }))94 defer redirectServer.Close()9596 client := NewSecureClient()97 req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil)98 if err != nil {99 t.Fatalf("Failed to create request: %v", err)100 }101102 resp, err := client.Do(req)103 if err != nil {104 // httptest uses 127.0.0.1, blocked by SSRF protection105 if !errors.Is(err, ErrPrivateIP) {106 t.Fatalf("Request failed with non-SSRF error: %v", err)107 }108 return109 }110 defer resp.Body.Close()111112 if resp.StatusCode != http.StatusFound {113 t.Errorf("Expected redirect response (302), got %d", resp.StatusCode)114 }115}116117func TestIsPrivateOrInternal(t *testing.T) {118 tests := []struct {119 ip string120 want bool121 }{122 // Public123 {"8.8.8.8", false},124 {"2001:4860:4860::8888", false},125126 // Loopback127 {"127.0.0.1", true},128 {"::1", true},129130 // Private ranges131 {"10.0.0.1", true},132 {"192.168.1.1", true},133 {"172.16.0.1", true},134135 // Link-local (cloud metadata)136 {"169.254.169.254", true},137138 // CGNAT boundaries139 {"100.64.0.1", true},140 {"100.127.255.255", true},141142 // IPv6-mapped IPv4 (bypass vector the old webhook code missed)143 {"::ffff:127.0.0.1", true},144 {"::ffff:169.254.169.254", true},145 {"::ffff:8.8.8.8", false},146147 // Reserved148 {"0.0.0.0", true},149 {"240.0.0.1", true},150 }151152 for _, tt := range tests {153 t.Run(tt.ip, func(t *testing.T) {154 ip := net.ParseIP(tt.ip)155 if ip == nil {156 t.Fatalf("failed to parse IP: %s", tt.ip)157 }158 if got := isPrivateOrInternal(ip); got != tt.want {159 t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want)160 }161 })162 }163}164165func TestValidateURL(t *testing.T) {166 tests := []struct {167 name string168 url string169 wantErr bool170 errType error171 }{172 // Valid173 {"valid https", "https://1.1.1.1/webhook", false, nil},174175 // Scheme validation176 {"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme},177 {"no scheme", "example.com/webhook", true, ErrInvalidScheme},178179 // Localhost180 {"localhost", "http://localhost/webhook", true, ErrPrivateIP},181 {"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP},182183 // IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal)184 {"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP},185 {"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP},186187 // Invalid URLs188 {"empty", "", true, ErrInvalidURL},189 {"missing hostname", "http:///webhook", true, ErrInvalidURL},190 }191192 for _, tt := range tests {193 t.Run(tt.name, func(t *testing.T) {194 err := ValidateURL(tt.url)195 if (err != nil) != tt.wantErr {196 t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr)197 return198 }199 if tt.wantErr && tt.errType != nil {200 if !errors.Is(err, tt.errType) {201 t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType)202 }203 }204 })205 }206}207208func TestIsLocalhost(t *testing.T) {209 tests := []struct {210 hostname string211 want bool212 }{213 {"localhost", true},214 {"LOCALHOST", true},215 {"test.localhost", true},216 {"example.com", false},217 {"localhost.com", false},218 }219220 for _, tt := range tests {221 t.Run(tt.hostname, func(t *testing.T) {222 if got := isLocalhost(tt.hostname); got != tt.want {223 t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want)224 }225 })226 }227}