[feature] persist worker queues to db (#3042)

* persist queued worker tasks to database on shutdown, fill worker queues from database on startup

* ensure the tasks are sorted by creation time before pushing them

* add migration to insert WorkerTask{} into database, add test for worker task persistence

* add test for recovering worker queues from database

* quick tweak

* whoops we ended up with double cleaner job scheduling

* insert each task separately, because bun is throwing some reflection error??

* add specific checking of cancelled worker contexts

* add http request signing to deliveries recovered from database

* add test for outgoing public key ID being correctly set on delivery

* replace select with Queue.PopCtx()

* get rid of loop now we don't use it

* remove field now we don't use it

* ensure that signing func is set

* header values weren't being copied over 🤦

* use ptr for httpclient.Request in delivery

* move worker queue filling to later in server init process

* fix rebase issues

* make logging less shouty

* use slices.Delete() instead of copying / reslicing

* have database return tasks in ascending order instead of sorting them

* add a 1 minute timeout to persisting worker queues
This commit is contained in:
kim 2024-07-30 11:58:31 +00:00 committed by GitHub
commit 87cff71af9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1191 additions and 93 deletions

View file

@ -21,6 +21,7 @@ import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
@ -169,6 +170,38 @@ func (t *transport) prepare(
}, nil
}
func (t *transport) SignDelivery(dlv *delivery.Delivery) error {
if dlv.Request.GetBody == nil {
return gtserror.New("delivery request body not rewindable")
}
// Get a new copy of the request body.
body, err := dlv.Request.GetBody()
if err != nil {
return gtserror.Newf("error getting request body: %w", err)
}
// Read body data into memory.
data, err := io.ReadAll(body)
if err != nil {
return gtserror.Newf("error reading request body: %w", err)
}
// Get signing function for POST data.
// (note that delivery is ALWAYS POST).
sign := t.signPOST(data)
// Extract delivery context.
ctx := dlv.Request.Context()
// Update delivery request context with signing details.
ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return nil
}
// getObjectID extracts an object ID from 'serialized' ActivityPub object map.
func getObjectID(obj map[string]interface{}) string {
switch t := obj["object"].(type) {

View file

@ -33,10 +33,6 @@ import (
// be indexed (and so, dropped from queue)
// by any of these possible ID IRIs.
type Delivery struct {
// PubKeyID is the signing public key
// ID of the actor performing request.
PubKeyID string
// ActorID contains the ActivityPub
// actor ID IRI (if any) of the activity
// being sent out by this request.
@ -55,7 +51,7 @@ type Delivery struct {
// Request is the prepared (+ wrapped)
// httpclient.Client{} request that
// constitutes this ActivtyPub delivery.
Request httpclient.Request
Request *httpclient.Request
// internal fields.
next time.Time
@ -66,7 +62,6 @@ type Delivery struct {
// a json serialize / deserialize
// able shape that minimizes data.
type delivery struct {
PubKeyID string `json:"pub_key_id,omitempty"`
ActorID string `json:"actor_id,omitempty"`
ObjectID string `json:"object_id,omitempty"`
TargetID string `json:"target_id,omitempty"`
@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {
// Marshal as internal JSON type.
return json.Marshal(delivery{
PubKeyID: dlv.PubKeyID,
ActorID: dlv.ActorID,
ObjectID: dlv.ObjectID,
TargetID: dlv.TargetID,
@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {
}
// Copy over simplest fields.
dlv.PubKeyID = idlv.PubKeyID
dlv.ActorID = idlv.ActorID
dlv.ObjectID = idlv.ObjectID
dlv.TargetID = idlv.TargetID
@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {
return err
}
// Copy over any stored header values.
for key, values := range idlv.Header {
for _, value := range values {
r.Header.Add(key, value)
}
}
// Wrap request in httpclient type.
dlv.Request = httpclient.WrapRequest(r)

View file

@ -35,32 +35,30 @@ var deliveryCases = []struct {
}{
{
msg: delivery.Delivery{
PubKeyID: "https://google.com/users/bigboy#pubkey",
ActorID: "https://google.com/users/bigboy",
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")),
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),
},
data: toJSON(map[string]any{
"pub_key_id": "https://google.com/users/bigboy#pubkey",
"actor_id": "https://google.com/users/bigboy",
"object_id": "https://google.com/users/bigboy/follow/1",
"target_id": "https://askjeeves.com/users/smallboy",
"method": "POST",
"url": "https://askjeeves.com/users/smallboy/inbox",
"body": []byte("data!"),
// "header": map[string][]string{},
"actor_id": "https://google.com/users/bigboy",
"object_id": "https://google.com/users/bigboy/follow/1",
"target_id": "https://askjeeves.com/users/smallboy",
"method": "POST",
"url": "https://askjeeves.com/users/smallboy/inbox",
"body": []byte("data!"),
"header": map[string][]string{"Hello": {"world1", "world2"}},
}),
},
{
msg: delivery.Delivery{
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")),
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),
},
data: toJSON(map[string]any{
"method": "GET",
"url": "https://google.com",
"body": []byte("uwu im just a wittle seawch engwin"),
// "header": map[string][]string{},
// "header": map[string][]string{},
}),
},
}
@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {
}
// Check that delivery fields are as expected.
assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)
assert.Equal(t, test.msg.ActorID, msg.ActorID)
assert.Equal(t, test.msg.ObjectID, msg.ObjectID)
assert.Equal(t, test.msg.TargetID, msg.TargetID)
assert.Equal(t, test.msg.Request.Method, msg.Request.Method)
assert.Equal(t, test.msg.Request.URL, msg.Request.URL)
assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))
assert.Equal(t, test.msg.Request.Header, msg.Request.Header)
}
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
func toRequest(method string, url string, body []byte) httpclient.Request {
func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {
if err != nil {
panic(err)
}
for key, values := range hdr {
for _, value := range values {
req.Header.Add(key, value)
}
}
return httpclient.WrapRequest(req)
}

View file

@ -19,6 +19,7 @@ package delivery
import (
"context"
"errors"
"slices"
"time"
@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {
loop:
for {
// Before trying to get
// next delivery, check
// context still valid.
if ctx.Err() != nil {
return true
}
// Get next delivery.
dlv, ok := w.next(ctx)
if !ok {
@ -195,16 +203,30 @@ loop:
// Attempt delivery of AP request.
rsp, retry, err := w.Client.DoOnce(
&dlv.Request,
dlv.Request,
)
if err == nil {
switch {
case err == nil:
// Ensure body closed.
_ = rsp.Body.Close()
continue loop
}
if !retry {
case errors.Is(err, context.Canceled) &&
ctx.Err() != nil:
// In the case of our own context
// being cancelled, push delivery
// back onto queue for persisting.
//
// Note we specifically check against
// context.Canceled here as it will
// be faster than the mutex lock of
// ctx.Err(), so gives an initial
// faster check in the if-clause.
w.Queue.Push(dlv)
continue loop
case !retry:
// Drop deliveries when no
// retry requested, or they
// reached max (either).
@ -222,42 +244,36 @@ loop:
// next gets the next available delivery, blocking until available if necessary.
func (w *Worker) next(ctx context.Context) (*Delivery, bool) {
loop:
for {
// Try pop next queued.
dlv, ok := w.Queue.Pop()
// Try a fast-pop of queued
// delivery before anything.
dlv, ok := w.Queue.Pop()
if !ok {
// Check the backlog.
if len(w.backlog) > 0 {
if !ok {
// Check the backlog.
if len(w.backlog) > 0 {
// Sort by 'next' time.
sortDeliveries(w.backlog)
// Sort by 'next' time.
sortDeliveries(w.backlog)
// Pop next delivery.
dlv := w.popBacklog()
// Pop next delivery.
dlv := w.popBacklog()
return dlv, true
}
select {
// Backlog is empty, we MUST
// block until next enqueued.
case <-w.Queue.Wait():
continue loop
// Worker was stopped.
case <-ctx.Done():
return nil, false
}
return dlv, true
}
// Replace request context for worker state canceling.
ctx := gtscontext.WithValues(ctx, dlv.Request.Context())
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return dlv, true
// Block on next delivery push
// OR worker context canceled.
dlv, ok = w.Queue.PopCtx(ctx)
if !ok {
return nil, false
}
}
// Replace request context for worker state canceling.
ctx = gtscontext.WithValues(ctx, dlv.Request.Context())
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return dlv, true
}
// popBacklog pops next available from the backlog.

View file

@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/httpsig"
)
@ -50,6 +51,10 @@ type Transport interface {
// transport client, retrying on certain preset errors.
POST(*http.Request, []byte) (*http.Response, error)
// SignDelivery adds HTTP request signing client "middleware"
// to the request context within given delivery.Delivery{}.
SignDelivery(*delivery.Delivery) error
// Deliver sends an ActivityStreams object.
Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error