1package database23import (4 "context"5 "strings"67 "github.com/charmbracelet/soft-serve/pkg/db"8 "github.com/charmbracelet/soft-serve/pkg/db/models"9 "github.com/charmbracelet/soft-serve/pkg/sshutils"10 "github.com/charmbracelet/soft-serve/pkg/store"11 "github.com/charmbracelet/soft-serve/pkg/utils"12 "golang.org/x/crypto/ssh"13)1415type userStore struct{}1617var _ store.UserStore = (*userStore)(nil)1819// AddPublicKeyByUsername implements store.UserStore.20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {21 username = strings.ToLower(username)22 if err := utils.ValidateUsername(username); err != nil {23 return err24 }2526 var userID int6427 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {28 return err29 }3031 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)32 VALUES (?, ?, CURRENT_TIMESTAMP);`)33 ak := sshutils.MarshalAuthorizedKey(pk)34 _, err := tx.ExecContext(ctx, query, userID, ak)3536 return err37}3839// CreateUser implements store.UserStore.40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {41 username = strings.ToLower(username)42 if err := utils.ValidateUsername(username); err != nil {43 return err44 }4546 query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)47 VALUES (?, ?, CURRENT_TIMESTAMP) RETURNING id;`)4849 var userID int6450 if err := tx.GetContext(ctx, &userID, query, username, isAdmin); err != nil {51 return err52 }5354 for _, pk := range pks {55 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)56 VALUES (?, ?, CURRENT_TIMESTAMP);`)57 ak := sshutils.MarshalAuthorizedKey(pk)58 _, err := tx.ExecContext(ctx, query, userID, ak)59 if err != nil {60 return err61 }62 }6364 return nil65}6667// DeleteUserByUsername implements store.UserStore.68func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {69 username = strings.ToLower(username)70 if err := utils.ValidateUsername(username); err != nil {71 return err72 }7374 query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)75 _, err := tx.ExecContext(ctx, query, username)76 return err77}7879// GetUserByID implements store.UserStore.80func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {81 var m models.User82 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)83 err := tx.GetContext(ctx, &m, query, id)84 return m, err85}8687// FindUserByPublicKey implements store.UserStore.88func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {89 var m models.User90 query := tx.Rebind(`SELECT users.*91 FROM users92 INNER JOIN public_keys ON users.id = public_keys.user_id93 WHERE public_keys.public_key = ?;`)94 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))95 return m, err96}9798// FindUserByUsername implements store.UserStore.99func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {100 username = strings.ToLower(username)101 if err := utils.ValidateUsername(username); err != nil {102 return models.User{}, err103 }104105 var m models.User106 query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)107 err := tx.GetContext(ctx, &m, query, username)108 return m, err109}110111// FindUserByAccessToken implements store.UserStore.112func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {113 var m models.User114 query := tx.Rebind(`SELECT users.*115 FROM users116 INNER JOIN access_tokens ON users.id = access_tokens.user_id117 WHERE access_tokens.token = ?;`)118 err := tx.GetContext(ctx, &m, query, token)119 return m, err120}121122// GetAllUsers implements store.UserStore.123func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {124 var ms []models.User125 query := tx.Rebind(`SELECT * FROM users;`)126 err := tx.SelectContext(ctx, &ms, query)127 return ms, err128}129130// ListPublicKeysByUserID implements store.UserStore..131func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {132 var aks []string133 query := tx.Rebind(`SELECT public_key FROM public_keys134 WHERE user_id = ?135 ORDER BY public_keys.id ASC;`)136 err := tx.SelectContext(ctx, &aks, query, id)137 if err != nil {138 return nil, err139 }140141 pks := make([]ssh.PublicKey, len(aks))142 for i, ak := range aks {143 pk, _, err := sshutils.ParseAuthorizedKey(ak)144 if err != nil {145 return nil, err146 }147 pks[i] = pk148 }149150 return pks, nil151}152153// ListPublicKeysByUsername implements store.UserStore.154func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {155 username = strings.ToLower(username)156 if err := utils.ValidateUsername(username); err != nil {157 return nil, err158 }159160 var aks []string161 query := tx.Rebind(`SELECT public_key FROM public_keys162 INNER JOIN users ON users.id = public_keys.user_id163 WHERE users.username = ?164 ORDER BY public_keys.id ASC;`)165 err := tx.SelectContext(ctx, &aks, query, username)166 if err != nil {167 return nil, err168 }169170 pks := make([]ssh.PublicKey, len(aks))171 for i, ak := range aks {172 pk, _, err := sshutils.ParseAuthorizedKey(ak)173 if err != nil {174 return nil, err175 }176 pks[i] = pk177 }178179 return pks, nil180}181182// RemovePublicKeyByUsername implements store.UserStore.183func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {184 username = strings.ToLower(username)185 if err := utils.ValidateUsername(username); err != nil {186 return err187 }188189 query := tx.Rebind(`DELETE FROM public_keys190 WHERE user_id = (SELECT id FROM users WHERE username = ?)191 AND public_key = ?;`)192 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))193 return err194}195196// SetAdminByUsername implements store.UserStore.197func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {198 username = strings.ToLower(username)199 if err := utils.ValidateUsername(username); err != nil {200 return err201 }202203 query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)204 _, err := tx.ExecContext(ctx, query, isAdmin, username)205 return err206}207208// SetUsernameByUsername implements store.UserStore.209func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {210 username = strings.ToLower(username)211 if err := utils.ValidateUsername(username); err != nil {212 return err213 }214215 newUsername = strings.ToLower(newUsername)216 if err := utils.ValidateUsername(newUsername); err != nil {217 return err218 }219220 query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)221 _, err := tx.ExecContext(ctx, query, newUsername, username)222 return err223}224225// SetUserPassword implements store.UserStore.226func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {227 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)228 _, err := tx.ExecContext(ctx, query, password, userID)229 return err230}231232// SetUserPasswordByUsername implements store.UserStore.233func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {234 username = strings.ToLower(username)235 if err := utils.ValidateUsername(username); err != nil {236 return err237 }238239 query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)240 _, err := tx.ExecContext(ctx, query, password, username)241 return err242}