kpaste

git clone git://git.lin.moe/go/kpaste.git

  1package main
  2
  3import (
  4	"context"
  5	"embed"
  6	"fmt"
  7	"io/fs"
  8	"log/slog"
  9	"net/http"
 10	"strings"
 11	"time"
 12)
 13
 14//go:embed static/css/*.css
 15var staticfs embed.FS
 16
 17type Mux struct {
 18	*http.ServeMux
 19	middlewares []Middleware
 20}
 21
 22type Middleware func(http.HandlerFunc) http.HandlerFunc
 23
 24func (m *Mux) Use(mws ...Middleware) *Mux {
 25	return &Mux{
 26		ServeMux:    m.ServeMux,
 27		middlewares: append(m.middlewares, mws...)[:],
 28	}
 29}
 30
 31func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 32	var handler http.HandlerFunc = m.ServeMux.ServeHTTP
 33	for _, mw := range m.middlewares {
 34		handler = mw(handler)
 35	}
 36	handler.ServeHTTP(w, r)
 37}
 38
 39func NewMux(db NoteDB) (*Mux, error) {
 40	l := slog.Default()
 41
 42	mux := &Mux{
 43		ServeMux: http.NewServeMux(),
 44	}
 45
 46	mux = mux.Use(
 47		LogMiddleware(l),
 48		URLMiddleware(),
 49		DBMiddleware(db),
 50	)
 51
 52	staticfs, _ := fs.Sub(staticfs, "static")
 53	mux.Handle("/-/", http.StripPrefix("/-/", http.FileServer(http.FS(staticfs))))
 54
 55	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 56		switch r.Method {
 57		case http.MethodPost:
 58			NewOrUpdateNoteAPI(!isPlainTextAgent(r.UserAgent()))(w, r)
 59		case http.MethodGet:
 60			if r.URL.EscapedPath() == "/" {
 61				Index(!isPlainTextAgent(r.UserAgent()))(w, r)
 62			} else {
 63				var isrich bool
 64				slog.Info(r.URL.EscapedPath())
 65				if strings.HasPrefix(r.URL.EscapedPath(), "/!/") {
 66					isrich = false
 67				} else {
 68					isrich = !isPlainTextAgent(r.UserAgent())
 69				}
 70				GetNoteAPI(isrich)(w, r)
 71			}
 72		case http.MethodDelete:
 73			DeleteNoteAPI(!isPlainTextAgent(r.UserAgent()))(w, r)
 74		default:
 75			w.WriteHeader(http.StatusNotFound)
 76			fmt.Fprintf(w, "Not found\n")
 77		}
 78	})
 79
 80	return mux, nil
 81}
 82
 83const LOGGER_MWKEY = "logger"
 84
 85func LogMiddleware(l *slog.Logger) Middleware {
 86	return func(next http.HandlerFunc) http.HandlerFunc {
 87		return func(w http.ResponseWriter, r *http.Request) {
 88			l := l.With(
 89				"path", r.URL.EscapedPath(),
 90				"remote", r.RemoteAddr,
 91			)
 92
 93			start := time.Now()
 94
 95			l.Debug("new request",
 96				"method", r.Method,
 97			)
 98
 99			r = r.WithContext(context.WithValue(r.Context(), LOGGER_MWKEY, l))
100			next(w, r)
101
102			l.Debug("request ended",
103				"duration", time.Now().Sub(start),
104			)
105		}
106	}
107}
108
109func LogCtx(ctx context.Context) *slog.Logger {
110	raw, ok := ctx.Value(LOGGER_MWKEY).(*slog.Logger)
111	if !ok {
112		return slog.Default()
113	}
114	return raw
115}
116
117func URLMiddleware() Middleware {
118	return func(next http.HandlerFunc) http.HandlerFunc {
119		return func(w http.ResponseWriter, r *http.Request) {
120			// scheme
121			if r.TLS != nil {
122				r.URL.Scheme = "https"
123			} else if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
124				r.URL.Scheme = scheme
125			} else if scheme := r.Header.Get("X-Forwarded-Protocol"); scheme != "" {
126				r.URL.Scheme = scheme
127			} else if ssl := r.Header.Get("X-Forwarded-Ssl"); ssl == "on" {
128				r.URL.Scheme = "https"
129			} else if scheme := r.Header.Get("X-Url-Scheme"); scheme != "" {
130				r.URL.Scheme = scheme
131			} else {
132				r.URL.Scheme = "http"
133			}
134
135			next(w, r)
136		}
137	}
138}
139
// isPlainTextAgent reports whether userAgent looks like a command-line or
// scripting HTTP client rather than a browser.
//
// Matching is a case-insensitive substring search against a fixed list of
// client signatures, so "curl/8.5.0" and "python-requests/2.31" both match.
// Short signatures such as "xh" can also match unrelated agents that happen
// to contain them.
func isPlainTextAgent(userAgent string) bool {
	signatures := [...]string{
		"curl", "httpie", "lwp-request", "wget",
		"python-httpx", "python-requests", "openbsd ftp", "powershell",
		"fetch", "aiohttp", "http_get", "xh",
	}

	ua := strings.ToLower(userAgent)
	for i := range signatures {
		if strings.Contains(ua, signatures[i]) {
			return true
		}
	}
	return false
}