From 6473886c8e4c713a6f5686e289985a0ebf3e9a82 Mon Sep 17 00:00:00 2001 From: Kelson Vibber Date: Thu, 3 Apr 2025 00:43:21 -0700 Subject: [PATCH 1/4] [bugfix] Fix Atkinson Hyperlegible font embedding on Ecks Pee theme. (#3964) Most browsers just take the second src line and they're fine, but Tor has trouble displaying the woff version on Linux. With two separate lines it doesn't fall back correctly. --- web/assets/themes/ecks-pee.css | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/web/assets/themes/ecks-pee.css b/web/assets/themes/ecks-pee.css index f474f800c..a85e5da0b 100644 --- a/web/assets/themes/ecks-pee.css +++ b/web/assets/themes/ecks-pee.css @@ -93,29 +93,29 @@ font-family: "Atkinson Hyperlegible"; font-weight: normal; font-style: normal; - src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102a.woff2) format('woff2'); - src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102.woff) format('woff'); + src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102a.woff2) format('woff2'), + url(/assets/fonts/Atkinson-Hyperlegible-Regular-102.woff) format('woff'); } @font-face { font-family: "Atkinson Hyperlegible"; font-weight: bold; font-style: normal; - src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102a.woff2) format('woff2'); - src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102.woff) format('woff'); + src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102a.woff2) format('woff2'), + url(/assets/fonts/Atkinson-Hyperlegible-Bold-102.woff) format('woff'); } @font-face { font-family: "Atkinson Hyperlegible"; font-weight: normal; font-style: italic; - src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102a.woff2) format('woff2'); - src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102.woff) format('woff'); + src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102a.woff2) format('woff2'), + url(/assets/fonts/Atkinson-Hyperlegible-Italic-102.woff) format('woff'); } @font-face { font-family: "Atkinson Hyperlegible"; font-weight: bold; font-style: italic; - src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102a.woff2) format('woff2'); - src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102.woff) format('woff'); + src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102a.woff2) format('woff2'), + url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102.woff) format('woff'); } /* Main page background */ From db4b85715966ee590c6cdff5cc52e592b66e3d17 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Fri, 4 Apr 2025 15:34:38 +0000 Subject: [PATCH 2/4] [chore] bump ncruces/go-sqlite3 to v0.25.0 (#3966) --- go.mod | 2 +- go.sum | 4 +- .../github.com/ncruces/go-sqlite3/README.md | 22 +- vendor/github.com/ncruces/go-sqlite3/blob.go | 9 +- .../github.com/ncruces/go-sqlite3/config.go | 16 +- vendor/github.com/ncruces/go-sqlite3/conn.go | 79 +++----- vendor/github.com/ncruces/go-sqlite3/const.go | 7 +- .../github.com/ncruces/go-sqlite3/context.go | 10 +- .../ncruces/go-sqlite3/driver/driver.go | 188 +++++++++++------- .../ncruces/go-sqlite3/driver/time.go | 11 +- vendor/github.com/ncruces/go-sqlite3/error.go | 18 +- vendor/github.com/ncruces/go-sqlite3/func.go | 159 +++++++++++---- .../ncruces/go-sqlite3/internal/util/error.go | 4 +- .../ncruces/go-sqlite3/internal/util/mem.go | 5 +- .../github.com/ncruces/go-sqlite3/sqlite.go | 38 ++-- vendor/github.com/ncruces/go-sqlite3/stmt.go | 140 ++++++++++--- vendor/github.com/ncruces/go-sqlite3/txn.go | 24 +-- .../ncruces/go-sqlite3/util/osutil/open.go | 16 -- .../go-sqlite3/util/osutil/open_windows.go | 115 ----------- .../ncruces/go-sqlite3/util/osutil/osfs.go | 33 --- .../ncruces/go-sqlite3/util/osutil/osutil.go | 2 - .../go-sqlite3/util/sql3util/sql3util.go | 2 +- vendor/github.com/ncruces/go-sqlite3/value.go | 8 +- .../ncruces/go-sqlite3/vfs/README.md | 35 +++- .../github.com/ncruces/go-sqlite3/vfs/cksm.go | 24 ++- .../github.com/ncruces/go-sqlite3/vfs/file.go | 28 ++- .../ncruces/go-sqlite3/vfs/os_bsd.go | 27 ++- .../ncruces/go-sqlite3/vfs/os_darwin.go | 22 +- .../ncruces/go-sqlite3/vfs/os_linux.go | 46 ++++- .../ncruces/go-sqlite3/vfs/os_std.go | 2 +- .../ncruces/go-sqlite3/vfs/os_unix.go | 13 +- .../ncruces/go-sqlite3/vfs/os_windows.go | 35 +--- .../ncruces/go-sqlite3/vfs/shm_bsd.go | 15 +- .../ncruces/go-sqlite3/vfs/shm_windows.go | 23 +-- vendor/github.com/ncruces/go-sqlite3/vtab.go | 29 ++- vendor/modules.txt | 3 +- 36 files changed, 636 insertions(+), 578 deletions(-) delete mode 100644 vendor/github.com/ncruces/go-sqlite3/util/osutil/open.go delete mode 100644 vendor/github.com/ncruces/go-sqlite3/util/osutil/open_windows.go delete mode 100644 vendor/github.com/ncruces/go-sqlite3/util/osutil/osfs.go delete mode 100644 vendor/github.com/ncruces/go-sqlite3/util/osutil/osutil.go diff --git a/go.mod b/go.mod index c23748996..479002530 100644 --- a/go.mod +++ b/go.mod @@ -54,7 +54,7 @@ require ( github.com/miekg/dns v1.1.64 github.com/minio/minio-go/v7 v7.0.85 github.com/mitchellh/mapstructure v1.5.0 - github.com/ncruces/go-sqlite3 v0.24.0 + github.com/ncruces/go-sqlite3 v0.25.0 github.com/oklog/ulid v1.3.1 github.com/prometheus/client_golang v1.21.1 github.com/rivo/uniseg v0.4.7 diff --git a/go.sum b/go.sum index 88fb45e71..f886b0e89 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/ncruces/go-sqlite3 v0.24.0 h1:Z4jfmzu2NCd4SmyFwLT2OmF3EnTZbqwATvdiuNHNhLA= -github.com/ncruces/go-sqlite3 v0.24.0/go.mod h1:/Vs8ACZHjJ1SA6E9RZUn3EyB1OP3nDQ4z/ar+0fplTQ= +github.com/ncruces/go-sqlite3 v0.25.0 h1:trugKUs98Zwy9KwRr/EUxZHL92LYt7UqcKqAfpGpK+I= +github.com/ncruces/go-sqlite3 v0.25.0/go.mod h1:n6Z7036yFilJx04yV0mi5JWaF66rUmXn1It9Ux8dx68= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= diff --git a/vendor/github.com/ncruces/go-sqlite3/README.md b/vendor/github.com/ncruces/go-sqlite3/README.md index cac6ee6b8..94fd9950b 100644 --- a/vendor/github.com/ncruces/go-sqlite3/README.md +++ b/vendor/github.com/ncruces/go-sqlite3/README.md @@ -65,17 +65,20 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version) This module replaces the SQLite [OS Interface](https://sqlite.org/vfs.html) (aka VFS) with a [pure Go](vfs/) implementation, which has advantages and disadvantages. - Read more about the Go VFS design [here](vfs/README.md). +Because each database connection executes within a Wasm sandboxed environment, +memory usage will be higher than alternatives. + ### Testing This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report). It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and -[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing. +[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) +thorough testing. Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) on -Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (amd64/arm64), +Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (arm64/amd64), Windows (amd64), FreeBSD (amd64/arm64), OpenBSD (amd64), NetBSD (amd64/arm64), DragonFly BSD (amd64), illumos (amd64), and Solaris (amd64). @@ -84,12 +87,21 @@ The Go VFS is tested by running SQLite's ### Performance -Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is +Performance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is [competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives. -The Wasm and VFS layers are also tested by running SQLite's +The Wasm and VFS layers are also benchmarked by running SQLite's [speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c). +### Concurrency + +This module behaves similarly to SQLite in [multi-thread](https://sqlite.org/threadsafe.html) mode: +it is goroutine-safe, provided that no single database connection, or object derived from it, +is used concurrently by multiple goroutines. + +The [`database/sql`](https://pkg.go.dev/database/sql) API is safe to use concurrently, +according to its documentation. + ### FAQ, issues, new features For questions, please see [Discussions](https://github.com/ncruces/go-sqlite3/discussions/categories/q-a). diff --git a/vendor/github.com/ncruces/go-sqlite3/blob.go b/vendor/github.com/ncruces/go-sqlite3/blob.go index 2fac72045..ea7caf9d8 100644 --- a/vendor/github.com/ncruces/go-sqlite3/blob.go +++ b/vendor/github.com/ncruces/go-sqlite3/blob.go @@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{} // // https://sqlite.org/c3ref/blob_open.html func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { + if c.interrupt.Err() != nil { + return nil, INTERRUPT + } + defer c.arena.mark()() blobPtr := c.arena.new(ptrlen) dbPtr := c.arena.string(db) @@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, flags = 1 } - c.checkInterrupt(c.handle) rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle), stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr), stk_t(row), stk_t(flags), stk_t(blobPtr))) @@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { - b.c.checkInterrupt(b.c.handle) + if b.c.interrupt.Err() != nil { + return INTERRUPT + } err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row)))) b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle)))) b.offset = 0 diff --git a/vendor/github.com/ncruces/go-sqlite3/config.go b/vendor/github.com/ncruces/go-sqlite3/config.go index 17166b9c5..3921fe98a 100644 --- a/vendor/github.com/ncruces/go-sqlite3/config.go +++ b/vendor/github.com/ncruces/go-sqlite3/config.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strconv" + "sync/atomic" "github.com/tetratelabs/wazero/api" @@ -48,6 +49,15 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { return util.ReadBool(c.mod, argsPtr), c.error(rc) } +var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)] + +// ConfigLog sets up the default error logging callback for new connections. +// +// https://sqlite.org/errlog.html +func ConfigLog(cb func(code ExtendedErrorCode, msg string)) { + defaultLogger.Store(&cb) +} + // ConfigLog sets up the error logging callback for the connection. // // https://sqlite.org/errlog.html @@ -265,6 +275,10 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr // // https://sqlite.org/c3ref/wal_checkpoint_v2.html func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) { + if c.interrupt.Err() != nil { + return 0, 0, INTERRUPT + } + defer c.arena.mark()() nLogPtr := c.arena.new(ptrlen) nCkptPtr := c.arena.new(ptrlen) @@ -378,6 +392,6 @@ func (c *Conn) EnableChecksums(schema string) error { } // Checkpoint the WAL. - _, _, err = c.WALCheckpoint(schema, CHECKPOINT_RESTART) + _, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL) return err } diff --git a/vendor/github.com/ncruces/go-sqlite3/conn.go b/vendor/github.com/ncruces/go-sqlite3/conn.go index 9f9251e9f..7e88d8c85 100644 --- a/vendor/github.com/ncruces/go-sqlite3/conn.go +++ b/vendor/github.com/ncruces/go-sqlite3/conn.go @@ -25,7 +25,6 @@ type Conn struct { *sqlite interrupt context.Context - pending *Stmt stmts []*Stmt busy func(context.Context, int) bool log func(xErrorCode, string) @@ -41,6 +40,7 @@ type Conn struct { busylst time.Time arena arena handle ptr_t + gosched uint8 } // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. @@ -49,7 +49,7 @@ func Open(filename string) (*Conn, error) { } // OpenContext is like [Open] but includes a context, -// which is used to interrupt the process of opening the connectiton. +// which is used to interrupt the process of opening the connection. func OpenContext(ctx context.Context, filename string) (*Conn, error) { return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) } @@ -92,6 +92,9 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ }() c.ctx = context.WithValue(c.ctx, connKey{}, c) + if logger := defaultLogger.Load(); logger != nil { + c.ConfigLog(*logger) + } c.arena = c.newArena() c.handle, err = c.openDB(filename, flags) if err == nil { @@ -117,7 +120,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { return 0, err } - c.call("sqlite3_progress_handler_go", stk_t(handle), 100) + c.call("sqlite3_progress_handler_go", stk_t(handle), 1000) if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { var pragmas strings.Builder if _, after, ok := strings.Cut(filename, "?"); ok { @@ -129,7 +132,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { } } if pragmas.Len() != 0 { - c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0)) if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil { @@ -163,9 +165,6 @@ func (c *Conn) Close() error { return nil } - c.pending.Close() - c.pending = nil - rc := res_t(c.call("sqlite3_close", stk_t(c.handle))) if err := c.error(rc); err != nil { return err @@ -180,11 +179,16 @@ func (c *Conn) Close() error { // // https://sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { - defer c.arena.mark()() - sqlPtr := c.arena.string(sql) + if c.interrupt.Err() != nil { + return INTERRUPT + } + return c.exec(sql) +} - c.checkInterrupt(c.handle) - rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0)) +func (c *Conn) exec(sql string) error { + defer c.arena.mark()() + textPtr := c.arena.string(sql) + rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0)) return c.error(rc, sql) } @@ -203,20 +207,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str if len(sql) > _MAX_SQL_LENGTH { return nil, "", TOOBIG } + if c.interrupt.Err() != nil { + return nil, "", INTERRUPT + } defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) tailPtr := c.arena.new(ptrlen) - sqlPtr := c.arena.string(sql) + textPtr := c.arena.string(sql) - c.checkInterrupt(c.handle) rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle), - stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags), + stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags), stk_t(stmtPtr), stk_t(tailPtr))) - stmt = &Stmt{c: c} + stmt = &Stmt{c: c, sql: sql} stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr) - if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" { + if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" { tail = sql } @@ -337,43 +343,17 @@ func (c *Conn) GetInterrupt() context.Context { // // https://sqlite.org/c3ref/interrupt.html func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { + if ctx == nil { + panic("nil Context") + } old = c.interrupt c.interrupt = ctx - - if ctx == old || ctx.Done() == old.Done() { - return old - } - - // A busy SQL statement prevents SQLite from ignoring an interrupt - // that comes before any other statements are started. - if c.pending == nil { - defer c.arena.mark()() - stmtPtr := c.arena.new(ptrlen) - loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) - c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64, - stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0) - c.pending = &Stmt{c: c} - c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr) - } - - if old.Done() != nil && ctx.Err() == nil { - c.pending.Reset() - } - if ctx.Done() != nil { - c.pending.Step() - } return old } -func (c *Conn) checkInterrupt(handle ptr_t) { - if c.interrupt.Err() != nil { - c.call("sqlite3_interrupt", stk_t(handle)) - } -} - func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok { - if c.interrupt.Done() != nil { + if c.gosched++; c.gosched%16 == 0 { runtime.Gosched() } if c.interrupt.Err() != nil { @@ -429,11 +409,8 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { - interrupt := c.interrupt - if interrupt == nil { - interrupt = context.Background() - } - if interrupt.Err() == nil && c.busy(interrupt, int(count)) { + if interrupt := c.interrupt; interrupt.Err() == nil && + c.busy(interrupt, int(count)) { retry = 1 } } diff --git a/vendor/github.com/ncruces/go-sqlite3/const.go b/vendor/github.com/ncruces/go-sqlite3/const.go index 82d80515e..522f68bfb 100644 --- a/vendor/github.com/ncruces/go-sqlite3/const.go +++ b/vendor/github.com/ncruces/go-sqlite3/const.go @@ -11,10 +11,9 @@ const ( _ROW = 100 /* sqlite3_step() has another row ready */ _DONE = 101 /* sqlite3_step() has finished executing */ - _MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings. - _MAX_LENGTH = 1e9 - _MAX_SQL_LENGTH = 1e9 - _MAX_FUNCTION_ARG = 100 + _MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings. + _MAX_LENGTH = 1e9 + _MAX_SQL_LENGTH = 1e9 ptrlen = util.PtrLen intlen = util.IntLen diff --git a/vendor/github.com/ncruces/go-sqlite3/context.go b/vendor/github.com/ncruces/go-sqlite3/context.go index 637ddc282..abee4ec1e 100644 --- a/vendor/github.com/ncruces/go-sqlite3/context.go +++ b/vendor/github.com/ncruces/go-sqlite3/context.go @@ -89,20 +89,26 @@ func (ctx Context) ResultText(value string) { } // ResultRawText sets the text result of the function to a []byte. -// Returning a nil slice is the same as calling [Context.ResultNull]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultRawText(value []byte) { + if len(value) == 0 { + ctx.ResultText("") + return + } ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_text_go", stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultBlob sets the result of the function to a []byte. -// Returning a nil slice is the same as calling [Context.ResultNull]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultBlob(value []byte) { + if len(value) == 0 { + ctx.ResultZeroBlob(0) + return + } ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_blob_go", stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go index 21799aeb2..9250cf39d 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go @@ -20,22 +20,45 @@ // - a [serializable] transaction is always "immediate"; // - a [read-only] transaction is always "deferred". // +// # Datatypes In SQLite +// +// SQLite is dynamically typed. +// Columns can mostly hold any value regardless of their declared type. +// SQLite supports most [driver.Value] types out of the box, +// but bool and [time.Time] require special care. +// +// Booleans can be stored on any column type and scanned back to a *bool. +// However, if scanned to a *any, booleans may either become an +// int64, string or bool, depending on the declared type of the column. +// If you use BOOLEAN for your column type, +// 1 and 0 will always scan as true and false. +// // # Working with time // +// Time values can similarly be stored on any column type. // The time encoding/decoding format can be specified using "_timefmt": // // sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite") // -// Possible values are: "auto" (the default), "sqlite", "rfc3339"; +// Special values are: "auto" (the default), "sqlite", "rfc3339"; // - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite; // - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite; // - "rfc3339" encodes and decodes RFC 3339 only. // -// If you encode as RFC 3339 (the default), -// consider using the TIME [collating sequence] to produce a time-ordered sequence. +// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout]. // -// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful. -// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding. +// If you encode as RFC 3339 (the default), +// consider using the TIME [collating sequence] to produce time-ordered sequences. +// +// If you encode as RFC 3339 (the default), +// time values will scan back to a *time.Time unless your column type is TEXT. +// Otherwise, if scanned to a *any, time values may either become an +// int64, float64 or string, depending on the time format and declared type of the column. +// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type, +// "_timefmt" will be used to decode values. +// +// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful. +// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding. // // When using a custom time struct, you'll have to implement // [database/sql/driver.Valuer] and [database/sql.Scanner]. @@ -48,7 +71,7 @@ // The Scan method needs to take into account that the value it receives can be of differing types. // It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules. // Or it can be a: string, int64, float64, []byte, or nil, -// depending on the column type and what whoever wrote the value. +// depending on the column type and whoever wrote the value. // [sqlite3.TimeFormat.Decode] may help. // // # Setting PRAGMAs @@ -358,13 +381,10 @@ func (c *conn) Commit() error { } func (c *conn) Rollback() error { - err := c.Conn.Exec(`ROLLBACK` + c.txReset) - if errors.Is(err, sqlite3.INTERRUPT) { - old := c.Conn.SetInterrupt(context.Background()) - defer c.Conn.SetInterrupt(old) - err = c.Conn.Exec(`ROLLBACK` + c.txReset) - } - return err + // ROLLBACK even if interrupted. + old := c.Conn.SetInterrupt(context.Background()) + defer c.Conn.SetInterrupt(old) + return c.Conn.Exec(`ROLLBACK` + c.txReset) } func (c *conn) Prepare(query string) (driver.Stmt, error) { @@ -598,6 +618,28 @@ const ( _TIME ) +func scanFromDecl(decl string) scantype { + // These types are only used before we have rows, + // and otherwise as type hints. + // The first few ensure STRICT tables are strictly typed. + // The other two are type hints for booleans and time. + switch decl { + case "INT", "INTEGER": + return _INT + case "REAL": + return _REAL + case "TEXT": + return _TEXT + case "BLOB": + return _BLOB + case "BOOLEAN": + return _BOOL + case "DATE", "TIME", "DATETIME", "TIMESTAMP": + return _TIME + } + return _ANY +} + var ( // Ensure these interfaces are implemented: _ driver.RowsColumnTypeDatabaseTypeName = &rows{} @@ -622,6 +664,18 @@ func (r *rows) Columns() []string { return r.names } +func (r *rows) scanType(index int) scantype { + if r.scans == nil { + count := r.Stmt.ColumnCount() + scans := make([]scantype, count) + for i := range scans { + scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i))) + } + r.scans = scans + } + return r.scans[index] +} + func (r *rows) loadColumnMetadata() { if r.nulls == nil { count := r.Stmt.ColumnCount() @@ -635,24 +689,7 @@ func (r *rows) loadColumnMetadata() { r.Stmt.ColumnTableName(i), col) types[i] = strings.ToUpper(types[i]) - // These types are only used before we have rows, - // and otherwise as type hints. - // The first few ensure STRICT tables are strictly typed. - // The other two are type hints for booleans and time. - switch types[i] { - case "INT", "INTEGER": - scans[i] = _INT - case "REAL": - scans[i] = _REAL - case "TEXT": - scans[i] = _TEXT - case "BLOB": - scans[i] = _BLOB - case "BOOLEAN": - scans[i] = _BOOL - case "DATE", "TIME", "DATETIME", "TIMESTAMP": - scans[i] = _TIME - } + scans[i] = scanFromDecl(types[i]) } } r.nulls = nulls @@ -661,27 +698,15 @@ func (r *rows) loadColumnMetadata() { } } -func (r *rows) declType(index int) string { - if r.types == nil { - count := r.Stmt.ColumnCount() - types := make([]string, count) - for i := range types { - types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) - } - r.types = types - } - return r.types[index] -} - func (r *rows) ColumnTypeDatabaseTypeName(index int) string { r.loadColumnMetadata() - decltype := r.types[index] - if len := len(decltype); len > 0 && decltype[len-1] == ')' { - if i := strings.LastIndexByte(decltype, '('); i >= 0 { - decltype = decltype[:i] + decl := r.types[index] + if len := len(decl); len > 0 && decl[len-1] == ')' { + if i := strings.LastIndexByte(decl, '('); i >= 0 { + decl = decl[:i] } } - return strings.TrimSpace(decltype) + return strings.TrimSpace(decl) } func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { @@ -748,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error { } data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest)) - err := r.Stmt.Columns(data...) + if err := r.Stmt.ColumnsRaw(data...); err != nil { + return err + } for i := range dest { - if t, ok := r.decodeTime(i, dest[i]); ok { - dest[i] = t - } - } - return err -} - -func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) { - switch v := v.(type) { - case int64, float64: - // could be a time value - case string: - if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano { + scan := r.scanType(i) + switch v := dest[i].(type) { + case int64: + if scan == _BOOL { + switch v { + case 1: + dest[i] = true + case 0: + dest[i] = false + } + continue + } + case []byte: + if len(v) == cap(v) { // a BLOB + continue + } + if scan != _TEXT { + switch r.tmWrite { + case "", time.RFC3339, time.RFC3339Nano: + t, ok := maybeTime(v) + if ok { + dest[i] = t + continue + } + } + } + dest[i] = string(v) + case float64: break + default: + continue } - t, ok := maybeTime(v) - if ok { - return t, true + if scan == _TIME { + t, err := r.tmRead.Decode(dest[i]) + if err == nil { + dest[i] = t + continue + } } - default: - return } - switch r.declType(i) { - case "DATE", "TIME", "DATETIME", "TIMESTAMP": - // could be a time value - default: - return - } - t, err := r.tmRead.Decode(v) - return t, err == nil + return nil } diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/time.go b/vendor/github.com/ncruces/go-sqlite3/driver/time.go index b3ebdd263..4d48bd8dc 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/time.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/time.go @@ -1,12 +1,15 @@ package driver -import "time" +import ( + "bytes" + "time" +) // Convert a string in [time.RFC3339Nano] format into a [time.Time] // if it roundtrips back to the same string. // This way times can be persisted to, and recovered from, the database, // but if a string is needed, [database/sql] will recover the same string. -func maybeTime(text string) (_ time.Time, _ bool) { +func maybeTime(text []byte) (_ time.Time, _ bool) { // Weed out (some) values that can't possibly be // [time.RFC3339Nano] timestamps. if len(text) < len("2006-01-02T15:04:05Z") { @@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) { // Slow path. var buf [len(time.RFC3339Nano)]byte - date, err := time.Parse(time.RFC3339Nano, text) - if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) { + date, err := time.Parse(time.RFC3339Nano, string(text)) + if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) { return date, true } return diff --git a/vendor/github.com/ncruces/go-sqlite3/error.go b/vendor/github.com/ncruces/go-sqlite3/error.go index 6d4bd63f8..59982eafd 100644 --- a/vendor/github.com/ncruces/go-sqlite3/error.go +++ b/vendor/github.com/ncruces/go-sqlite3/error.go @@ -2,7 +2,6 @@ package sqlite3 import ( "errors" - "strconv" "strings" "github.com/ncruces/go-sqlite3/internal/util" @@ -12,7 +11,6 @@ import ( // // https://sqlite.org/c3ref/errcode.html type Error struct { - str string msg string sql string code res_t @@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode { // // https://sqlite.org/rescode.html func (e *Error) ExtendedCode() ExtendedErrorCode { - return ExtendedErrorCode(e.code) + return xErrorCode(e.code) } // Error implements the error interface. func (e *Error) Error() string { var b strings.Builder - b.WriteString("sqlite3: ") - - if e.str != "" { - b.WriteString(e.str) - } else { - b.WriteString(strconv.Itoa(int(e.code))) - } + b.WriteString(util.ErrorCodeString(uint32(e.code))) if e.msg != "" { b.WriteString(": ") @@ -103,12 +95,12 @@ func (e ErrorCode) Error() string { // Temporary returns true for [BUSY] errors. func (e ErrorCode) Temporary() bool { - return e == BUSY + return e == BUSY || e == INTERRUPT } // ExtendedCode returns the extended error code for this error. func (e ErrorCode) ExtendedCode() ExtendedErrorCode { - return ExtendedErrorCode(e) + return xErrorCode(e) } // Error implements the error interface. @@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool { // Temporary returns true for [BUSY] errors. func (e ExtendedErrorCode) Temporary() bool { - return ErrorCode(e) == BUSY + return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT } // Timeout returns true for [BUSY_TIMEOUT] errors. diff --git a/vendor/github.com/ncruces/go-sqlite3/func.go b/vendor/github.com/ncruces/go-sqlite3/func.go index f907fa940..16b43056d 100644 --- a/vendor/github.com/ncruces/go-sqlite3/func.go +++ b/vendor/github.com/ncruces/go-sqlite3/func.go @@ -3,7 +3,9 @@ package sqlite3 import ( "context" "io" + "iter" "sync" + "sync/atomic" "github.com/tetratelabs/wazero/api" @@ -45,7 +47,7 @@ func (c Conn) AnyCollationNeeded() error { // CreateCollation defines a new collating sequence. // // https://sqlite.org/c3ref/create_collation.html -func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { +func (c *Conn) CreateCollation(name string, fn CollatingFunction) error { var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) @@ -57,6 +59,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { return c.error(rc) } +// Collating function is the type of a collation callback. +// Implementations must not retain a or b. +type CollatingFunction func(a, b []byte) int + // CreateFunction defines a new scalar SQL function. // // https://sqlite.org/c3ref/create_function.html @@ -77,34 +83,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala // Implementations must not retain arg. type ScalarFunction func(ctx Context, arg ...Value) -// CreateWindowFunction defines a new aggregate or aggregate window SQL function. -// If fn returns a [WindowFunction], then an aggregate window function is created. -// If fn returns an [io.Closer], it will be called to free resources. +// CreateAggregateFunction defines a new aggregate SQL function. // // https://sqlite.org/c3ref/create_function.html -func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { +func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error { var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) - call := "sqlite3_create_aggregate_function_go" if fn != nil { - agg := fn() - if c, ok := agg.(io.Closer); ok { - if err := c.Close(); err != nil { - return err + funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { + var a aggregateFunc + coro := func(yieldCoro func(struct{}) bool) { + seq := func(yieldSeq func([]Value) bool) { + for yieldSeq(a.arg) { + if !yieldCoro(struct{}{}) { + break + } + } + } + fn(&a.ctx, seq) } - } - if _, ok := agg.(WindowFunction); ok { - call = "sqlite3_create_window_function_go" - } - funcPtr = util.AddHandle(c.ctx, fn) + a.next, a.stop = iter.Pull(coro) + return &a + })) } - rc := res_t(c.call(call, + rc := res_t(c.call("sqlite3_create_aggregate_function_go", stk_t(c.handle), stk_t(namePtr), stk_t(nArg), stk_t(flag), stk_t(funcPtr))) return c.error(rc) } +// AggregateSeqFunction is the type of an aggregate SQL function. +// Implementations must not retain the slices yielded by seq. +type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value]) + +// CreateWindowFunction defines a new aggregate or aggregate window SQL function. +// If fn returns a [WindowFunction], an aggregate window function is created. +// If fn returns an [io.Closer], it will be called to free resources. +// +// https://sqlite.org/c3ref/create_function.html +func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error { + var funcPtr ptr_t + defer c.arena.mark()() + namePtr := c.arena.string(name) + if fn != nil { + funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { + agg := fn() + if win, ok := agg.(WindowFunction); ok { + return win + } + return windowFunc{agg, name} + })) + } + rc := res_t(c.call("sqlite3_create_window_function_go", + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) + return c.error(rc) +} + +// AggregateConstructor is a an [AggregateFunction] constructor. +type AggregateConstructor func() AggregateFunction + // AggregateFunction is the interface an aggregate function should implement. // // https://sqlite.org/appfunc.html @@ -153,26 +192,24 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe } func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { - fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) + fn := util.GetHandle(ctx, pApp).(CollatingFunction) return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) } func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn := util.GetHandle(db.ctx, pApp).(ScalarFunction) - callbackArgs(db, args[:nArg], pArg) - fn(Context{db, pCtx}, args[:nArg]...) + fn(Context{db, pCtx}, *args...) } func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) - callbackArgs(db, args[:nArg], pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn, _ := callbackAggregate(db, pAgg, pApp) - fn.Step(Context{db, pCtx}, args[:nArg]...) + fn.Step(Context{db, pCtx}, *args...) } func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) { @@ -196,12 +233,11 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, } func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { - args := getFuncArgs() - defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) - callbackArgs(db, args[:nArg], pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) fn := util.GetHandle(db.ctx, pAgg).(WindowFunction) - fn.Inverse(Context{db, pCtx}, args[:nArg]...) + fn.Inverse(Context{db, pCtx}, *args...) } func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { @@ -211,7 +247,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { } // We need to create the aggregate. - fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() + fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)() if pAgg != 0 { handle := util.AddHandle(db.ctx, fn) util.Write32(db.mod, pAgg, handle) @@ -220,25 +256,64 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { return fn, 0 } -func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { - for i := range arg { - arg[i] = Value{ +var ( + valueArgsPool sync.Pool + valueArgsLen atomic.Int32 +) + +func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value { + arg, ok := valueArgsPool.Get().(*[]Value) + if !ok || cap(*arg) < int(nArg) { + max := valueArgsLen.Or(nArg) | nArg + lst := make([]Value, max) + arg = &lst + } + lst := (*arg)[:nArg] + for i := range lst { + lst[i] = Value{ c: db, handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen), } } + *arg = lst + return arg } -var funcArgsPool sync.Pool - -func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { - funcArgsPool.Put(p) +func returnArgs(p *[]Value) { + valueArgsPool.Put(p) } -func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { - if p := funcArgsPool.Get(); p == nil { - return new([_MAX_FUNCTION_ARG]Value) - } else { - return p.(*[_MAX_FUNCTION_ARG]Value) +type aggregateFunc struct { + next func() (struct{}, bool) + stop func() + ctx Context + arg []Value +} + +func (a *aggregateFunc) Step(ctx Context, arg ...Value) { + a.ctx = ctx + a.arg = append(a.arg[:0], arg...) + if _, more := a.next(); !more { + a.stop() } } + +func (a *aggregateFunc) Value(ctx Context) { + a.ctx = ctx + a.stop() +} + +func (a *aggregateFunc) Close() error { + a.stop() + return nil +} + +type windowFunc struct { + AggregateFunction + name string +} + +func (w windowFunc) Inverse(ctx Context, arg ...Value) { + // Implementing inverse allows certain queries that don't really need it to succeed. + ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function")) +} diff --git a/vendor/github.com/ncruces/go-sqlite3/internal/util/error.go b/vendor/github.com/ncruces/go-sqlite3/internal/util/error.go index 2aecac96e..76769ed2e 100644 --- a/vendor/github.com/ncruces/go-sqlite3/internal/util/error.go +++ b/vendor/github.com/ncruces/go-sqlite3/internal/util/error.go @@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string { return "sqlite3: unable to open database file" case PROTOCOL: return "sqlite3: locking protocol" - case FORMAT: + case EMPTY: break case SCHEMA: return "sqlite3: database schema has changed" @@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string { break case AUTH: return "sqlite3: authorization denied" - case EMPTY: + case FORMAT: break case RANGE: return "sqlite3: column index out of range" diff --git a/vendor/github.com/ncruces/go-sqlite3/internal/util/mem.go b/vendor/github.com/ncruces/go-sqlite3/internal/util/mem.go index d2fea08b4..90c0e9e54 100644 --- a/vendor/github.com/ncruces/go-sqlite3/internal/util/mem.go +++ b/vendor/github.com/ncruces/go-sqlite3/internal/util/mem.go @@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string { panic(RangeErr) } } - if i := bytes.IndexByte(buf, 0); i < 0 { - panic(NoNulErr) - } else { + if i := bytes.IndexByte(buf, 0); i >= 0 { return string(buf[:i]) } + panic(NoNulErr) } func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { diff --git a/vendor/github.com/ncruces/go-sqlite3/sqlite.go b/vendor/github.com/ncruces/go-sqlite3/sqlite.go index 9e2d1d381..c05a86fde 100644 --- a/vendor/github.com/ncruces/go-sqlite3/sqlite.go +++ b/vendor/github.com/ncruces/go-sqlite3/sqlite.go @@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error { return nil } - err := Error{code: rc} - - if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { + if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM { panic(util.OOMErr) } - if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 { - err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME) - } - if handle != 0 { + var msg, query string if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 { - err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) + msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) + switch { + case msg == "not an error": + msg = "" + case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]: + msg = "" + } } if len(sql) != 0 { if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 { - err.sql = sql[0][i:] + query = sql[0][i:] } } - } - switch err.msg { - case err.str, "not an error": - err.msg = "" + if msg != "" || query != "" { + return &Error{code: rc, msg: msg, sql: query} + } } - return &err + return xErrorCode(rc) } func (sqlt *sqlite) getfn(name string) api.Function { @@ -212,14 +212,10 @@ func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t { } func (sqlt *sqlite) newBytes(b []byte) ptr_t { - if (*[0]byte)(b) == nil { + if len(b) == 0 { return 0 } - size := len(b) - if size == 0 { - size = 1 - } - ptr := sqlt.new(int64(size)) + ptr := sqlt.new(int64(len(b))) util.WriteBytes(sqlt.mod, ptr, b) return ptr } @@ -288,7 +284,7 @@ func (a *arena) new(size int64) ptr_t { } func (a *arena) bytes(b []byte) ptr_t { - if (*[0]byte)(b) == nil { + if len(b) == 0 { return 0 } ptr := a.new(int64(len(b))) diff --git a/vendor/github.com/ncruces/go-sqlite3/stmt.go b/vendor/github.com/ncruces/go-sqlite3/stmt.go index 4e17d1039..1ea726ea1 100644 --- a/vendor/github.com/ncruces/go-sqlite3/stmt.go +++ b/vendor/github.com/ncruces/go-sqlite3/stmt.go @@ -106,7 +106,14 @@ func (s *Stmt) Busy() bool { // // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { - s.c.checkInterrupt(s.c.handle) + if s.c.interrupt.Err() != nil { + s.err = INTERRUPT + return false + } + return s.step() +} + +func (s *Stmt) step() bool { rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle))) switch rc { case _ROW: @@ -131,7 +138,11 @@ func (s *Stmt) Err() error { // Exec is a convenience function that repeatedly calls [Stmt.Step] until it returns false, // then calls [Stmt.Reset] to reset the statement and get any error that occurred. func (s *Stmt) Exec() error { - for s.Step() { + if s.c.interrupt.Err() != nil { + return INTERRUPT + } + // TODO: implement this in C. + for s.step() { } return s.Reset() } @@ -254,13 +265,15 @@ func (s *Stmt) BindText(param int, value string) error { // BindRawText binds a []byte to the prepared statement as text. // The leftmost SQL parameter has an index of 1. -// Binding a nil slice is the same as calling [Stmt.BindNull]. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindRawText(param int, value []byte) error { if len(value) > _MAX_LENGTH { return TOOBIG } + if len(value) == 0 { + return s.BindText(param, "") + } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_text_go", stk_t(s.handle), stk_t(param), @@ -270,13 +283,15 @@ func (s *Stmt) BindRawText(param int, value []byte) error { // BindBlob binds a []byte to the prepared statement. // The leftmost SQL parameter has an index of 1. -// Binding a nil slice is the same as calling [Stmt.BindNull]. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindBlob(param int, value []byte) error { if len(value) > _MAX_LENGTH { return TOOBIG } + if len(value) == 0 { + return s.BindZeroBlob(param, 0) + } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_blob_go", stk_t(s.handle), stk_t(param), @@ -560,7 +575,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { func (s *Stmt) ColumnRawText(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_text", stk_t(s.handle), stk_t(col))) - return s.columnRawBytes(col, ptr) + return s.columnRawBytes(col, ptr, 1) } // ColumnRawBlob returns the value of the result column as a []byte. @@ -572,10 +587,10 @@ func (s *Stmt) ColumnRawText(col int) []byte { func (s *Stmt) ColumnRawBlob(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_blob", stk_t(s.handle), stk_t(col))) - return s.columnRawBytes(col, ptr) + return s.columnRawBytes(col, ptr, 0) } -func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { +func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte { if ptr == 0 { rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle))) if rc != _ROW && rc != _DONE { @@ -586,7 +601,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { n := int32(s.c.call("sqlite3_column_bytes", stk_t(s.handle), stk_t(col))) - return util.View(s.c.mod, ptr, int64(n)) + return util.View(s.c.mod, ptr, int64(n+nul))[:n] } // ColumnJSON parses the JSON-encoded value of the result column @@ -633,22 +648,12 @@ func (s *Stmt) ColumnValue(col int) Value { // [INTEGER] columns will be retrieved as int64 values, // [FLOAT] as float64, [NULL] as nil, // [TEXT] as string, and [BLOB] as []byte. -// Any []byte are owned by SQLite and may be invalidated by -// subsequent calls to [Stmt] methods. func (s *Stmt) Columns(dest ...any) error { - defer s.c.arena.mark()() - count := int64(len(dest)) - typePtr := s.c.arena.new(count) - dataPtr := s.c.arena.new(count * 8) - - rc := res_t(s.c.call("sqlite3_columns_go", - stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) - if err := s.c.error(rc); err != nil { + types, ptr, err := s.columns(int64(len(dest))) + if err != nil { return err } - types := util.View(s.c.mod, typePtr, count) - // Avoid bounds checks on types below. if len(types) != len(dest) { panic(util.AssertErr()) @@ -657,26 +662,95 @@ func (s *Stmt) Columns(dest ...any) error { for i := range dest { switch types[i] { case byte(INTEGER): - dest[i] = util.Read64[int64](s.c.mod, dataPtr) + dest[i] = util.Read64[int64](s.c.mod, ptr) case byte(FLOAT): - dest[i] = util.ReadFloat64(s.c.mod, dataPtr) + dest[i] = util.ReadFloat64(s.c.mod, ptr) case byte(NULL): dest[i] = nil - default: - ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0) - if ptr == 0 { - dest[i] = []byte{} - continue - } - len := util.Read32[int32](s.c.mod, dataPtr+4) - buf := util.View(s.c.mod, ptr, int64(len)) - if types[i] == byte(TEXT) { + case byte(TEXT): + len := util.Read32[int32](s.c.mod, ptr+4) + if len != 0 { + ptr := util.Read32[ptr_t](s.c.mod, ptr) + buf := util.View(s.c.mod, ptr, int64(len)) dest[i] = string(buf) } else { - dest[i] = buf + dest[i] = "" + } + case byte(BLOB): + len := util.Read32[int32](s.c.mod, ptr+4) + if len != 0 { + ptr := util.Read32[ptr_t](s.c.mod, ptr) + buf := util.View(s.c.mod, ptr, int64(len)) + tmp, _ := dest[i].([]byte) + dest[i] = append(tmp[:0], buf...) + } else { + dest[i], _ = dest[i].([]byte) } } - dataPtr += 8 + ptr += 8 } return nil } + +// ColumnsRaw populates result columns into the provided slice. +// The slice must have [Stmt.ColumnCount] length. +// +// [INTEGER] columns will be retrieved as int64 values, +// [FLOAT] as float64, [NULL] as nil, +// [TEXT] and [BLOB] as []byte. +// Any []byte are owned by SQLite and may be invalidated by +// subsequent calls to [Stmt] methods. +func (s *Stmt) ColumnsRaw(dest ...any) error { + types, ptr, err := s.columns(int64(len(dest))) + if err != nil { + return err + } + + // Avoid bounds checks on types below. + if len(types) != len(dest) { + panic(util.AssertErr()) + } + + for i := range dest { + switch types[i] { + case byte(INTEGER): + dest[i] = util.Read64[int64](s.c.mod, ptr) + case byte(FLOAT): + dest[i] = util.ReadFloat64(s.c.mod, ptr) + case byte(NULL): + dest[i] = nil + default: + len := util.Read32[int32](s.c.mod, ptr+4) + if len == 0 && types[i] == byte(BLOB) { + dest[i] = []byte{} + } else { + cap := len + if types[i] == byte(TEXT) { + cap++ + } + ptr := util.Read32[ptr_t](s.c.mod, ptr) + buf := util.View(s.c.mod, ptr, int64(cap))[:len] + dest[i] = buf + } + } + ptr += 8 + } + return nil +} + +func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) { + defer s.c.arena.mark()() + typePtr := s.c.arena.new(count) + dataPtr := s.c.arena.new(count * 8) + + rc := res_t(s.c.call("sqlite3_columns_go", + stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) + if rc == res_t(MISUSE) { + return nil, 0, MISUSE + } + if err := s.c.error(rc); err != nil { + return nil, 0, err + } + + return util.View(s.c.mod, typePtr, count), dataPtr, nil +} diff --git a/vendor/github.com/ncruces/go-sqlite3/txn.go b/vendor/github.com/ncruces/go-sqlite3/txn.go index b24789f87..931b89958 100644 --- a/vendor/github.com/ncruces/go-sqlite3/txn.go +++ b/vendor/github.com/ncruces/go-sqlite3/txn.go @@ -2,7 +2,6 @@ package sqlite3 import ( "context" - "errors" "math/rand" "runtime" "strconv" @@ -21,11 +20,13 @@ type Txn struct { } // Begin starts a deferred transaction. +// It panics if a transaction is in-progress. +// For nested transactions, use [Conn.Savepoint]. // // https://sqlite.org/lang_transaction.html func (c *Conn) Begin() Txn { // BEGIN even if interrupted. - err := c.txnExecInterrupted(`BEGIN DEFERRED`) + err := c.exec(`BEGIN DEFERRED`) if err != nil { panic(err) } @@ -120,7 +121,8 @@ func (tx Txn) Commit() error { // // https://sqlite.org/lang_transaction.html func (tx Txn) Rollback() error { - return tx.c.txnExecInterrupted(`ROLLBACK`) + // ROLLBACK even if interrupted. + return tx.c.exec(`ROLLBACK`) } // Savepoint is a marker within a transaction @@ -143,7 +145,7 @@ func (c *Conn) Savepoint() Savepoint { // Names can be reused, but this makes catching bugs more likely. name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31()))) - err := c.txnExecInterrupted(`SAVEPOINT ` + name) + err := c.exec(`SAVEPOINT ` + name) if err != nil { panic(err) } @@ -199,7 +201,7 @@ func (s Savepoint) Release(errp *error) { return } // ROLLBACK and RELEASE even if interrupted. - err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name) + err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name) if err != nil { panic(err) } @@ -212,17 +214,7 @@ func (s Savepoint) Release(errp *error) { // https://sqlite.org/lang_transaction.html func (s Savepoint) Rollback() error { // ROLLBACK even if interrupted. - return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name) -} - -func (c *Conn) txnExecInterrupted(sql string) error { - err := c.Exec(sql) - if errors.Is(err, INTERRUPT) { - old := c.SetInterrupt(context.Background()) - defer c.SetInterrupt(old) - err = c.Exec(sql) - } - return err + return s.c.exec(`ROLLBACK TO ` + s.name) } // TxnState determines the transaction state of a database. diff --git a/vendor/github.com/ncruces/go-sqlite3/util/osutil/open.go b/vendor/github.com/ncruces/go-sqlite3/util/osutil/open.go deleted file mode 100644 index 0242ad032..000000000 --- a/vendor/github.com/ncruces/go-sqlite3/util/osutil/open.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build !windows - -package osutil - -import ( - "io/fs" - "os" -) - -// OpenFile behaves the same as [os.OpenFile], -// except on Windows it sets [syscall.FILE_SHARE_DELETE]. -// -// See: https://go.dev/issue/32088#issuecomment-502850674 -func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { - return os.OpenFile(name, flag, perm) -} diff --git a/vendor/github.com/ncruces/go-sqlite3/util/osutil/open_windows.go b/vendor/github.com/ncruces/go-sqlite3/util/osutil/open_windows.go deleted file mode 100644 index febaf846e..000000000 --- a/vendor/github.com/ncruces/go-sqlite3/util/osutil/open_windows.go +++ /dev/null @@ -1,115 +0,0 @@ -package osutil - -import ( - "io/fs" - "os" - . "syscall" - "unsafe" -) - -// OpenFile behaves the same as [os.OpenFile], -// except on Windows it sets [syscall.FILE_SHARE_DELETE]. -// -// See: https://go.dev/issue/32088#issuecomment-502850674 -func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { - if name == "" { - return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT} - } - r, e := syscallOpen(name, flag|O_CLOEXEC, uint32(perm.Perm())) - if e != nil { - return nil, &os.PathError{Op: "open", Path: name, Err: e} - } - return os.NewFile(uintptr(r), name), nil -} - -// syscallOpen is a copy of [syscall.Open] -// that uses [syscall.FILE_SHARE_DELETE]. -// -// https://go.dev/src/syscall/syscall_windows.go -func syscallOpen(path string, mode int, perm uint32) (fd Handle, err error) { - if len(path) == 0 { - return InvalidHandle, ERROR_FILE_NOT_FOUND - } - pathp, err := UTF16PtrFromString(path) - if err != nil { - return InvalidHandle, err - } - var access uint32 - switch mode & (O_RDONLY | O_WRONLY | O_RDWR) { - case O_RDONLY: - access = GENERIC_READ - case O_WRONLY: - access = GENERIC_WRITE - case O_RDWR: - access = GENERIC_READ | GENERIC_WRITE - } - if mode&O_CREAT != 0 { - access |= GENERIC_WRITE - } - if mode&O_APPEND != 0 { - access &^= GENERIC_WRITE - access |= FILE_APPEND_DATA - } - sharemode := uint32(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE) - var sa *SecurityAttributes - if mode&O_CLOEXEC == 0 { - sa = makeInheritSa() - } - var createmode uint32 - switch { - case mode&(O_CREAT|O_EXCL) == (O_CREAT | O_EXCL): - createmode = CREATE_NEW - case mode&(O_CREAT|O_TRUNC) == (O_CREAT | O_TRUNC): - createmode = CREATE_ALWAYS - case mode&O_CREAT == O_CREAT: - createmode = OPEN_ALWAYS - case mode&O_TRUNC == O_TRUNC: - createmode = TRUNCATE_EXISTING - default: - createmode = OPEN_EXISTING - } - var attrs uint32 = FILE_ATTRIBUTE_NORMAL - if perm&S_IWRITE == 0 { - attrs = FILE_ATTRIBUTE_READONLY - if createmode == CREATE_ALWAYS { - const _ERROR_BAD_NETPATH = Errno(53) - // We have been asked to create a read-only file. - // If the file already exists, the semantics of - // the Unix open system call is to preserve the - // existing permissions. If we pass CREATE_ALWAYS - // and FILE_ATTRIBUTE_READONLY to CreateFile, - // and the file already exists, CreateFile will - // change the file permissions. - // Avoid that to preserve the Unix semantics. - h, e := CreateFile(pathp, access, sharemode, sa, TRUNCATE_EXISTING, FILE_ATTRIBUTE_NORMAL, 0) - switch e { - case ERROR_FILE_NOT_FOUND, _ERROR_BAD_NETPATH, ERROR_PATH_NOT_FOUND: - // File does not exist. These are the same - // errors as Errno.Is checks for ErrNotExist. - // Carry on to create the file. - default: - // Success or some different error. - return h, e - } - } - } - if createmode == OPEN_EXISTING && access == GENERIC_READ { - // Necessary for opening directory handles. - attrs |= FILE_FLAG_BACKUP_SEMANTICS - } - if mode&O_SYNC != 0 { - const _FILE_FLAG_WRITE_THROUGH = 0x80000000 - attrs |= _FILE_FLAG_WRITE_THROUGH - } - if mode&O_NONBLOCK != 0 { - attrs |= FILE_FLAG_OVERLAPPED - } - return CreateFile(pathp, access, sharemode, sa, createmode, attrs, 0) -} - -func makeInheritSa() *SecurityAttributes { - var sa SecurityAttributes - sa.Length = uint32(unsafe.Sizeof(sa)) - sa.InheritHandle = 1 - return &sa -} diff --git a/vendor/github.com/ncruces/go-sqlite3/util/osutil/osfs.go b/vendor/github.com/ncruces/go-sqlite3/util/osutil/osfs.go deleted file mode 100644 index 2e1195934..000000000 --- a/vendor/github.com/ncruces/go-sqlite3/util/osutil/osfs.go +++ /dev/null @@ -1,33 +0,0 @@ -package osutil - -import ( - "io/fs" - "os" -) - -// FS implements [fs.FS], [fs.StatFS], and [fs.ReadFileFS] -// using package [os]. -// -// This filesystem does not respect [fs.ValidPath] rules, -// and fails [testing/fstest.TestFS]! -// -// Still, it can be a useful tool to unify implementations -// that can access either the [os] filesystem or an [fs.FS]. -// It's OK to use this to open files, but you should avoid -// opening directories, resolving paths, or walking the file system. -type FS struct{} - -// Open implements [fs.FS]. -func (FS) Open(name string) (fs.File, error) { - return OpenFile(name, os.O_RDONLY, 0) -} - -// ReadFileFS implements [fs.StatFS]. -func (FS) Stat(name string) (fs.FileInfo, error) { - return os.Stat(name) -} - -// ReadFile implements [fs.ReadFileFS]. -func (FS) ReadFile(name string) ([]byte, error) { - return os.ReadFile(name) -} diff --git a/vendor/github.com/ncruces/go-sqlite3/util/osutil/osutil.go b/vendor/github.com/ncruces/go-sqlite3/util/osutil/osutil.go deleted file mode 100644 index 83444e906..000000000 --- a/vendor/github.com/ncruces/go-sqlite3/util/osutil/osutil.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package osutil implements operating system utilities. -package osutil diff --git a/vendor/github.com/ncruces/go-sqlite3/util/sql3util/sql3util.go b/vendor/github.com/ncruces/go-sqlite3/util/sql3util/sql3util.go index 6be61927d..f2e33c0b2 100644 --- a/vendor/github.com/ncruces/go-sqlite3/util/sql3util/sql3util.go +++ b/vendor/github.com/ncruces/go-sqlite3/util/sql3util/sql3util.go @@ -5,5 +5,5 @@ package sql3util // // https://sqlite.org/fileformat.html#pages func ValidPageSize(s int) bool { - return 512 <= s && s <= 65536 && s&(s-1) == 0 + return s&(s-1) == 0 && 512 <= s && s <= 65536 } diff --git a/vendor/github.com/ncruces/go-sqlite3/value.go b/vendor/github.com/ncruces/go-sqlite3/value.go index a2399fba0..6753027b5 100644 --- a/vendor/github.com/ncruces/go-sqlite3/value.go +++ b/vendor/github.com/ncruces/go-sqlite3/value.go @@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte { // https://sqlite.org/c3ref/value_blob.html func (v Value) RawText() []byte { ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected())) - return v.rawBytes(ptr) + return v.rawBytes(ptr, 1) } // RawBlob returns the value as a []byte. @@ -149,16 +149,16 @@ func (v Value) RawText() []byte { // https://sqlite.org/c3ref/value_blob.html func (v Value) RawBlob() []byte { ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected())) - return v.rawBytes(ptr) + return v.rawBytes(ptr, 0) } -func (v Value) rawBytes(ptr ptr_t) []byte { +func (v Value) rawBytes(ptr ptr_t, nul int32) []byte { if ptr == 0 { return nil } n := int32(v.c.call("sqlite3_value_bytes", v.protected())) - return util.View(v.c.mod, ptr, int64(n)) + return util.View(v.c.mod, ptr, int64(n+nul))[:n] } // Pointer gets the pointer associated with this value, diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/README.md b/vendor/github.com/ncruces/go-sqlite3/vfs/README.md index 4e987ce3f..17c24ec65 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/README.md +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/README.md @@ -6,22 +6,30 @@ It replaces the default SQLite VFS with a **pure Go** implementation, and exposes [interfaces](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#VFS) that should allow you to implement your own [custom VFSes](#custom-vfses). -Since it is a from scratch reimplementation, -there are naturally some ways it deviates from the original. +See the [support matrix](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) +for the list of supported OS and CPU architectures. -The main differences are [file locking](#file-locking) and [WAL mode](#write-ahead-logging) support. +Since this is a from scratch reimplementation, +there are naturally some ways it deviates from the original. +It's also not as battle tested as the original. + +The main differences to be aware of are +[file locking](#file-locking) and +[WAL mode](#write-ahead-logging) support. ### File Locking -POSIX advisory locks, which SQLite uses on Unix, are -[broken by design](https://github.com/sqlite/sqlite/blob/b74eb0/src/os_unix.c#L1073-L1161). +POSIX advisory locks, +which SQLite uses on [Unix](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L13-L14), +are [broken by design](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L1074-L1162). Instead, on Linux and macOS, this package uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html) to synchronize access to database files. This package can also use [BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2), -albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`). +albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`, +[docs](https://sqlite.org/lang_transaction.html#immediate)). BSD locks are the default on BSD and illumos, but you can opt into them with the `sqlite3_flock` build tag. @@ -44,11 +52,11 @@ to check if your build supports file locking. ### Write-Ahead Logging -On Unix, this package may use `mmap` to implement +On Unix, this package uses `mmap` to implement [shared-memory for the WAL-index](https://sqlite.org/wal.html#implementation_of_shared_memory_for_the_wal_index), like SQLite. -On Windows, this package may use `MapViewOfFile`, like SQLite. +On Windows, this package uses `MapViewOfFile`, like SQLite. You can also opt into a cross-platform, in-process, memory sharing implementation with the `sqlite3_dotlk` build tag. @@ -63,6 +71,11 @@ you must disable connection pooling by calling You can use [`vfs.SupportsSharedMemory`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#SupportsSharedMemory) to check if your build supports shared memory. +### Blocking Locks + +On Windows and macOS, this package implements +[Wal-mode blocking locks](https://sqlite.org/src/doc/tip/doc/wal-lock.md). + ### Batch-Atomic Write On Linux, this package may support @@ -94,8 +107,10 @@ The VFS can be customized with a few build tags: > [`unix-flock` VFS](https://sqlite.org/compile.html#enable_locking_style); > `sqlite3_dotlk` builds are compatible with the > [`unix-dotfile` VFS](https://sqlite.org/compile.html#enable_locking_style). -> If incompatible file locking is used, accessing databases concurrently with -> _other_ SQLite libraries will eventually corrupt data. + +> [!CAUTION] +> Concurrently accessing databases using incompatible VFSes +> will eventually corrupt data. ### Custom VFSes diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/cksm.go b/vendor/github.com/ncruces/go-sqlite3/vfs/cksm.go index 041defec3..0ff7b6f18 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/cksm.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/cksm.go @@ -49,9 +49,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) { n, err = c.File.ReadAt(p, off) p = p[:n] - // SQLite is reading the header of a database file. - if c.isDB && off == 0 && len(p) >= 100 && - bytes.HasPrefix(p, []byte("SQLite format 3\000")) { + if isHeader(c.isDB, p, off) { c.init((*[100]byte)(p)) } @@ -67,9 +65,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) { } func (c cksmFile) WriteAt(p []byte, off int64) (n int, err error) { - // SQLite is writing the first page of a database file. - if c.isDB && off == 0 && len(p) >= 100 && - bytes.HasPrefix(p, []byte("SQLite format 3\000")) { + if isHeader(c.isDB, p, off) { c.init((*[100]byte)(p)) } @@ -116,9 +112,11 @@ func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpco c.inCkpt = true case _FCNTL_CKPT_DONE: c.inCkpt = false - } - if rc := vfsFileControlImpl(ctx, mod, c, op, pArg); rc != _NOTFOUND { - return rc + case _FCNTL_PRAGMA: + rc := vfsFileControlImpl(ctx, mod, c, op, pArg) + if rc != _NOTFOUND { + return rc + } } return vfsFileControlImpl(ctx, mod, c.File, op, pArg) } @@ -135,6 +133,14 @@ func (f *cksmFlags) init(header *[100]byte) { } } +func isHeader(isDB bool, p []byte, off int64) bool { + check := sql3util.ValidPageSize(len(p)) + if isDB { + check = off == 0 && len(p) >= 100 + } + return check && bytes.HasPrefix(p, []byte("SQLite format 3\000")) +} + func cksmCompute(a []byte) (cksm [8]byte) { var s1, s2 uint32 for len(a) >= 8 { diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/file.go b/vendor/github.com/ncruces/go-sqlite3/vfs/file.go index 0a3c9d622..65409823c 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/file.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/file.go @@ -6,9 +6,8 @@ import ( "io/fs" "os" "path/filepath" + "runtime" "syscall" - - "github.com/ncruces/go-sqlite3/util/osutil" ) type vfsOS struct{} @@ -40,7 +39,7 @@ func (vfsOS) Delete(path string, syncDir bool) error { if err != nil { return err } - if canSyncDirs && syncDir { + if isUnix && syncDir { f, err := os.Open(filepath.Dir(path)) if err != nil { return _OK @@ -96,7 +95,7 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error if name == nil { f, err = os.CreateTemp(os.Getenv("SQLITE_TMPDIR"), "*.db") } else { - f, err = osutil.OpenFile(name.String(), oflags, 0666) + f, err = os.OpenFile(name.String(), oflags, 0666) } if err != nil { if name == nil { @@ -118,15 +117,17 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error return nil, flags, _IOERR_FSTAT } } - if flags&OPEN_DELETEONCLOSE != 0 { + if isUnix && flags&OPEN_DELETEONCLOSE != 0 { os.Remove(f.Name()) } file := vfsFile{ File: f, psow: true, + atomic: osBatchAtomic(f), readOnly: flags&OPEN_READONLY != 0, - syncDir: canSyncDirs && isCreate && isJournl, + syncDir: isUnix && isCreate && isJournl, + delete: !isUnix && flags&OPEN_DELETEONCLOSE != 0, shm: NewSharedMemory(name.String()+"-shm", flags), } return &file, flags, nil @@ -139,6 +140,8 @@ type vfsFile struct { readOnly bool keepWAL bool syncDir bool + atomic bool + delete bool psow bool } @@ -152,6 +155,9 @@ var ( ) func (f *vfsFile) Close() error { + if f.delete { + defer os.Remove(f.Name()) + } if f.shm != nil { f.shm.Close() } @@ -175,7 +181,7 @@ func (f *vfsFile) Sync(flags SyncFlag) error { if err != nil { return err } - if canSyncDirs && f.syncDir { + if isUnix && f.syncDir { f.syncDir = false d, err := os.Open(filepath.Dir(f.File.Name())) if err != nil { @@ -200,12 +206,15 @@ func (f *vfsFile) SectorSize() int { func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic { ret := IOCAP_SUBPAGE_READ - if osBatchAtomic(f.File) { + if f.atomic { ret |= IOCAP_BATCH_ATOMIC } if f.psow { ret |= IOCAP_POWERSAFE_OVERWRITE } + if runtime.GOOS == "windows" { + ret |= IOCAP_UNDELETABLE_WHEN_OPEN + } return ret } @@ -214,6 +223,9 @@ func (f *vfsFile) SizeHint(size int64) error { } func (f *vfsFile) HasMoved() (bool, error) { + if runtime.GOOS == "windows" { + return false, nil + } fi, err := f.Stat() if err != nil { return false, err diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_bsd.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_bsd.go index 4f6fadef4..4542f8e7c 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_bsd.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_bsd.go @@ -50,11 +50,15 @@ func osDowngradeLock(file *os.File, _ LockLevel) _ErrorCode { } func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode { - err := unix.Flock(int(file.Fd()), unix.LOCK_UN) - if err != nil { - return _IOERR_UNLOCK + for { + err := unix.Flock(int(file.Fd()), unix.LOCK_UN) + if err == nil { + return _OK + } + if err != unix.EINTR { + return _IOERR_UNLOCK + } } - return _OK } func osCheckReservedLock(file *os.File) (bool, _ErrorCode) { @@ -89,13 +93,18 @@ func osLock(file *os.File, typ int16, start, len int64, def _ErrorCode) _ErrorCo } func osUnlock(file *os.File, start, len int64) _ErrorCode { - err := unix.FcntlFlock(file.Fd(), unix.F_SETLK, &unix.Flock_t{ + lock := unix.Flock_t{ Type: unix.F_UNLCK, Start: start, Len: len, - }) - if err != nil { - return _IOERR_UNLOCK } - return _OK + for { + err := unix.FcntlFlock(file.Fd(), unix.F_SETLK, &lock) + if err == nil { + return _OK + } + if err != unix.EINTR { + return _IOERR_UNLOCK + } + } } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_darwin.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_darwin.go index 07de7c3d8..ee08e9a7b 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_darwin.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_darwin.go @@ -27,7 +27,12 @@ func osSync(file *os.File, fullsync, _ /*dataonly*/ bool) error { if fullsync { return file.Sync() } - return unix.Fsync(int(file.Fd())) + for { + err := unix.Fsync(int(file.Fd())) + if err != unix.EINTR { + return err + } + } } func osAllocate(file *os.File, size int64) error { @@ -85,13 +90,18 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d } func osUnlock(file *os.File, start, len int64) _ErrorCode { - err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{ + lock := unix.Flock_t{ Type: unix.F_UNLCK, Start: start, Len: len, - }) - if err != nil { - return _IOERR_UNLOCK } - return _OK + for { + err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock) + if err == nil { + return _OK + } + if err != unix.EINTR { + return _IOERR_UNLOCK + } + } } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_linux.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_linux.go index 6199c7b00..d112c5a99 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_linux.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_linux.go @@ -3,6 +3,7 @@ package vfs import ( + "io" "os" "time" @@ -11,14 +12,36 @@ import ( func osSync(file *os.File, _ /*fullsync*/, _ /*dataonly*/ bool) error { // SQLite trusts Linux's fdatasync for all fsync's. - return unix.Fdatasync(int(file.Fd())) + for { + err := unix.Fdatasync(int(file.Fd())) + if err != unix.EINTR { + return err + } + } } func osAllocate(file *os.File, size int64) error { if size == 0 { return nil } - return unix.Fallocate(int(file.Fd()), 0, 0, size) + for { + err := unix.Fallocate(int(file.Fd()), 0, 0, size) + if err == unix.EOPNOTSUPP { + break + } + if err != unix.EINTR { + return err + } + } + off, err := file.Seek(0, io.SeekEnd) + if err != nil { + return err + } + if size <= off { + return nil + } + return file.Truncate(size) + } func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { @@ -37,22 +60,27 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d } var err error switch { - case timeout < 0: - err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock) default: err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock) + case timeout < 0: + err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock) } return osLockErrorCode(err, def) } func osUnlock(file *os.File, start, len int64) _ErrorCode { - err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{ + lock := unix.Flock_t{ Type: unix.F_UNLCK, Start: start, Len: len, - }) - if err != nil { - return _IOERR_UNLOCK } - return _OK + for { + err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock) + if err == nil { + return _OK + } + if err != unix.EINTR { + return _IOERR_UNLOCK + } + } } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_std.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_std.go index 0d0ca24c9..a48c71e9f 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_std.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_std.go @@ -8,8 +8,8 @@ import ( ) const ( + isUnix = false _O_NOFOLLOW = 0 - canSyncDirs = false ) func osAccess(path string, flags AccessFlag) error { diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_unix.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_unix.go index 9f42b5f6c..ec312ccd3 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_unix.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_unix.go @@ -10,8 +10,8 @@ import ( ) const ( + isUnix = true _O_NOFOLLOW = unix.O_NOFOLLOW - canSyncDirs = true ) func osAccess(path string, flags AccessFlag) error { @@ -65,10 +65,15 @@ func osTestLock(file *os.File, start, len int64) (int16, _ErrorCode) { Start: start, Len: len, } - if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil { - return 0, _IOERR_CHECKRESERVEDLOCK + for { + err := unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) + if err == nil { + return lock.Type, _OK + } + if err != unix.EINTR { + return 0, _IOERR_CHECKRESERVEDLOCK + } } - return lock.Type, _OK } func osLockErrorCode(err error, def _ErrorCode) _ErrorCode { diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/os_windows.go b/vendor/github.com/ncruces/go-sqlite3/vfs/os_windows.go index ecce3cfa2..0a6693de5 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/os_windows.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/os_windows.go @@ -135,12 +135,10 @@ func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _Error func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode { var err error switch { - case timeout == 0: + default: err = osLockEx(file, flags|windows.LOCKFILE_FAIL_IMMEDIATELY, start, len) case timeout < 0: err = osLockEx(file, flags, start, len) - default: - err = osLockExTimeout(file, flags, start, len, timeout) } return osLockErrorCode(err, def) } @@ -162,37 +160,6 @@ func osLockEx(file *os.File, flags, start, len uint32) error { 0, len, 0, &windows.Overlapped{Offset: start}) } -func osLockExTimeout(file *os.File, flags, start, len uint32, timeout time.Duration) error { - event, err := windows.CreateEvent(nil, 1, 0, nil) - if err != nil { - return err - } - defer windows.CloseHandle(event) - - fd := windows.Handle(file.Fd()) - overlapped := &windows.Overlapped{ - Offset: start, - HEvent: event, - } - - err = windows.LockFileEx(fd, flags, 0, len, 0, overlapped) - if err != windows.ERROR_IO_PENDING { - return err - } - - ms := (timeout + time.Millisecond - 1) / time.Millisecond - rc, err := windows.WaitForSingleObject(event, uint32(ms)) - if rc == windows.WAIT_OBJECT_0 { - return nil - } - defer windows.CancelIoEx(fd, overlapped) - - if err != nil { - return err - } - return windows.Errno(rc) -} - func osLockErrorCode(err error, def _ErrorCode) _ErrorCode { if err == nil { return _OK diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go index 11e7bb2fd..be1495d99 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go @@ -68,16 +68,11 @@ func (s *vfsShm) Close() error { panic(util.AssertErr()) } -func (s *vfsShm) shmOpen() _ErrorCode { +func (s *vfsShm) shmOpen() (rc _ErrorCode) { if s.vfsShmParent != nil { return _OK } - var f *os.File - // Close file on error. - // Keep this here to avoid confusing checklocks. - defer func() { f.Close() }() - vfsShmListMtx.Lock() defer vfsShmListMtx.Unlock() @@ -98,11 +93,16 @@ func (s *vfsShm) shmOpen() _ErrorCode { } // Always open file read-write, as it will be shared. - f, err = os.OpenFile(s.path, + f, err := os.OpenFile(s.path, os.O_RDWR|os.O_CREATE|_O_NOFOLLOW, 0666) if err != nil { return _CANTOPEN } + defer func() { + if rc != _OK { + f.Close() + } + }() // Dead man's switch. if lock, rc := osTestLock(f, _SHM_DMS, 1); rc != _OK { @@ -131,7 +131,6 @@ func (s *vfsShm) shmOpen() _ErrorCode { File: f, info: fi, } - f = nil // Don't close the file. for i, g := range vfsShmList { if g == nil { vfsShmList[i] = s.vfsShmParent diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_windows.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_windows.go index ed2e93f8e..7cc5b2a23 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_windows.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_windows.go @@ -7,14 +7,11 @@ import ( "io" "os" "sync" - "syscall" - "time" "github.com/tetratelabs/wazero/api" "golang.org/x/sys/windows" "github.com/ncruces/go-sqlite3/internal/util" - "github.com/ncruces/go-sqlite3/util/osutil" ) type vfsShm struct { @@ -33,8 +30,6 @@ type vfsShm struct { sync.Mutex } -var _ blockingSharedMemory = &vfsShm{} - func (s *vfsShm) Close() error { // Unmap regions. for _, r := range s.regions { @@ -48,8 +43,7 @@ func (s *vfsShm) Close() error { func (s *vfsShm) shmOpen() _ErrorCode { if s.File == nil { - f, err := osutil.OpenFile(s.path, - os.O_RDWR|os.O_CREATE|syscall.O_NONBLOCK, 0666) + f, err := os.OpenFile(s.path, os.O_RDWR|os.O_CREATE, 0666) if err != nil { return _CANTOPEN } @@ -67,7 +61,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return _IOERR_SHMOPEN } } - rc := osReadLock(s.File, _SHM_DMS, 1, time.Millisecond) + rc := osReadLock(s.File, _SHM_DMS, 1, 0) s.fileLock = rc == _OK return rc } @@ -135,11 +129,6 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext } func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) { - var timeout time.Duration - if s.blocking { - timeout = time.Millisecond - } - switch { case flags&_SHM_LOCK != 0: defer s.shmAcquire(&rc) @@ -151,9 +140,9 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) { case flags&_SHM_UNLOCK != 0: return osUnlock(s.File, _SHM_BASE+uint32(offset), uint32(n)) case flags&_SHM_SHARED != 0: - return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout) + return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0) case flags&_SHM_EXCLUSIVE != 0: - return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout) + return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0) default: panic(util.AssertErr()) } @@ -184,7 +173,3 @@ func (s *vfsShm) shmUnmap(delete bool) { os.Remove(s.path) } } - -func (s *vfsShm) shmEnableBlocking(block bool) { - s.blocking = block -} diff --git a/vendor/github.com/ncruces/go-sqlite3/vtab.go b/vendor/github.com/ncruces/go-sqlite3/vtab.go index 884aaaa0c..16ff2806b 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vtab.go +++ b/vendor/github.com/ncruces/go-sqlite3/vtab.go @@ -79,9 +79,12 @@ func implements[T any](typ reflect.Type) bool { // // https://sqlite.org/c3ref/declare_vtab.html func (c *Conn) DeclareVTab(sql string) error { + if c.interrupt.Err() != nil { + return INTERRUPT + } defer c.arena.mark()() - sqlPtr := c.arena.string(sql) - rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(sqlPtr))) + textPtr := c.arena.string(sql) + rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr))) return c.error(rc) } @@ -162,6 +165,7 @@ type VTabDestroyer interface { } // A VTabUpdater allows a virtual table to be updated. +// Implementations must not retain arg. type VTabUpdater interface { VTab // https://sqlite.org/vtab.html#xupdate @@ -241,6 +245,7 @@ type VTabSavepointer interface { // to loop through the virtual table. // A VTabCursor may optionally implement // [io.Closer] to free resources. +// Implementations of Filter must not retain arg. // // https://sqlite.org/c3ref/vtab_cursor.html type VTabCursor interface { @@ -489,12 +494,12 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo } func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t { - vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) - db := ctx.Value(connKey{}).(*Conn) - args := make([]Value, nArg) - callbackArgs(db, args, pArg) - rowID, err := vtab.Update(args...) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) + + vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) + rowID, err := vtab.Update(*args...) if err == nil { util.Write64(mod, pRowID, rowID) } @@ -593,15 +598,17 @@ func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t } func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t { - cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) - args := make([]Value, nArg) - callbackArgs(db, args, pArg) + args := callbackArgs(db, nArg, pArg) + defer returnArgs(args) + var idxName string if idxStr != 0 { idxName = util.ReadString(mod, idxStr, _MAX_LENGTH) } - err := cursor.Filter(int(idxNum), idxName, args...) + + cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) + err := cursor.Filter(int(idxNum), idxName, *args...) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } diff --git a/vendor/modules.txt b/vendor/modules.txt index a6317e4d9..5e421a2da 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -691,7 +691,7 @@ github.com/modern-go/reflect2 # github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 ## explicit github.com/munnerz/goautoneg -# github.com/ncruces/go-sqlite3 v0.24.0 +# github.com/ncruces/go-sqlite3 v0.25.0 ## explicit; go 1.23.0 github.com/ncruces/go-sqlite3 github.com/ncruces/go-sqlite3/driver @@ -699,7 +699,6 @@ github.com/ncruces/go-sqlite3/embed github.com/ncruces/go-sqlite3/internal/alloc github.com/ncruces/go-sqlite3/internal/dotlk github.com/ncruces/go-sqlite3/internal/util -github.com/ncruces/go-sqlite3/util/osutil github.com/ncruces/go-sqlite3/util/sql3util github.com/ncruces/go-sqlite3/vfs github.com/ncruces/go-sqlite3/vfs/memdb From b1844323314dd1f0832f1fcdb765a7f67ca01dbc Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Fri, 4 Apr 2025 18:29:22 +0200 Subject: [PATCH 3/4] [feature] Allow editing domain blocks/allows, fix comment import (#3967) * start implementing editing of existing domain permissions * [feature] Allow editing domain blocks/allows, fix comment import * [bugfix] Use "comment" via /api/v1/instance * fix the stuff --- docs/admin/domain_permission_subscriptions.md | 21 ++ docs/api/swagger.yaml | 116 +++++++- internal/api/client/admin/admin.go | 2 + .../api/client/admin/domainallowupdate.go | 91 +++++++ .../api/client/admin/domainblockupdate.go | 91 +++++++ internal/api/client/admin/domainpermission.go | 89 +++++- .../admin/domainpermissiondraftcreate.go | 7 +- .../domainpermissionsubscriptiontest_test.go | 28 +- .../client/instance/instancepeersget_test.go | 10 +- internal/api/model/domain.go | 17 +- internal/db/bundb/domain.go | 4 +- internal/db/bundb/domain_test.go | 14 +- .../domainpermissionsubscription_test.go | 2 +- internal/db/domain.go | 8 +- internal/gtsmodel/domainallow.go | 2 +- internal/gtsmodel/domainblock.go | 2 +- internal/processing/admin/domainallow.go | 50 +++- internal/processing/admin/domainblock.go | 50 +++- internal/processing/admin/domainpermission.go | 179 +++++++++--- internal/processing/instance.go | 6 +- internal/subscriptions/domainperms.go | 39 +-- internal/subscriptions/subscriptions_test.go | 4 +- internal/typeutils/internaltofrontend.go | 6 +- testrig/transportcontroller.go | 2 +- .../query/admin/domain-permissions/import.ts | 38 +-- .../query/admin/domain-permissions/update.ts | 43 +++ .../settings/lib/types/domain-permission.ts | 8 +- web/source/settings/style.css | 41 ++- .../moderation/domain-permissions/detail.tsx | 255 +++++++++++------- .../domain-permissions/import-export.tsx | 2 +- .../moderation/domain-permissions/process.tsx | 103 ++++--- .../settings/views/moderation/router.tsx | 4 +- 32 files changed, 1021 insertions(+), 313 deletions(-) create mode 100644 internal/api/client/admin/domainallowupdate.go create mode 100644 internal/api/client/admin/domainblockupdate.go diff --git a/docs/admin/domain_permission_subscriptions.md b/docs/admin/domain_permission_subscriptions.md index 77ec831e1..78518e187 100644 --- a/docs/admin/domain_permission_subscriptions.md +++ b/docs/admin/domain_permission_subscriptions.md @@ -113,6 +113,27 @@ nothanks.com,suspend,false,false,,false JSON lists use content type `application/json`. +```json +[ + { + "domain": "bumfaces.net", + "suspended_at": "2020-05-13T13:29:12.000Z", + "comment": "big jerks" + }, + { + "domain": "peepee.poopoo", + "suspended_at": "2020-05-13T13:29:12.000Z", + "comment": "harassment" + }, + { + "domain": "nothanks.com", + "suspended_at": "2020-05-13T13:29:12.000Z" + } +] +``` + +As an alternative to `"comment"`, `"public_comment"` will also work: + ```json [ { diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index e1c1c14e9..778b1c843 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -1099,13 +1099,22 @@ definitions: domain: description: Domain represents a remote domain properties: + comment: + description: |- + If the domain is blocked, what's the publicly-stated reason for the block. + Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance. + example: they smell + type: string + x-go-name: Comment domain: description: The hostname of the domain. example: example.org type: string x-go-name: Domain public_comment: - description: If the domain is blocked, what's the publicly-stated reason for the block. + description: |- + If the domain is blocked, what's the publicly-stated reason for the block. + Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance. example: they smell type: string x-go-name: PublicComment @@ -1124,6 +1133,13 @@ definitions: x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model domainPermission: properties: + comment: + description: |- + If the domain is blocked, what's the publicly-stated reason for the block. + Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance. + example: they smell + type: string + x-go-name: Comment created_at: description: Time at which the permission entry was created (ISO 8601 Datetime). example: "2021-07-30T09:20:25+00:00" @@ -1162,7 +1178,9 @@ definitions: type: string x-go-name: PrivateComment public_comment: - description: If the domain is blocked, what's the publicly-stated reason for the block. + description: |- + If the domain is blocked, what's the publicly-stated reason for the block. + Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance. example: they smell type: string x-go-name: PublicComment @@ -5823,6 +5841,53 @@ paths: summary: View domain allow with the given ID. tags: - admin + put: + consumes: + - multipart/form-data + operationId: domainAllowUpdate + parameters: + - description: The id of the domain allow. + in: path + name: id + required: true + type: string + - description: Obfuscate the name of the domain when serving it publicly. Eg., `example.org` becomes something like `ex***e.org`. + in: formData + name: obfuscate + type: boolean + - description: Public comment about this domain allow. This will be displayed alongside the domain allow if you choose to share allows. + in: formData + name: public_comment + type: string + - description: Private comment about this domain allow. Will only be shown to other admins, so this is a useful way of internally keeping track of why a certain domain ended up allowed. + in: formData + name: private_comment + type: string + produces: + - application/json + responses: + "200": + description: The updated domain allow. + schema: + $ref: '#/definitions/domainPermission' + "400": + description: bad request + "401": + description: unauthorized + "403": + description: forbidden + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - admin:write:domain_allows + summary: Update a single domain allow. + tags: + - admin /api/v1/admin/domain_blocks: get: operationId: domainBlocksGet @@ -5990,6 +6055,53 @@ paths: summary: View domain block with the given ID. tags: - admin + put: + consumes: + - multipart/form-data + operationId: domainBlockUpdate + parameters: + - description: The id of the domain block. + in: path + name: id + required: true + type: string + - description: Obfuscate the name of the domain when serving it publicly. Eg., `example.org` becomes something like `ex***e.org`. + in: formData + name: obfuscate + type: boolean + - description: Public comment about this domain block. This will be displayed alongside the domain block if you choose to share blocks. + in: formData + name: public_comment + type: string + - description: Private comment about this domain block. Will only be shown to other admins, so this is a useful way of internally keeping track of why a certain domain ended up blocked. + in: formData + name: private_comment + type: string + produces: + - application/json + responses: + "200": + description: The updated domain block. + schema: + $ref: '#/definitions/domainPermission' + "400": + description: bad request + "401": + description: unauthorized + "403": + description: forbidden + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - admin:write:domain_blocks + summary: Update a single domain block. + tags: + - admin /api/v1/admin/domain_keys_expire: post: consumes: diff --git a/internal/api/client/admin/admin.go b/internal/api/client/admin/admin.go index a5a16f35f..01a5796ae 100644 --- a/internal/api/client/admin/admin.go +++ b/internal/api/client/admin/admin.go @@ -102,12 +102,14 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H attachHandler(http.MethodPost, DomainBlocksPath, m.DomainBlocksPOSTHandler) attachHandler(http.MethodGet, DomainBlocksPath, m.DomainBlocksGETHandler) attachHandler(http.MethodGet, DomainBlocksPathWithID, m.DomainBlockGETHandler) + attachHandler(http.MethodPut, DomainBlocksPathWithID, m.DomainBlockUpdatePUTHandler) attachHandler(http.MethodDelete, DomainBlocksPathWithID, m.DomainBlockDELETEHandler) // domain allow stuff attachHandler(http.MethodPost, DomainAllowsPath, m.DomainAllowsPOSTHandler) attachHandler(http.MethodGet, DomainAllowsPath, m.DomainAllowsGETHandler) attachHandler(http.MethodGet, DomainAllowsPathWithID, m.DomainAllowGETHandler) + attachHandler(http.MethodPut, DomainAllowsPathWithID, m.DomainAllowUpdatePUTHandler) attachHandler(http.MethodDelete, DomainAllowsPathWithID, m.DomainAllowDELETEHandler) // domain permission draft stuff diff --git a/internal/api/client/admin/domainallowupdate.go b/internal/api/client/admin/domainallowupdate.go new file mode 100644 index 000000000..02edfdfef --- /dev/null +++ b/internal/api/client/admin/domainallowupdate.go @@ -0,0 +1,91 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package admin + +import ( + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// DomainAllowUpdatePUTHandler swagger:operation PUT /api/v1/admin/domain_allows/{id} domainAllowUpdate +// +// Update a single domain allow. +// +// --- +// tags: +// - admin +// +// consumes: +// - multipart/form-data +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: The id of the domain allow. +// in: path +// required: true +// - +// name: obfuscate +// in: formData +// description: >- +// Obfuscate the name of the domain when serving it publicly. +// Eg., `example.org` becomes something like `ex***e.org`. +// type: boolean +// - +// name: public_comment +// in: formData +// description: >- +// Public comment about this domain allow. +// This will be displayed alongside the domain allow if you choose to share allows. +// type: string +// - +// name: private_comment +// in: formData +// description: >- +// Private comment about this domain allow. Will only be shown to other admins, so this +// is a useful way of internally keeping track of why a certain domain ended up allowed. +// type: string +// +// security: +// - OAuth2 Bearer: +// - admin:write:domain_allows +// +// responses: +// '200': +// description: The updated domain allow. +// schema: +// "$ref": "#/definitions/domainPermission" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) DomainAllowUpdatePUTHandler(c *gin.Context) { + m.updateDomainPermission(c, gtsmodel.DomainPermissionAllow) +} diff --git a/internal/api/client/admin/domainblockupdate.go b/internal/api/client/admin/domainblockupdate.go new file mode 100644 index 000000000..0fbe72aa8 --- /dev/null +++ b/internal/api/client/admin/domainblockupdate.go @@ -0,0 +1,91 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package admin + +import ( + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// DomainBlockUpdatePUTHandler swagger:operation PUT /api/v1/admin/domain_blocks/{id} domainBlockUpdate +// +// Update a single domain block. +// +// --- +// tags: +// - admin +// +// consumes: +// - multipart/form-data +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: The id of the domain block. +// in: path +// required: true +// - +// name: obfuscate +// in: formData +// description: >- +// Obfuscate the name of the domain when serving it publicly. +// Eg., `example.org` becomes something like `ex***e.org`. +// type: boolean +// - +// name: public_comment +// in: formData +// description: >- +// Public comment about this domain block. +// This will be displayed alongside the domain block if you choose to share blocks. +// type: string +// - +// name: private_comment +// in: formData +// description: >- +// Private comment about this domain block. Will only be shown to other admins, so this +// is a useful way of internally keeping track of why a certain domain ended up blocked. +// type: string +// +// security: +// - OAuth2 Bearer: +// - admin:write:domain_blocks +// +// responses: +// '200': +// description: The updated domain block. +// schema: +// "$ref": "#/definitions/domainPermission" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) DomainBlockUpdatePUTHandler(c *gin.Context) { + m.updateDomainPermission(c, gtsmodel.DomainPermissionBlock) +} diff --git a/internal/api/client/admin/domainpermission.go b/internal/api/client/admin/domainpermission.go index c64c90eb2..91b95334b 100644 --- a/internal/api/client/admin/domainpermission.go +++ b/internal/api/client/admin/domainpermission.go @@ -29,6 +29,7 @@ import ( apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" ) type singleDomainPermCreate func( @@ -112,7 +113,7 @@ func (m *Module) createDomainPermissions( if importing && form.Domains.Size == 0 { err = errors.New("import was specified but list of domains is empty") } else if !importing && form.Domain == "" { - err = errors.New("empty domain provided") + err = errors.New("no domain provided") } if err != nil { @@ -122,14 +123,14 @@ func (m *Module) createDomainPermissions( if !importing { // Single domain permission creation. - domainBlock, _, errWithCode := single( + perm, _, errWithCode := single( c.Request.Context(), permType, authed.Account, form.Domain, - form.Obfuscate, - form.PublicComment, - form.PrivateComment, + util.PtrOrZero(form.Obfuscate), + util.PtrOrZero(form.PublicComment), + util.PtrOrZero(form.PrivateComment), "", // No sub ID for single perm creation. ) @@ -138,7 +139,7 @@ func (m *Module) createDomainPermissions( return } - apiutil.JSON(c, http.StatusOK, domainBlock) + apiutil.JSON(c, http.StatusOK, perm) return } @@ -177,6 +178,82 @@ func (m *Module) createDomainPermissions( apiutil.JSON(c, http.StatusOK, domainPerms) } +func (m *Module) updateDomainPermission( + c *gin.Context, + permType gtsmodel.DomainPermissionType, +) { + // Scope differs based on permType. + var requireScope apiutil.Scope + if permType == gtsmodel.DomainPermissionBlock { + requireScope = apiutil.ScopeAdminWriteDomainBlocks + } else { + requireScope = apiutil.ScopeAdminWriteDomainAllows + } + + authed, errWithCode := apiutil.TokenAuth(c, + true, true, true, true, + requireScope, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if !*authed.User.Admin { + err := fmt.Errorf("user %s not an admin", authed.User.ID) + apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if authed.Account.IsMoving() { + apiutil.ForbiddenAfterMove(c) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + permID, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + // Parse + validate form. + form := new(apimodel.DomainPermissionRequest) + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if form.Obfuscate == nil && + form.PrivateComment == nil && + form.PublicComment == nil { + const errText = "empty form submitted" + errWithCode := gtserror.NewErrorBadRequest(errors.New(errText), errText) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + perm, errWithCode := m.processor.Admin().DomainPermissionUpdate( + c.Request.Context(), + permType, + permID, + form.Obfuscate, + form.PublicComment, + form.PrivateComment, + nil, // Can't update perm sub ID this way yet. + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, perm) +} + // deleteDomainPermission deletes a single domain permission (block or allow). func (m *Module) deleteDomainPermission( c *gin.Context, diff --git a/internal/api/client/admin/domainpermissiondraftcreate.go b/internal/api/client/admin/domainpermissiondraftcreate.go index b8d3085e9..e7fcd2c40 100644 --- a/internal/api/client/admin/domainpermissiondraftcreate.go +++ b/internal/api/client/admin/domainpermissiondraftcreate.go @@ -26,6 +26,7 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // DomainPermissionDraftsPOSTHandler swagger:operation POST /api/v1/admin/domain_permission_drafts domainPermissionDraftCreate @@ -148,9 +149,9 @@ func (m *Module) DomainPermissionDraftsPOSTHandler(c *gin.Context) { authed.Account, form.Domain, permType, - form.Obfuscate, - form.PublicComment, - form.PrivateComment, + util.PtrOrZero(form.Obfuscate), + util.PtrOrZero(form.PublicComment), + util.PtrOrZero(form.PrivateComment), ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) diff --git a/internal/api/client/admin/domainpermissionsubscriptiontest_test.go b/internal/api/client/admin/domainpermissionsubscriptiontest_test.go index c03b950a9..4ac366520 100644 --- a/internal/api/client/admin/domainpermissionsubscriptiontest_test.go +++ b/internal/api/client/admin/domainpermissionsubscriptiontest_test.go @@ -97,14 +97,21 @@ func (suite *DomainPermissionSubscriptionTestTestSuite) TestDomainPermissionSubs suite.Equal(`[ { "domain": "bumfaces.net", - "public_comment": "big jerks" + "public_comment": "big jerks", + "obfuscate": false, + "private_comment": "" }, { "domain": "peepee.poopoo", - "public_comment": "harassment" + "public_comment": "harassment", + "obfuscate": false, + "private_comment": "" }, { - "domain": "nothanks.com" + "domain": "nothanks.com", + "public_comment": "", + "obfuscate": false, + "private_comment": "" } ]`, dst.String()) @@ -177,13 +184,22 @@ func (suite *DomainPermissionSubscriptionTestTestSuite) TestDomainPermissionSubs // Ensure expected. suite.Equal(`[ { - "domain": "bumfaces.net" + "domain": "bumfaces.net", + "public_comment": "", + "obfuscate": false, + "private_comment": "" }, { - "domain": "peepee.poopoo" + "domain": "peepee.poopoo", + "public_comment": "", + "obfuscate": false, + "private_comment": "" }, { - "domain": "nothanks.com" + "domain": "nothanks.com", + "public_comment": "", + "obfuscate": false, + "private_comment": "" } ]`, dst.String()) diff --git a/internal/api/client/instance/instancepeersget_test.go b/internal/api/client/instance/instancepeersget_test.go index a2c81cc4e..2421205f7 100644 --- a/internal/api/client/instance/instancepeersget_test.go +++ b/internal/api/client/instance/instancepeersget_test.go @@ -136,7 +136,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetOnlySuspended() { { "domain": "replyguys.com", "suspended_at": "2020-05-13T13:29:12.000Z", - "public_comment": "reply-guying to tech posts" + "comment": "reply-guying to tech posts" } ]`, dst.String()) } @@ -186,7 +186,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetOnlySuspendedAuthori { "domain": "replyguys.com", "suspended_at": "2020-05-13T13:29:12.000Z", - "public_comment": "reply-guying to tech posts" + "comment": "reply-guying to tech posts" } ]`, dst.String()) } @@ -219,7 +219,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetAll() { { "domain": "replyguys.com", "suspended_at": "2020-05-13T13:29:12.000Z", - "public_comment": "reply-guying to tech posts" + "comment": "reply-guying to tech posts" } ]`, dst.String()) } @@ -263,12 +263,12 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetAllWithObfuscated() { "domain": "o*g.*u**.t**.*or*t.*r**ev**", "suspended_at": "2021-06-09T10:34:55.000Z", - "public_comment": "just absolutely the worst, wowza" + "comment": "just absolutely the worst, wowza" }, { "domain": "replyguys.com", "suspended_at": "2020-05-13T13:29:12.000Z", - "public_comment": "reply-guying to tech posts" + "comment": "reply-guying to tech posts" } ]`, dst.String()) } diff --git a/internal/api/model/domain.go b/internal/api/model/domain.go index 94a190f63..8d94321d0 100644 --- a/internal/api/model/domain.go +++ b/internal/api/model/domain.go @@ -33,8 +33,13 @@ type Domain struct { // example: 2021-07-30T09:20:25+00:00 SilencedAt string `json:"silenced_at,omitempty"` // If the domain is blocked, what's the publicly-stated reason for the block. + // Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance. // example: they smell - PublicComment string `form:"public_comment" json:"public_comment,omitempty"` + Comment *string `form:"comment" json:"comment,omitempty"` + // If the domain is blocked, what's the publicly-stated reason for the block. + // Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance. + // example: they smell + PublicComment *string `form:"public_comment" json:"public_comment,omitempty"` } // DomainPermission represents a permission applied to one domain (explicit block/allow). @@ -48,10 +53,10 @@ type DomainPermission struct { ID string `json:"id,omitempty"` // Obfuscate the domain name when serving this domain permission entry publicly. // example: false - Obfuscate bool `json:"obfuscate,omitempty"` + Obfuscate *bool `json:"obfuscate,omitempty"` // Private comment for this permission entry, visible to this instance's admins only. // example: they are poopoo - PrivateComment string `json:"private_comment,omitempty"` + PrivateComment *string `json:"private_comment,omitempty"` // If applicable, the ID of the subscription that caused this domain permission entry to be created. // example: 01FBW25TF5J67JW3HFHZCSD23K SubscriptionID string `json:"subscription_id,omitempty"` @@ -80,14 +85,14 @@ type DomainPermissionRequest struct { // Obfuscate the domain name when displaying this permission entry publicly. // Ie., instead of 'example.org' show something like 'e**mpl*.or*'. // example: false - Obfuscate bool `form:"obfuscate" json:"obfuscate"` + Obfuscate *bool `form:"obfuscate" json:"obfuscate"` // Private comment for other admins on why this permission entry was created. // example: don't like 'em!!!! - PrivateComment string `form:"private_comment" json:"private_comment"` + PrivateComment *string `form:"private_comment" json:"private_comment"` // Public comment on why this permission entry was created. // Will be visible to requesters at /api/v1/instance/peers if this endpoint is exposed. // example: foss dorks 😫 - PublicComment string `form:"public_comment" json:"public_comment"` + PublicComment *string `form:"public_comment" json:"public_comment"` // Permission type to create (only applies to domain permission drafts, not explicit blocks and allows). PermissionType string `form:"permission_type" json:"permission_type"` } diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 925387bd9..23b9abc74 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -36,7 +36,7 @@ type domainDB struct { state *state.State } -func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) { +func (d *domainDB) PutDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) { // Normalize the domain as punycode, note the extra // validation step for domain name write operations. allow.Domain, err = util.PunifySafely(allow.Domain) @@ -162,7 +162,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { return nil } -func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { +func (d *domainDB) PutDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { var err error // Normalize the domain as punycode, note the extra diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index 8164259e8..a56f469c4 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -46,7 +46,7 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() { suite.NoError(err) suite.False(blocked) - err = suite.db.CreateDomainBlock(ctx, domainBlock) + err = suite.db.PutDomainBlock(ctx, domainBlock) suite.NoError(err) // domain block now exists @@ -75,7 +75,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() { suite.False(blocked) // Block this domain. - if err := suite.db.CreateDomainBlock(ctx, domainBlock); err != nil { + if err := suite.db.PutDomainBlock(ctx, domainBlock); err != nil { suite.FailNow(err.Error()) } @@ -96,7 +96,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() { CreatedByAccount: suite.testAccounts["admin_account"], } - if err := suite.db.CreateDomainAllow(ctx, domainAllow); err != nil { + if err := suite.db.PutDomainAllow(ctx, domainAllow); err != nil { suite.FailNow(err.Error()) } @@ -124,7 +124,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { suite.NoError(err) suite.False(blocked) - err = suite.db.CreateDomainBlock(ctx, domainBlock) + err = suite.db.PutDomainBlock(ctx, domainBlock) suite.NoError(err) // Start with the base block domain @@ -164,7 +164,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { suite.NoError(err) suite.False(blocked) - err = suite.db.CreateDomainBlock(ctx, domainBlock) + err = suite.db.PutDomainBlock(ctx, domainBlock) suite.NoError(err) // domain block now exists @@ -200,7 +200,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() { suite.NoError(err) suite.False(blocked) - err = suite.db.CreateDomainBlock(ctx, domainBlock) + err = suite.db.PutDomainBlock(ctx, domainBlock) suite.NoError(err) // domain block now exists @@ -232,7 +232,7 @@ func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() { } for _, block := range blocks { - if err := suite.db.CreateDomainBlock(ctx, block); err != nil { + if err := suite.db.PutDomainBlock(ctx, block); err != nil { suite.FailNow(err.Error()) } } diff --git a/internal/db/bundb/domainpermissionsubscription_test.go b/internal/db/bundb/domainpermissionsubscription_test.go index 732befbff..7a5cf8685 100644 --- a/internal/db/bundb/domainpermissionsubscription_test.go +++ b/internal/db/bundb/domainpermissionsubscription_test.go @@ -80,7 +80,7 @@ func (suite *DomainPermissionSubscriptionTestSuite) TestCount() { // Whack the perms in the db. for _, perm := range perms { - if err := suite.state.DB.CreateDomainBlock(ctx, perm); err != nil { + if err := suite.state.DB.PutDomainBlock(ctx, perm); err != nil { suite.FailNow(err.Error()) } } diff --git a/internal/db/domain.go b/internal/db/domain.go index 643538e7e..95a2f0755 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -31,8 +31,8 @@ type Domain interface { Block/allow storage + retrieval functions. */ - // CreateDomainAllow puts the given instance-level domain allow into the database. - CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error + // PutDomainAllow puts the given instance-level domain allow into the database. + PutDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error // GetDomainAllow returns one instance-level domain allow with the given domain, if it exists. GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) @@ -49,8 +49,8 @@ type Domain interface { // DeleteDomainAllow deletes an instance-level domain allow with the given domain, if it exists. DeleteDomainAllow(ctx context.Context, domain string) error - // CreateDomainBlock puts the given instance-level domain block into the database. - CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error + // PutDomainBlock puts the given instance-level domain block into the database. + PutDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error // GetDomainBlock returns one instance-level domain block with the given domain, if it exists. GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) diff --git a/internal/gtsmodel/domainallow.go b/internal/gtsmodel/domainallow.go index 3a7ca8774..f6aedbbba 100644 --- a/internal/gtsmodel/domainallow.go +++ b/internal/gtsmodel/domainallow.go @@ -26,7 +26,7 @@ type DomainAllow struct { UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated Domain string `bun:",nullzero,notnull"` // domain to allow. Eg. 'whatever.com' CreatedByAccountID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this allow - CreatedByAccount *Account `bun:"rel:belongs-to"` // Account corresponding to createdByAccountID + CreatedByAccount *Account `bun:"-"` // Account corresponding to createdByAccountID PrivateComment string `bun:""` // Private comment on this allow, viewable to admins PublicComment string `bun:""` // Public comment on this allow, viewable (optionally) by everyone Obfuscate *bool `bun:",nullzero,notnull,default:false"` // whether the domain name should appear obfuscated when displaying it publicly diff --git a/internal/gtsmodel/domainblock.go b/internal/gtsmodel/domainblock.go index 4a0e1c5b7..fb0921c25 100644 --- a/internal/gtsmodel/domainblock.go +++ b/internal/gtsmodel/domainblock.go @@ -26,7 +26,7 @@ type DomainBlock struct { UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated Domain string `bun:",nullzero,notnull"` // domain to block. Eg. 'whatever.com' CreatedByAccountID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this block - CreatedByAccount *Account `bun:"rel:belongs-to"` // Account corresponding to createdByAccountID + CreatedByAccount *Account `bun:"-"` // Account corresponding to createdByAccountID PrivateComment string `bun:""` // Private comment on this block, viewable to admins PublicComment string `bun:""` // Public comment on this block, viewable (optionally) by everyone Obfuscate *bool `bun:",nullzero,notnull,default:false"` // whether the domain name should appear obfuscated when displaying it publicly diff --git a/internal/processing/admin/domainallow.go b/internal/processing/admin/domainallow.go index 02101ccff..134351ad5 100644 --- a/internal/processing/admin/domainallow.go +++ b/internal/processing/admin/domainallow.go @@ -60,7 +60,7 @@ func (p *Processor) createDomainAllow( } // Insert the new allow into the database. - if err := p.state.DB.CreateDomainAllow(ctx, domainAllow); err != nil { + if err := p.state.DB.PutDomainAllow(ctx, domainAllow); err != nil { err = gtserror.Newf("db error putting domain allow %s: %w", domain, err) return nil, "", gtserror.NewErrorInternalError(err) } @@ -92,6 +92,54 @@ func (p *Processor) createDomainAllow( return apiDomainAllow, action.ID, nil } +func (p *Processor) updateDomainAllow( + ctx context.Context, + domainAllowID string, + obfuscate *bool, + publicComment *string, + privateComment *string, + subscriptionID *string, +) (*apimodel.DomainPermission, gtserror.WithCode) { + domainAllow, err := p.state.DB.GetDomainAllowByID(ctx, domainAllowID) + if err != nil { + if !errors.Is(err, db.ErrNoEntries) { + // Real error. + err = gtserror.Newf("db error getting domain allow: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // There are just no entries for this ID. + err = fmt.Errorf("no domain allow entry exists with ID %s", domainAllowID) + return nil, gtserror.NewErrorNotFound(err, err.Error()) + } + + var columns []string + if obfuscate != nil { + domainAllow.Obfuscate = obfuscate + columns = append(columns, "obfuscate") + } + if publicComment != nil { + domainAllow.PublicComment = *publicComment + columns = append(columns, "public_comment") + } + if privateComment != nil { + domainAllow.PrivateComment = *privateComment + columns = append(columns, "private_comment") + } + if subscriptionID != nil { + domainAllow.SubscriptionID = *subscriptionID + columns = append(columns, "subscription_id") + } + + // Update the domain allow. + if err := p.state.DB.UpdateDomainAllow(ctx, domainAllow, columns...); err != nil { + err = gtserror.Newf("db error updating domain allow: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiDomainPerm(ctx, domainAllow, false) +} + func (p *Processor) deleteDomainAllow( ctx context.Context, adminAcct *gtsmodel.Account, diff --git a/internal/processing/admin/domainblock.go b/internal/processing/admin/domainblock.go index 249df744c..3dd5a256f 100644 --- a/internal/processing/admin/domainblock.go +++ b/internal/processing/admin/domainblock.go @@ -60,7 +60,7 @@ func (p *Processor) createDomainBlock( } // Insert the new block into the database. - if err := p.state.DB.CreateDomainBlock(ctx, domainBlock); err != nil { + if err := p.state.DB.PutDomainBlock(ctx, domainBlock); err != nil { err = gtserror.Newf("db error putting domain block %s: %w", domain, err) return nil, "", gtserror.NewErrorInternalError(err) } @@ -93,6 +93,54 @@ func (p *Processor) createDomainBlock( return apiDomainBlock, action.ID, nil } +func (p *Processor) updateDomainBlock( + ctx context.Context, + domainBlockID string, + obfuscate *bool, + publicComment *string, + privateComment *string, + subscriptionID *string, +) (*apimodel.DomainPermission, gtserror.WithCode) { + domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, domainBlockID) + if err != nil { + if !errors.Is(err, db.ErrNoEntries) { + // Real error. + err = gtserror.Newf("db error getting domain block: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // There are just no entries for this ID. + err = fmt.Errorf("no domain block entry exists with ID %s", domainBlockID) + return nil, gtserror.NewErrorNotFound(err, err.Error()) + } + + var columns []string + if obfuscate != nil { + domainBlock.Obfuscate = obfuscate + columns = append(columns, "obfuscate") + } + if publicComment != nil { + domainBlock.PublicComment = *publicComment + columns = append(columns, "public_comment") + } + if privateComment != nil { + domainBlock.PrivateComment = *privateComment + columns = append(columns, "private_comment") + } + if subscriptionID != nil { + domainBlock.SubscriptionID = *subscriptionID + columns = append(columns, "subscription_id") + } + + // Update the domain block. + if err := p.state.DB.UpdateDomainBlock(ctx, domainBlock, columns...); err != nil { + err = gtserror.Newf("db error updating domain block: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiDomainPerm(ctx, domainBlock, false) +} + func (p *Processor) deleteDomainBlock( ctx context.Context, adminAcct *gtsmodel.Account, diff --git a/internal/processing/admin/domainpermission.go b/internal/processing/admin/domainpermission.go index 55800f458..04ee2ab26 100644 --- a/internal/processing/admin/domainpermission.go +++ b/internal/processing/admin/domainpermission.go @@ -18,6 +18,7 @@ package admin import ( + "cmp" "context" "encoding/json" "errors" @@ -29,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // DomainPermissionCreate creates an instance-level permission @@ -84,6 +86,50 @@ func (p *Processor) DomainPermissionCreate( } } +// DomainPermissionUpdate updates a domain permission +// of the given permissionType, with the given ID. +func (p *Processor) DomainPermissionUpdate( + ctx context.Context, + permissionType gtsmodel.DomainPermissionType, + permID string, + obfuscate *bool, + publicComment *string, + privateComment *string, + subscriptionID *string, +) (*apimodel.DomainPermission, gtserror.WithCode) { + switch permissionType { + + // Explicitly block a domain. + case gtsmodel.DomainPermissionBlock: + return p.updateDomainBlock( + ctx, + permID, + obfuscate, + publicComment, + privateComment, + subscriptionID, + ) + + // Explicitly allow a domain. + case gtsmodel.DomainPermissionAllow: + return p.updateDomainAllow( + ctx, + permID, + obfuscate, + publicComment, + privateComment, + subscriptionID, + ) + + // 🎵 Why don't we all strap bombs to our chests, + // and ride our bikes to the next G7 picnic? + // Seems easier with every clock-tick. 🎵 + default: + err := gtserror.Newf("unrecognized permission type %d", permissionType) + return nil, gtserror.NewErrorInternalError(err) + } +} + // DomainPermissionDelete removes one domain block with the given ID, // and processes side effects of removing the block asynchronously. // @@ -153,14 +199,14 @@ func (p *Processor) DomainPermissionsImport( } defer file.Close() - // Parse file as slice of domain blocks. - domainPerms := make([]*apimodel.DomainPermission, 0) - if err := json.NewDecoder(file).Decode(&domainPerms); err != nil { + // Parse file as slice of domain permissions. + apiDomainPerms := make([]*apimodel.DomainPermission, 0) + if err := json.NewDecoder(file).Decode(&apiDomainPerms); err != nil { err = gtserror.Newf("error parsing attachment as domain permissions: %w", err) return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - count := len(domainPerms) + count := len(apiDomainPerms) if count == 0 { err = gtserror.New("error importing domain permissions: 0 entries provided") return nil, gtserror.NewErrorBadRequest(err, err.Error()) @@ -170,52 +216,97 @@ func (p *Processor) DomainPermissionsImport( // between successes and errors so that the caller can // try failed imports again if desired. multiStatusEntries := make([]apimodel.MultiStatusEntry, 0, count) - - for _, domainPerm := range domainPerms { - var ( - domain = domainPerm.Domain.Domain - obfuscate = domainPerm.Obfuscate - publicComment = domainPerm.PublicComment - privateComment = domainPerm.PrivateComment - subscriptionID = "" // No sub ID for imports. - errWithCode gtserror.WithCode + for _, apiDomainPerm := range apiDomainPerms { + multiStatusEntries = append( + multiStatusEntries, + p.importOrUpdateDomainPerm( + ctx, + permissionType, + account, + apiDomainPerm, + ), ) - - domainPerm, _, errWithCode = p.DomainPermissionCreate( - ctx, - permissionType, - account, - domain, - obfuscate, - publicComment, - privateComment, - subscriptionID, - ) - - var entry *apimodel.MultiStatusEntry - - if errWithCode != nil { - entry = &apimodel.MultiStatusEntry{ - // Use the failed domain entry as the resource value. - Resource: domain, - Message: errWithCode.Safe(), - Status: errWithCode.Code(), - } - } else { - entry = &apimodel.MultiStatusEntry{ - // Use successfully created API model domain block as the resource value. - Resource: domainPerm, - Message: http.StatusText(http.StatusOK), - Status: http.StatusOK, - } - } - - multiStatusEntries = append(multiStatusEntries, *entry) } return apimodel.NewMultiStatus(multiStatusEntries), nil } +func (p *Processor) importOrUpdateDomainPerm( + ctx context.Context, + permType gtsmodel.DomainPermissionType, + account *gtsmodel.Account, + apiDomainPerm *apimodel.DomainPermission, +) apimodel.MultiStatusEntry { + var ( + domain = apiDomainPerm.Domain.Domain + obfuscate = apiDomainPerm.Obfuscate + publicComment = cmp.Or(apiDomainPerm.PublicComment, apiDomainPerm.Comment) + privateComment = apiDomainPerm.PrivateComment + subscriptionID = "" // No sub ID for imports. + ) + + // Check if this domain + // perm already exists. + var ( + domainPerm gtsmodel.DomainPermission + err error + ) + if permType == gtsmodel.DomainPermissionBlock { + domainPerm, err = p.state.DB.GetDomainBlock(ctx, domain) + } else { + domainPerm, err = p.state.DB.GetDomainAllow(ctx, domain) + } + + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real db error. + return apimodel.MultiStatusEntry{ + Resource: domain, + Message: "db error checking for existence of domain permission", + Status: http.StatusInternalServerError, + } + } + + var errWithCode gtserror.WithCode + if domainPerm != nil { + // Permission already exists, update it. + apiDomainPerm, errWithCode = p.DomainPermissionUpdate( + ctx, + permType, + domainPerm.GetID(), + obfuscate, + publicComment, + privateComment, + nil, + ) + } else { + // Permission didn't exist yet, create it. + apiDomainPerm, _, errWithCode = p.DomainPermissionCreate( + ctx, + permType, + account, + domain, + util.PtrOrZero(obfuscate), + util.PtrOrZero(publicComment), + util.PtrOrZero(privateComment), + subscriptionID, + ) + } + + if errWithCode != nil { + return apimodel.MultiStatusEntry{ + Resource: domain, + Message: errWithCode.Safe(), + Status: errWithCode.Code(), + } + } + + return apimodel.MultiStatusEntry{ + Resource: apiDomainPerm, + Message: http.StatusText(http.StatusOK), + Status: http.StatusOK, + } +} + // DomainPermissionsGet returns all existing domain // permissions of the requested type. If export is // true, the format will be suitable for writing out diff --git a/internal/processing/instance.go b/internal/processing/instance.go index 4cbbb742a..e723c751e 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -106,9 +106,9 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool, } domains = append(domains, &apimodel.Domain{ - Domain: d, - SuspendedAt: util.FormatISO8601(domainBlock.CreatedAt), - PublicComment: domainBlock.PublicComment, + Domain: d, + SuspendedAt: util.FormatISO8601(domainBlock.CreatedAt), + Comment: &domainBlock.PublicComment, }) } } diff --git a/internal/subscriptions/domainperms.go b/internal/subscriptions/domainperms.go index c9f569f94..8da9064f6 100644 --- a/internal/subscriptions/domainperms.go +++ b/internal/subscriptions/domainperms.go @@ -438,7 +438,7 @@ func (s *Subscriptions) processDomainPermission( Obfuscate: wantedPerm.GetObfuscate(), SubscriptionID: permSub.ID, } - insertF = func() error { return s.state.DB.CreateDomainBlock(ctx, domainBlock) } + insertF = func() error { return s.state.DB.PutDomainBlock(ctx, domainBlock) } action = >smodel.AdminAction{ ID: id.NewULID(), @@ -461,7 +461,7 @@ func (s *Subscriptions) processDomainPermission( Obfuscate: wantedPerm.GetObfuscate(), SubscriptionID: permSub.ID, } - insertF = func() error { return s.state.DB.CreateDomainAllow(ctx, domainAllow) } + insertF = func() error { return s.state.DB.PutDomainAllow(ctx, domainAllow) } action = >smodel.AdminAction{ ID: id.NewULID(), @@ -564,13 +564,13 @@ func permsFromCSV( for i, columnHeader := range columnHeaders { // Remove leading # if present. - normal := strings.TrimLeft(columnHeader, "#") + columnHeader = strings.TrimLeft(columnHeader, "#") // Find index of each column header we // care about, ensuring no duplicates. - switch normal { + switch { - case "domain": + case columnHeader == "domain": if domainI != nil { body.Close() err := gtserror.NewfAt(3, "duplicate domain column header in csv: %+v", columnHeaders) @@ -578,7 +578,7 @@ func permsFromCSV( } domainI = &i - case "severity": + case columnHeader == "severity": if severityI != nil { body.Close() err := gtserror.NewfAt(3, "duplicate severity column header in csv: %+v", columnHeaders) @@ -586,15 +586,15 @@ func permsFromCSV( } severityI = &i - case "public_comment": + case columnHeader == "public_comment" || columnHeader == "comment": if publicCommentI != nil { body.Close() - err := gtserror.NewfAt(3, "duplicate public_comment column header in csv: %+v", columnHeaders) + err := gtserror.NewfAt(3, "duplicate public_comment or comment column header in csv: %+v", columnHeaders) return nil, err } publicCommentI = &i - case "obfuscate": + case columnHeader == "obfuscate": if obfuscateI != nil { body.Close() err := gtserror.NewfAt(3, "duplicate obfuscate column header in csv: %+v", columnHeaders) @@ -674,15 +674,15 @@ func permsFromCSV( perm.SetPublicComment(record[*publicCommentI]) } + var obfuscate bool if obfuscateI != nil { - obfuscate, err := strconv.ParseBool(record[*obfuscateI]) + obfuscate, err = strconv.ParseBool(record[*obfuscateI]) if err != nil { l.Warnf("couldn't parse obfuscate field of record: %+v", record) continue } - - perm.SetObfuscate(&obfuscate) } + perm.SetObfuscate(&obfuscate) // We're done. perms = append(perms, perm) @@ -742,8 +742,9 @@ func permsFromJSON( } // Set remaining fields. - perm.SetPublicComment(apiPerm.PublicComment) - perm.SetObfuscate(&apiPerm.Obfuscate) + publicComment := cmp.Or(apiPerm.PublicComment, apiPerm.Comment) + perm.SetPublicComment(util.PtrOrZero(publicComment)) + perm.SetObfuscate(util.Ptr(util.PtrOrZero(apiPerm.Obfuscate))) // We're done. perms = append(perms, perm) @@ -792,9 +793,15 @@ func permsFromPlain( var perm gtsmodel.DomainPermission switch permType { case gtsmodel.DomainPermissionBlock: - perm = >smodel.DomainBlock{Domain: domain} + perm = >smodel.DomainBlock{ + Domain: domain, + Obfuscate: util.Ptr(false), + } case gtsmodel.DomainPermissionAllow: - perm = >smodel.DomainAllow{Domain: domain} + perm = >smodel.DomainAllow{ + Domain: domain, + Obfuscate: util.Ptr(false), + } } // We're done. diff --git a/internal/subscriptions/subscriptions_test.go b/internal/subscriptions/subscriptions_test.go index 133db4b7c..4441d8c15 100644 --- a/internal/subscriptions/subscriptions_test.go +++ b/internal/subscriptions/subscriptions_test.go @@ -775,7 +775,7 @@ func (suite *SubscriptionsTestSuite) TestAdoption() { existingBlock2, existingBlock3, } { - if err := testStructs.State.DB.CreateDomainBlock( + if err := testStructs.State.DB.PutDomainBlock( ctx, block, ); err != nil { suite.FailNow(err.Error()) @@ -876,7 +876,7 @@ func (suite *SubscriptionsTestSuite) TestDomainAllowsAndBlocks() { } // Store existing allow. - if err := testStructs.State.DB.CreateDomainAllow(ctx, existingAllow); err != nil { + if err := testStructs.State.DB.PutDomainAllow(ctx, existingAllow); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index b0f5d12fa..62a1ebc1e 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -2182,7 +2182,7 @@ func (c *Converter) DomainPermToAPIDomainPerm( domainPerm := &apimodel.DomainPermission{ Domain: apimodel.Domain{ Domain: domain, - PublicComment: d.GetPublicComment(), + PublicComment: util.Ptr(d.GetPublicComment()), }, } @@ -2193,8 +2193,8 @@ func (c *Converter) DomainPermToAPIDomainPerm( } domainPerm.ID = d.GetID() - domainPerm.Obfuscate = util.PtrOrZero(d.GetObfuscate()) - domainPerm.PrivateComment = d.GetPrivateComment() + domainPerm.Obfuscate = d.GetObfuscate() + domainPerm.PrivateComment = util.Ptr(d.GetPrivateComment()) domainPerm.SubscriptionID = d.GetSubscriptionID() domainPerm.CreatedBy = d.GetCreatedByAccountID() if createdAt := d.GetCreatedAt(); !createdAt.IsZero() { diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index bbcb3901d..d66c71179 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -627,7 +627,7 @@ nothanks.com` { "domain": "bumfaces.net", "suspended_at": "2020-05-13T13:29:12.000Z", - "public_comment": "big jerks" + "comment": "big jerks" }, { "domain": "peepee.poopoo", diff --git a/web/source/settings/lib/query/admin/domain-permissions/import.ts b/web/source/settings/lib/query/admin/domain-permissions/import.ts index cbcf44964..a83448a1f 100644 --- a/web/source/settings/lib/query/admin/domain-permissions/import.ts +++ b/web/source/settings/lib/query/admin/domain-permissions/import.ts @@ -40,39 +40,19 @@ function importEntriesProcessor(formData: ImportDomainPermsParams): (_entry: Dom // Override each obfuscate entry if necessary. if (formData.obfuscate !== undefined) { - const obfuscateEntry = (entry: DomainPerm) => { + processingFuncs.push((entry: DomainPerm) => { entry.obfuscate = formData.obfuscate; - }; - processingFuncs.push(obfuscateEntry); + }); } - // Check whether we need to append or replace - // private_comment and public_comment. + // Check whether we need to replace + // private_comment and/or public_comment. ["private_comment","public_comment"].forEach((commentType) => { - let text = formData.commentType?.trim(); - if (!text) { - return; - } - - switch(formData[`${commentType}_behavior`]) { - case "append": - const appendComment = (entry: DomainPerm) => { - if (entry.commentType == undefined) { - entry.commentType = text; - } else { - entry.commentType = [entry.commentType, text].join("\n"); - } - }; - - processingFuncs.push(appendComment); - break; - case "replace": - const replaceComment = (entry: DomainPerm) => { - entry.commentType = text; - }; - - processingFuncs.push(replaceComment); - break; + if (formData[`replace_${commentType}`]) { + const text = formData[commentType]?.trim(); + processingFuncs.push((entry: DomainPerm) => { + entry[commentType] = text; + }); } }); diff --git a/web/source/settings/lib/query/admin/domain-permissions/update.ts b/web/source/settings/lib/query/admin/domain-permissions/update.ts index a6b4b2039..396c30d6e 100644 --- a/web/source/settings/lib/query/admin/domain-permissions/update.ts +++ b/web/source/settings/lib/query/admin/domain-permissions/update.ts @@ -22,6 +22,7 @@ import { gtsApi } from "../../gts-api"; import { replaceCacheOnMutation, removeFromCacheOnMutation, + updateCacheOnMutation, } from "../../query-modifiers"; import { listToKeyedObject } from "../../transforms"; import type { @@ -55,6 +56,36 @@ const extended = gtsApi.injectEndpoints({ ...replaceCacheOnMutation("domainAllows") }), + updateDomainBlock: build.mutation({ + query: ({ id, ...formData}) => ({ + method: "PUT", + url: `/api/v1/admin/domain_blocks/${id}`, + asForm: true, + body: formData, + discardEmpty: false + }), + ...updateCacheOnMutation("domainBlocks", { + key: (_draft, newData) => { + return newData.domain; + } + }) + }), + + updateDomainAllow: build.mutation({ + query: ({ id, ...formData}) => ({ + method: "PUT", + url: `/api/v1/admin/domain_allows/${id}`, + asForm: true, + body: formData, + discardEmpty: false + }), + ...updateCacheOnMutation("domainAllows", { + key: (_draft, newData) => { + return newData.domain; + } + }) + }), + removeDomainBlock: build.mutation({ query: (id) => ({ method: "DELETE", @@ -91,6 +122,16 @@ const useAddDomainBlockMutation = extended.useAddDomainBlockMutation; */ const useAddDomainAllowMutation = extended.useAddDomainAllowMutation; +/** + * Update a single domain permission (block) by PUTing to `/api/v1/admin/domain_blocks/{id}`. + */ +const useUpdateDomainBlockMutation = extended.useUpdateDomainBlockMutation; + +/** + * Update a single domain permission (allow) by PUTing to `/api/v1/admin/domain_allows/{id}`. + */ +const useUpdateDomainAllowMutation = extended.useUpdateDomainAllowMutation; + /** * Remove a single domain permission (block) by DELETEing to `/api/v1/admin/domain_blocks/{id}`. */ @@ -104,6 +145,8 @@ const useRemoveDomainAllowMutation = extended.useRemoveDomainAllowMutation; export { useAddDomainBlockMutation, useAddDomainAllowMutation, + useUpdateDomainBlockMutation, + useUpdateDomainAllowMutation, useRemoveDomainBlockMutation, useRemoveDomainAllowMutation }; diff --git a/web/source/settings/lib/types/domain-permission.ts b/web/source/settings/lib/types/domain-permission.ts index c4560d79b..27c4b56c9 100644 --- a/web/source/settings/lib/types/domain-permission.ts +++ b/web/source/settings/lib/types/domain-permission.ts @@ -46,8 +46,8 @@ export interface DomainPerm { valid?: boolean; checked?: boolean; commentType?: string; - private_comment_behavior?: "append" | "replace"; - public_comment_behavior?: "append" | "replace"; + replace_private_comment?: boolean; + replace_public_comment?: boolean; } /** @@ -65,8 +65,8 @@ const domainPermStripOnImport: Set = new Set([ "valid", "checked", "commentType", - "private_comment_behavior", - "public_comment_behavior", + "replace_private_comment", + "replace_public_comment", ]); /** diff --git a/web/source/settings/style.css b/web/source/settings/style.css index fc146cdd7..c05072043 100644 --- a/web/source/settings/style.css +++ b/web/source/settings/style.css @@ -618,6 +618,15 @@ span.form-info { } } +section > div.domain-block, +section > div.domain-allow { + height: 100%; + + > a { + margin-top: auto; + } +} + .domain-permissions-list { p { margin-top: 0; @@ -976,32 +985,26 @@ button.tab-button { .domain-perm-import-list { .checkbox-list-wrapper { - overflow-x: auto; display: grid; gap: 1rem; } .checkbox-list { + overflow-x: auto; .header { + align-items: center; input[type="checkbox"] { - align-self: start; height: 1.5rem; } } .entry { - gap: 0; - width: 100%; - grid-template-columns: auto minmax(25ch, 2fr) minmax(40ch, 1fr); - grid-template-rows: auto 1fr; - - input[type="checkbox"] { - margin-right: 1rem; - } + grid-template-columns: auto max(50%, 14rem) 1fr; + column-gap: 1rem; + align-items: center; .domain-input { - margin-right: 0.5rem; display: grid; grid-template-columns: 1fr $fa-fw; gap: 0.5rem; @@ -1020,13 +1023,21 @@ button.tab-button { } p { - align-self: center; margin: 0; - grid-column: 4; - grid-row: 1 / span 2; } } } + + .set-comment-checkbox { + display: flex; + flex-direction: column; + gap: 0.25rem; + + padding: 0.5rem 1rem 1rem 1rem; + width: 100%; + border: 0.1rem solid var(--gray1); + border-radius: 0.1rem; + } } .import-export { @@ -1406,6 +1417,7 @@ button.tab-button { } } +.domain-permission-details, .domain-permission-draft-details, .domain-permission-exclude-details, .domain-permission-subscription-details { @@ -1414,6 +1426,7 @@ button.tab-button { } } +.domain-permission-details, .domain-permission-drafts-view, .domain-permission-draft-details, .domain-permission-subscriptions-view, diff --git a/web/source/settings/views/moderation/domain-permissions/detail.tsx b/web/source/settings/views/moderation/domain-permissions/detail.tsx index 0105d9615..e8ef487e3 100644 --- a/web/source/settings/views/moderation/domain-permissions/detail.tsx +++ b/web/source/settings/views/moderation/domain-permissions/detail.tsx @@ -32,8 +32,18 @@ import Loading from "../../../components/loading"; import BackButton from "../../../components/back-button"; import MutationButton from "../../../components/form/mutation-button"; -import { useDomainAllowsQuery, useDomainBlocksQuery } from "../../../lib/query/admin/domain-permissions/get"; -import { useAddDomainAllowMutation, useAddDomainBlockMutation, useRemoveDomainAllowMutation, useRemoveDomainBlockMutation } from "../../../lib/query/admin/domain-permissions/update"; +import { + useDomainAllowsQuery, + useDomainBlocksQuery, +} from "../../../lib/query/admin/domain-permissions/get"; +import { + useAddDomainAllowMutation, + useAddDomainBlockMutation, + useRemoveDomainAllowMutation, + useRemoveDomainBlockMutation, + useUpdateDomainAllowMutation, + useUpdateDomainBlockMutation, +} from "../../../lib/query/admin/domain-permissions/update"; import { DomainPerm } from "../../../lib/types/domain-permission"; import { NoArg } from "../../../lib/types/query"; import { Error } from "../../../components/error"; @@ -41,8 +51,10 @@ import { useBaseUrl } from "../../../lib/navigation/util"; import { PermType } from "../../../lib/types/perm"; import { useCapitalize } from "../../../lib/util"; import { formDomainValidator } from "../../../lib/util/formvalidators"; +import UsernameLozenge from "../../../components/username-lozenge"; +import { FormSubmitEvent } from "../../../lib/form/types"; -export default function DomainPermDetail() { +export default function DomainPermView() { const baseUrl = useBaseUrl(); const search = useSearch(); @@ -101,33 +113,16 @@ export default function DomainPermDetail() { ? blocks[domain] : allows[domain]; - // Render different into content depending on - // if we have a perm already for this domain. - let infoContent: React.JSX.Element; - if (existingPerm === undefined) { - infoContent = ( - - No stored {permType} yet, you can add one below: - - ); - } else { - infoContent = ( -
- - Editing existing domain {permTypeRaw} isn't implemented yet, check here for progress -
- ); - } + const title = Domain {permType} for {domain}; return ( -
-

- - {" "} - Domain {permType} for {domain} -

- {infoContent} - +

{title}

+ { existingPerm + ? + : No stored {permType} yet, you can add one below: + } + { + if (perm.created_at) { + return new Date(perm.created_at).toDateString(); + } + return "unknown"; + }, [perm.created_at]); + + return ( +
+
+
Created
+
+
+
+
Created By
+
+ +
+
+
+
Domain
+
{perm.domain}
+
+
+
Permission type
+
+ + {permType} +
+
+
+
Subscription ID
+
{perm.subscription_id ?? "[none]"}
+
+
+ ); +} + +interface CreateOrUpdateDomainPermProps { defaultDomain: string; perm?: DomainPerm; permType: PermType; } -function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps) { +function CreateOrUpdateDomainPerm({ + defaultDomain, + perm, + permType +}: CreateOrUpdateDomainPermProps) { const isExistingPerm = perm !== undefined; - const disabledForm = isExistingPerm - ? { - disabled: true, - title: "Domain permissions currently cannot be edited." - } - : { - disabled: false, - title: "", - }; const form = { domain: useTextInput("domain", { @@ -161,8 +208,8 @@ function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps) validator: formDomainValidator, }), obfuscate: useBoolInput("obfuscate", { source: perm }), - commentPrivate: useTextInput("private_comment", { source: perm }), - commentPublic: useTextInput("public_comment", { source: perm }) + privateComment: useTextInput("private_comment", { source: perm }), + publicComment: useTextInput("public_comment", { source: perm }) }; // Check which perm type we're meant to be handling @@ -171,112 +218,132 @@ function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps) // react is like "weh" (mood), but we can decide // which ones to use conditionally. const [ addBlock, addBlockResult ] = useAddDomainBlockMutation(); + const [ updateBlock, updateBlockResult ] = useUpdateDomainBlockMutation({ fixedCacheKey: perm?.id }); const [ removeBlock, removeBlockResult] = useRemoveDomainBlockMutation({ fixedCacheKey: perm?.id }); const [ addAllow, addAllowResult ] = useAddDomainAllowMutation(); + const [ updateAllow, updateAllowResult ] = useUpdateDomainAllowMutation({ fixedCacheKey: perm?.id }); const [ removeAllow, removeAllowResult ] = useRemoveDomainAllowMutation({ fixedCacheKey: perm?.id }); const [ - addTrigger, - addResult, + createOrUpdateTrigger, + createOrUpdateResult, removeTrigger, removeResult, ] = useMemo(() => { - return permType == "block" - ? [ - addBlock, - addBlockResult, - removeBlock, - removeBlockResult, - ] - : [ - addAllow, - addAllowResult, - removeAllow, - removeAllowResult, - ]; - }, [permType, - addBlock, addBlockResult, removeBlock, removeBlockResult, - addAllow, addAllowResult, removeAllow, removeAllowResult, + switch (true) { + case (permType === "block" && !isExistingPerm): + return [ addBlock, addBlockResult, removeBlock, removeBlockResult ]; + case (permType === "block"): + return [ updateBlock, updateBlockResult, removeBlock, removeBlockResult ]; + case !isExistingPerm: + return [ addAllow, addAllowResult, removeAllow, removeAllowResult ]; + default: + return [ updateAllow, updateAllowResult, removeAllow, removeAllowResult ]; + } + }, [permType, isExistingPerm, + addBlock, addBlockResult, updateBlock, updateBlockResult, removeBlock, removeBlockResult, + addAllow, addAllowResult, updateAllow, updateAllowResult, removeAllow, removeAllowResult, ]); - // Use appropriate submission params for this permType. - const [submitForm, submitFormResult] = useFormSubmit(form, [addTrigger, addResult], { changedOnly: false }); + // Use appropriate submission params for this + // permType, and whether we're creating or updating. + const [submit, submitResult] = useFormSubmit( + form, + [ createOrUpdateTrigger, createOrUpdateResult ], + { + changedOnly: isExistingPerm, + // If we're updating an existing perm, + // insert the perm ID into the mutation + // data before submitting. Otherwise just + // return the mutationData unmodified. + customizeMutationArgs: (mutationData) => { + if (isExistingPerm) { + return { + id: perm?.id, + ...mutationData, + }; + } else { + return mutationData; + } + }, + }, + ); // Uppercase first letter of given permType. const permTypeUpper = useCapitalize(permType); const [location, setLocation] = useLocation(); - - function verifyUrlThenSubmit(e) { + function onSubmit(e: FormSubmitEvent) { // Adding a new domain permissions happens on a url like // "/settings/admin/domain-permissions/:permType/domain.com", // but if domain input changes, that doesn't match anymore // and causes issues later on so, before submitting the form, // silently change url, and THEN submit. - let correctUrl = `/${permType}s/${form.domain.value}`; - if (location != correctUrl) { - setLocation(correctUrl); + if (!isExistingPerm) { + let correctUrl = `/${permType}s/${form.domain.value}`; + if (location != correctUrl) { + setLocation(correctUrl); + } } - return submitForm(e); + return submit(e); } return ( -
- + + { !isExistingPerm && + + }