kpaste

git clone git://git.lin.moe/go/kpaste.git

  1package main
  2
  3import (
  4	"bytes"
  5	"context"
  6	"database/sql"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"io/fs"
 11	"log/slog"
 12	"math/rand"
 13	"mime"
 14	"net/http"
 15	"os"
 16	"path"
 17	"strings"
 18	"time"
 19
 20	_ "github.com/mattn/go-sqlite3"
 21)
 22
// init registers additional file-extension to MIME-type mappings with the
// mime package. The error returned by AddExtensionType is not inspected.
func init() {
	extraTypes := [...][2]string{
		{".org", "text/org"},
		{".markdown", "text/markdown"},
		{".md", "text/markdown"},
	}
	for _, pair := range extraTypes {
		mime.AddExtensionType(pair[0], pair[1])
	}
}
 28
// ErrNotExist is the sentinel returned when no note is stored under the
// requested key. Compare with errors.Is.
var ErrNotExist = errors.New("not existed")
 32
// MediaFS is a small object store addressed by path. In this file it is used
// to hold the payload of notes that are not text.
type MediaFS interface {
	// Reader opens the object at path for reading. The caller closes it.
	Reader(path string) (io.ReadCloser, error)
	// Writer opens the object at path for writing. The caller closes it.
	Writer(path string) (io.WriteCloser, error)
	// Delete removes the object at path.
	Delete(path string) error
}
 38
