1package testscript23import (4 "bytes"5 "context"6 "encoding/json"7 "flag"8 "fmt"9 "io"10 "math/rand"11 "net"12 "net/http"13 "net/url"14 "os"15 "os/exec"16 "path/filepath"17 "runtime"18 "strconv"19 "strings"20 "testing"21 "time"2223 "github.com/charmbracelet/keygen"24 "github.com/charmbracelet/soft-serve/pkg/config"25 "github.com/charmbracelet/soft-serve/pkg/db"26 "github.com/charmbracelet/soft-serve/pkg/test"27 "github.com/rogpeppe/go-internal/testscript"28 "github.com/spf13/cobra"29 "golang.org/x/crypto/ssh"30)3132var (33 update = flag.Bool("update", false, "update script files")34 binPath string35)3637func PrepareBuildCommand(binPath string) *exec.Cmd {38 _, disableRaceSet := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS")39 if disableRaceSet {40 // don't add the -race flag41 return exec.Command("go", "build", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))42 }43 return exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))44}4546func TestMain(m *testing.M) {47 tmp, err := os.MkdirTemp("", "soft-serve*")48 if err != nil {49 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)50 os.Exit(1)51 }52 defer os.RemoveAll(tmp)5354 binPath = filepath.Join(tmp, "soft")55 if runtime.GOOS == "windows" {56 binPath += ".exe"57 }5859 // Build the soft binary with -cover flag.60 cmd := PrepareBuildCommand(binPath)61 if err := cmd.Run(); err != nil {62 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)63 os.Exit(1)64 }6566 // Run tests67 os.Exit(m.Run())68}6970func TestScript(t *testing.T) {71 flag.Parse()7273 mkkey := func(name string) (string, *keygen.SSHKeyPair) {74 path := filepath.Join(t.TempDir(), name)75 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())76 if err != nil {77 t.Fatal(err)78 }79 return path, pair80 }8182 admin1Key, admin1 := mkkey("admin1")83 _, admin2 := mkkey("admin2")84 user1Key, user1 := mkkey("user1")8586 testscript.Run(t, testscript.Params{87 Dir: "./testdata/",88 UpdateScripts: *update,89 RequireExplicitExec: true,90 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){91 "soft": cmdSoft("admin", admin1.Signer()),92 "usoft": cmdSoft("user1", user1.Signer()),93 "git": cmdGit(admin1Key),94 "ugit": cmdGit(user1Key),95 "curl": cmdCurl,96 "mkfile": cmdMkfile,97 "envfile": cmdEnvfile,98 "readfile": cmdReadfile,99 "dos2unix": cmdDos2Unix,100 "new-webhook": cmdNewWebhook,101 "ensureserverrunning": cmdEnsureServerRunning,102 "ensureservernotrunning": cmdEnsureServerNotRunning,103 "stopserver": cmdStopserver,104 "ui": cmdUI(admin1.Signer()),105 "uui": cmdUI(user1.Signer()),106 },107 Setup: func(e *testscript.Env) error {108 // Add binPath to PATH109 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))110111 data := t.TempDir()112 sshPort := test.RandomPort()113 sshListen := fmt.Sprintf("localhost:%d", sshPort)114 gitPort := test.RandomPort()115 gitListen := fmt.Sprintf("localhost:%d", gitPort)116 httpPort := test.RandomPort()117 httpListen := fmt.Sprintf("localhost:%d", httpPort)118 statsPort := test.RandomPort()119 statsListen := fmt.Sprintf("localhost:%d", statsPort)120 serverName := "Test Soft Serve"121122 e.Setenv("DATA_PATH", data)123 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))124 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))125 e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))126 e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))127 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())128 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())129 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())130 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))131 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))132133 // This is used to set up test specific configuration and http endpoints134 e.Setenv("SOFT_SERVE_TESTRUN", "1")135136 // This will disable the default lipgloss renderer colors137 e.Setenv("SOFT_SERVE_NO_COLOR", "1")138139 // Soft Serve debug environment variables140 for _, env := range []string{141 "SOFT_SERVE_DEBUG",142 "SOFT_SERVE_VERBOSE",143 } {144 if v, ok := os.LookupEnv(env); ok {145 e.Setenv(env, v)146 }147 }148149 // TODO: test different configs150 cfg := config.DefaultConfig()151 cfg.DataPath = data152 cfg.Name = serverName153 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}154 cfg.SSH.ListenAddr = sshListen155 cfg.SSH.PublicURL = "ssh://" + sshListen156 cfg.Git.ListenAddr = gitListen157 cfg.HTTP.ListenAddr = httpListen158 cfg.HTTP.PublicURL = "http://" + httpListen159 cfg.Stats.ListenAddr = statsListen160 cfg.LFS.Enabled = true161162 // Parse os SOFT_SERVE environment variables163 if err := cfg.ParseEnv(); err != nil {164 return err165 }166167 // Override the database data source if we're using postgres168 // so we can create a temporary database for the tests.169 if cfg.DB.Driver == "postgres" {170 err, cleanup := setupPostgres(e.T(), cfg)171 if err != nil {172 return err173 }174 if cleanup != nil {175 e.Defer(cleanup)176 }177 }178179 for _, env := range cfg.Environ() {180 parts := strings.SplitN(env, "=", 2)181 if len(parts) != 2 {182 e.T().Fatal("invalid environment variable", env)183 }184 e.Setenv(parts[0], parts[1])185 }186187 return nil188 },189 })190}191192func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {193 return func(ts *testscript.TestScript, neg bool, args []string) {194 cli, err := ssh.Dial(195 "tcp",196 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),197 &ssh.ClientConfig{198 User: user,199 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},200 HostKeyCallback: ssh.InsecureIgnoreHostKey(),201 },202 )203 ts.Check(err)204 defer cli.Close()205206 sess, err := cli.NewSession()207 ts.Check(err)208 defer sess.Close()209210 sess.Stdout = ts.Stdout()211 sess.Stderr = ts.Stderr()212213 check(ts, sess.Run(strings.Join(args, " ")), neg)214 }215}216217func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {218 return func(ts *testscript.TestScript, neg bool, args []string) {219 if len(args) < 1 {220 ts.Fatalf("usage: ui <quoted string input>")221 return222 }223224 cli, err := ssh.Dial(225 "tcp",226 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),227 &ssh.ClientConfig{228 User: "git",229 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},230 HostKeyCallback: ssh.InsecureIgnoreHostKey(),231 },232 )233 check(ts, err, neg)234 defer cli.Close()235236 sess, err := cli.NewSession()237 check(ts, err, neg)238 defer sess.Close()239240 // XXX: this is a hack to make the UI tests work241 // cmp command always complains about an extra newline242 // in the output243 defer ts.Stdout().Write([]byte("\n"))244245 sess.Stdout = ts.Stdout()246 sess.Stderr = ts.Stderr()247248 stdin, err := sess.StdinPipe()249 check(ts, err, neg)250251 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})252 check(ts, err, neg)253 check(ts, sess.Start(""), neg)254255 in, err := strconv.Unquote(args[0])256 check(ts, err, neg)257 reader := strings.NewReader(in)258 go func() {259 defer stdin.Close()260 for {261 r, _, err := reader.ReadRune()262 if err == io.EOF {263 break264 }265 check(ts, err, neg)266 stdin.Write([]byte(string(r))) // nolint: errcheck267268 // Wait for the UI to process the input269 time.Sleep(100 * time.Millisecond)270 }271 }()272273 check(ts, sess.Wait(), neg)274 }275}276277func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {278 if neg {279 ts.Fatalf("unsupported: ! dos2unix")280 }281 if len(args) < 1 {282 ts.Fatalf("usage: dos2unix paths...")283 }284 for _, arg := range args {285 filename := ts.MkAbs(arg)286 data, err := os.ReadFile(filename)287 if err != nil {288 ts.Fatalf("%s: %v", filename, err)289 }290291 // Replace all '\r\n' with '\n'.292 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})293294 if err := os.WriteFile(filename, data, 0o644); err != nil {295 ts.Fatalf("%s: %v", filename, err)296 }297 }298}299300var sshConfig = `301Host *302 UserKnownHostsFile %q303 StrictHostKeyChecking no304 IdentityAgent none305 IdentitiesOnly yes306 ServerAliveInterval 60307`308309func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {310 return func(ts *testscript.TestScript, neg bool, args []string) {311 ts.Check(os.WriteFile(312 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),313 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),314 0o600,315 ))316 sshArgs := []string{317 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),318 "-i", filepath.ToSlash(key),319 }320 ts.Setenv(321 "GIT_SSH_COMMAND",322 strings.Join(append([]string{"ssh"}, sshArgs...), " "),323 )324 // Disable git prompting for credentials.325 ts.Setenv("GIT_TERMINAL_PROMPT", "0")326 args = append([]string{327 "-c", "user.email=john@example.com",328 "-c", "user.name=John Doe",329 }, args...)330 check(ts, ts.Exec("git", args...), neg)331 }332}333334func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {335 if len(args) < 2 {336 ts.Fatalf("usage: mkfile path content")337 }338 check(ts, os.WriteFile(339 ts.MkAbs(args[0]),340 []byte(strings.Join(args[1:], " ")),341 0o644,342 ), neg)343}344345func check(ts *testscript.TestScript, err error, neg bool) {346 if neg && err == nil {347 ts.Fatalf("expected error, got nil")348 }349 if !neg {350 ts.Check(err)351 }352}353354func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {355 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))356}357358func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {359 if len(args) < 1 {360 ts.Fatalf("usage: envfile key=file...")361 }362363 for _, arg := range args {364 parts := strings.SplitN(arg, "=", 2)365 if len(parts) != 2 {366 ts.Fatalf("usage: envfile key=file...")367 }368 key := parts[0]369 file := parts[1]370 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))371 }372}373374func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {375 type webhookSite struct {376 UUID string `json:"uuid"`377 }378379 if len(args) != 1 {380 ts.Fatalf("usage: new-webhook <env-name>")381 }382383 const whSite = "https://webhook.site"384 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)385 check(ts, err, neg)386387 resp, err := http.DefaultClient.Do(req)388 check(ts, err, neg)389390 defer resp.Body.Close()391 var site webhookSite392 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)393394 ts.Setenv(args[0], whSite+"/"+site.UUID)395}396397func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {398 var verbose bool399 var headers []string400 var data string401 method := http.MethodGet402403 cmd := &cobra.Command{404 Use: "curl",405 Args: cobra.MinimumNArgs(1),406 RunE: func(cmd *cobra.Command, args []string) error {407 url, err := url.Parse(args[0])408 if err != nil {409 return err410 }411412 req, err := http.NewRequest(method, url.String(), nil)413 if err != nil {414 return err415 }416417 if data != "" {418 req.Body = io.NopCloser(strings.NewReader(data))419 }420421 if verbose {422 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())423 }424425 for _, header := range headers {426 parts := strings.SplitN(header, ":", 2)427 if len(parts) != 2 {428 return fmt.Errorf("invalid header: %s", header)429 }430 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))431 }432433 if userInfo := url.User; userInfo != nil {434 password, _ := userInfo.Password()435 req.SetBasicAuth(userInfo.Username(), password)436 }437438 if verbose {439 for key, values := range req.Header {440 for _, value := range values {441 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)442 }443 }444 }445446 resp, err := http.DefaultClient.Do(req)447 if err != nil {448 return err449 }450451 if verbose {452 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)453 for key, values := range resp.Header {454 for _, value := range values {455 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)456 }457 }458 }459460 defer resp.Body.Close()461 buf, err := io.ReadAll(resp.Body)462 if err != nil {463 return err464 }465466 cmd.Print(string(buf))467468 return nil469 },470 }471472 cmd.SetArgs(args)473 cmd.SetOut(ts.Stdout())474 cmd.SetErr(ts.Stderr())475476 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")477 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")478 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")479 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")480481 check(ts, cmd.Execute(), neg)482}483484func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {485 if len(args) < 1 {486 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +487 "These are set as env vars as they are randomized. " +488 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +489 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")490 }491492 port := ts.Getenv(args[0])493494 // verify that the server is up495 addr := net.JoinHostPort("localhost", port)496 for {497 conn, _ := net.DialTimeout(498 "tcp",499 addr,500 time.Second,501 )502 if conn != nil {503 ts.Logf("Server is running on port: %s", port)504 conn.Close()505 break506 }507 }508}509510func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {511 if len(args) < 1 {512 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +513 "These are set as env vars as they are randomized. " +514 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +515 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")516 }517518 port := ts.Getenv(args[0])519520 // verify that the server is not up521 addr := net.JoinHostPort("localhost", port)522 for {523 conn, _ := net.DialTimeout(524 "tcp",525 addr,526 time.Second,527 )528 if conn != nil {529 ts.Fatalf("server is running on port %s while it should not be running", port)530 conn.Close()531 }532 break533 }534}535536func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {537 // stop the server538 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))539 check(ts, err, neg)540 resp.Body.Close()541 time.Sleep(time.Second * 2) // Allow some time for the server to stop542}543544func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {545 // Indicates postgres546 // Create a disposable database547 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))548 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())549 dbDsn := cfg.DB.DataSource550 if dbDsn == "" {551 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"552 }553554 dbUrl, err := url.Parse(cfg.DB.DataSource)555 if err != nil {556 return err, nil557 }558559 scheme := dbUrl.Scheme560 if scheme == "" {561 scheme = "postgres"562 }563564 host := dbUrl.Hostname()565 if host == "" {566 host = "localhost"567 }568569 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)570 username := dbUrl.User.Username()571 if username != "" {572 connInfo += fmt.Sprintf(" user=%s", username)573 password, ok := dbUrl.User.Password()574 if ok {575 username = fmt.Sprintf("%s:%s", username, password)576 connInfo += fmt.Sprintf(" password=%s", password)577 }578 username = fmt.Sprintf("%s@", username)579 } else {580 connInfo += " user=postgres"581 username = "postgres@"582 }583584 port := dbUrl.Port()585 if port != "" {586 connInfo += fmt.Sprintf(" port=%s", port)587 port = fmt.Sprintf(":%s", port)588 }589590 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",591 scheme,592 username,593 host,594 port,595 dbName,596 )597598 // Create the database599 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)600 if err != nil {601 return err, nil602 }603604 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {605 return err, nil606 }607608 return nil, func() {609 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)610 if err != nil {611 t.Fatal("failed to open database", dbName, err)612 }613614 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {615 t.Fatal("failed to drop database", dbName, err)616 }617 }618}