1package testscript23import (4 "bytes"5 "context"6 "encoding/json"7 "flag"8 "fmt"9 "io"10 "math/rand"11 "net"12 "net/http"13 "net/url"14 "os"15 "os/exec"16 "path/filepath"17 "runtime"18 "strconv"19 "strings"20 "testing"21 "time"2223 "github.com/charmbracelet/keygen"24 "github.com/charmbracelet/soft-serve/pkg/config"25 "github.com/charmbracelet/soft-serve/pkg/db"26 "github.com/charmbracelet/soft-serve/pkg/test"27 "github.com/rogpeppe/go-internal/testscript"28 "github.com/spf13/cobra"29 "golang.org/x/crypto/ssh"30)3132var (33 update = flag.Bool("update", false, "update script files")34 binPath string35)3637func PrepareBuildCommand(binPath string) *exec.Cmd {38 _, disableRaceSet := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS")39 if disableRaceSet {40 // don't add the -race flag41 return exec.Command("go", "build", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft")) //nolint:noctx42 }43 return exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft")) //nolint:noctx44}4546func TestMain(m *testing.M) {47 tmp, err := os.MkdirTemp("", "soft-serve*")48 if err != nil {49 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)50 os.Exit(1)51 }52 defer os.RemoveAll(tmp)5354 binPath = filepath.Join(tmp, "soft")55 if runtime.GOOS == "windows" {56 binPath += ".exe"57 }5859 // Build the soft binary with -cover flag.60 cmd := PrepareBuildCommand(binPath)61 if err := cmd.Run(); err != nil {62 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)63 os.Exit(1)64 }6566 // Run tests67 os.Exit(m.Run())68}6970func TestScript(t *testing.T) {71 flag.Parse()7273 mkkey := func(name string) (string, *keygen.SSHKeyPair) {74 path := filepath.Join(t.TempDir(), name)75 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())76 if err != nil {77 t.Fatal(err)78 }79 return path, pair80 }8182 admin1Key, admin1 := mkkey("admin1")83 _, admin2 := mkkey("admin2")84 user1Key, user1 := mkkey("user1")85 attackerKey, attacker := mkkey("attacker")86 attackerSigner := &maliciousSigner{87 publicKey: admin1.PublicKey(),88 }8990 testscript.Run(t, testscript.Params{91 Dir: "./testdata/",92 UpdateScripts: *update,93 RequireExplicitExec: true,94 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){95 "soft": cmdSoft("admin", admin1.Signer()),96 "usoft": cmdSoft("user1", user1.Signer()),97 "attacksoft": cmdSoft("attacker", attackerSigner, attacker.Signer()),98 "git": cmdGit(admin1Key),99 "ugit": cmdGit(user1Key),100 "agit": cmdGit(attackerKey),101 "curl": cmdCurl,102 "mkfile": cmdMkfile,103 "envfile": cmdEnvfile,104 "readfile": cmdReadfile,105 "dos2unix": cmdDos2Unix,106 "new-webhook": cmdNewWebhook,107 "ensureserverrunning": cmdEnsureServerRunning,108 "ensureservernotrunning": cmdEnsureServerNotRunning,109 "stopserver": cmdStopserver,110 "ui": cmdUI(admin1.Signer()),111 "uui": cmdUI(user1.Signer()),112 },113 Setup: func(e *testscript.Env) error {114 // Add binPath to PATH115 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))116117 data := t.TempDir()118 sshPort := test.RandomPort()119 sshListen := fmt.Sprintf("localhost:%d", sshPort)120 gitPort := test.RandomPort()121 gitListen := fmt.Sprintf("localhost:%d", gitPort)122 httpPort := test.RandomPort()123 httpListen := fmt.Sprintf("localhost:%d", httpPort)124 statsPort := test.RandomPort()125 statsListen := fmt.Sprintf("localhost:%d", statsPort)126 serverName := "Test Soft Serve"127128 e.Setenv("DATA_PATH", data)129 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))130 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))131 e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))132 e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))133 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())134 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())135 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())136 e.Setenv("ATTACKER_AUTHORIZED_KEY", attacker.AuthorizedKey())137 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))138 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))139140 // This is used to set up test specific configuration and http endpoints141 e.Setenv("SOFT_SERVE_TESTRUN", "1")142143 // This will disable the default lipgloss renderer colors144 e.Setenv("SOFT_SERVE_NO_COLOR", "1")145146 // Soft Serve debug environment variables147 for _, env := range []string{148 "SOFT_SERVE_DEBUG",149 "SOFT_SERVE_VERBOSE",150 } {151 if v, ok := os.LookupEnv(env); ok {152 e.Setenv(env, v)153 }154 }155156 // TODO: test different configs157 cfg := config.DefaultConfig()158 cfg.DataPath = data159 cfg.Name = serverName160 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}161 cfg.SSH.ListenAddr = sshListen162 cfg.SSH.PublicURL = "ssh://" + sshListen163 cfg.Git.ListenAddr = gitListen164 cfg.HTTP.ListenAddr = httpListen165 cfg.HTTP.PublicURL = "http://" + httpListen166 cfg.Stats.ListenAddr = statsListen167 cfg.LFS.Enabled = true168169 // Parse os SOFT_SERVE environment variables170 if err := cfg.ParseEnv(); err != nil {171 return err172 }173174 // Override the database data source if we're using postgres175 // so we can create a temporary database for the tests.176 if cfg.DB.Driver == "postgres" {177 cleanup, err := setupPostgres(e.T(), cfg)178 if err != nil {179 return err180 }181 if cleanup != nil {182 e.Defer(cleanup)183 }184 }185186 for _, env := range cfg.Environ() {187 parts := strings.SplitN(env, "=", 2)188 if len(parts) != 2 {189 e.T().Fatal("invalid environment variable", env)190 }191 e.Setenv(parts[0], parts[1])192 }193194 return nil195 },196 })197}198199func cmdSoft(user string, keys ...ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {200 return func(ts *testscript.TestScript, neg bool, args []string) {201 cli, err := ssh.Dial(202 "tcp",203 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),204 &ssh.ClientConfig{205 User: user,206 Auth: []ssh.AuthMethod{ssh.PublicKeys(keys...)},207 HostKeyCallback: ssh.InsecureIgnoreHostKey(),208 },209 )210 ts.Check(err)211 defer cli.Close()212213 sess, err := cli.NewSession()214 ts.Check(err)215 defer sess.Close()216217 sess.Stdout = ts.Stdout()218 sess.Stderr = ts.Stderr()219220 check(ts, sess.Run(strings.Join(args, " ")), neg)221 }222}223224func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {225 return func(ts *testscript.TestScript, neg bool, args []string) {226 if len(args) < 1 {227 ts.Fatalf("usage: ui <quoted string input>")228 return229 }230231 cli, err := ssh.Dial(232 "tcp",233 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),234 &ssh.ClientConfig{235 User: "git",236 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},237 HostKeyCallback: ssh.InsecureIgnoreHostKey(),238 },239 )240 check(ts, err, neg)241 defer cli.Close()242243 sess, err := cli.NewSession()244 check(ts, err, neg)245 defer sess.Close()246247 // XXX: this is a hack to make the UI tests work248 // cmp command always complains about an extra newline249 // in the output250 defer ts.Stdout().Write([]byte("\n"))251252 sess.Stdout = ts.Stdout()253 sess.Stderr = ts.Stderr()254255 stdin, err := sess.StdinPipe()256 check(ts, err, neg)257258 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})259 check(ts, err, neg)260 check(ts, sess.Start(""), neg)261262 in, err := strconv.Unquote(args[0])263 check(ts, err, neg)264 reader := strings.NewReader(in)265 go func() {266 defer stdin.Close()267 for {268 r, _, err := reader.ReadRune()269 if err == io.EOF {270 break271 }272 check(ts, err, neg)273 _, _ = io.WriteString(stdin, string(r))274275 // Wait for the UI to process the input276 time.Sleep(100 * time.Millisecond)277 }278 }()279280 check(ts, sess.Wait(), neg)281 }282}283284func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {285 if neg {286 ts.Fatalf("unsupported: ! dos2unix")287 }288 if len(args) < 1 {289 ts.Fatalf("usage: dos2unix paths...")290 }291 for _, arg := range args {292 filename := ts.MkAbs(arg)293 data, err := os.ReadFile(filename)294 if err != nil {295 ts.Fatalf("%s: %v", filename, err)296 }297298 // Replace all '\r\n' with '\n'.299 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})300301 if err := os.WriteFile(filename, data, 0o644); err != nil {302 ts.Fatalf("%s: %v", filename, err)303 }304 }305}306307var sshConfig = `308Host *309 UserKnownHostsFile %q310 StrictHostKeyChecking no311 IdentityAgent none312 IdentitiesOnly yes313 ServerAliveInterval 60314`315316func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {317 return func(ts *testscript.TestScript, neg bool, args []string) {318 ts.Check(os.WriteFile(319 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),320 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),321 0o600,322 ))323 sshArgs := []string{324 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),325 "-i", filepath.ToSlash(key),326 }327 ts.Setenv(328 "GIT_SSH_COMMAND",329 strings.Join(append([]string{"ssh"}, sshArgs...), " "),330 )331 // Disable git prompting for credentials.332 ts.Setenv("GIT_TERMINAL_PROMPT", "0")333 args = append([]string{334 "-c", "user.email=john@example.com",335 "-c", "user.name=John Doe",336 }, args...)337 check(ts, ts.Exec("git", args...), neg)338 }339}340341func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {342 if len(args) < 2 {343 ts.Fatalf("usage: mkfile path content")344 }345 check(ts, os.WriteFile(346 ts.MkAbs(args[0]),347 []byte(strings.Join(args[1:], " ")),348 0o644,349 ), neg)350}351352func check(ts *testscript.TestScript, err error, neg bool) {353 if neg && err == nil {354 ts.Fatalf("expected error, got nil")355 }356 if !neg {357 ts.Check(err)358 }359}360361func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {362 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))363}364365func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {366 if len(args) < 1 {367 ts.Fatalf("usage: envfile key=file...")368 }369370 for _, arg := range args {371 parts := strings.SplitN(arg, "=", 2)372 if len(parts) != 2 {373 ts.Fatalf("usage: envfile key=file...")374 }375 key := parts[0]376 file := parts[1]377 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))378 }379}380381func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {382 type webhookSite struct {383 UUID string `json:"uuid"`384 }385386 if len(args) != 1 {387 ts.Fatalf("usage: new-webhook <env-name>")388 }389390 const whSite = "https://webhook.site"391 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil) //nolint:noctx392 check(ts, err, neg)393394 resp, err := http.DefaultClient.Do(req)395 check(ts, err, neg)396397 defer resp.Body.Close()398 var site webhookSite399 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)400401 ts.Setenv(args[0], whSite+"/"+site.UUID)402}403404func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {405 var verbose bool406 var headers []string407 var data string408 method := http.MethodGet409410 cmd := &cobra.Command{411 Use: "curl",412 Args: cobra.MinimumNArgs(1),413 RunE: func(cmd *cobra.Command, args []string) error {414 url, err := url.Parse(args[0])415 if err != nil {416 return err417 }418419 req, err := http.NewRequest(method, url.String(), nil) //nolint:noctx420 if err != nil {421 return err422 }423424 if data != "" {425 req.Body = io.NopCloser(strings.NewReader(data))426 }427428 if verbose {429 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())430 }431432 for _, header := range headers {433 parts := strings.SplitN(header, ":", 2)434 if len(parts) != 2 {435 return fmt.Errorf("invalid header: %s", header)436 }437 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))438 }439440 if userInfo := url.User; userInfo != nil {441 password, _ := userInfo.Password()442 req.SetBasicAuth(userInfo.Username(), password)443 }444445 if verbose {446 for key, values := range req.Header {447 for _, value := range values {448 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)449 }450 }451 }452453 resp, err := http.DefaultClient.Do(req)454 if err != nil {455 return err456 }457458 if verbose {459 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)460 for key, values := range resp.Header {461 for _, value := range values {462 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)463 }464 }465 }466467 defer resp.Body.Close()468 buf, err := io.ReadAll(resp.Body)469 if err != nil {470 return err471 }472473 cmd.Print(string(buf))474475 return nil476 },477 }478479 cmd.SetArgs(args)480 cmd.SetOut(ts.Stdout())481 cmd.SetErr(ts.Stderr())482483 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")484 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")485 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")486 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")487488 check(ts, cmd.Execute(), neg)489}490491func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {492 if len(args) < 1 {493 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +494 "These are set as env vars as they are randomized. " +495 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +496 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")497 }498499 port := ts.Getenv(args[0])500501 // verify that the server is up502 addr := net.JoinHostPort("localhost", port)503 for {504 conn, _ := net.DialTimeout( //nolint:noctx505 "tcp",506 addr,507 time.Second,508 )509 if conn != nil {510 ts.Logf("Server is running on port: %s", port)511 conn.Close()512 break513 }514 }515}516517func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {518 if len(args) < 1 {519 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +520 "These are set as env vars as they are randomized. " +521 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +522 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")523 }524525 port := ts.Getenv(args[0])526527 // verify that the server is not up528 addr := net.JoinHostPort("localhost", port)529 conn, _ := net.DialTimeout( //nolint:noctx530 "tcp",531 addr,532 time.Second,533 )534 if conn != nil {535 ts.Fatalf("server is running on port %s while it should not be running", port)536 conn.Close()537 }538}539540func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {541 // stop the server542 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL"))) //nolint:noctx543 check(ts, err, neg)544 resp.Body.Close()545 time.Sleep(time.Second * 2) // Allow some time for the server to stop546}547548func setupPostgres(t testscript.T, cfg *config.Config) (func(), error) {549 // Indicates postgres550 // Create a disposable database551 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))552 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())553 dbDsn := cfg.DB.DataSource554 if dbDsn == "" {555 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"556 }557558 dbUrl, err := url.Parse(cfg.DB.DataSource)559 if err != nil {560 return nil, err561 }562563 scheme := dbUrl.Scheme564 if scheme == "" {565 scheme = "postgres"566 }567568 host := dbUrl.Hostname()569 if host == "" {570 host = "localhost"571 }572573 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)574 username := dbUrl.User.Username()575 if username != "" {576 connInfo += fmt.Sprintf(" user=%s", username)577 password, ok := dbUrl.User.Password()578 if ok {579 username = fmt.Sprintf("%s:%s", username, password)580 connInfo += fmt.Sprintf(" password=%s", password)581 }582 username = fmt.Sprintf("%s@", username)583 } else {584 connInfo += " user=postgres"585 username = "postgres@"586 }587588 port := dbUrl.Port()589 if port != "" {590 connInfo += fmt.Sprintf(" port=%s", port)591 port = fmt.Sprintf(":%s", port)592 }593594 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",595 scheme,596 username,597 host,598 port,599 dbName,600 )601602 // Create the database603 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)604 if err != nil {605 return nil, err606 }607608 if _, err := dbx.ExecContext(context.TODO(), "CREATE DATABASE "+dbName); err != nil {609 return nil, err610 }611612 return func() {613 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)614 if err != nil {615 t.Fatal("failed to open database", dbName, err)616 }617618 if _, err := dbx.ExecContext(context.TODO(), "DROP DATABASE "+dbName); err != nil {619 t.Fatal("failed to drop database", dbName, err)620 }621 }, nil622}623624type maliciousSigner struct {625 publicKey ssh.PublicKey626}627628var _ ssh.Signer = (*maliciousSigner)(nil)629630// PublicKey implements ssh.Signer.631func (m *maliciousSigner) PublicKey() ssh.PublicKey {632 return m.publicKey633}634635// Sign implements ssh.Signer.636func (m *maliciousSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {637 // The attacker doesn't know how to sign the data without a private key.638 return &ssh.Signature{}, nil639}