// NoteDB is the storage interface for notes, addressed by key.
type NoteDB interface {
	// SaveNote stores note under key. The SQLite implementation in this file
	// overwrites an existing row and writes the row's creation time back into
	// note.CreateAt.
	SaveNote(ctx context.Context, key string, note *Note) (err error)
	// LoadNote returns the note stored under key. The implementations in this
	// file report a missing key as ErrNotExist.
	LoadNote(ctx context.Context, key string) (*Note, error)
	// IncRead increments the read counter of the note stored under key.
	IncRead(ctx context.Context, key string) (err error)
	// DeleteNote removes the note stored under key.
	DeleteNote(ctx context.Context, key string) error
	// CleanCache removes notes that are expired or have used up their read
	// limit and returns the removed notes indexed by key.
	CleanCache(ctx context.Context) (map[string]Note, error)
}
 46
 47type sqliteDB struct {
 48	*sql.DB
 49}
 50
 51func NewSqliteDB(dsn string) (NoteDB, error) {
 52	var err error
 53	db, err := sql.Open("sqlite3", dsn)
 54	if err != nil {
 55		return nil, err
 56	}
 57
 58	// TODO: init tables
 59	_, err = db.Exec(`
 60CREATE TABLE IF NOT EXISTS note (
 61key TEXT PRIMARY KEY,
 62filename TEXT,
 63mime_type TEXT,
 64create_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
 65expire_at TIMESTAMP NOT NULL,
 66payload BLOB NOT NULL,
 67readed INTEGER NOT NULL DEFAULT 0,
 68maxread INTEGER,
 69auth_user TEXT,
 70auth_password TEXT
 71);`)
 72	if err != nil {
 73		return nil, err
 74	}
 75
 76	return &sqliteDB{
 77		DB: db,
 78	}, nil
 79}
 80
 81func (db *sqliteDB) SaveNote(ctx context.Context, key string, note *Note) (err error) {
 82	row := db.QueryRowContext(ctx, `
 83                insert into note
 84                (key, mime_type, filename, expire_at, payload, maxread, auth_user, auth_password)
 85                values ($1, $2, $3, $4, $5, $6, $7, $8)
 86                ON CONFLICT(key) DO UPDATE SET mime_type=$2, filename=$3, expire_at=$4, payload=$5, maxread=$6
 87                returning create_at`,
 88		key, note.MimeType, note.Filename, note.ExpireAt, note.Payload, note.MaxRead, note.AuthUser, note.AuthHashedPassword)
 89	return row.Scan(&note.CreateAt)
 90}
 91
 92func (db *sqliteDB) LoadNote(ctx context.Context, key string) (*Note, error) {
 93	row := db.QueryRowContext(ctx, `
 94select mime_type, filename, create_at, expire_at, payload, readed, maxread, auth_user, auth_password
 95from note where key=?`, key)
 96	note := new(Note)
 97	if err := row.Scan(&note.MimeType, &note.Filename, &note.CreateAt, &note.ExpireAt, &note.Payload, &note.Readed, &note.MaxRead, &note.AuthUser, &note.AuthHashedPassword); err != nil {
 98		if errors.Is(err, sql.ErrNoRows) {
 99			return nil, ErrNotExist
100		}
101		return nil, err
102	}
103
104	return note, nil
105}
106
107func (db *sqliteDB) IncRead(ctx context.Context, key string) error {
108	var (
109		readed, maxread int64
110	)
111
112	row := db.QueryRowContext(ctx, `update note set readed = readed + 1 where key=? returning readed,maxread`, key)
113
114	err := row.Scan(&readed, &maxread)
115	if errors.Is(err, sql.ErrNoRows) {
116		return ErrNotExist
117	}
118
119	return err
120}
121
122func (db *sqliteDB) DeleteNote(ctx context.Context, key string) error {
123	_, err := db.ExecContext(ctx, `delete from note where key=?`, key)
124	return err
125}
126
127func (db *sqliteDB) CleanCache(ctx context.Context) (notes map[string]Note, err error) {
128	rows, err := db.QueryContext(ctx, `
129delete from note
130where (maxread >= 0 and readed >= maxread)
131       or expire_at < CURRENT_TIMESTAMP
132RETURNING key, mime_type, filename, create_at, expire_at, payload, readed, maxread, auth_user, auth_password`)
133	if err != nil {
134		return
135	}
136
137	notes = make(map[string]Note)
138
139	for rows.Next() {
140		var (
141			key  string
142			note Note
143		)
144		if err := rows.Scan(
145			&key,
146			&note.MimeType, &note.Filename,
147			&note.CreateAt, &note.ExpireAt,
148			&note.Payload,
149			&note.Readed, &note.MaxRead,
150			&note.AuthUser, &note.AuthHashedPassword); err != nil {
151			return nil, err
152		}
153		notes[key] = note
154	}
155
156	return notes, nil
157}
158
159type dbfs struct {
160	NoteDB
161	MediaFS
162}
163
164func DBWithExternFS(db NoteDB, fs MediaFS) NoteDB {
165	return &dbfs{
166		NoteDB:  db,
167		MediaFS: fs,
168	}
169}
170
171func (db *dbfs) SaveNote(ctx context.Context, key string, note *Note) (err error) {
172	if !note.IsText() {
173		// TODO: rename local path
174		localpath := key
175
176		writer, err := db.MediaFS.Writer(localpath)
177		if err != nil {
178			return err
179		}
180		defer writer.Close()
181
182		if _, err := writer.Write(note.Payload); err != nil {
183			return err
184		}
185
186		note.Payload = []byte(localpath)
187	}
188
189	return db.NoteDB.SaveNote(ctx, key, note)
190}
191
192func (db *dbfs) LoadNote(ctx context.Context, key string) (*Note, error) {
193	note, err := db.NoteDB.LoadNote(ctx, key)
194	if err != nil {
195		return nil, err
196	}
197
198	if !note.IsText() {
199		reader, err := db.MediaFS.Reader(string(note.Payload))
200		if err != nil {
201			return nil, err
202		}
203		defer reader.Close()
204
205		buf := new(bytes.Buffer)
206		if _, err := io.Copy(buf, reader); err != nil {
207			return nil, err
208		}
209		note.Payload = buf.Bytes()
210		return note, nil
211
212	} else {
213		return note, nil
214	}
215}
216
217func (db *dbfs) DeleteNote(ctx context.Context, key string) error {
218	note, err := db.NoteDB.LoadNote(ctx, key)
219	if err != nil {
220		return err
221	}
222
223	if !note.IsText() {
224		if err := db.MediaFS.Delete(string(note.Payload)); err != nil {
225			return err
226		}
227	}
228
229	if err := db.NoteDB.DeleteNote(ctx, key); err != nil {
230		return err
231	}
232
233	return nil
234}
235
236func (db *dbfs) CleanCache(ctx context.Context) (notes map[string]Note, err error) {
237	notes, err = db.NoteDB.CleanCache(ctx)
238	if err != nil {
239		return
240	}
241	for _, note := range notes {
242		if !note.IsText() {
243			if err := db.MediaFS.Delete(string(note.Payload)); err != nil {
244				continue
245			}
246		}
247	}
248	return notes, nil
249}
250
// Note is a stored paste together with its lifetime and access metadata.
type Note struct {
	// MimeType is the media type; a "text/" prefix marks the note as text.
	MimeType string
	Filename string // original uploaded filename
	CreateAt time.Time
	// ExpireAt is the instant after which the note is no longer visible.
	ExpireAt time.Time
	Payload  []byte

	// Readed is the read counter and MaxRead the read limit; a negative
	// MaxRead disables the limit check in IsVisable.
	Readed  int
	MaxRead int
	// NOTE(review): presumably HTTP auth credentials with the password stored
	// hashed, judging by the field name; confirm against the handlers.
	AuthUser           string
	AuthHashedPassword string
}

