1package web23import (4 "context"5 "errors"6 "fmt"7 "net/http"8 "strings"910 "github.com/charmbracelet/log/v2"11 "github.com/charmbracelet/soft-serve/pkg/backend"12 "github.com/charmbracelet/soft-serve/pkg/config"13 "github.com/charmbracelet/soft-serve/pkg/proto"14 "github.com/golang-jwt/jwt/v5"15)1617// authenticate authenticates the user from the request.18func authenticate(r *http.Request) (proto.User, error) {19 // Prefer the Authorization header20 user, err := parseAuthHdr(r)21 if err != nil || user == nil {22 if errors.Is(err, ErrInvalidToken) || errors.Is(err, ErrInvalidPassword) {23 return nil, err24 }25 return nil, proto.ErrUserNotFound26 }2728 return user, nil29}3031// ErrInvalidPassword is returned when the password is invalid.32var ErrInvalidPassword = errors.New("invalid password")3334func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {35 logger := log.FromContext(ctx)36 be := backend.FromContext(ctx)3738 if username != "" && password != "" {39 user, err := be.User(ctx, username)40 if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {41 return user, nil42 }4344 // Try to authenticate using access token as the password45 user, err = be.UserByAccessToken(ctx, password)46 if err == nil {47 return user, nil48 }4950 logger.Error("invalid password or token", "username", username, "err", err)51 return nil, ErrInvalidPassword52 } else if username != "" {53 // Try to authenticate using access token as the username54 logger.Debug("trying to authenticate using access token as username", "username", username)55 user, err := be.UserByAccessToken(ctx, username)56 if err == nil {57 return user, nil58 }5960 logger.Error("failed to get user", "err", err)61 return nil, ErrInvalidToken62 }6364 return nil, proto.ErrUserNotFound65}6667// ErrInvalidHeader is returned when the authorization header is invalid.68var ErrInvalidHeader = errors.New("invalid authorization header")6970func parseAuthHdr(r *http.Request) (proto.User, error) {71 // Check for auth header72 header := r.Header.Get("Authorization")73 if header == "" {74 return nil, ErrInvalidHeader75 }7677 ctx := r.Context()78 logger := log.FromContext(ctx).WithPrefix("http.auth")79 be := backend.FromContext(ctx)8081 logger.Debug("authorization auth header", "header", header)8283 parts := strings.SplitN(header, " ", 2)84 if len(parts) != 2 {85 return nil, errors.New("invalid authorization header")86 }8788 switch strings.ToLower(parts[0]) {89 case "token":90 user, err := be.UserByAccessToken(ctx, parts[1])91 if err != nil {92 logger.Error("failed to get user", "err", err)93 return nil, err94 }9596 return user, nil97 case "bearer":98 claims, err := parseJWT(ctx, parts[1])99 if err != nil {100 return nil, err101 }102103 // Find the user104 parts := strings.SplitN(claims.Subject, "#", 2)105 if len(parts) != 2 {106 logger.Error("invalid jwt subject", "subject", claims.Subject)107 return nil, errors.New("invalid jwt subject")108 }109110 user, err := be.User(ctx, parts[0])111 if err != nil {112 logger.Error("failed to get user", "err", err)113 return nil, err114 }115116 expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())117 if expectedSubject != claims.Subject {118 logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)119 return nil, errors.New("invalid jwt subject")120 }121122 return user, nil123 default:124 username, password, ok := r.BasicAuth()125 if !ok {126 return nil, ErrInvalidHeader127 }128129 return parseUsernamePassword(ctx, username, password)130 }131}132133// ErrInvalidToken is returned when a token is invalid.134var ErrInvalidToken = errors.New("invalid token")135136func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {137 cfg := config.FromContext(ctx)138 logger := log.FromContext(ctx).WithPrefix("http.auth")139 kp, err := config.KeyPair(cfg)140 if err != nil {141 return nil, err142 }143144 repo := proto.RepositoryFromContext(ctx)145 if repo == nil {146 return nil, errors.New("missing repository")147 }148149 token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {150 if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {151 return nil, errors.New("invalid signing method")152 }153154 return kp.CryptoPublicKey(), nil155 },156 jwt.WithIssuer(cfg.HTTP.PublicURL),157 jwt.WithIssuedAt(),158 jwt.WithAudience(repo.Name()),159 )160 if err != nil {161 logger.Error("failed to parse jwt", "err", err)162 return nil, ErrInvalidToken163 }164165 claims, ok := token.Claims.(*jwt.RegisteredClaims)166 if !token.Valid || !ok {167 return nil, ErrInvalidToken168 }169170 return claims, nil171}