// IsVisable reports whether the note may still be shown. It is false once
// Readed exceeds a non-negative MaxRead, or once ExpireAt lies in the past.
// NOTE(review): the SQL purge in this file uses readed >= maxread while this
// check uses >; confirm the difference is intended.
func (n *Note) IsVisable() bool {
	switch {
	case n.MaxRead >= 0 && n.Readed > n.MaxRead:
		return false
	case time.Now().After(n.ExpireAt):
		return false
	}
	return true
}

// IsText reports whether the note's MIME type belongs to the text/ family.
func (n *Note) IsText() bool {
	const textFamily = "text/"
	return strings.HasPrefix(n.MimeType, textFamily)
}
281
282type localFS struct {
283	root string
284}
285
286func NewLocalFS(root string) MediaFS {
287	return &localFS{root}
288}
289
290func (rwfs *localFS) Reader(fpath string) (io.ReadCloser, error) {
291	if !fs.ValidPath(fpath) {
292		return nil, fmt.Errorf("invalid path")
293	}
294	fpath = path.Join(rwfs.root, fpath)
295	return os.Open(fpath)
296}
297
298func (rwfs *localFS) Writer(fpath string) (io.WriteCloser, error) {
299	if !fs.ValidPath(fpath) {
300		return nil, fmt.Errorf("invalid path")
301	}
302	fpath = path.Join(rwfs.root, fpath)
303
304	dirpath := path.Dir(fpath)
305	_, err := os.Stat(dirpath)
306	if errors.Is(err, os.ErrNotExist) {
307		if err := os.Mkdir(dirpath, 0700); err != nil {
308			return nil, err
309		}
310	}
311	return os.OpenFile(fpath, os.O_RDWR|os.O_CREATE, 0600)
312}
313
314func (rwfs *localFS) Delete(fpath string) error {
315	if !fs.ValidPath(fpath) {
316		return fmt.Errorf("invalid path")
317	}
318	fpath = path.Join(rwfs.root, fpath)
319	return os.Remove(fpath)
320}
321
322func NextKey(ctx context.Context, db NoteDB) string {
323	const (
324		minkeylen   = 4
325		maxkeylen   = 16
326		letterBytes = "23456789abcdefghijkmnpqrstuvwxyz" // remove 1, l, o, 0
327	)
328	buf := make([]byte, maxkeylen)
329	for i := 0; i < maxkeylen; i += 1 {
330		buf[i] = letterBytes[rand.Intn(len(letterBytes))]
331		if i >= minkeylen-1 {
332			if _, err := db.LoadNote(ctx, string(buf[:i+1])); errors.Is(err, ErrNotExist) {
333				return string(buf[:i+1])
334			} else {
335				slog.Warn(err.Error())
336			}
337		}
338	}
339	return string(buf)
340}
341
342func DatabaseFromCtx(ctx context.Context) NoteDB {
343	st, ok := ctx.Value("database").(NoteDB)
344	if !ok {
345		panic("not found database in context")
346	}
347	return st
348}
349
350func DBMiddleware(db NoteDB) Middleware {
351	return func(next http.HandlerFunc) http.HandlerFunc {
352		return func(w http.ResponseWriter, r *http.Request) {
353			r = r.WithContext(context.WithValue(r.Context(), "database", db))
354			next(w, r)
355		}
356	}
357}