diff --git a/go.mod b/go.mod
index 01b5338fa..b8752992c 100644
--- a/go.mod
+++ b/go.mod
@@ -21,7 +21,6 @@ require (
github.com/go-errors/errors v1.4.0 // indirect
github.com/go-fed/activity v1.0.1-0.20210803212804-d866ba75dd0f
github.com/go-fed/httpsig v1.1.0
- github.com/go-pg/pg/extra/pgdebug v0.2.0
github.com/go-pg/pg/v10 v10.10.3
github.com/go-playground/validator/v10 v10.7.0 // indirect
github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect
@@ -47,13 +46,16 @@ require (
github.com/superseriousbusiness/oauth2/v4 v4.3.0-SSB
github.com/tdewolff/minify/v2 v2.9.21
github.com/tidwall/buntdb v1.2.4 // indirect
+ github.com/uptrace/bun v0.4.3
+ github.com/uptrace/bun/dialect/pgdialect v0.4.3
+ github.com/uptrace/bun/driver/pgdriver v0.4.3
github.com/urfave/cli/v2 v2.3.0
- github.com/vmihailenco/msgpack/v5 v5.3.4 // indirect
github.com/wagslane/go-password-validator v0.3.0
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect
golang.org/x/text v0.3.6
+ google.golang.org/appengine v1.6.7 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
diff --git a/go.sum b/go.sum
index d074f2c62..9acb6dc7a 100644
--- a/go.sum
+++ b/go.sum
@@ -131,9 +131,6 @@ github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-pg/pg/extra/pgdebug v0.2.0 h1:t62UhMiV6KYAxSWojwIJiyX06TdepkzCeIzdeb00184=
-github.com/go-pg/pg/extra/pgdebug v0.2.0/go.mod h1:KmW//PLshMAQunfInLv9mFIbYXuGplOY9bc6qo3CaY0=
-github.com/go-pg/pg/v10 v10.6.2/go.mod h1:BfgPoQnD2wXNd986RYEHzikqv9iE875PrFaZ9vXvtNM=
github.com/go-pg/pg/v10 v10.10.3 h1:WobSfk5I+v7XwD1h9x2B7n4slDzjdBIonJ5PID95Aag=
github.com/go-pg/pg/v10 v10.10.3/go.mod h1:EmoJGYErc+stNN/1Jf+o4csXuprjxcRztBnn6cHe38E=
github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU=
@@ -200,7 +197,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
@@ -382,6 +378,12 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ=
github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw=
+github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE=
+github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
+github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo=
+github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8=
+github.com/uptrace/bun/driver/pgdriver v0.4.3 h1:WLtUL3xtnZuryRcXII8PV8dm6UfEMWQniHFmV5T4tEw=
+github.com/uptrace/bun/driver/pgdriver v0.4.3/go.mod h1:CQsGmzHK9Sq70avzRy7aFYAomoT3XihjGPtRDSToV+0=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -391,12 +393,9 @@ github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
-github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4=
-github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1/go.mod h1:xlngVLeyQ/Qi05oQxhQ+oTuqa03RjMwMfk/7/TCs+QI=
github.com/vmihailenco/msgpack/v5 v5.3.1/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
-github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc=
github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
@@ -426,7 +425,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opentelemetry.io/otel v0.13.0/go.mod h1:dlSNewoRYikTkotEnxdmuBHgzT+k/idJSfDv/FxEnOY=
go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4=
go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI=
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -436,7 +434,6 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
@@ -502,7 +499,6 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
@@ -562,7 +558,6 @@ golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -570,6 +565,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
diff --git a/internal/db/account.go b/internal/db/account.go
index 0e1575f9b..61d97bf8c 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -27,40 +28,40 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
- GetAccountByID(id string) (*gtsmodel.Account, Error)
+ GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
- GetAccountByURI(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
- GetAccountByURL(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
- GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
+ GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
- GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
+ GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- CountAccountStatuses(accountID string) (int, Error)
+ CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
- GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
+ GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
- GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
+ GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
- GetAccountLastPosted(accountID string) (time.Time, Error)
+ GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
- SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
+ SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
- GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
+ GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
}
diff --git a/internal/db/admin.go b/internal/db/admin.go
index aa2b22f47..f9b4f821e 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -28,26 +29,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) Error
+ IsUsernameAvailable(ctx context.Context, username string) Error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
- IsEmailAvailable(email string) Error
+ IsEmailAvailable(ctx context.Context, email string) Error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
+ NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() Error
+ CreateInstanceAccount(ctx context.Context) Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() Error
+ CreateInstanceInstance(ctx context.Context) Error
}
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 729920bba..d48ffb018 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -24,15 +24,15 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) Error
+ CreateTable(ctx context.Context, i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) Error
+ DropTable(ctx context.Context, i interface{}) Error
// RegisterTable registers a table for use in many2many relations.
// For implementations that don't use tables, or many2many relations, this can just return nil.
- RegisterTable(i interface{}) Error
+ RegisterTable(ctx context.Context, i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
@@ -45,43 +45,43 @@ type Basic interface {
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) Error
+ GetByID(ctx context.Context, id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) Error
+ GetWhere(ctx context.Context, where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) Error
+ GetAll(ctx context.Context, i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) Error
+ Put(ctx context.Context, i interface{}) Error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) Error
+ Upsert(ctx context.Context, i interface{}, conflictColumn string) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) Error
+ UpdateByID(ctx context.Context, id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
+ UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
+ UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) Error
+ DeleteByID(ctx context.Context, id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) Error
+ DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
}
diff --git a/internal/db/db.go b/internal/db/db.go
index d6ac883e4..71ac887bb 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -19,6 +19,8 @@
package db
import (
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -52,7 +54,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
+ MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
// TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -61,7 +63,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking
// if they exist in the db already, and conveniently returning them, or creating new tag structs.
- TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
+ TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
// EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -69,5 +71,5 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
+ EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index a6583c80c..df50a6770 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -18,19 +18,22 @@
package db
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
- IsDomainBlocked(domain string) (bool, Error)
+ IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
- AreDomainsBlocked(domains []string) (bool, Error)
+ AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
- IsURIBlocked(uri *url.URL) (bool, Error)
+ IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
- AreURIsBlocked(uris []*url.URL) (bool, Error)
+ AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
}
diff --git a/internal/db/instance.go b/internal/db/instance.go
index 1f7c83e4f..dcd978a81 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -18,19 +18,23 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
- CountInstanceUsers(domain string) (int, Error)
+ CountInstanceUsers(ctx context.Context, domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
- CountInstanceStatuses(domain string) (int, Error)
+ CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
- CountInstanceDomains(domain string) (int, Error)
+ CountInstanceDomains(ctx context.Context, domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
- GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
+ GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}
diff --git a/internal/db/media.go b/internal/db/media.go
index db4db3411..b779dd276 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -18,10 +18,14 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
- GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
+ GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
}
diff --git a/internal/db/mention.go b/internal/db/mention.go
index cb1c56dc1..b9b45546a 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -18,13 +18,17 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
- GetMention(id string) (*gtsmodel.Mention, Error)
+ GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
- GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
+ GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 326f0f149..09c17f031 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -18,14 +18,18 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(id string) (*gtsmodel.Notification, Error)
+ GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
}
diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go
index 3889c6601..4cf69df18 100644
--- a/internal/db/pg/account.go
+++ b/internal/db/pg/account.go
@@ -24,61 +24,62 @@ import (
"fmt"
"time"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type accountDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
- return a.conn.Model(account).
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
+ return a.conn.
+ NewSelect().
+ Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
-func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
account := >smodel.Account{}
q := a.newAccountQ(account).
Where("account.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := >smodel.Account{}
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := >smodel.Account{}
q := a.newAccountQ(account).
Where("account.url = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := >smodel.Account{}
q := a.newAccountQ(account)
@@ -90,29 +91,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err
} else {
q = q.
Where("account.username = ?", domain).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
}
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
status := >smodel.Status{}
- q := a.conn.Model(status).
+ q := a.conn.
+ NewSelect().
+ Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
}
-func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@@ -127,51 +130,58 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta
}
// TODO: there are probably more side effects here that need to be handled
- if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if _, err := a.conn.NewInsert().Model(mediaAttachment).On("CONFLICT (id) DO UPDATE").Exec(ctx); err != nil {
return err
}
- if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
+ if _, err := a.conn.NewInsert().Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Exec(ctx); err != nil {
return err
}
return nil
}
-func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
account := >smodel.Account{}
q := a.newAccountQ(account).
Where("username = ?", username).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
- if err := a.conn.Model(&faves).
+ if err := a.conn.
+ NewSelect().
+ Model(&faves).
Where("account_id = ?", accountID).
- Select(); err != nil {
- if err == pg.ErrNoRows {
- return faves, nil
- }
+ Scan(ctx); err != nil {
return nil, err
}
return faves, nil
}
-func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
- return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
+ return a.conn.
+ NewSelect().
+ Model(>smodel.Status{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
}
-func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
- q := a.conn.Model(&statuses).Order("id DESC")
+ q := a.conn.
+ NewSelect().
+ Model(&statuses).
+ Order("id DESC")
+
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@@ -181,7 +191,7 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
}
if excludeReplies {
- q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
+ q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
}
if pinnedOnly {
@@ -189,8 +199,10 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
}
if mediaOnly {
- q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
+ q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("? IS NOT NULL", bun.Ident("attachments")).
+ WhereOr("attachments != '{}'")
})
}
@@ -198,10 +210,7 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
q = q.Where("id < ?", maxID)
}
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
+ if err := q.Scan(ctx); err != nil {
return nil, err
}
@@ -213,10 +222,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
return statuses, nil
}
-func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
- fq := a.conn.Model(&blocks).
+ fq := a.conn.
+ NewSelect().
+ Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
@@ -233,11 +244,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str
fq = fq.Limit(limit)
}
- err := fq.Select()
+ err := fq.Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
return nil, "", "", err
}
diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go
index 7ea5ff39a..df4d244bf 100644
--- a/internal/db/pg/account_test.go
+++ b/internal/db/pg/account_test.go
@@ -19,6 +19,7 @@
package pg_test
import (
+ "context"
"testing"
"github.com/stretchr/testify/suite"
@@ -54,7 +55,7 @@ func (suite *AccountTestSuite) TearDownTest() {
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
+ account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go
index 854f56ef0..6319ba273 100644
--- a/internal/db/pg/admin.go
+++ b/internal/db/pg/admin.go
@@ -35,21 +35,27 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
-func (a *adminDB) IsUsernameAvailable(username string) db.Error {
+func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) db.Error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
+ if err := a.conn.
+ NewSelect().
+ Model(>smodel.Account{}).
+ Where("username = ?", username).
+ Where("domain = ?", nil).
+ Scan(ctx); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
index 6e76b4450..f43b80afe 100644
--- a/internal/db/pg/basic.go
+++ b/internal/db/pg/basic.go
@@ -29,11 +29,12 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
)
type basicDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go
index 4e9b2ab48..1666d8b11 100644
--- a/internal/db/pg/domain.go
+++ b/internal/db/pg/domain.go
@@ -22,17 +22,17 @@ import (
"context"
"net/url"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go
index 968832ca5..2f0c326ca 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/pg/instance.go
@@ -26,11 +26,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type instanceDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/media.go b/internal/db/pg/media.go
index 618030af3..4b421cca4 100644
--- a/internal/db/pg/media.go
+++ b/internal/db/pg/media.go
@@ -21,17 +21,17 @@ package pg
import (
"context"
- "github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mediaDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go
index b31f07b67..fac7ca5ad 100644
--- a/internal/db/pg/mention.go
+++ b/internal/db/pg/mention.go
@@ -21,18 +21,18 @@ package pg
import (
"context"
- "github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mentionDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go
index 281a76d85..27fe00b1e 100644
--- a/internal/db/pg/notification.go
+++ b/internal/db/pg/notification.go
@@ -21,18 +21,18 @@ package pg
import (
"context"
- "github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type notificationDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index 0437baf02..7626386ee 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
+ "database/sql"
"encoding/pem"
"errors"
"fmt"
@@ -29,7 +30,6 @@ import (
"strings"
"time"
- "github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
@@ -37,6 +37,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/pgdialect"
+ "github.com/uptrace/bun/driver/pgdriver"
)
var registerTables []interface{} = []interface{}{
@@ -77,105 +80,75 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
}
log.Debugf("using pg options: %+v", opts)
- // create a connection
- pgCtx, cancel := context.WithCancel(ctx)
- conn := pg.Connect(opts).WithContext(pgCtx)
+ sqldb := sql.OpenDB(pgdriver.NewConnector(opts...))
- // this will break the logfmt format we normally log in,
- // since we can't choose where pg outputs to and it defaults to
- // stdout. So use this option with care!
- if log.GetLevel() >= logrus.TraceLevel {
- conn.AddQueryHook(pgdebug.DebugHook{
- // Print all queries.
- Verbose: true,
- })
- }
+ conn := bun.NewDB(sqldb, pgdialect.New())
// actually *begin* the connection so that we can tell if the db is there and listening
- if err := conn.Ping(ctx); err != nil {
- cancel()
+ if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
-
- // print out discovered postgres version
- var version string
- if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
- }
- log.Infof("connected to postgres version: %s", version)
+ log.Info("connected to postgres")
ps := &postgresService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Notification: ¬ificationDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
config: c,
conn: conn,
log: log,
- cancel: cancel,
}
// we can confidently return this useable postgres service now
@@ -188,7 +161,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pg.Options, error) {
+func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@@ -268,13 +241,13 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
// We can rely on the pg library we're using to set
// sensible defaults for everything we don't set here.
- options := &pg.Options{
- Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
- User: c.DBConfig.User,
- Password: c.DBConfig.Password,
- Database: c.DBConfig.Database,
- ApplicationName: c.ApplicationName,
- TLSConfig: tlsConfig,
+ options := []pgdriver.DriverOption{
+ pgdriver.WithAddr(fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port)),
+ pgdriver.WithUser(c.DBConfig.User),
+ pgdriver.WithPassword(c.DBConfig.Password),
+ pgdriver.WithDatabase(c.DBConfig.Database),
+ pgdriver.WithApplicationName(c.ApplicationName),
+ pgdriver.WithTLSConfig(tlsConfig),
}
return options, nil
diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go
index 76bd50c76..35f5a7eab 100644
--- a/internal/db/pg/relationship.go
+++ b/internal/db/pg/relationship.go
@@ -28,11 +28,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type relationshipDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
index 99790428e..312845765 100644
--- a/internal/db/pg/status.go
+++ b/internal/db/pg/status.go
@@ -31,11 +31,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type statusDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go
index fa8b07aab..a00e39942 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/pg/timeline.go
@@ -27,11 +27,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type timelineDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go
index 17c09b720..e6d31780a 100644
--- a/internal/db/pg/util.go
+++ b/internal/db/pg/util.go
@@ -3,7 +3,6 @@ package pg
import (
"strings"
- "github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
@@ -12,9 +11,9 @@ func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
- case pg.ErrNoRows:
+ case bun.ErrNoRows:
return db.ErrNoEntries
- case pg.ErrMultiRows:
+ case bun.ErrMultiRows:
return db.ErrMultipleEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 85f64d72b..804526425 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -18,54 +18,58 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
+ IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
- GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+ GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
+ AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
- GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
- GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- CountAccountFollows(accountID string, localOnly bool) (int, Error)
+ CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
+ GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
- CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
+ CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 9d206c198..7430433c4 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -18,58 +18,62 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
- GetStatusByID(id string) (*gtsmodel.Status, Error)
+ GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURI(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURL(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
- PutStatus(status *gtsmodel.Status) Error
+ PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
- CountStatusReplies(status *gtsmodel.Status) (int, Error)
+ CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- CountStatusReblogs(status *gtsmodel.Status) (int, Error)
+ CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- CountStatusFaves(status *gtsmodel.Status) (int, Error)
+ CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
- GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
+ GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
- GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
+ GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
- IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
- IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
+ GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 74aa5c781..83fb3a959 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -18,20 +18,24 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@@ -40,5 +44,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+ GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}
diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go
index ba6766061..50ec04f51 100644
--- a/internal/federation/dereferencing/account.go
+++ b/internal/federation/dereferencing/account.go
@@ -39,12 +39,12 @@ import (
//
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
// the federatingDB's Create function, or during the federated authorization flow.
-func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
- if err := d.PopulateAccountFields(account, username, false); err != nil {
+func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
+ if err := d.PopulateAccountFields(ctx, account, username, false); err != nil {
return nil, err
}
- if err := d.db.UpdateByID(account.ID, account); err != nil {
+ if err := d.db.UpdateByID(ctx, account.ID, account); err != nil {
return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
@@ -60,22 +60,22 @@ func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account)
// the remote instance again.
//
// SIDE EFFECTS: remote account will be stored in the database, or updated if it already exists (and refresh is true).
-func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
+func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
new := true
// check if we already have the account in our db
- maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String())
+ maybeAccount, err := d.db.GetAccountByURI(ctx, remoteAccountID.String())
if err == nil {
// we've seen this account before so it's not new
new = false
if !refresh {
// we're not being asked to refresh, but just in case we don't have the avatar/header cached yet....
- maybeAccount, err = d.EnrichRemoteAccount(username, maybeAccount)
+ maybeAccount, err = d.EnrichRemoteAccount(ctx, username, maybeAccount)
return maybeAccount, new, err
}
}
- accountable, err := d.dereferenceAccountable(username, remoteAccountID)
+ accountable, err := d.dereferenceAccountable(ctx, username, remoteAccountID)
if err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error dereferencing accountable: %s", err)
}
@@ -93,22 +93,22 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
}
gtsAccount.ID = ulid
- if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil {
+ if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
- if err := d.db.Put(gtsAccount); err != nil {
+ if err := d.db.Put(ctx, gtsAccount); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error putting new account: %s", err)
}
} else {
// take the id we already have and do an update
gtsAccount.ID = maybeAccount.ID
- if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil {
+ if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
- if err := d.db.UpdateByID(gtsAccount.ID, gtsAccount); err != nil {
+ if err := d.db.UpdateByID(ctx, gtsAccount.ID, gtsAccount); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err)
}
}
@@ -120,15 +120,15 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
// it finds as something that an account model can be constructed out of.
//
// Will work for Person, Application, or Service models.
-func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL) (ap.Accountable, error) {
+func (d *deref) dereferenceAccountable(ctx context.Context, username string, remoteAccountID *url.URL) (ap.Accountable, error) {
d.startHandshake(username, remoteAccountID)
defer d.stopHandshake(username, remoteAccountID)
- if blocked, err := d.blockedDomain(remoteAccountID.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteAccountID.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceAccountable: domain %s is blocked", remoteAccountID.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceAccountable: transport err: %s", err)
}
@@ -174,7 +174,7 @@ func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL
// PopulateAccountFields populates any fields on the given account that weren't populated by the initial
// dereferencing. This includes things like header and avatar etc.
-func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsername string, refresh bool) error {
+func (d *deref) PopulateAccountFields(ctx context.Context, account *gtsmodel.Account, requestingUsername string, refresh bool) error {
l := d.log.WithFields(logrus.Fields{
"func": "PopulateAccountFields",
"requestingUsername": requestingUsername,
@@ -184,17 +184,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern
if err != nil {
return fmt.Errorf("PopulateAccountFields: couldn't parse account URI %s: %s", account.URI, err)
}
- if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil {
return fmt.Errorf("PopulateAccountFields: domain %s is blocked", accountURI.Host)
}
- t, err := d.transportController.NewTransportForUsername(requestingUsername)
+ t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername)
if err != nil {
return fmt.Errorf("PopulateAccountFields: error getting transport for user: %s", err)
}
// fetch the header and avatar
- if err := d.fetchHeaderAndAviForAccount(account, t, refresh); err != nil {
+ if err := d.fetchHeaderAndAviForAccount(ctx, account, t, refresh); err != nil {
// if this doesn't work, just skip it -- we can do it later
l.Debugf("error fetching header/avi for account: %s", err)
}
@@ -208,12 +208,12 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern
// targetAccount's AvatarMediaAttachmentID and HeaderMediaAttachmentID will be updated as necessary.
//
// SIDE EFFECTS: remote header and avatar will be stored in local storage.
-func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error {
+func (d *deref) fetchHeaderAndAviForAccount(ctx context.Context, targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error {
accountURI, err := url.Parse(targetAccount.URI)
if err != nil {
return fmt.Errorf("fetchHeaderAndAviForAccount: couldn't parse account URI %s: %s", targetAccount.URI, err)
}
- if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil {
return fmt.Errorf("fetchHeaderAndAviForAccount: domain %s is blocked", accountURI.Host)
}
diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go
index 6773db425..33af74ebe 100644
--- a/internal/federation/dereferencing/announce.go
+++ b/internal/federation/dereferencing/announce.go
@@ -19,6 +19,7 @@
package dereferencing
import (
+ "context"
"errors"
"fmt"
"net/url"
@@ -26,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error {
+func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error {
if announce.BoostOf == nil || announce.BoostOf.URI == "" {
// we can't do anything unfortunately
return errors.New("DereferenceAnnounce: no URI to dereference")
@@ -36,16 +37,16 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
if err != nil {
return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err)
}
- if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, boostedStatusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host)
}
// dereference statuses in the thread of the boosted status
- if err := d.DereferenceThread(requestingUsername, boostedStatusURI); err != nil {
+ if err := d.DereferenceThread(ctx, requestingUsername, boostedStatusURI); err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err)
}
- boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false)
+ boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false)
if err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
}
diff --git a/internal/federation/dereferencing/blocked.go b/internal/federation/dereferencing/blocked.go
deleted file mode 100644
index c8a4c6ade..000000000
--- a/internal/federation/dereferencing/blocked.go
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package dereferencing
-
-import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-func (d *deref) blockedDomain(host string) (bool, error) {
- b := >smodel.DomainBlock{}
- err := d.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
- if err == nil {
- // block exists
- return true, nil
- }
-
- if err == db.ErrNoEntries {
- // there are no entries so there's no block
- return false, nil
- }
-
- // there's an actual error
- return false, err
-}
diff --git a/internal/federation/dereferencing/collectionpage.go b/internal/federation/dereferencing/collectionpage.go
index 5feadc1ad..6f0beeaf6 100644
--- a/internal/federation/dereferencing/collectionpage.go
+++ b/internal/federation/dereferencing/collectionpage.go
@@ -32,12 +32,12 @@ import (
)
// DereferenceCollectionPage returns the activitystreams CollectionPage at the specified IRI, or an error if something goes wrong.
-func (d *deref) DereferenceCollectionPage(username string, pageIRI *url.URL) (ap.CollectionPageable, error) {
- if blocked, err := d.blockedDomain(pageIRI.Host); blocked || err != nil {
+func (d *deref) DereferenceCollectionPage(ctx context.Context, username string, pageIRI *url.URL) (ap.CollectionPageable, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, pageIRI.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceCollectionPage: domain %s is blocked", pageIRI.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceCollectionPage: error creating transport: %s", err)
}
diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go
index 03b90569a..71625ed88 100644
--- a/internal/federation/dereferencing/dereferencer.go
+++ b/internal/federation/dereferencing/dereferencer.go
@@ -19,6 +19,7 @@
package dereferencing
import (
+ "context"
"net/url"
"sync"
@@ -34,18 +35,18 @@ import (
// Dereferencer wraps logic and functionality for doing dereferencing of remote accounts, statuses, etc, from federated instances.
type Dereferencer interface {
- GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
- EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
+ GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
+ EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
- GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
- EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
+ GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
+ EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
- GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
+ GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
- DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error
- DereferenceThread(username string, statusIRI *url.URL) error
+ DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error
+ DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error
- Handshaking(username string, remoteAccountID *url.URL) bool
+ Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool
}
type deref struct {
diff --git a/internal/federation/dereferencing/handshake.go b/internal/federation/dereferencing/handshake.go
index cda8eafd0..17003be84 100644
--- a/internal/federation/dereferencing/handshake.go
+++ b/internal/federation/dereferencing/handshake.go
@@ -18,9 +18,12 @@
package dereferencing
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
-func (d *deref) Handshaking(username string, remoteAccountID *url.URL) bool {
+func (d *deref) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool {
d.handshakeSync.Lock()
defer d.handshakeSync.Unlock()
diff --git a/internal/federation/dereferencing/instance.go b/internal/federation/dereferencing/instance.go
index 80f626662..ec3c3f13d 100644
--- a/internal/federation/dereferencing/instance.go
+++ b/internal/federation/dereferencing/instance.go
@@ -26,12 +26,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (d *deref) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
- if blocked, err := d.blockedDomain(remoteInstanceURI.Host); blocked || err != nil {
+func (d *deref) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteInstanceURI.Host); blocked || err != nil {
return nil, fmt.Errorf("GetRemoteInstance: domain %s is blocked", remoteInstanceURI.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("transport err: %s", err)
}
diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go
index 68693c021..ac42be14e 100644
--- a/internal/federation/dereferencing/status.go
+++ b/internal/federation/dereferencing/status.go
@@ -39,12 +39,12 @@ import (
//
// EnrichRemoteStatus is mostly useful for calling after a status has been initially created by
// the federatingDB's Create function, but additional dereferencing is needed on it.
-func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
- if err := d.populateStatusFields(status, username); err != nil {
+func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
+ if err := d.populateStatusFields(ctx, status, username); err != nil {
return nil, err
}
- if err := d.db.UpdateByID(status.ID, status); err != nil {
+ if err := d.db.UpdateByID(ctx, status.ID, status); err != nil {
return nil, fmt.Errorf("EnrichRemoteStatus: error updating status: %s", err)
}
@@ -62,11 +62,11 @@ func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*g
// If a dereference was performed, then the function also returns the ap.Statusable representation for further processing.
//
// SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored.
-func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
+func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
new := true
// check if we already have the status in our db
- maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String())
+ maybeStatus, err := d.db.GetStatusByURI(ctx, remoteStatusID.String())
if err == nil {
// we've seen this status before so it's not new
new = false
@@ -77,7 +77,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
}
- statusable, err := d.dereferenceStatusable(username, remoteStatusID)
+ statusable, err := d.dereferenceStatusable(ctx, username, remoteStatusID)
if err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error dereferencing statusable: %s", err)
}
@@ -88,7 +88,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
// do this so we know we have the remote account of the status in the db
- _, _, err = d.GetRemoteAccount(username, accountURI, false)
+ _, _, err = d.GetRemoteAccount(ctx, username, accountURI, false)
if err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: couldn't derive status author: %s", err)
}
@@ -105,21 +105,21 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
gtsStatus.ID = ulid
- if err := d.populateStatusFields(gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
- if err := d.db.PutStatus(gtsStatus); err != nil {
+ if err := d.db.PutStatus(ctx, gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err)
}
} else {
gtsStatus.ID = maybeStatus.ID
- if err := d.populateStatusFields(gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
- if err := d.db.UpdateByID(gtsStatus.ID, gtsStatus); err != nil {
+ if err := d.db.UpdateByID(ctx, gtsStatus.ID, gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error updating status: %s", err)
}
}
@@ -127,12 +127,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
return gtsStatus, statusable, new, nil
}
-func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) (ap.Statusable, error) {
- if blocked, err := d.blockedDomain(remoteStatusID.Host); blocked || err != nil {
+func (d *deref) dereferenceStatusable(ctx context.Context, username string, remoteStatusID *url.URL) (ap.Statusable, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteStatusID.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceStatusable: domain %s is blocked", remoteStatusID.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceStatusable: transport err: %s", err)
}
@@ -236,7 +236,7 @@ func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL)
// This function will deference all of the above, insert them in the database as necessary,
// and attach them to the status. The status itself will not be added to the database yet,
// that's up the caller to do.
-func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername string) error {
+func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error {
l := d.log.WithFields(logrus.Fields{
"func": "dereferenceStatusFields",
"status": fmt.Sprintf("%+v", status),
@@ -248,12 +248,12 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
if err != nil {
return fmt.Errorf("DereferenceStatusFields: couldn't parse status URI %s: %s", status.URI, err)
}
- if blocked, err := d.blockedDomain(statusURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, statusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceStatusFields: domain %s is blocked", statusURI.Host)
}
// we can continue -- create a new transport here because we'll probably need it
- t, err := d.transportController.NewTransportForUsername(requestingUsername)
+ t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername)
if err != nil {
return fmt.Errorf("error creating transport: %s", err)
}
@@ -281,7 +281,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// it might have been processed elsewhere so check first if it's already in the database or not
maybeAttachment := >smodel.MediaAttachment{}
- err := d.db.GetWhere([]db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment)
+ err := d.db.GetWhere(ctx, []db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment)
if err == nil {
// we already have it in the db, dereferenced, no need to do it again
l.Tracef("attachment already exists with id %s", maybeAttachment.ID)
@@ -302,7 +302,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
l.Debugf("dereferenced attachment: %+v", deferencedAttachment)
deferencedAttachment.StatusID = status.ID
deferencedAttachment.Description = a.Description
- if err := d.db.Put(deferencedAttachment); err != nil {
+ if err := d.db.Put(ctx, deferencedAttachment); err != nil {
return fmt.Errorf("error inserting dereferenced attachment with remote url %s: %s", a.RemoteURL, err)
}
attachmentIDs = append(attachmentIDs, deferencedAttachment.ID)
@@ -338,9 +338,9 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
}
var targetAccount *gtsmodel.Account
- if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil {
+ if a, err := d.db.GetAccountByURL(ctx, targetAccountURI.String()); err == nil {
targetAccount = a
- } else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil {
+ } else if a, _, err := d.GetRemoteAccount(ctx, requestingUsername, targetAccountURI, false); err == nil {
targetAccount = a
} else {
// we can't find the target account so bail
@@ -369,7 +369,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
TargetAccountURL: targetAccount.URL,
}
- if err := d.db.Put(m); err != nil {
+ if err := d.db.Put(ctx, m); err != nil {
return fmt.Errorf("error creating mention: %s", err)
}
mentionIDs = append(mentionIDs, m.ID)
@@ -382,13 +382,13 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
if err != nil {
return err
}
- if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil {
+ if replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err == nil {
// we have the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
- } else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil {
+ } else if replyToStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err == nil {
// we got the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go
index 2a407f923..328a1c4ee 100644
--- a/internal/federation/dereferencing/thread.go
+++ b/internal/federation/dereferencing/thread.go
@@ -19,6 +19,7 @@
package dereferencing
import (
+ "context"
"fmt"
"net/url"
@@ -34,7 +35,7 @@ import (
// This process involves working up and down the chain of replies, and parsing through the collections of IDs
// presented by remote instances as part of their replies collections, and will likely involve making several calls to
// multiple different hosts.
-func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
+func (d *deref) DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error {
l := d.log.WithFields(logrus.Fields{
"func": "DereferenceThread",
"username": username,
@@ -49,18 +50,18 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
}
// first make sure we have this status in our db
- _, statusable, _, err := d.GetRemoteStatus(username, statusIRI, true)
+ _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true)
if err != nil {
return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err)
}
// first iterate up through ancestors, dereferencing if necessary as we go
- if err := d.iterateAncestors(username, *statusIRI); err != nil {
+ if err := d.iterateAncestors(ctx, username, *statusIRI); err != nil {
return fmt.Errorf("error iterating ancestors of status %s: %s", statusIRI.String(), err)
}
// now iterate down through descendants, again dereferencing as we go
- if err := d.iterateDescendants(username, *statusIRI, statusable); err != nil {
+ if err := d.iterateDescendants(ctx, username, *statusIRI, statusable); err != nil {
return fmt.Errorf("error iterating descendants of status %s: %s", statusIRI.String(), err)
}
@@ -68,7 +69,7 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
}
// iterateAncestors has the goal of reaching the oldest ancestor of a given status, and stashing all statuses along the way.
-func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
+func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI url.URL) error {
l := d.log.WithFields(logrus.Fields{
"func": "iterateAncestors",
"username": username,
@@ -87,7 +88,7 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
}
status := >smodel.Status{}
- if err := d.db.GetByID(id, status); err != nil {
+ if err := d.db.GetByID(ctx, id, status); err != nil {
return err
}
@@ -99,12 +100,12 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
if err != nil {
return err
}
- return d.iterateAncestors(username, *nextIRI)
+ return d.iterateAncestors(ctx, username, *nextIRI)
}
// If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus
// We call it with refresh to true because we want the statusable representation to parse inReplyTo from.
- status, statusable, _, err := d.GetRemoteStatus(username, &statusIRI, true)
+ status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true)
if err != nil {
l.Debugf("error getting remote status: %s", err)
return nil
@@ -117,22 +118,22 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
}
// get the ancestor status into our database if we don't have it yet
- if _, _, _, err := d.GetRemoteStatus(username, inReplyTo, false); err != nil {
+ if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil {
l.Debugf("error getting remote status: %s", err)
return nil
}
// now enrich the current status, since we should have the ancestor in the db
- if _, err := d.EnrichRemoteStatus(username, status); err != nil {
+ if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil {
l.Debugf("error enriching remote status: %s", err)
return nil
}
// now move up to the next ancestor
- return d.iterateAncestors(username, *inReplyTo)
+ return d.iterateAncestors(ctx, username, *inReplyTo)
}
-func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusable ap.Statusable) error {
+func (d *deref) iterateDescendants(ctx context.Context, username string, statusIRI url.URL, statusable ap.Statusable) error {
l := d.log.WithFields(logrus.Fields{
"func": "iterateDescendants",
"username": username,
@@ -182,7 +183,7 @@ func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusabl
pageLoop:
for {
l.Debugf("dereferencing page %s", currentPageIRI)
- nextPage, err := d.DereferenceCollectionPage(username, currentPageIRI)
+ nextPage, err := d.DereferenceCollectionPage(ctx, username, currentPageIRI)
if err != nil {
return nil
}
@@ -226,10 +227,10 @@ pageLoop:
foundReplies = foundReplies + 1
// get the remote statusable and put it in the db
- _, statusable, new, err := d.GetRemoteStatus(username, itemURI, false)
+ _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false)
if new && err == nil && statusable != nil {
// now iterate descendants of *that* status
- if err := d.iterateDescendants(username, *itemURI, statusable); err != nil {
+ if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil {
continue
}
}
diff --git a/internal/processing/account.go b/internal/processing/account.go
index f722c88eb..94ba596ac 100644
--- a/internal/processing/account.go
+++ b/internal/processing/account.go
@@ -19,51 +19,53 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
- return p.accountProcessor.Create(authed.Token, authed.Application, form)
+func (p *processor) AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
+ return p.accountProcessor.Create(ctx, authed.Token, authed.Application, form)
}
-func (p *processor) AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) {
- return p.accountProcessor.Get(authed.Account, targetAccountID)
+func (p *processor) AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) {
+ return p.accountProcessor.Get(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
- return p.accountProcessor.Update(authed.Account, form)
+func (p *processor) AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
+ return p.accountProcessor.Update(ctx, authed.Account, form)
}
-func (p *processor) AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
- return p.accountProcessor.StatusesGet(authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
+func (p *processor) AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
+ return p.accountProcessor.StatusesGet(ctx, authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
}
-func (p *processor) AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- return p.accountProcessor.FollowersGet(authed.Account, targetAccountID)
+func (p *processor) AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ return p.accountProcessor.FollowersGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- return p.accountProcessor.FollowingGet(authed.Account, targetAccountID)
+func (p *processor) AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ return p.accountProcessor.FollowingGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.RelationshipGet(authed.Account, targetAccountID)
+func (p *processor) AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.RelationshipGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.FollowCreate(authed.Account, form)
+func (p *processor) AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.FollowCreate(ctx, authed.Account, form)
}
-func (p *processor) AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.FollowRemove(authed.Account, targetAccountID)
+func (p *processor) AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.FollowRemove(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.BlockCreate(authed.Account, targetAccountID)
+func (p *processor) AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.BlockCreate(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.BlockRemove(authed.Account, targetAccountID)
+func (p *processor) AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.BlockRemove(ctx, authed.Account, targetAccountID)
}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index 7b8910149..81701fd7c 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"mime/multipart"
"github.com/sirupsen/logrus"
@@ -38,40 +39,40 @@ import (
// Processor wraps a bunch of functions for processing account actions.
type Processor interface {
// Create processes the given form for creating a new account, returning an oauth token for that account if successful.
- Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
+ Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
// Delete deletes an account, and all of that account's statuses, media, follows, notifications, etc etc etc.
// The origin passed here should be either the ID of the account doing the delete (can be itself), or the ID of a domain block.
- Delete(account *gtsmodel.Account, origin string) error
+ Delete(ctx context.Context, account *gtsmodel.Account, origin string) error
// Get processes the given request for account information.
- Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error)
+ Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error)
// Update processes the update of an account with the given form
- Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
+ Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
// StatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for
// the account given in authed.
- StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
+ StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
// FollowersGet fetches a list of the target account's followers.
- FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// FollowingGet fetches a list of the accounts that target account is following.
- FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
- RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// FollowCreate handles a follow request to an account, either remote or local.
- FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
+ FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
- FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
- BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.
- BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// UpdateHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
- UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
+ UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
// UpdateAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
- UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
+ UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
}
type processor struct {
diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go
index 83e76973d..1eae90d03 100644
--- a/internal/processing/account/create.go
+++ b/internal/processing/account/create.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,14 +28,14 @@ import (
"github.com/superseriousbusiness/oauth2/v4"
)
-func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
+func (p *processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
l := p.log.WithField("func", "accountCreate")
- if err := p.db.IsEmailAvailable(form.Email); err != nil {
+ if err := p.db.IsEmailAvailable(ctx, form.Email); err != nil {
return nil, err
}
- if err := p.db.IsUsernameAvailable(form.Username); err != nil {
+ if err := p.db.IsUsernameAvailable(ctx, form.Username); err != nil {
return nil, err
}
@@ -45,7 +46,7 @@ func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmo
}
l.Trace("creating new username and account")
- user, err := p.db.NewSignup(form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false)
+ user, err := p.db.NewSignup(ctx, form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false)
if err != nil {
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
}
diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go
index f10a2efa3..06f82b37d 100644
--- a/internal/processing/account/createblock.go
+++ b/internal/processing/account/createblock.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,18 +30,18 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(targetAccountID)
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// if requestingAccount already blocks target account, we don't need to do anything
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
} else if blocked {
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
// make the block
@@ -57,18 +58,18 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID)
// whack it in the database
- if err := p.db.Put(block); err != nil {
+ if err := p.db.Put(ctx, block); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))
}
// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, >smodel.Follow{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))
}
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, >smodel.FollowRequest{}); err != nil {
@@ -82,12 +83,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
var frChanged bool
var frURI string
fr := >smodel.FollowRequest{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(fr.ID, fr); err != nil {
+ if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))
}
frChanged = true
@@ -97,12 +98,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
var fChanged bool
var fURI string
f := >smodel.Follow{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(f.ID, f); err != nil {
+ if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))
}
fChanged = true
@@ -147,5 +148,5 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
TargetAccount: targetAccount,
}
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go
index 8c856a50e..a7767afea 100644
--- a/internal/processing/account/createfollow.go
+++ b/internal/processing/account/createfollow.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,16 +30,16 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
// make sure the target account actually exists in our db
- targetAcct, err := p.db.GetAccountByID(form.ID)
+ targetAcct, err := p.db.GetAccountByID(ctx, form.ID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
@@ -47,19 +48,19 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// check if a follow exists already
- if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil {
+ if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
} else if follows {
// already follows so just return the relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// check if a follow request exists already
- if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil {
+ if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
} else if followRequested {
// already follow requested so just return the relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// make the follow request
@@ -84,17 +85,17 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// whack it in the database
- if err := p.db.Put(fr); err != nil {
+ if err := p.db.Put(ctx, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))
}
// if it's a local account that's not locked we can just straight up accept the follow request
if !targetAcct.Locked && targetAcct.Domain == "" {
- if _, err := p.db.AcceptFollowRequest(requestingAccount.ID, form.ID); err != nil {
+ if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))
}
// return the new relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously
@@ -107,5 +108,5 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// return whatever relationship results from this
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index e8840abae..a0758c846 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"time"
"github.com/sirupsen/logrus"
@@ -48,7 +49,7 @@ import (
// 16. Delete account's user
// 17. Delete account's timeline
// 18. Delete account itself
-func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
+func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origin string) error {
l := p.log.WithFields(logrus.Fields{
"func": "Delete",
"username": account.Username,
@@ -61,22 +62,22 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
if account.Domain == "" {
// see if we can get a user for this account
u := >smodel.User{}
- if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
// we got one! select all tokens with the user's ID
tokens := []*oauth.Token{}
- if err := p.db.GetWhere([]db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
// we have some tokens to delete
for _, t := range tokens {
// delete client(s) associated with this token
- if err := p.db.DeleteByID(t.ClientID, &oauth.Client{}); err != nil {
+ if err := p.db.DeleteByID(ctx, t.ClientID, &oauth.Client{}); err != nil {
l.Errorf("error deleting oauth client: %s", err)
}
// delete application(s) associated with this token
- if err := p.db.DeleteWhere([]db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
l.Errorf("error deleting application: %s", err)
}
// delete the token itself
- if err := p.db.DeleteByID(t.ID, t); err != nil {
+ if err := p.db.DeleteByID(ctx, t.ID, t); err != nil {
l.Errorf("error deleting oauth token: %s", err)
}
}
@@ -87,12 +88,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// 2. Delete account's blocks
l.Debug("deleting account blocks")
// first delete any blocks that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
l.Errorf("error deleting blocks created by account: %s", err)
}
// now delete any blocks that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
l.Errorf("error deleting blocks targeting account: %s", err)
}
@@ -103,12 +104,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// TODO: federate these if necessary
l.Debug("deleting account follow requests")
// first delete any follow requests that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests created by account: %s", err)
}
// now delete any follow requests that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests targeting account: %s", err)
}
@@ -116,12 +117,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// TODO: federate these if necessary
l.Debug("deleting account follows")
// first delete any follows that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows created by account: %s", err)
}
// now delete any follows that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows targeting account: %s", err)
}
@@ -133,7 +134,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
var maxID string
selectStatusesLoop:
for {
- statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
+ statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, maxID, false, false)
if err != nil {
if err == db.ErrNoEntries {
// no statuses left for this instance so we're done
@@ -157,7 +158,7 @@ selectStatusesLoop:
TargetAccount: account,
}
- if err := p.db.DeleteByID(s.ID, s); err != nil {
+ if err := p.db.DeleteByID(ctx, s.ID, s); err != nil {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
@@ -167,7 +168,7 @@ selectStatusesLoop:
// if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{}
- if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
if err != db.ErrNoEntries {
// an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
@@ -177,7 +178,7 @@ selectStatusesLoop:
for _, b := range boosts {
oa := >smodel.Account{}
- if err := p.db.GetByID(b.AccountID, oa); err == nil {
+ if err := p.db.GetByID(ctx, b.AccountID, oa); err == nil {
l.Debug("putting boost undo in the client api channel")
p.fromClientAPI <- gtsmodel.FromClientAPI{
@@ -189,7 +190,7 @@ selectStatusesLoop:
}
}
- if err := p.db.DeleteByID(b.ID, b); err != nil {
+ if err := p.db.DeleteByID(ctx, b.ID, b); err != nil {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
@@ -208,26 +209,26 @@ selectStatusesLoop:
// 10. Delete account's notifications
l.Debug("deleting account notifications")
- if err := p.db.DeleteWhere([]db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications created by account: %s", err)
}
// 11. Delete account's bookmarks
l.Debug("deleting account bookmarks")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
l.Errorf("error deleting bookmarks created by account: %s", err)
}
// 12. Delete account's faves
// TODO: federate these if necessary
l.Debug("deleting account faves")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
l.Errorf("error deleting faves created by account: %s", err)
}
// 13. Delete account's mutes
l.Debug("deleting account mutes")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
l.Errorf("error deleting status mutes created by account: %s", err)
}
@@ -239,7 +240,7 @@ selectStatusesLoop:
// 16. Delete account's user
l.Debug("deleting account user")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil {
return err
}
@@ -266,7 +267,7 @@ selectStatusesLoop:
account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin
- if err := p.db.UpdateByID(account.ID, account); err != nil {
+ if err := p.db.UpdateByID(ctx, account.ID, account); err != nil {
return err
}
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index 3dfc54b51..01a0fb51a 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"errors"
"fmt"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
+func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := >smodel.Account{}
- if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
+ if err := p.db.GetByID(ctx, targetAccountID, targetAccount); err != nil {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
@@ -39,7 +40,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
var blocked bool
var err error
if requestingAccount != nil {
- blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
+ blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, fmt.Errorf("error checking account block: %s", err)
}
diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go
index 4f66b40ee..f90b0f767 100644
--- a/internal/processing/account/getfollowers.go
+++ b/internal/processing/account/getfollowers.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,15 +28,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollowedBy(targetAccountID, false)
+ follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -44,7 +45,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -53,7 +54,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
}
if f.Account == nil {
- a, err := p.db.GetAccountByID(f.AccountID)
+ a, err := p.db.GetAccountByID(ctx, f.AccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go
index c7fb426f9..4082e89c1 100644
--- a/internal/processing/account/getfollowing.go
+++ b/internal/processing/account/getfollowing.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,15 +28,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollows(targetAccountID)
+ follows, err := p.db.GetAccountFollows(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -44,7 +45,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -53,7 +54,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
}
if f.TargetAccount == nil {
- a, err := p.db.GetAccountByID(f.TargetAccountID)
+ a, err := p.db.GetAccountByID(ctx, f.TargetAccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
diff --git a/internal/processing/account/getrelationship.go b/internal/processing/account/getrelationship.go
index a0a93a4c2..615f30d5a 100644
--- a/internal/processing/account/getrelationship.go
+++ b/internal/processing/account/getrelationship.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"errors"
"fmt"
@@ -27,12 +28,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
if requestingAccount == nil {
return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
}
- gtsR, err := p.db.GetRelationship(requestingAccount.ID, targetAccountID)
+ gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
}
diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go
index dc21e7006..dae9fee33 100644
--- a/internal/processing/account/getstatuses.go
+++ b/internal/processing/account/getstatuses.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
@@ -36,7 +37,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
apiStatuses := []apimodel.Status{}
- statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
+ statuses, err := p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil {
if err == db.ErrNoEntries {
return apiStatuses, nil
diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go
index 7c1f2bc17..7e3d78076 100644
--- a/internal/processing/account/removeblock.go
+++ b/internal/processing/account/removeblock.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(targetAccountID)
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
@@ -37,13 +38,13 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
// check if a block exists, and remove it if it does (storing the URI for later)
var blockChanged bool
block := >smodel.Block{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, block); err == nil {
block.Account = requestingAccount
block.TargetAccount = targetAccount
- if err := p.db.DeleteByID(block.ID, >smodel.Block{}); err != nil {
+ if err := p.db.DeleteByID(ctx, block.ID, >smodel.Block{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
}
blockChanged = true
@@ -61,5 +62,5 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
}
// return whatever relationship results from all this
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go
index 6646d694e..c271de79f 100644
--- a/internal/processing/account/removefollow.go
+++ b/internal/processing/account/removefollow.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't do anything
- blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -39,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
// make sure the target account actually exists in our db
targetAcct := >smodel.Account{}
- if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
+ if err := p.db.GetByID(ctx, targetAccountID, targetAcct); err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}
@@ -49,12 +50,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
var frChanged bool
var frURI string
fr := >smodel.FollowRequest{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(fr.ID, fr); err != nil {
+ if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))
}
frChanged = true
@@ -64,12 +65,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
var fChanged bool
var fURI string
f := >smodel.Follow{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(f.ID, f); err != nil {
+ if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))
}
fChanged = true
@@ -106,5 +107,5 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
}
// return whatever relationship results from all this
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go
index df842bacd..46aa10ce1 100644
--- a/internal/processing/account/update.go
+++ b/internal/processing/account/update.go
@@ -20,6 +20,7 @@ package account
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -32,17 +33,17 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
+func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
l := p.log.WithField("func", "AccountUpdate")
if form.Discoverable != nil {
- if err := p.db.UpdateOneByID(account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil {
return nil, fmt.Errorf("error updating discoverable: %s", err)
}
}
if form.Bot != nil {
- if err := p.db.UpdateOneByID(account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil {
return nil, fmt.Errorf("error updating bot: %s", err)
}
}
@@ -52,7 +53,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
return nil, err
}
displayName := text.RemoveHTML(*form.DisplayName) // no html allowed in display name
- if err := p.db.UpdateOneByID(account.ID, "display_name", displayName, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "display_name", displayName, >smodel.Account{}); err != nil {
return nil, err
}
}
@@ -62,13 +63,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
return nil, err
}
note := text.SanitizeHTML(*form.Note) // html OK in note but sanitize it
- if err := p.db.UpdateOneByID(account.ID, "note", note, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "note", note, >smodel.Account{}); err != nil {
return nil, err
}
}
if form.Avatar != nil && form.Avatar.Size != 0 {
- avatarInfo, err := p.UpdateAvatar(form.Avatar, account.ID)
+ avatarInfo, err := p.UpdateAvatar(ctx, form.Avatar, account.ID)
if err != nil {
return nil, err
}
@@ -76,7 +77,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
}
if form.Header != nil && form.Header.Size != 0 {
- headerInfo, err := p.UpdateHeader(form.Header, account.ID)
+ headerInfo, err := p.UpdateHeader(ctx, form.Header, account.ID)
if err != nil {
return nil, err
}
@@ -84,7 +85,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
}
if form.Locked != nil {
- if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil {
return nil, err
}
}
@@ -94,13 +95,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
if err := util.ValidateLanguage(*form.Source.Language); err != nil {
return nil, err
}
- if err := p.db.UpdateOneByID(account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil {
return nil, err
}
}
if form.Source.Sensitive != nil {
- if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil {
return nil, err
}
}
@@ -109,7 +110,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil {
return nil, err
}
- if err := p.db.UpdateOneByID(account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil {
return nil, err
}
}
@@ -117,7 +118,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
// fetch the account with all updated values set
updatedAccount := >smodel.Account{}
- if err := p.db.GetByID(account.ID, updatedAccount); err != nil {
+ if err := p.db.GetByID(ctx, account.ID, updatedAccount); err != nil {
return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err)
}
@@ -138,7 +139,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
// UpdateAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
-func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (p *processor) UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
var err error
if int(avatar.Size) > p.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, p.config.MediaConfig.MaxImageSize)
@@ -171,7 +172,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string)
// UpdateHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
-func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
var err error
if int(header.Size) > p.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, p.config.MediaConfig.MaxImageSize)
diff --git a/internal/processing/admin.go b/internal/processing/admin.go
index 9a38f5ec1..48faee986 100644
--- a/internal/processing/admin.go
+++ b/internal/processing/admin.go
@@ -19,31 +19,33 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
- return p.adminProcessor.EmojiCreate(authed.Account, authed.User, form)
+func (p *processor) AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
+ return p.adminProcessor.EmojiCreate(ctx, authed.Account, authed.User, form)
}
-func (p *processor) AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockCreate(authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "")
+func (p *processor) AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockCreate(ctx, authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "")
}
-func (p *processor) AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlocksImport(authed.Account, form.Domains)
+func (p *processor) AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlocksImport(ctx, authed.Account, form.Domains)
}
-func (p *processor) AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlocksGet(authed.Account, export)
+func (p *processor) AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlocksGet(ctx, authed.Account, export)
}
-func (p *processor) AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockGet(authed.Account, id, export)
+func (p *processor) AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockGet(ctx, authed.Account, id, export)
}
-func (p *processor) AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockDelete(authed.Account, id)
+func (p *processor) AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockDelete(ctx, authed.Account, id)
}
diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go
index fd63d8a10..de288811b 100644
--- a/internal/processing/admin/admin.go
+++ b/internal/processing/admin/admin.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"mime/multipart"
"github.com/sirupsen/logrus"
@@ -33,12 +34,12 @@ import (
// Processor wraps a bunch of functions for processing admin actions.
type Processor interface {
- DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode)
- EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
+ DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode)
+ EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
}
type processor struct {
diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go
index 624f632dc..d845fe4b9 100644
--- a/internal/processing/admin/createdomainblock.go
+++ b/internal/processing/admin/createdomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
"time"
@@ -31,10 +32,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
)
-func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) {
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
domainBlock := >smodel.DomainBlock{}
- err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
+ err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
if err != nil {
if err != db.ErrNoEntries {
// something went wrong in the DB
@@ -59,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
}
// put the new block in the database
- if err := p.db.Put(domainBlock); err != nil {
+ if err := p.db.Put(ctx, domainBlock); err != nil {
if err != db.ErrNoEntries {
// there's a real error creating the block
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err))
@@ -67,7 +68,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
}
// process the side effects of the domain block asynchronously since it might take a while
- go p.initiateDomainBlockSideEffects(account, domainBlock) // TODO: add this to a queuing system so it can retry/resume
+ go p.initiateDomainBlockSideEffects(ctx, account, domainBlock) // TODO: add this to a queuing system so it can retry/resume
}
mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false)
@@ -83,7 +84,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
// 1. Strip most info away from the instance entry for the domain.
// 2. Delete the instance account for that instance if it exists.
// 3. Select all accounts from this instance and pass them through the delete functionality of the processor.
-func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
+func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
l := p.log.WithFields(logrus.Fields{
"func": "domainBlockProcessSideEffects",
"domain": block.Domain,
@@ -93,7 +94,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
// if we have an instance entry for this domain, update it with the new block ID and clear all fields
instance := >smodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil {
instance.Title = ""
instance.UpdatedAt = time.Now()
instance.SuspendedAt = time.Now()
@@ -105,14 +106,14 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
- if err := p.db.UpdateByID(instance.ID, instance); err != nil {
+ if err := p.db.UpdateByID(ctx, instance.ID, instance); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
}
l.Debug("domainBlockProcessSideEffects: instance entry updated")
}
// if we have an instance account for this instance, delete it
- if err := p.db.DeleteWhere([]db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
}
@@ -123,7 +124,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
selectAccountsLoop:
for {
- accounts, err := p.db.GetInstanceAccounts(block.Domain, maxID, limit)
+ accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
if err != nil {
if err == db.ErrNoEntries {
// no accounts left for this instance so we're done
diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go
index edb0a58f9..c57554772 100644
--- a/internal/processing/admin/deletedomainblock.go
+++ b/internal/processing/admin/deletedomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
"time"
@@ -28,10 +29,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := >smodel.DomainBlock{}
- if err := p.db.GetByID(id, domainBlock); err != nil {
+ if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -47,33 +48,33 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap
}
// delete the domain block
- if err := p.db.DeleteByID(id, domainBlock); err != nil {
+ if err := p.db.DeleteByID(ctx, id, domainBlock); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// remove the domain block reference from the instance, if we have an entry for it
i := >smodel.Instance{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
{Key: "domain_block_id", Value: id},
}, i); err == nil {
i.SuspendedAt = time.Time{}
i.DomainBlockID = ""
- if err := p.db.UpdateByID(i.ID, i); err != nil {
+ if err := p.db.UpdateByID(ctx, i.ID, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
}
}
// unsuspend all accounts whose suspension origin was this domain block
// 1. remove the 'suspended_at' entry from their accounts
- if err := p.db.UpdateWhere([]db.Where{
+ if err := p.db.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
}
// 2. remove the 'suspension_origin' entry from their accounts
- if err := p.db.UpdateWhere([]db.Where{
+ if err := p.db.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go
index f19e173b5..23ed1b6c2 100644
--- a/internal/processing/admin/emoji.go
+++ b/internal/processing/admin/emoji.go
@@ -20,6 +20,7 @@ package admin
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
+func (p *processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
if user.Admin {
return nil, fmt.Errorf("user %s not an admin", user.ID)
}
@@ -65,7 +66,7 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User,
return nil, fmt.Errorf("error converting emoji to mastotype: %s", err)
}
- if err := p.db.Put(emoji); err != nil {
+ if err := p.db.Put(ctx, emoji); err != nil {
return nil, fmt.Errorf("database error while processing emoji: %s", err)
}
diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go
index f74010627..676aff1ae 100644
--- a/internal/processing/admin/getdomainblock.go
+++ b/internal/processing/admin/getdomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,10 +28,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := >smodel.DomainBlock{}
- if err := p.db.GetByID(id, domainBlock); err != nil {
+ if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go
index f827d03fc..9f0940aef 100644
--- a/internal/processing/admin/getdomainblocks.go
+++ b/internal/processing/admin/getdomainblocks.go
@@ -19,16 +19,18 @@
package admin
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks := []*gtsmodel.DomainBlock{}
- if err := p.db.GetAll(&domainBlocks); err != nil {
+ if err := p.db.GetAll(ctx, &domainBlocks); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/admin/importdomainblocks.go b/internal/processing/admin/importdomainblocks.go
index ab171b712..66326bd62 100644
--- a/internal/processing/admin/importdomainblocks.go
+++ b/internal/processing/admin/importdomainblocks.go
@@ -20,6 +20,7 @@ package admin
import (
"bytes"
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -32,7 +33,7 @@ import (
)
// DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file.
-func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) {
f, err := domains.Open()
if err != nil {
@@ -54,7 +55,7 @@ func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multi
blocks := []*apimodel.DomainBlock{}
for _, d := range d {
- block, err := p.DomainBlockCreate(account, d.Domain, false, d.PublicComment, "", "")
+ block, err := p.DomainBlockCreate(ctx, account, d.Domain, false, d.PublicComment, "", "")
if err != nil {
return nil, err
diff --git a/internal/processing/app.go b/internal/processing/app.go
index 7da5344ac..444c4dda2 100644
--- a/internal/processing/app.go
+++ b/internal/processing/app.go
@@ -19,6 +19,8 @@
package processing
import (
+ "context"
+
"github.com/google/uuid"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -26,7 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) {
+func (p *processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) {
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
@@ -61,7 +63,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea
}
// chuck it in the db
- if err := p.db.Put(app); err != nil {
+ if err := p.db.Put(ctx, app); err != nil {
return nil, err
}
@@ -74,7 +76,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea
}
// chuck it in the db
- if err := p.db.Put(oc); err != nil {
+ if err := p.db.Put(ctx, oc); err != nil {
return nil, err
}
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
index 809cbde8e..7451485f9 100644
--- a/internal/processing/blocks.go
+++ b/internal/processing/blocks.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"net/url"
@@ -28,8 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
- accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit)
+func (p *processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
+ accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries
diff --git a/internal/processing/federation.go b/internal/processing/federation.go
index cea14b4de..2f8b15d16 100644
--- a/internal/processing/federation.go
+++ b/internal/processing/federation.go
@@ -36,7 +36,7 @@ import (
func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -62,7 +62,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -90,7 +90,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -106,7 +106,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -135,7 +135,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri
func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -151,7 +151,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -180,7 +180,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri
func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, requestedStatusID string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -198,7 +198,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
// authorize the request:
// 1. check if a block exists between the requester and the requestee
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -209,7 +209,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
// get the status out of the database here
s := >smodel.Status{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "id", Value: requestedStatusID},
{Key: "account_id", Value: requestedAccount.ID},
}, s); err != nil {
@@ -240,7 +240,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, minID string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -258,7 +258,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
// authorize the request:
// 1. check if a block exists between the requester and the requestee
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -269,7 +269,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
// get the status out of the database here
s := >smodel.Status{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "id", Value: requestedStatusID},
{Key: "account_id", Value: requestedAccount.ID},
}, s); err != nil {
@@ -320,7 +320,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
} else {
// scenario 3
// get immediate children
- replies, err := p.db.GetStatusChildren(s, true, minID)
+ replies, err := p.db.GetStatusChildren(ctx, s, true, minID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -373,7 +373,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -400,7 +400,7 @@ func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername s
}, nil
}
-func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) {
+func (p *processor) GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) {
return &apimodel.WellKnownResponse{
Links: []apimodel.Link{
{
@@ -411,7 +411,7 @@ func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownRe
}, nil
}
-func (p *processor) GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) {
+func (p *processor) GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) {
return &apimodel.Nodeinfo{
Version: "2.0",
Software: apimodel.NodeInfoSoftware{
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index 867725023..783a69a96 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -19,6 +19,8 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@@ -26,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
- frs, err := p.db.GetAccountFollowRequests(auth.Account.ID)
+func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
+ frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
@@ -37,7 +39,7 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts
accts := []apimodel.Account{}
for _, fr := range frs {
acct := >smodel.Account{}
- if err := p.db.GetByID(fr.AccountID, acct); err != nil {
+ if err := p.db.GetByID(ctx, fr.AccountID, acct); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
mastoAcct, err := p.tc.AccountToMastoPublic(acct)
@@ -49,19 +51,19 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts
return accts, nil
}
-func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- follow, err := p.db.AcceptFollowRequest(accountID, auth.Account.ID)
+func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
originAccount := >smodel.Account{}
- if err := p.db.GetByID(follow.AccountID, originAccount); err != nil {
+ if err := p.db.GetByID(ctx, follow.AccountID, originAccount); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
targetAccount := >smodel.Account{}
- if err := p.db.GetByID(follow.TargetAccountID, targetAccount); err != nil {
+ if err := p.db.GetByID(ctx, follow.TargetAccountID, targetAccount); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -73,7 +75,7 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap
TargetAccount: targetAccount,
}
- gtsR, err := p.db.GetRelationship(auth.Account.ID, accountID)
+ gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -86,6 +88,6 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap
return r, nil
}
-func (p *processor) FollowRequestDeny(auth *oauth.Auth) gtserror.WithCode {
+func (p *processor) FollowRequestDeny(ctx context.Context, auth *oauth.Auth) gtserror.WithCode {
return nil
}
diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go
index beed283c1..7c8743005 100644
--- a/internal/processing/fromclientapi.go
+++ b/internal/processing/fromclientapi.go
@@ -29,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error {
+func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel.FromClientAPI) error {
switch clientMsg.APActivityType {
case gtsmodel.ActivityStreamsCreate:
// CREATE
@@ -41,16 +41,16 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("note was not parseable as *gtsmodel.Status")
}
- if err := p.timelineStatus(status); err != nil {
+ if err := p.timelineStatus(ctx, status); err != nil {
return err
}
- if err := p.notifyStatus(status); err != nil {
+ if err := p.notifyStatus(ctx, status); err != nil {
return err
}
if status.VisibilityAdvanced != nil && status.VisibilityAdvanced.Federated {
- return p.federateStatus(status)
+ return p.federateStatus(ctx, status)
}
case gtsmodel.ActivityStreamsFollow:
// CREATE FOLLOW REQUEST
@@ -59,11 +59,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("followrequest was not parseable as *gtsmodel.FollowRequest")
}
- if err := p.notifyFollowRequest(followRequest, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFollowRequest(ctx, followRequest, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateFollow(followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateFollow(ctx, followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsLike:
// CREATE LIKE/FAVE
fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave)
@@ -71,11 +71,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("fave was not parseable as *gtsmodel.StatusFave")
}
- if err := p.notifyFave(fave, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFave(ctx, fave, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateFave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateFave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsAnnounce:
// CREATE BOOST/ANNOUNCE
boostWrapperStatus, ok := clientMsg.GTSModel.(*gtsmodel.Status)
@@ -83,15 +83,15 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("boost was not parseable as *gtsmodel.Status")
}
- if err := p.timelineStatus(boostWrapperStatus); err != nil {
+ if err := p.timelineStatus(ctx, boostWrapperStatus); err != nil {
return err
}
- if err := p.notifyAnnounce(boostWrapperStatus); err != nil {
+ if err := p.notifyAnnounce(ctx, boostWrapperStatus); err != nil {
return err
}
- return p.federateAnnounce(boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateAnnounce(ctx, boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsBlock:
// CREATE BLOCK
block, ok := clientMsg.GTSModel.(*gtsmodel.Block)
@@ -110,7 +110,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// TODO: same with notifications
// TODO: same with bookmarks
- return p.federateBlock(block)
+ return p.federateBlock(ctx, block)
}
case gtsmodel.ActivityStreamsUpdate:
// UPDATE
@@ -122,7 +122,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("account was not parseable as *gtsmodel.Account")
}
- return p.federateAccountUpdate(account, clientMsg.OriginAccount)
+ return p.federateAccountUpdate(ctx, account, clientMsg.OriginAccount)
}
case gtsmodel.ActivityStreamsAccept:
// ACCEPT
@@ -134,11 +134,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("accept was not parseable as *gtsmodel.Follow")
}
- if err := p.notifyFollow(follow, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFollow(ctx, follow, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateAcceptFollowRequest(follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateAcceptFollowRequest(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
}
case gtsmodel.ActivityStreamsUndo:
// UNDO
@@ -149,21 +149,21 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
if !ok {
return errors.New("undo was not parseable as *gtsmodel.Follow")
}
- return p.federateUnfollow(follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnfollow(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsBlock:
// UNDO BLOCK
block, ok := clientMsg.GTSModel.(*gtsmodel.Block)
if !ok {
return errors.New("undo was not parseable as *gtsmodel.Block")
}
- return p.federateUnblock(block)
+ return p.federateUnblock(ctx, block)
case gtsmodel.ActivityStreamsLike:
// UNDO LIKE/FAVE
fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave)
if !ok {
return errors.New("undo was not parseable as *gtsmodel.StatusFave")
}
- return p.federateUnfave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnfave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsAnnounce:
// UNDO ANNOUNCE/BOOST
boost, ok := clientMsg.GTSModel.(*gtsmodel.Status)
@@ -175,7 +175,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return err
}
- return p.federateUnannounce(boost, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnannounce(ctx, boost, clientMsg.OriginAccount, clientMsg.TargetAccount)
}
case gtsmodel.ActivityStreamsDelete:
// DELETE
@@ -200,13 +200,13 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// delete all mentions for this status
for _, m := range statusToDelete.MentionIDs {
- if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil {
+ if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
return err
}
}
// delete all notifications for this status
- if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err
}
@@ -215,7 +215,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return err
}
- return p.federateStatusDelete(statusToDelete)
+ return p.federateStatusDelete(ctx, statusToDelete)
case gtsmodel.ActivityStreamsProfile, gtsmodel.ActivityStreamsPerson:
// DELETE ACCOUNT/PROFILE
@@ -228,7 +228,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// origin is whichever account caused this message
origin = clientMsg.OriginAccount.ID
}
- return p.accountProcessor.Delete(clientMsg.TargetAccount, origin)
+ return p.accountProcessor.Delete(ctx, clientMsg.TargetAccount, origin)
}
}
return nil
@@ -236,10 +236,10 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// TODO: move all the below functions into federation.Federator
-func (p *processor) federateStatus(status *gtsmodel.Status) error {
+func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(status.AccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
status.Account = a
@@ -260,14 +260,14 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error {
return fmt.Errorf("federateStatus: error parsing outboxURI %s: %s", status.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asStatus)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asStatus)
return err
}
-func (p *processor) federateStatusDelete(status *gtsmodel.Status) error {
+func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(status.AccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
status.Account = a
@@ -310,11 +310,11 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error {
delete.SetActivityStreamsTo(asStatus.GetActivityStreamsTo())
delete.SetActivityStreamsCc(asStatus.GetActivityStreamsCc())
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, delete)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, delete)
return err
}
-func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateFollow(ctx context.Context, followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
@@ -332,11 +332,11 @@ func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, origin
return fmt.Errorf("federateFollow: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFollow)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFollow)
return err
}
-func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnfollow(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
@@ -373,11 +373,11 @@ func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gts
}
// send off the Undo
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnfave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
@@ -412,11 +412,11 @@ func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gts
if err != nil {
return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
if originAccount.Domain != "" {
// nothing to do here
return nil
@@ -447,11 +447,11 @@ func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gt
return fmt.Errorf("federateUnannounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
@@ -497,11 +497,11 @@ func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originA
}
// send off the accept using the accepter's outbox
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, accept)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, accept)
return err
}
-func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateFave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
@@ -517,11 +517,11 @@ func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmo
if err != nil {
return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFave)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFave)
return err
}
-func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error {
+func (p *processor) federateAnnounce(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error {
announce, err := p.tc.BoostToAS(boostWrapperStatus, boostingAccount, boostedAccount)
if err != nil {
return fmt.Errorf("federateAnnounce: error converting status to announce: %s", err)
@@ -532,11 +532,11 @@ func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boosti
return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", boostingAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, announce)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, announce)
return err
}
-func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error {
+func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error {
person, err := p.tc.AccountToAS(updatedAccount)
if err != nil {
return fmt.Errorf("federateAccountUpdate: error converting account to person: %s", err)
@@ -552,14 +552,14 @@ func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, orig
return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, update)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, update)
return err
}
-func (p *processor) federateBlock(block *gtsmodel.Block) error {
+func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(block.AccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
}
block.Account = a
@@ -567,7 +567,7 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error {
if block.TargetAccount == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(block.TargetAccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
}
block.TargetAccount = a
@@ -588,14 +588,14 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error {
return fmt.Errorf("federateBlock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asBlock)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asBlock)
return err
}
-func (p *processor) federateUnblock(block *gtsmodel.Block) error {
+func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(block.AccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
}
block.Account = a
@@ -603,7 +603,7 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error {
if block.TargetAccount == nil {
a := >smodel.Account{}
- if err := p.db.GetByID(block.TargetAccountID, a); err != nil {
+ if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
}
block.TargetAccount = a
@@ -642,6 +642,6 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error {
if err != nil {
return fmt.Errorf("federateUnblock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go
index 2c2635175..47dd907ee 100644
--- a/internal/processing/fromcommon.go
+++ b/internal/processing/fromcommon.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"strings"
"sync"
@@ -28,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) notifyStatus(status *gtsmodel.Status) error {
+func (p *processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) error {
// if there are no mentions in this status then just bail
if len(status.MentionIDs) == 0 {
return nil
@@ -36,7 +37,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this
- menchies, err := p.db.GetMentions(status.MentionIDs)
+ menchies, err := p.db.GetMentions(ctx, status.MentionIDs)
if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
}
@@ -47,7 +48,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
for _, m := range status.Mentions {
// make sure this is a local account, otherwise we don't need to create a notification for it
if m.TargetAccount == nil {
- a, err := p.db.GetAccountByID(m.TargetAccountID)
+ a, err := p.db.GetAccountByID(ctx, m.TargetAccountID)
if err != nil {
// we don't have the account or there's been an error
return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err)
@@ -60,7 +61,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
}
// make sure a notif doesn't already exist for this mention
- err := p.db.GetWhere([]db.Where{
+ err := p.db.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationMention},
{Key: "target_account_id", Value: m.TargetAccountID},
{Key: "origin_account_id", Value: status.AccountID},
@@ -92,7 +93,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
Status: status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
}
@@ -110,7 +111,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error {
+func (p *processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error {
// return if this isn't a local account
if receivingAccount.Domain != "" {
return nil
@@ -128,7 +129,7 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r
OriginAccountID: followRequest.AccountID,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
}
@@ -145,14 +146,14 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r
return nil
}
-func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error {
+func (p *processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error {
// return if this isn't a local account
if targetAccount.Domain != "" {
return nil
}
// first remove the follow request notification
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},
{Key: "target_account_id", Value: follow.TargetAccountID},
{Key: "origin_account_id", Value: follow.AccountID},
@@ -174,7 +175,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode
OriginAccountID: follow.AccountID,
OriginAccount: follow.Account,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
}
@@ -191,7 +192,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode
return nil
}
-func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error {
+func (p *processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error {
// return if this isn't a local account
if targetAccount.Domain != "" {
return nil
@@ -213,7 +214,7 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode
Status: fave.Status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
}
@@ -230,14 +231,14 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode
return nil
}
-func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
+func (p *processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) error {
if status.BoostOfID == "" {
// not a boost, nothing to do
return nil
}
if status.BoostOf == nil {
- boostedStatus, err := p.db.GetStatusByID(status.BoostOfID)
+ boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)
}
@@ -245,7 +246,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
}
if status.BoostOfAccount == nil {
- boostedAcct, err := p.db.GetAccountByID(status.BoostOfAccountID)
+ boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)
}
@@ -264,7 +265,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
}
// make sure a notif doesn't already exist for this announce
- err := p.db.GetWhere([]db.Where{
+ err := p.db.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationReblog},
{Key: "target_account_id", Value: status.BoostOfAccountID},
{Key: "origin_account_id", Value: status.AccountID},
@@ -292,7 +293,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
Status: status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
}
@@ -309,10 +310,10 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) timelineStatus(status *gtsmodel.Status) error {
+func (p *processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
// make sure the author account is pinned onto the status
if status.Account == nil {
- a, err := p.db.GetAccountByID(status.AccountID)
+ a, err := p.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
}
@@ -320,7 +321,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
}
// get local followers of the account that posted the status
- follows, err := p.db.GetAccountFollowedBy(status.AccountID, true)
+ follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true)
if err != nil {
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
}
@@ -338,7 +339,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
errors := make(chan error, len(follows))
for _, f := range follows {
- go p.timelineStatusForAccount(status, f.AccountID, errors, &wg)
+ go p.timelineStatusForAccount(ctx, status, f.AccountID, errors, &wg)
}
// read any errors that come in from the async functions
@@ -365,11 +366,11 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) {
+func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) {
defer wg.Done()
// get the timeline owner account
- timelineAccount, err := p.db.GetAccountByID(accountID)
+ timelineAccount, err := p.db.GetAccountByID(ctx, accountID)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)
return
diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go
index c95c27778..6d8de289f 100644
--- a/internal/processing/fromfederator.go
+++ b/internal/processing/fromfederator.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"errors"
"fmt"
"net/url"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) error {
+func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmodel.FromFederator) error {
l := p.log.WithFields(logrus.Fields{
"func": "processFromFederator",
"federatorMsg": fmt.Sprintf("%+v", federatorMsg),
@@ -53,11 +54,11 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return err
}
- if err := p.timelineStatus(status); err != nil {
+ if err := p.timelineStatus(ctx, status); err != nil {
return err
}
- if err := p.notifyStatus(status); err != nil {
+ if err := p.notifyStatus(ctx, status); err != nil {
return err
}
case gtsmodel.ActivityStreamsProfile:
@@ -70,7 +71,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("like was not parseable as *gtsmodel.StatusFave")
}
- if err := p.notifyFave(incomingFave, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFave(ctx, incomingFave, federatorMsg.ReceivingAccount); err != nil {
return err
}
case gtsmodel.ActivityStreamsFollow:
@@ -80,7 +81,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("incomingFollowRequest was not parseable as *gtsmodel.FollowRequest")
}
- if err := p.notifyFollowRequest(incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFollowRequest(ctx, incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil {
return err
}
case gtsmodel.ActivityStreamsAnnounce:
@@ -100,17 +101,17 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
}
incomingAnnounce.ID = incomingAnnounceID
- if err := p.db.PutStatus(incomingAnnounce); err != nil {
+ if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil {
if err != db.ErrNoEntries {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
}
}
- if err := p.timelineStatus(incomingAnnounce); err != nil {
+ if err := p.timelineStatus(ctx, incomingAnnounce); err != nil {
return err
}
- if err := p.notifyAnnounce(incomingAnnounce); err != nil {
+ if err := p.notifyAnnounce(ctx, incomingAnnounce); err != nil {
return err
}
case gtsmodel.ActivityStreamsBlock:
@@ -172,13 +173,13 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
// delete all mentions for this status
for _, m := range statusToDelete.MentionIDs {
- if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil {
+ if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
return err
}
}
// delete all notifications for this status
- if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err
}
@@ -198,7 +199,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("follow was not parseable as *gtsmodel.Follow")
}
- if err := p.notifyFollow(follow, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFollow(ctx, follow, federatorMsg.ReceivingAccount); err != nil {
return err
}
}
diff --git a/internal/processing/instance.go b/internal/processing/instance.go
index b151744ef..b42be869d 100644
--- a/internal/processing/instance.go
+++ b/internal/processing/instance.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,9 +30,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) {
+func (p *processor) InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) {
i := >smodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain}}, i); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err))
}
@@ -43,15 +44,15 @@ func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.Wit
return ai, nil
}
-func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) {
+func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) {
// fetch the instance entry from the db for processing
i := >smodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err))
}
// fetch the instance account from the db for processing
- ia, err := p.db.GetInstanceAccount("")
+ ia, err := p.db.GetInstanceAccount(ctx, "")
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", p.config.Host, err))
}
@@ -67,13 +68,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// validate & update site contact account if it's set on the form
if form.ContactUsername != nil {
// make sure the account with the given username exists in the db
- contactAccount, err := p.db.GetLocalAccountByUsername(*form.ContactUsername)
+ contactAccount, err := p.db.GetLocalAccountByUsername(ctx, *form.ContactUsername)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
}
// make sure it has a user associated with it
contactUser := >smodel.User{}
- if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
}
// suspended accounts cannot be contact accounts
@@ -132,7 +133,7 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// process avatar if provided
if form.Avatar != nil && form.Avatar.Size != 0 {
- _, err := p.accountProcessor.UpdateAvatar(form.Avatar, ia.ID)
+ _, err := p.accountProcessor.UpdateAvatar(ctx, form.Avatar, ia.ID)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, "error processing avatar")
}
@@ -140,13 +141,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// process header if provided
if form.Header != nil && form.Header.Size != 0 {
- _, err := p.accountProcessor.UpdateHeader(form.Header, ia.ID)
+ _, err := p.accountProcessor.UpdateHeader(ctx, form.Header, ia.ID)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, "error processing header")
}
}
- if err := p.db.UpdateByID(i.ID, i); err != nil {
+ if err := p.db.UpdateByID(ctx, i.ID, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", p.config.Host, err))
}
diff --git a/internal/processing/media.go b/internal/processing/media.go
index 6ca0eda5b..0b2443893 100644
--- a/internal/processing/media.go
+++ b/internal/processing/media.go
@@ -19,23 +19,25 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
- return p.mediaProcessor.Create(authed.Account, form)
+func (p *processor) MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
+ return p.mediaProcessor.Create(ctx, authed.Account, form)
}
-func (p *processor) MediaGet(authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- return p.mediaProcessor.GetMedia(authed.Account, mediaAttachmentID)
+func (p *processor) MediaGet(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
+ return p.mediaProcessor.GetMedia(ctx, authed.Account, mediaAttachmentID)
}
-func (p *processor) MediaUpdate(authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
- return p.mediaProcessor.Update(authed.Account, mediaAttachmentID, form)
+func (p *processor) MediaUpdate(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
+ return p.mediaProcessor.Update(ctx, authed.Account, mediaAttachmentID, form)
}
-func (p *processor) FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
- return p.mediaProcessor.GetFile(authed.Account, form)
+func (p *processor) FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
+ return p.mediaProcessor.GetFile(ctx, authed.Account, form)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 48ed2a35f..c2ddbbed9 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -64,108 +64,108 @@ type Processor interface {
*/
// AccountCreate processes the given form for creating a new account, returning an oauth token for that account if successful.
- AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
+ AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
// AccountGet processes the given request for account information.
- AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error)
+ AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error)
// AccountUpdate processes the update of an account with the given form
- AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
+ AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
// AccountStatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for
// the account given in authed.
- AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
+ AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
// AccountFollowersGet fetches a list of the target account's followers.
- AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// AccountFollowingGet fetches a list of the accounts that target account is following.
- AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// AccountRelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
- AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountFollowCreate handles a follow request to an account, either remote or local.
- AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
+ AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
// AccountFollowRemove handles the removal of a follow/follow request to an account, either remote or local.
- AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountBlockCreate handles the creation of a block from authed account to target account, either remote or local.
- AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountBlockRemove handles the removal of a block from authed account to target account, either remote or local.
- AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AdminEmojiCreate handles the creation of a new instance emoji by an admin, using the given form.
- AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
+ AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
// AdminDomainBlockCreate handles the creation of a new domain block by an admin, using the given form.
- AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlocksImport handles the import of multiple domain blocks by an admin, using the given form.
- AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlocksGet returns a list of currently blocked domains.
- AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlockGet returns one domain block, specified by ID.
- AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlockDelete deletes one domain block, specified by ID, returning the deleted domain block.
- AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode)
// AppCreate processes the creation of a new API application
- AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error)
+ AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error)
// BlocksGet returns a list of accounts blocked by the requesting account.
- BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode)
+ BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode)
// FileGet handles the fetching of a media attachment file via the fileserver.
- FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
+ FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
// FollowRequestsGet handles the getting of the authed account's incoming follow requests
- FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode)
+ FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode)
// FollowRequestAccept handles the acceptance of a follow request from the given account ID
- FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode)
+ FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode)
// InstanceGet retrieves instance information for serving at api/v1/instance
- InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode)
+ InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode)
// InstancePatch updates this instance according to the given form.
//
// It should already be ascertained that the requesting account is authenticated and an admin.
- InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode)
+ InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode)
// MediaCreate handles the creation of a media attachment, using the given form.
- MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
+ MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
// MediaGet handles the GET of a media attachment with the given ID
- MediaGet(authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode)
+ MediaGet(ctx context.Context, authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode)
// MediaUpdate handles the PUT of a media attachment with the given ID and form
- MediaUpdate(authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
+ MediaUpdate(ctx context.Context, authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
// NotificationsGet
- NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode)
+ NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode)
// SearchGet performs a search with the given params, resolving/dereferencing remotely as desired
- SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode)
+ SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode)
// StatusCreate processes the given form to create a new status, returning the api model representation of that status if it's OK.
- StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error)
+ StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error)
// StatusDelete processes the delete of a given status, returning the deleted status if the delete goes through.
- StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusFave processes the faving of a given status, returning the updated status if the fave goes through.
- StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusBoost processes the boost/reblog of a given status, returning the newly-created boost if all is well.
- StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// StatusUnboost processes the unboost/unreblog of a given status, returning the status if all is well.
- StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
- StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
+ StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
// StatusFavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.
- StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error)
+ StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error)
// StatusGet gets the given status, taking account of privacy settings and blocks etc.
- StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusUnfave processes the unfaving of a given status, returning the updated status if the fave goes through.
- StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusGetContext returns the context (previous and following posts) from the given status ID
- StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
+ StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
// HomeTimelineGet returns statuses from the home timeline, with the given filters/parameters.
- HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters.
- PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// FavedTimelineGet returns faved statuses, with the given filters/parameters.
- FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid.
- AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error)
+ AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error)
// OpenStreamForAccount opens a new stream for the given account, with the given stream type.
- OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
+ OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
/*
FEDERATION API-FACING PROCESSING FUNCTIONS
@@ -199,10 +199,10 @@ type Processor interface {
GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode)
// GetNodeInfoRel returns a well known response giving the path to node info.
- GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode)
+ GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode)
// GetNodeInfo returns a node info struct in response to a node info request.
- GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode)
+ GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode)
// InboxPost handles POST requests to a user's inbox for new activitypub messages.
//
@@ -280,7 +280,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f
}
// Start starts the Processor, reading from its channels and passing messages back and forth.
-func (p *processor) Start() error {
+func (p *processor) Start(ctx context.Context) error {
go func() {
DistLoop:
for {
@@ -288,14 +288,14 @@ func (p *processor) Start() error {
case clientMsg := <-p.fromClientAPI:
p.log.Tracef("received message FROM client API: %+v", clientMsg)
go func() {
- if err := p.processFromClientAPI(clientMsg); err != nil {
+ if err := p.processFromClientAPI(ctx, clientMsg); err != nil {
p.log.Error(err)
}
}()
case federatorMsg := <-p.fromFederator:
p.log.Tracef("received message FROM federator: %+v", federatorMsg)
go func() {
- if err := p.processFromFederator(federatorMsg); err != nil {
+ if err := p.processFromFederator(ctx, federatorMsg); err != nil {
p.log.Error(err)
}
}()
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index 4eb6b5658..c2f5026e0 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -19,6 +19,7 @@
package transport
import (
+ "context"
"crypto"
"fmt"
"sync"
@@ -33,7 +34,7 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
- NewTransportForUsername(username string) (Transport, error)
+ NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
@@ -90,7 +91,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (T
}, nil
}
-func (c *controller) NewTransportForUsername(username string) (Transport, error) {
+func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
// We need an account to use to create a transport for dereferecing something.
// If a username has been given, we can fetch the account with that username and use it.
// Otherwise, we can take the instance account and use those credentials to make the request.
@@ -101,7 +102,7 @@ func (c *controller) NewTransportForUsername(username string) (Transport, error)
u = username
}
- ourAccount, err := c.db.GetLocalAccountByUsername(u)
+ ourAccount, err := c.db.GetLocalAccountByUsername(ctx, u)
if err != nil {
return nil, fmt.Errorf("error getting account %s from db: %s", username, err)
}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index 844cb6bea..fd0fb576f 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -5,12 +23,12 @@ import (
"net/url"
)
-func (t *transport) BatchDeliver(c context.Context, b []byte, recipients []*url.URL) error {
- return t.sigTransport.BatchDeliver(c, b, recipients)
+func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error {
+ return t.sigTransport.BatchDeliver(ctx, b, recipients)
}
-func (t *transport) Deliver(c context.Context, b []byte, to *url.URL) error {
+func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
l := t.log.WithField("func", "Deliver")
l.Debugf("performing POST to %s", to.String())
- return t.sigTransport.Deliver(c, b, to)
+ return t.sigTransport.Deliver(ctx, b, to)
}
diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go
index d7a28fe17..85fa370ee 100644
--- a/internal/transport/dereference.go
+++ b/internal/transport/dereference.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -5,8 +23,8 @@ import (
"net/url"
)
-func (t *transport) Dereference(c context.Context, iri *url.URL) ([]byte, error) {
+func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
l := t.log.WithField("func", "Dereference")
l.Debugf("performing GET to %s", iri.String())
- return t.sigTransport.Dereference(c, iri)
+ return t.sigTransport.Dereference(ctx, iri)
}
diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go
index a8b2ddfc7..fca611598 100644
--- a/internal/transport/derefinstance.go
+++ b/internal/transport/derefinstance.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -16,7 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) {
+func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) {
l := t.log.WithField("func", "DereferenceInstance")
var i *gtsmodel.Instance
@@ -27,7 +45,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
//
// This will only work with Mastodon-api compatible instances: Mastodon, some Pleroma instances, GoToSocial.
l.Debugf("trying to dereference instance %s by /api/v1/instance", iri.Host)
- i, err = dereferenceByAPIV1Instance(c, t, iri)
+ i, err = dereferenceByAPIV1Instance(ctx, t, iri)
if err == nil {
l.Debugf("successfully dereferenced instance using /api/v1/instance")
return i, nil
@@ -37,7 +55,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
// If that doesn't work, try to dereference using /.well-known/nodeinfo.
// This will involve two API calls and return less info overall, but should be more widely compatible.
l.Debugf("trying to dereference instance %s by /.well-known/nodeinfo", iri.Host)
- i, err = dereferenceByNodeInfo(c, t, iri)
+ i, err = dereferenceByNodeInfo(ctx, t, iri)
if err == nil {
l.Debugf("successfully dereferenced instance using /.well-known/nodeinfo")
return i, nil
@@ -58,7 +76,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
}, nil
}
-func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
+func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
l := t.log.WithField("func", "dereferenceByAPIV1Instance")
cleanIRI := &url.URL{
@@ -68,11 +86,10 @@ func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (
}
l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequest("GET", cleanIRI.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
@@ -216,7 +233,7 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
return i, nil
}
-func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.URL, error) {
+func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
l := t.log.WithField("func", "callNodeInfoWellKnown")
cleanIRI := &url.URL{
@@ -226,11 +243,11 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.
}
l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequest("GET", cleanIRI.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
@@ -281,15 +298,15 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.
return nodeinfoHref, nil
}
-func callNodeInfo(c context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
+func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
l := t.log.WithField("func", "callNodeInfo")
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go
index 5fa901100..e265bfdd4 100644
--- a/internal/transport/derefmedia.go
+++ b/internal/transport/derefmedia.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -8,14 +26,13 @@ import (
"net/url"
)
-func (t *transport) DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) {
+func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) {
l := t.log.WithField("func", "DereferenceMedia")
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
if expectedContentType == "" {
req.Header.Add("Accept", "*/*")
} else {
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index 12cd2fb64..bf64521c4 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -8,7 +26,7 @@ import (
"net/url"
)
-func (t *transport) Finger(c context.Context, targetUsername string, targetDomain string) ([]byte, error) {
+func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
l := t.log.WithField("func", "Finger")
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
l.Debugf("performing GET to %s", urlString)
@@ -20,11 +38,11 @@ func (t *transport) Finger(c context.Context, targetUsername string, targetDomai
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 04c72de5c..8d8262834 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package transport
import (
@@ -17,11 +35,11 @@ import (
type Transport interface {
pub.Transport
// DereferenceMedia fetches the bytes of the given media attachment IRI, with the expectedContentType.
- DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error)
+ DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error)
// DereferenceInstance dereferences remote instance information, first by checking /api/v1/instance, and then by checking /.well-known/nodeinfo.
- DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error)
+ DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
- Finger(c context.Context, targetUsername string, targetDomains string) ([]byte, error)
+ Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
}
// transport implements the Transport interface
diff --git a/testrig/db.go b/testrig/db.go
index c670103f1..47ff1a78c 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -84,111 +84,114 @@ func NewTestDB() db.DB {
// signatures with, otherwise this function will randomly generate new keys for accounts and signature
// verification will fail.
func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
+ ctx := context.Background()
+
for _, m := range testModels {
- if err := db.CreateTable(m); err != nil {
+ if err := db.CreateTable(ctx, m); err != nil {
panic(err)
}
}
for _, v := range NewTestTokens() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestClients() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestApplications() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestUsers() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
if accounts == nil {
for _, v := range NewTestAccounts() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
} else {
for _, v := range accounts {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
}
for _, v := range NewTestAttachments() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestStatuses() {
- if err := db.PutStatus(v); err != nil {
+ if err := db.PutStatus(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestEmojis() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestTags() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestMentions() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestFaves() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestFollows() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
for _, v := range NewTestNotifications() {
- if err := db.Put(v); err != nil {
+ if err := db.Put(ctx, v); err != nil {
panic(err)
}
}
- if err := db.CreateInstanceAccount(); err != nil {
+ if err := db.CreateInstanceAccount(ctx); err != nil {
panic(err)
}
- if err := db.CreateInstanceInstance(); err != nil {
+ if err := db.CreateInstanceInstance(ctx); err != nil {
panic(err)
}
}
// StandardDBTeardown drops all the standard testing tables/models from the database to ensure it's clean for the next test.
func StandardDBTeardown(db db.DB) {
+ ctx := context.Background()
for _, m := range testModels {
- if err := db.DropTable(m); err != nil {
+ if err := db.DropTable(ctx, m); err != nil {
panic(err)
}
}
diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod b/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod
deleted file mode 100644
index d44ba0123..000000000
--- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod
+++ /dev/null
@@ -1,7 +0,0 @@
-module github.com/go-pg/pg/extra/pgdebug
-
-go 1.15
-
-replace github.com/go-pg/pg/v10 => ../..
-
-require github.com/go-pg/pg/v10 v10.6.2
diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum b/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum
deleted file mode 100644
index 8483a864a..000000000
--- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum
+++ /dev/null
@@ -1,161 +0,0 @@
-cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
-github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
-github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
-github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
-github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
-github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU=
-github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo=
-github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
-github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
-github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
-github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
-github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
-github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
-github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM=
-github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo=
-github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
-github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
-github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
-github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78=
-github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
-github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
-github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M=
-github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
-github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
-github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
-github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA=
-github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
-github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
-github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
-github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
-github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
-github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
-github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4=
-github.com/vmihailenco/msgpack/v5 v5.0.0 h1:nCaMMPEyfgwkGc/Y0GreJPhuvzqCqW+Ufq5lY7zLO2c=
-github.com/vmihailenco/msgpack/v5 v5.0.0/go.mod h1:HVxBVPUK/+fZMonk4bi1islLa8V3cfnBug0+4dykPzo=
-github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
-github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc=
-github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
-go.opentelemetry.io/otel v0.14.0 h1:YFBEfjCk9MTjaytCNSUkp9Q8lF7QJezA06T71FbQxLQ=
-go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw=
-golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
-golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
-golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
-golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
-golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
-golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
-golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
-golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
-google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
-google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
-google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
-google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
-google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
-google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
-google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
-google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
-google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
-google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
-google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c=
-google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
-gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
-gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
-mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=
diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go b/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go
deleted file mode 100644
index bbf6ada19..000000000
--- a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package pgdebug
-
-import (
- "context"
- "fmt"
-
- "github.com/go-pg/pg/v10"
-)
-
-// DebugHook is a query hook that logs an error with a query if there are any.
-// It can be installed with:
-//
-// db.AddQueryHook(pgext.DebugHook{})
-type DebugHook struct {
- // Verbose causes hook to print all queries (even those without an error).
- Verbose bool
- EmptyLine bool
-}
-
-var _ pg.QueryHook = (*DebugHook)(nil)
-
-func (h DebugHook) BeforeQuery(ctx context.Context, evt *pg.QueryEvent) (context.Context, error) {
- q, err := evt.FormattedQuery()
- if err != nil {
- return nil, err
- }
-
- if evt.Err != nil {
- fmt.Printf("%s executing a query:\n%s\n", evt.Err, q)
- } else if h.Verbose {
- if h.EmptyLine {
- fmt.Println()
- }
- fmt.Println(string(q))
- }
-
- return ctx, nil
-}
-
-func (DebugHook) AfterQuery(context.Context, *pg.QueryEvent) error {
- return nil
-}
diff --git a/vendor/github.com/uptrace/bun/.gitignore b/vendor/github.com/uptrace/bun/.gitignore
new file mode 100644
index 000000000..6f7763c71
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/.gitignore
@@ -0,0 +1,3 @@
+*.s3db
+*.prof
+*.test
diff --git a/vendor/github.com/uptrace/bun/.prettierrc.yaml b/vendor/github.com/uptrace/bun/.prettierrc.yaml
new file mode 100644
index 000000000..decea5634
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/.prettierrc.yaml
@@ -0,0 +1,6 @@
+trailingComma: all
+tabWidth: 2
+semi: false
+singleQuote: true
+proseWrap: always
+printWidth: 100
diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md
new file mode 100644
index 000000000..01bf6ba31
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/CHANGELOG.md
@@ -0,0 +1,99 @@
+# Changelog
+
+## v0.4.1 - Aug 18 2021
+
+- Fixed migrate package to properly rollback migrations.
+- Added `allowzero` tag option that undoes `nullzero` option.
+
+## v0.4.0 - Aug 11 2021
+
+- Changed `WhereGroup` function to accept `*SelectQuery`.
+- Fixed query hooks for count queries.
+
+## v0.3.4 - Jul 19 2021
+
+- Renamed `migrate.CreateGo` to `CreateGoMigration`.
+- Added `migrate.WithPackageName` to customize the Go package name in generated migrations.
+- Renamed `migrate.CreateSQL` to `CreateSQLMigrations` and changed `CreateSQLMigrations` to create
+ both up and down migration files.
+
+## v0.3.1 - Jul 12 2021
+
+- Renamed `alias` field struct tag to `alt` so it is not confused with column alias.
+- Reworked migrate package API. See
+ [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details.
+
+## v0.3.0 - Jul 09 2021
+
+- Changed migrate package to return structured data instead of logging the progress. See
+ [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details.
+
+## v0.2.14 - Jul 01 2021
+
+- Added [sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) by
+ [Ivan Trubach](https://github.com/tie).
+- Added support for MySQL 5.7 in addition to MySQL 8.
+
+## v0.2.12 - Jun 29 2021
+
+- Fixed scanners for net.IP and net.IPNet.
+
+## v0.2.10 - Jun 29 2021
+
+- Fixed pgdriver to format passed query args.
+
+## v0.2.9 - Jun 27 2021
+
+- Added support for prepared statements in pgdriver.
+
+## v0.2.7 - Jun 26 2021
+
+- Added `UpdateQuery.Bulk` helper to generate bulk-update queries.
+
+ Before:
+
+ ```go
+ models := []Model{
+ {42, "hello"},
+ {43, "world"},
+ }
+ return db.NewUpdate().
+ With("_data", db.NewValues(&models)).
+ Model(&models).
+ Table("_data").
+ Set("model.str = _data.str").
+ Where("model.id = _data.id")
+ ```
+
+ Now:
+
+ ```go
+ db.NewUpdate().
+ Model(&models).
+ Bulk()
+ ```
+
+## v0.2.5 - Jun 25 2021
+
+- Changed time.Time to always append zero time as `NULL`.
+- Added `db.RunInTx` helper.
+
+## v0.2.4 - Jun 21 2021
+
+- Added SSL support to pgdriver.
+
+## v0.2.3 - Jun 20 2021
+
+- Replaced `ForceDelete(ctx)` with `ForceDelete().Exec(ctx)` for soft deletes.
+
+## v0.2.1 - Jun 17 2021
+
+- Renamed `DBI` to `IConn`. `IConn` is a common interface for `*sql.DB`, `*sql.Conn`, and `*sql.Tx`.
+- Added `IDB`. `IDB` is a common interface for `*bun.DB`, `bun.Conn`, and `bun.Tx`.
+
+## v0.2.0 - Jun 16 2021
+
+- Changed [model hooks](https://bun.uptrace.dev/guide/hooks.html#model-hooks). See
+ [model-hooks](example/model-hooks) example.
+- Renamed `has-one` to `belongs-to`. Renamed `belongs-to` to `has-one`. Previously Bun used
+ incorrect names for these relations.
diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE b/vendor/github.com/uptrace/bun/LICENSE
similarity index 94%
rename from vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE
rename to vendor/github.com/uptrace/bun/LICENSE
index 7751509b8..7ec81810c 100644
--- a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE
+++ b/vendor/github.com/uptrace/bun/LICENSE
@@ -1,4 +1,4 @@
-Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved.
+Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
diff --git a/vendor/github.com/uptrace/bun/Makefile b/vendor/github.com/uptrace/bun/Makefile
new file mode 100644
index 000000000..54744c617
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/Makefile
@@ -0,0 +1,21 @@
+ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
+
+test:
+ set -e; for dir in $(ALL_GO_MOD_DIRS); do \
+ echo "go test in $${dir}"; \
+ (cd "$${dir}" && \
+ go test ./... && \
+ go vet); \
+ done
+
+go_mod_tidy:
+ set -e; for dir in $(ALL_GO_MOD_DIRS); do \
+ echo "go mod tidy in $${dir}"; \
+ (cd "$${dir}" && \
+ go get -d ./... && \
+ go mod tidy); \
+ done
+
+fmt:
+ gofmt -w -s ./
+ goimports -w -local github.com/uptrace/bun ./
diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md
new file mode 100644
index 000000000..e7cc77a60
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/README.md
@@ -0,0 +1,267 @@
+
+
+
+
+
+
+# Simple and performant SQL database client
+
+[](https://github.com/uptrace/bun/actions)
+[](https://pkg.go.dev/github.com/uptrace/bun)
+[](https://bun.uptrace.dev/)
+[](https://discord.gg/rWtp5Aj)
+
+Main features are:
+
+- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql),
+ [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql),
+ [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite).
+- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars.
+- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert).
+- [Bulk updates](https://bun.uptrace.dev/guide/queries.html#update) using common table expressions.
+- [Bulk deletes](https://bun.uptrace.dev/guide/queries.html#delete).
+- [Fixtures](https://bun.uptrace.dev/guide/fixtures.html).
+- [Migrations](https://bun.uptrace.dev/guide/migrations.html).
+- [Soft deletes](https://bun.uptrace.dev/guide/soft-deletes.html).
+
+Resources:
+
+- [Examples](https://github.com/uptrace/bun/tree/master/example)
+- [Documentation](https://bun.uptrace.dev/)
+- [Reference](https://pkg.go.dev/github.com/uptrace/bun)
+- [Starter kit](https://github.com/go-bun/bun-starter-kit)
+- [RealWorld app](https://github.com/go-bun/bun-realworld-app)
+
+
+ github.com/frederikhors/orm-benchmark results
+
+```
+ 4000 times - Insert
+ raw_stmt: 0.38s 94280 ns/op 718 B/op 14 allocs/op
+ raw: 0.39s 96719 ns/op 718 B/op 13 allocs/op
+ beego_orm: 0.48s 118994 ns/op 2411 B/op 56 allocs/op
+ bun: 0.57s 142285 ns/op 918 B/op 12 allocs/op
+ pg: 0.58s 145496 ns/op 1235 B/op 12 allocs/op
+ gorm: 0.70s 175294 ns/op 6665 B/op 88 allocs/op
+ xorm: 0.76s 189533 ns/op 3032 B/op 94 allocs/op
+
+ 4000 times - MultiInsert 100 row
+ raw: 4.59s 1147385 ns/op 135155 B/op 916 allocs/op
+ raw_stmt: 4.59s 1148137 ns/op 131076 B/op 916 allocs/op
+ beego_orm: 5.50s 1375637 ns/op 179962 B/op 2747 allocs/op
+ bun: 6.18s 1544648 ns/op 4265 B/op 214 allocs/op
+ pg: 7.01s 1753495 ns/op 5039 B/op 114 allocs/op
+ gorm: 9.52s 2379219 ns/op 293956 B/op 3729 allocs/op
+ xorm: 11.66s 2915478 ns/op 286140 B/op 7422 allocs/op
+
+ 4000 times - Update
+ raw_stmt: 0.26s 65781 ns/op 773 B/op 14 allocs/op
+ raw: 0.31s 77209 ns/op 757 B/op 13 allocs/op
+ beego_orm: 0.43s 107064 ns/op 1802 B/op 47 allocs/op
+ bun: 0.56s 139839 ns/op 589 B/op 4 allocs/op
+ pg: 0.60s 149608 ns/op 896 B/op 11 allocs/op
+ gorm: 0.74s 185970 ns/op 6604 B/op 81 allocs/op
+ xorm: 0.81s 203240 ns/op 2994 B/op 119 allocs/op
+
+ 4000 times - Read
+ raw: 0.33s 81671 ns/op 2081 B/op 49 allocs/op
+ raw_stmt: 0.34s 85847 ns/op 2112 B/op 50 allocs/op
+ beego_orm: 0.38s 94777 ns/op 2106 B/op 75 allocs/op
+ pg: 0.42s 106148 ns/op 1526 B/op 22 allocs/op
+ bun: 0.43s 106904 ns/op 1319 B/op 18 allocs/op
+ gorm: 0.65s 162221 ns/op 5240 B/op 108 allocs/op
+ xorm: 1.13s 281738 ns/op 8326 B/op 237 allocs/op
+
+ 4000 times - MultiRead limit 100
+ raw: 1.52s 380351 ns/op 38356 B/op 1037 allocs/op
+ raw_stmt: 1.54s 385541 ns/op 38388 B/op 1038 allocs/op
+ pg: 1.86s 465468 ns/op 24045 B/op 631 allocs/op
+ bun: 2.58s 645354 ns/op 30009 B/op 1122 allocs/op
+ beego_orm: 2.93s 732028 ns/op 55280 B/op 3077 allocs/op
+ gorm: 4.97s 1241831 ns/op 71628 B/op 3877 allocs/op
+ xorm: doesn't work
+```
+
+
+
+## Installation
+
+```go
+go get github.com/uptrace/bun
+```
+
+You also need to install a database/sql driver and the corresponding Bun
+[dialect](https://bun.uptrace.dev/guide/drivers.html).
+
+## Quickstart
+
+First you need to create a `sql.DB`. Here we are using the
+[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses
+between [modernc.org/sqlite](https://modernc.org/sqlite/) and
+[mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform.
+
+```go
+import "github.com/uptrace/bun/driver/sqliteshim"
+
+sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
+if err != nil {
+ panic(err)
+}
+```
+
+And then create a `bun.DB` on top of it using the corresponding SQLite dialect:
+
+```go
+import (
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/sqlitedialect"
+)
+
+db := bun.NewDB(sqldb, sqlitedialect.New())
+```
+
+Now you are ready to issue some queries:
+
+```go
+type User struct {
+ ID int64
+ Name string
+}
+
+user := new(User)
+err := db.NewSelect().
+ Model(user).
+ Where("name != ?", "").
+ OrderExpr("id ASC").
+ Limit(1).
+ Scan(ctx)
+```
+
+The code above is equivalent to:
+
+```go
+query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1"
+
+rows, err := sqldb.QueryContext(ctx, query)
+if err != nil {
+ panic(err)
+}
+
+if !rows.Next() {
+ panic(sql.ErrNoRows)
+}
+
+user := new(User)
+if err := db.ScanRow(ctx, rows, user); err != nil {
+ panic(err)
+}
+
+if err := rows.Err(); err != nil {
+ panic(err)
+}
+```
+
+## Basic example
+
+To provide initial data for our [example](/example/basic/), we will use Bun
+[fixtures](https://bun.uptrace.dev/guide/fixtures.html):
+
+```go
+import "github.com/uptrace/bun/dbfixture"
+
+// Register models for the fixture.
+db.RegisterModel((*User)(nil), (*Story)(nil))
+
+// WithRecreateTables tells Bun to drop existing tables and create new ones.
+fixture := dbfixture.New(db, dbfixture.WithRecreateTables())
+
+// Load fixture.yaml which contains data for User and Story models.
+if err := fixture.Load(ctx, os.DirFS("."), "fixture.yaml"); err != nil {
+ panic(err)
+}
+```
+
+The `fixture.yaml` looks like this:
+
+```yaml
+- model: User
+ rows:
+ - _id: admin
+ name: admin
+ emails: ['admin1@admin', 'admin2@admin']
+ - _id: root
+ name: root
+ emails: ['root1@root', 'root2@root']
+
+- model: Story
+ rows:
+ - title: Cool story
+ author_id: '{{ $.User.admin.ID }}'
+```
+
+To select all users:
+
+```go
+users := make([]User, 0)
+if err := db.NewSelect().Model(&users).OrderExpr("id ASC").Scan(ctx); err != nil {
+ panic(err)
+}
+```
+
+To select a single user by id:
+
+```go
+user1 := new(User)
+if err := db.NewSelect().Model(user1).Where("id = ?", 1).Scan(ctx); err != nil {
+ panic(err)
+}
+```
+
+To select a story and the associated author in a single query:
+
+```go
+story := new(Story)
+if err := db.NewSelect().
+ Model(story).
+ Relation("Author").
+ Limit(1).
+ Scan(ctx); err != nil {
+ panic(err)
+}
+```
+
+To select a user into a map:
+
+```go
+m := make(map[string]interface{})
+if err := db.NewSelect().
+ Model((*User)(nil)).
+ Limit(1).
+ Scan(ctx, &m); err != nil {
+ panic(err)
+}
+```
+
+To select all users scanning each column into a separate slice:
+
+```go
+var ids []int64
+var names []string
+if err := db.NewSelect().
+ ColumnExpr("id, name").
+ Model((*User)(nil)).
+ OrderExpr("id ASC").
+ Scan(ctx, &ids, &names); err != nil {
+ panic(err)
+}
+```
+
+For more details, please consult [docs](https://bun.uptrace.dev/) and check [examples](example).
+
+## Contributors
+
+Thanks to all the people who already contributed!
+
+
+
+
diff --git a/vendor/github.com/uptrace/bun/RELEASING.md b/vendor/github.com/uptrace/bun/RELEASING.md
new file mode 100644
index 000000000..9e50c1063
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/RELEASING.md
@@ -0,0 +1,21 @@
+# Releasing
+
+1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub:
+
+```shell
+./scripts/release.sh -t v1.0.0
+```
+
+2. Open a pull request and wait for the build to finish.
+
+3. Merge the pull request and run `tag.sh` to create tags for packages:
+
+```shell
+./scripts/tag.sh -t v1.0.0
+```
+
+4. Push the tags:
+
+```shell
+git push origin --tags
+```
diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go
new file mode 100644
index 000000000..92ebe691a
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/bun.go
@@ -0,0 +1,122 @@
+package bun
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/schema"
+)
+
+type (
+ Safe = schema.Safe
+ Ident = schema.Ident
+)
+
+type NullTime = schema.NullTime
+
+type BaseModel = schema.BaseModel
+
+type (
+ BeforeScanHook = schema.BeforeScanHook
+ AfterScanHook = schema.AfterScanHook
+)
+
+type BeforeSelectHook interface {
+ BeforeSelect(ctx context.Context, query *SelectQuery) error
+}
+
+type AfterSelectHook interface {
+ AfterSelect(ctx context.Context, query *SelectQuery) error
+}
+
+type BeforeInsertHook interface {
+ BeforeInsert(ctx context.Context, query *InsertQuery) error
+}
+
+type AfterInsertHook interface {
+ AfterInsert(ctx context.Context, query *InsertQuery) error
+}
+
+type BeforeUpdateHook interface {
+ BeforeUpdate(ctx context.Context, query *UpdateQuery) error
+}
+
+type AfterUpdateHook interface {
+ AfterUpdate(ctx context.Context, query *UpdateQuery) error
+}
+
+type BeforeDeleteHook interface {
+ BeforeDelete(ctx context.Context, query *DeleteQuery) error
+}
+
+type AfterDeleteHook interface {
+ AfterDelete(ctx context.Context, query *DeleteQuery) error
+}
+
+type BeforeCreateTableHook interface {
+ BeforeCreateTable(ctx context.Context, query *CreateTableQuery) error
+}
+
+type AfterCreateTableHook interface {
+ AfterCreateTable(ctx context.Context, query *CreateTableQuery) error
+}
+
+type BeforeDropTableHook interface {
+ BeforeDropTable(ctx context.Context, query *DropTableQuery) error
+}
+
+type AfterDropTableHook interface {
+ AfterDropTable(ctx context.Context, query *DropTableQuery) error
+}
+
+//------------------------------------------------------------------------------
+
+type InValues struct {
+ slice reflect.Value
+ err error
+}
+
+var _ schema.QueryAppender = InValues{}
+
+func In(slice interface{}) InValues {
+ v := reflect.ValueOf(slice)
+ if v.Kind() != reflect.Slice {
+ return InValues{
+ err: fmt.Errorf("bun: In(non-slice %T)", slice),
+ }
+ }
+ return InValues{
+ slice: v,
+ }
+}
+
+func (in InValues) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if in.err != nil {
+ return nil, in.err
+ }
+ return appendIn(fmter, b, in.slice), nil
+}
+
+func appendIn(fmter schema.Formatter, b []byte, slice reflect.Value) []byte {
+ sliceLen := slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ elem := slice.Index(i)
+ if elem.Kind() == reflect.Interface {
+ elem = elem.Elem()
+ }
+
+ if elem.Kind() == reflect.Slice {
+ b = append(b, '(')
+ b = appendIn(fmter, b, elem)
+ b = append(b, ')')
+ } else {
+ b = fmter.AppendValue(b, elem)
+ }
+ }
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go
new file mode 100644
index 000000000..d08adefb5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/db.go
@@ -0,0 +1,502 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync/atomic"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+const (
+ discardUnknownColumns internal.Flag = 1 << iota
+)
+
+type DBStats struct {
+ Queries uint64
+ Errors uint64
+}
+
+type DBOption func(db *DB)
+
+func WithDiscardUnknownColumns() DBOption {
+ return func(db *DB) {
+ db.flags = db.flags.Set(discardUnknownColumns)
+ }
+}
+
+type DB struct {
+ *sql.DB
+ dialect schema.Dialect
+ features feature.Feature
+
+ queryHooks []QueryHook
+
+ fmter schema.Formatter
+ flags internal.Flag
+
+ stats DBStats
+}
+
+func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
+ dialect.Init(sqldb)
+
+ db := &DB{
+ DB: sqldb,
+ dialect: dialect,
+ features: dialect.Features(),
+ fmter: schema.NewFormatter(dialect),
+ }
+
+ for _, opt := range opts {
+ opt(db)
+ }
+
+ return db
+}
+
+func (db *DB) String() string {
+ var b strings.Builder
+ b.WriteString("DB")
+ return b.String()
+}
+
+func (db *DB) DBStats() DBStats {
+ return DBStats{
+ Queries: atomic.LoadUint64(&db.stats.Queries),
+ Errors: atomic.LoadUint64(&db.stats.Errors),
+ }
+}
+
+func (db *DB) NewValues(model interface{}) *ValuesQuery {
+ return NewValuesQuery(db, model)
+}
+
+func (db *DB) NewSelect() *SelectQuery {
+ return NewSelectQuery(db)
+}
+
+func (db *DB) NewInsert() *InsertQuery {
+ return NewInsertQuery(db)
+}
+
+func (db *DB) NewUpdate() *UpdateQuery {
+ return NewUpdateQuery(db)
+}
+
+func (db *DB) NewDelete() *DeleteQuery {
+ return NewDeleteQuery(db)
+}
+
+func (db *DB) NewCreateTable() *CreateTableQuery {
+ return NewCreateTableQuery(db)
+}
+
+func (db *DB) NewDropTable() *DropTableQuery {
+ return NewDropTableQuery(db)
+}
+
+func (db *DB) NewCreateIndex() *CreateIndexQuery {
+ return NewCreateIndexQuery(db)
+}
+
+func (db *DB) NewDropIndex() *DropIndexQuery {
+ return NewDropIndexQuery(db)
+}
+
+func (db *DB) NewTruncateTable() *TruncateTableQuery {
+ return NewTruncateTableQuery(db)
+}
+
+func (db *DB) NewAddColumn() *AddColumnQuery {
+ return NewAddColumnQuery(db)
+}
+
+func (db *DB) NewDropColumn() *DropColumnQuery {
+ return NewDropColumnQuery(db)
+}
+
+func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
+ for _, model := range models {
+ if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil {
+ return err
+ }
+ if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (db *DB) Dialect() schema.Dialect {
+ return db.dialect
+}
+
+func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
+ model, err := newModel(db, dest)
+ if err != nil {
+ return err
+ }
+
+ _, err = model.ScanRows(ctx, rows)
+ return err
+}
+
+func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
+ model, err := newModel(db, dest)
+ if err != nil {
+ return err
+ }
+
+ rs, ok := model.(rowScanner)
+ if !ok {
+ return fmt.Errorf("bun: %T does not support ScanRow", model)
+ }
+
+ return rs.ScanRow(ctx, rows)
+}
+
+func (db *DB) AddQueryHook(hook QueryHook) {
+ db.queryHooks = append(db.queryHooks, hook)
+}
+
+func (db *DB) Table(typ reflect.Type) *schema.Table {
+ return db.dialect.Tables().Get(typ)
+}
+
+func (db *DB) RegisterModel(models ...interface{}) {
+ db.dialect.Tables().Register(models...)
+}
+
+func (db *DB) clone() *DB {
+ clone := *db
+
+ l := len(clone.queryHooks)
+ clone.queryHooks = clone.queryHooks[:l:l]
+
+ return &clone
+}
+
+func (db *DB) WithNamedArg(name string, value interface{}) *DB {
+ clone := db.clone()
+ clone.fmter = clone.fmter.WithNamedArg(name, value)
+ return clone
+}
+
+func (db *DB) Formatter() schema.Formatter {
+ return db.fmter
+}
+
+//------------------------------------------------------------------------------
+
+func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
+ return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) ExecContext(
+ ctx context.Context, query string, args ...interface{},
+) (sql.Result, error) {
+ ctx, event := db.beforeQuery(ctx, nil, query, args)
+ res, err := db.DB.ExecContext(ctx, db.format(query, args))
+ db.afterQuery(ctx, event, res, err)
+ return res, err
+}
+
+func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) QueryContext(
+ ctx context.Context, query string, args ...interface{},
+) (*sql.Rows, error) {
+ ctx, event := db.beforeQuery(ctx, nil, query, args)
+ rows, err := db.DB.QueryContext(ctx, db.format(query, args))
+ db.afterQuery(ctx, event, nil, err)
+ return rows, err
+}
+
+func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
+ return db.QueryRowContext(context.Background(), query, args...)
+}
+
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ ctx, event := db.beforeQuery(ctx, nil, query, args)
+ row := db.DB.QueryRowContext(ctx, db.format(query, args))
+ db.afterQuery(ctx, event, nil, row.Err())
+ return row
+}
+
+func (db *DB) format(query string, args []interface{}) string {
+ return db.fmter.FormatQuery(query, args...)
+}
+
+//------------------------------------------------------------------------------
+
+type Conn struct {
+ db *DB
+ *sql.Conn
+}
+
+func (db *DB) Conn(ctx context.Context) (Conn, error) {
+ conn, err := db.DB.Conn(ctx)
+ if err != nil {
+ return Conn{}, err
+ }
+ return Conn{
+ db: db,
+ Conn: conn,
+ }, nil
+}
+
+func (c Conn) ExecContext(
+ ctx context.Context, query string, args ...interface{},
+) (sql.Result, error) {
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args)
+ res, err := c.Conn.ExecContext(ctx, c.db.format(query, args))
+ c.db.afterQuery(ctx, event, res, err)
+ return res, err
+}
+
+func (c Conn) QueryContext(
+ ctx context.Context, query string, args ...interface{},
+) (*sql.Rows, error) {
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args)
+ rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args))
+ c.db.afterQuery(ctx, event, nil, err)
+ return rows, err
+}
+
+func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args)
+ row := c.Conn.QueryRowContext(ctx, c.db.format(query, args))
+ c.db.afterQuery(ctx, event, nil, row.Err())
+ return row
+}
+
+func (c Conn) NewValues(model interface{}) *ValuesQuery {
+ return NewValuesQuery(c.db, model).Conn(c)
+}
+
+func (c Conn) NewSelect() *SelectQuery {
+ return NewSelectQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewInsert() *InsertQuery {
+ return NewInsertQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewUpdate() *UpdateQuery {
+ return NewUpdateQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewDelete() *DeleteQuery {
+ return NewDeleteQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewCreateTable() *CreateTableQuery {
+ return NewCreateTableQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewDropTable() *DropTableQuery {
+ return NewDropTableQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewCreateIndex() *CreateIndexQuery {
+ return NewCreateIndexQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewDropIndex() *DropIndexQuery {
+ return NewDropIndexQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewTruncateTable() *TruncateTableQuery {
+ return NewTruncateTableQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewAddColumn() *AddColumnQuery {
+ return NewAddColumnQuery(c.db).Conn(c)
+}
+
+func (c Conn) NewDropColumn() *DropColumnQuery {
+ return NewDropColumnQuery(c.db).Conn(c)
+}
+
+//------------------------------------------------------------------------------
+
+type Stmt struct {
+ *sql.Stmt
+}
+
+func (db *DB) Prepare(query string) (Stmt, error) {
+ return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
+ stmt, err := db.DB.PrepareContext(ctx, query)
+ if err != nil {
+ return Stmt{}, err
+ }
+ return Stmt{Stmt: stmt}, nil
+}
+
+//------------------------------------------------------------------------------
+
+type Tx struct {
+ db *DB
+ *sql.Tx
+}
+
+// RunInTx runs the function in a transaction. If the function returns an error,
+// the transaction is rolled back. Otherwise, the transaction is committed.
+func (db *DB) RunInTx(
+ ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
+) error {
+ tx, err := db.BeginTx(ctx, opts)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback() //nolint:errcheck
+
+ if err := fn(ctx, tx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (db *DB) Begin() (Tx, error) {
+ return db.BeginTx(context.Background(), nil)
+}
+
+func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
+ tx, err := db.DB.BeginTx(ctx, opts)
+ if err != nil {
+ return Tx{}, err
+ }
+ return Tx{
+ db: db,
+ Tx: tx,
+ }, nil
+}
+
+func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
+ return tx.ExecContext(context.TODO(), query, args...)
+}
+
+func (tx Tx) ExecContext(
+ ctx context.Context, query string, args ...interface{},
+) (sql.Result, error) {
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
+ res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args))
+ tx.db.afterQuery(ctx, event, res, err)
+ return res, err
+}
+
+func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
+ return tx.QueryContext(context.TODO(), query, args...)
+}
+
+func (tx Tx) QueryContext(
+ ctx context.Context, query string, args ...interface{},
+) (*sql.Rows, error) {
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
+ rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args))
+ tx.db.afterQuery(ctx, event, nil, err)
+ return rows, err
+}
+
+func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row {
+ return tx.QueryRowContext(context.TODO(), query, args...)
+}
+
+func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
+ row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args))
+ tx.db.afterQuery(ctx, event, nil, row.Err())
+ return row
+}
+
+//------------------------------------------------------------------------------
+
+func (tx Tx) NewValues(model interface{}) *ValuesQuery {
+ return NewValuesQuery(tx.db, model).Conn(tx)
+}
+
+func (tx Tx) NewSelect() *SelectQuery {
+ return NewSelectQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewInsert() *InsertQuery {
+ return NewInsertQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewUpdate() *UpdateQuery {
+ return NewUpdateQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewDelete() *DeleteQuery {
+ return NewDeleteQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewCreateTable() *CreateTableQuery {
+ return NewCreateTableQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewDropTable() *DropTableQuery {
+ return NewDropTableQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewCreateIndex() *CreateIndexQuery {
+ return NewCreateIndexQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewDropIndex() *DropIndexQuery {
+ return NewDropIndexQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewTruncateTable() *TruncateTableQuery {
+ return NewTruncateTableQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewAddColumn() *AddColumnQuery {
+ return NewAddColumnQuery(tx.db).Conn(tx)
+}
+
+func (tx Tx) NewDropColumn() *DropColumnQuery {
+ return NewDropColumnQuery(tx.db).Conn(tx)
+}
+
+//------------------------------------------------------------------------------0
+
+func (db *DB) makeQueryBytes() []byte {
+ // TODO: make this configurable?
+ return make([]byte, 0, 4096)
+}
+
+//------------------------------------------------------------------------------
+
+type result struct {
+ r sql.Result
+ n int
+}
+
+func (r result) RowsAffected() (int64, error) {
+ if r.r != nil {
+ return r.r.RowsAffected()
+ }
+ return int64(r.n), nil
+}
+
+func (r result) LastInsertId() (int64, error) {
+ if r.r != nil {
+ return r.r.LastInsertId()
+ }
+ return 0, errors.New("LastInsertId is not available")
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/append.go b/vendor/github.com/uptrace/bun/dialect/append.go
new file mode 100644
index 000000000..7040c5155
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/append.go
@@ -0,0 +1,178 @@
+package dialect
+
+import (
+ "encoding/hex"
+ "math"
+ "strconv"
+ "time"
+ "unicode/utf8"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/parser"
+)
+
+func AppendError(b []byte, err error) []byte {
+ b = append(b, "?!("...)
+ b = append(b, err.Error()...)
+ b = append(b, ')')
+ return b
+}
+
+func AppendNull(b []byte) []byte {
+ return append(b, "NULL"...)
+}
+
+func AppendBool(b []byte, v bool) []byte {
+ if v {
+ return append(b, "TRUE"...)
+ }
+ return append(b, "FALSE"...)
+}
+
+func AppendFloat32(b []byte, v float32) []byte {
+ return appendFloat(b, float64(v), 32)
+}
+
+func AppendFloat64(b []byte, v float64) []byte {
+ return appendFloat(b, v, 64)
+}
+
+func appendFloat(b []byte, v float64, bitSize int) []byte {
+ switch {
+ case math.IsNaN(v):
+ return append(b, "'NaN'"...)
+ case math.IsInf(v, 1):
+ return append(b, "'Infinity'"...)
+ case math.IsInf(v, -1):
+ return append(b, "'-Infinity'"...)
+ default:
+ return strconv.AppendFloat(b, v, 'f', -1, bitSize)
+ }
+}
+
+func AppendString(b []byte, s string) []byte {
+ b = append(b, '\'')
+ for _, r := range s {
+ if r == '\000' {
+ continue
+ }
+
+ if r == '\'' {
+ b = append(b, '\'', '\'')
+ continue
+ }
+
+ if r < utf8.RuneSelf {
+ b = append(b, byte(r))
+ continue
+ }
+
+ l := len(b)
+ if cap(b)-l < utf8.UTFMax {
+ b = append(b, make([]byte, utf8.UTFMax)...)
+ }
+ n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
+ b = b[:l+n]
+ }
+ b = append(b, '\'')
+ return b
+}
+
+func AppendBytes(b []byte, bytes []byte) []byte {
+ if bytes == nil {
+ return AppendNull(b)
+ }
+
+ b = append(b, `'\x`...)
+
+ s := len(b)
+ b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...)
+ hex.Encode(b[s:], bytes)
+
+ b = append(b, '\'')
+
+ return b
+}
+
+func AppendTime(b []byte, tm time.Time) []byte {
+ if tm.IsZero() {
+ return AppendNull(b)
+ }
+ b = append(b, '\'')
+ b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00")
+ b = append(b, '\'')
+ return b
+}
+
+func AppendJSON(b, jsonb []byte) []byte {
+ b = append(b, '\'')
+
+ p := parser.New(jsonb)
+ for p.Valid() {
+ c := p.Read()
+ switch c {
+ case '"':
+ b = append(b, '"')
+ case '\'':
+ b = append(b, "''"...)
+ case '\000':
+ continue
+ case '\\':
+ if p.SkipBytes([]byte("u0000")) {
+ b = append(b, `\\u0000`...)
+ } else {
+ b = append(b, '\\')
+ if p.Valid() {
+ b = append(b, p.Read())
+ }
+ }
+ default:
+ b = append(b, c)
+ }
+ }
+
+ b = append(b, '\'')
+
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func AppendIdent(b []byte, field string, quote byte) []byte {
+ return appendIdent(b, internal.Bytes(field), quote)
+}
+
+func appendIdent(b, src []byte, quote byte) []byte {
+ var quoted bool
+loop:
+ for _, c := range src {
+ switch c {
+ case '*':
+ if !quoted {
+ b = append(b, '*')
+ continue loop
+ }
+ case '.':
+ if quoted {
+ b = append(b, quote)
+ quoted = false
+ }
+ b = append(b, '.')
+ continue loop
+ }
+
+ if !quoted {
+ b = append(b, quote)
+ quoted = true
+ }
+ if c == quote {
+ b = append(b, quote, quote)
+ } else {
+ b = append(b, c)
+ }
+ }
+ if quoted {
+ b = append(b, quote)
+ }
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go
new file mode 100644
index 000000000..9ff8b2461
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/dialect.go
@@ -0,0 +1,26 @@
+package dialect
+
+type Name int
+
+func (n Name) String() string {
+ switch n {
+ case PG:
+ return "pg"
+ case SQLite:
+ return "sqlite"
+ case MySQL5:
+ return "mysql5"
+ case MySQL8:
+ return "mysql8"
+ default:
+ return "invalid"
+ }
+}
+
+const (
+ Invalid Name = iota
+ PG
+ SQLite
+ MySQL5
+ MySQL8
+)
diff --git a/vendor/github.com/uptrace/bun/dialect/feature/feature.go b/vendor/github.com/uptrace/bun/dialect/feature/feature.go
new file mode 100644
index 000000000..ff8f1d625
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/feature/feature.go
@@ -0,0 +1,22 @@
+package feature
+
+import "github.com/uptrace/bun/internal"
+
+type Feature = internal.Flag
+
+const DefaultFeatures = Returning | TableCascade
+
+const (
+ Returning Feature = 1 << iota
+ DefaultPlaceholder
+ DoubleColonCast
+ ValuesRow
+ UpdateMultiTable
+ InsertTableAlias
+ DeleteTableAlias
+ AutoIncrement
+ TableCascade
+ TableIdentity
+ TableTruncate
+ OnDuplicateKey
+)
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE
new file mode 100644
index 000000000..7ec81810c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE
@@ -0,0 +1,24 @@
+Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go
new file mode 100644
index 000000000..475621197
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go
@@ -0,0 +1,303 @@
+package pgdialect
+
+import (
+ "database/sql/driver"
+ "fmt"
+ "reflect"
+ "strconv"
+ "time"
+ "unicode/utf8"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/schema"
+)
+
+var (
+ driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+ stringType = reflect.TypeOf((*string)(nil)).Elem()
+ sliceStringType = reflect.TypeOf([]string(nil))
+
+ intType = reflect.TypeOf((*int)(nil)).Elem()
+ sliceIntType = reflect.TypeOf([]int(nil))
+
+ int64Type = reflect.TypeOf((*int64)(nil)).Elem()
+ sliceInt64Type = reflect.TypeOf([]int64(nil))
+
+ float64Type = reflect.TypeOf((*float64)(nil)).Elem()
+ sliceFloat64Type = reflect.TypeOf([]float64(nil))
+)
+
+func customAppender(typ reflect.Type) schema.AppenderFunc {
+ switch typ.Kind() {
+ case reflect.Uint32:
+ return appendUint32ValueAsInt
+ case reflect.Uint, reflect.Uint64:
+ return appendUint64ValueAsInt
+ }
+ return nil
+}
+
+func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendInt(b, int64(int32(v.Uint())), 10)
+}
+
+func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendInt(b, int64(v.Uint()), 10)
+}
+
+//------------------------------------------------------------------------------
+
+func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
+ switch v := v.(type) {
+ case int64:
+ return strconv.AppendInt(b, v, 10)
+ case float64:
+ return dialect.AppendFloat64(b, v)
+ case bool:
+ return dialect.AppendBool(b, v)
+ case []byte:
+ return dialect.AppendBytes(b, v)
+ case string:
+ return arrayAppendString(b, v)
+ case time.Time:
+ return dialect.AppendTime(b, v)
+ default:
+ err := fmt.Errorf("pgdialect: can't append %T", v)
+ return dialect.AppendError(b, err)
+ }
+}
+
+func arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
+ if typ.Kind() == reflect.String {
+ return arrayAppendStringValue
+ }
+ if typ.Implements(driverValuerType) {
+ return arrayAppendDriverValue
+ }
+ return schema.Appender(typ, customAppender)
+}
+
+func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ return arrayAppendString(b, v.String())
+}
+
+func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ iface, err := v.Interface().(driver.Valuer).Value()
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return arrayAppend(fmter, b, iface)
+}
+
+//------------------------------------------------------------------------------
+
+func arrayAppender(typ reflect.Type) schema.AppenderFunc {
+ kind := typ.Kind()
+ if kind == reflect.Ptr {
+ typ = typ.Elem()
+ kind = typ.Kind()
+ }
+
+ switch kind {
+ case reflect.Slice, reflect.Array:
+ // ok:
+ default:
+ return nil
+ }
+
+ elemType := typ.Elem()
+
+ if kind == reflect.Slice {
+ switch elemType {
+ case stringType:
+ return appendStringSliceValue
+ case intType:
+ return appendIntSliceValue
+ case int64Type:
+ return appendInt64SliceValue
+ case float64Type:
+ return appendFloat64SliceValue
+ }
+ }
+
+ appendElem := arrayElemAppender(elemType)
+ if appendElem == nil {
+ panic(fmt.Errorf("pgdialect: %s is not supported", typ))
+ }
+
+ return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ kind := v.Kind()
+ switch kind {
+ case reflect.Ptr, reflect.Slice:
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ }
+
+ if kind == reflect.Ptr {
+ v = v.Elem()
+ }
+
+ b = append(b, '\'')
+
+ b = append(b, '{')
+ for i := 0; i < v.Len(); i++ {
+ elem := v.Index(i)
+ b = appendElem(fmter, b, elem)
+ b = append(b, ',')
+ }
+ if v.Len() > 0 {
+ b[len(b)-1] = '}' // Replace trailing comma.
+ } else {
+ b = append(b, '}')
+ }
+
+ b = append(b, '\'')
+
+ return b
+ }
+}
+
+func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ ss := v.Convert(sliceStringType).Interface().([]string)
+ return appendStringSlice(b, ss)
+}
+
+func appendStringSlice(b []byte, ss []string) []byte {
+ if ss == nil {
+ return dialect.AppendNull(b)
+ }
+
+ b = append(b, '\'')
+
+ b = append(b, '{')
+ for _, s := range ss {
+ b = arrayAppendString(b, s)
+ b = append(b, ',')
+ }
+ if len(ss) > 0 {
+ b[len(b)-1] = '}' // Replace trailing comma.
+ } else {
+ b = append(b, '}')
+ }
+
+ b = append(b, '\'')
+
+ return b
+}
+
+func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ ints := v.Convert(sliceIntType).Interface().([]int)
+ return appendIntSlice(b, ints)
+}
+
+func appendIntSlice(b []byte, ints []int) []byte {
+ if ints == nil {
+ return dialect.AppendNull(b)
+ }
+
+ b = append(b, '\'')
+
+ b = append(b, '{')
+ for _, n := range ints {
+ b = strconv.AppendInt(b, int64(n), 10)
+ b = append(b, ',')
+ }
+ if len(ints) > 0 {
+ b[len(b)-1] = '}' // Replace trailing comma.
+ } else {
+ b = append(b, '}')
+ }
+
+ b = append(b, '\'')
+
+ return b
+}
+
+func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ ints := v.Convert(sliceInt64Type).Interface().([]int64)
+ return appendInt64Slice(b, ints)
+}
+
+func appendInt64Slice(b []byte, ints []int64) []byte {
+ if ints == nil {
+ return dialect.AppendNull(b)
+ }
+
+ b = append(b, '\'')
+
+ b = append(b, '{')
+ for _, n := range ints {
+ b = strconv.AppendInt(b, n, 10)
+ b = append(b, ',')
+ }
+ if len(ints) > 0 {
+ b[len(b)-1] = '}' // Replace trailing comma.
+ } else {
+ b = append(b, '}')
+ }
+
+ b = append(b, '\'')
+
+ return b
+}
+
+func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
+ floats := v.Convert(sliceFloat64Type).Interface().([]float64)
+ return appendFloat64Slice(b, floats)
+}
+
+func appendFloat64Slice(b []byte, floats []float64) []byte {
+ if floats == nil {
+ return dialect.AppendNull(b)
+ }
+
+ b = append(b, '\'')
+
+ b = append(b, '{')
+ for _, n := range floats {
+ b = dialect.AppendFloat64(b, n)
+ b = append(b, ',')
+ }
+ if len(floats) > 0 {
+ b[len(b)-1] = '}' // Replace trailing comma.
+ } else {
+ b = append(b, '}')
+ }
+
+ b = append(b, '\'')
+
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func arrayAppendString(b []byte, s string) []byte {
+ b = append(b, '"')
+ for _, r := range s {
+ switch r {
+ case 0:
+ // ignore
+ case '\'':
+ b = append(b, "'''"...)
+ case '"':
+ b = append(b, '\\', '"')
+ case '\\':
+ b = append(b, '\\', '\\')
+ default:
+ if r < utf8.RuneSelf {
+ b = append(b, byte(r))
+ break
+ }
+ l := len(b)
+ if cap(b)-l < utf8.UTFMax {
+ b = append(b, make([]byte, utf8.UTFMax)...)
+ }
+ n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
+ b = b[:l+n]
+ }
+ }
+ b = append(b, '"')
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go
new file mode 100644
index 000000000..57f5a4384
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go
@@ -0,0 +1,65 @@
+package pgdialect
+
+import (
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/schema"
+)
+
+type ArrayValue struct {
+ v reflect.Value
+
+ append schema.AppenderFunc
+ scan schema.ScannerFunc
+}
+
+// Array accepts a slice and returns a wrapper for working with PostgreSQL
+// array data type.
+//
+// For struct fields you can use array tag:
+//
+// Emails []string `bun:",array"`
+func Array(vi interface{}) *ArrayValue {
+ v := reflect.ValueOf(vi)
+ if !v.IsValid() {
+ panic(fmt.Errorf("bun: Array(nil)"))
+ }
+
+ return &ArrayValue{
+ v: v,
+
+ append: arrayAppender(v.Type()),
+ scan: arrayScanner(v.Type()),
+ }
+}
+
+var (
+ _ schema.QueryAppender = (*ArrayValue)(nil)
+ _ sql.Scanner = (*ArrayValue)(nil)
+)
+
+func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
+ if a.append == nil {
+ panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()))
+ }
+ return a.append(fmter, b, a.v), nil
+}
+
+func (a *ArrayValue) Scan(src interface{}) error {
+ if a.scan == nil {
+ return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())
+ }
+ if a.v.Kind() != reflect.Ptr {
+ return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type())
+ }
+ return a.scan(a.v, src)
+}
+
+func (a *ArrayValue) Value() interface{} {
+ if a.v.IsValid() {
+ return a.v.Interface()
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go
new file mode 100644
index 000000000..1c927fca0
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go
@@ -0,0 +1,146 @@
+package pgdialect
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+)
+
+type arrayParser struct {
+ b []byte
+ i int
+
+ buf []byte
+ err error
+}
+
+func newArrayParser(b []byte) *arrayParser {
+ p := &arrayParser{
+ b: b,
+ i: 1,
+ }
+ if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
+ p.err = fmt.Errorf("bun: can't parse array: %q", b)
+ }
+ return p
+}
+
+func (p *arrayParser) NextElem() ([]byte, error) {
+ if p.err != nil {
+ return nil, p.err
+ }
+
+ c, err := p.readByte()
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case '}':
+ return nil, io.EOF
+ case '"':
+ b, err := p.readSubstring()
+ if err != nil {
+ return nil, err
+ }
+
+ if p.peek() == ',' {
+ p.skipNext()
+ }
+
+ return b, nil
+ default:
+ b := p.readSimple()
+ if bytes.Equal(b, []byte("NULL")) {
+ b = nil
+ }
+
+ if p.peek() == ',' {
+ p.skipNext()
+ }
+
+ return b, nil
+ }
+}
+
+func (p *arrayParser) readSimple() []byte {
+ p.unreadByte()
+
+ if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 {
+ b := p.b[p.i : p.i+i]
+ p.i += i
+ return b
+ }
+
+ b := p.b[p.i : len(p.b)-1]
+ p.i = len(p.b) - 1
+ return b
+}
+
+func (p *arrayParser) readSubstring() ([]byte, error) {
+ c, err := p.readByte()
+ if err != nil {
+ return nil, err
+ }
+
+ p.buf = p.buf[:0]
+ for {
+ if c == '"' {
+ break
+ }
+
+ next, err := p.readByte()
+ if err != nil {
+ return nil, err
+ }
+
+ if c == '\\' {
+ switch next {
+ case '\\', '"':
+ p.buf = append(p.buf, next)
+
+ c, err = p.readByte()
+ if err != nil {
+ return nil, err
+ }
+ default:
+ p.buf = append(p.buf, '\\')
+ c = next
+ }
+ continue
+ }
+
+ p.buf = append(p.buf, c)
+ c = next
+ }
+
+ return p.buf, nil
+}
+
+func (p *arrayParser) valid() bool {
+ return p.i < len(p.b)
+}
+
+func (p *arrayParser) readByte() (byte, error) {
+ if p.valid() {
+ c := p.b[p.i]
+ p.i++
+ return c, nil
+ }
+ return 0, io.EOF
+}
+
+func (p *arrayParser) unreadByte() {
+ p.i--
+}
+
+func (p *arrayParser) peek() byte {
+ if p.valid() {
+ return p.b[p.i]
+ }
+ return 0
+}
+
+func (p *arrayParser) skipNext() {
+ p.i++
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go
new file mode 100644
index 000000000..33d31f325
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go
@@ -0,0 +1,302 @@
+package pgdialect
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "strconv"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+func arrayScanner(typ reflect.Type) schema.ScannerFunc {
+ kind := typ.Kind()
+ if kind == reflect.Ptr {
+ typ = typ.Elem()
+ kind = typ.Kind()
+ }
+
+ switch kind {
+ case reflect.Slice, reflect.Array:
+ // ok:
+ default:
+ return nil
+ }
+
+ elemType := typ.Elem()
+
+ if kind == reflect.Slice {
+ switch elemType {
+ case stringType:
+ return scanStringSliceValue
+ case intType:
+ return scanIntSliceValue
+ case int64Type:
+ return scanInt64SliceValue
+ case float64Type:
+ return scanFloat64SliceValue
+ }
+ }
+
+ scanElem := schema.Scanner(elemType)
+ return func(dest reflect.Value, src interface{}) error {
+ dest = reflect.Indirect(dest)
+ if !dest.CanSet() {
+ return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
+ }
+
+ kind := dest.Kind()
+
+ if src == nil {
+ if kind != reflect.Slice || !dest.IsNil() {
+ dest.Set(reflect.Zero(dest.Type()))
+ }
+ return nil
+ }
+
+ if kind == reflect.Slice {
+ if dest.IsNil() {
+ dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
+ } else if dest.Len() > 0 {
+ dest.Set(dest.Slice(0, 0))
+ }
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ p := newArrayParser(b)
+ nextValue := internal.MakeSliceNextElemFunc(dest)
+ for {
+ elem, err := p.NextElem()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+
+ elemValue := nextValue()
+ if err := scanElem(elemValue, elem); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+}
+
+func scanStringSliceValue(dest reflect.Value, src interface{}) error {
+ dest = reflect.Indirect(dest)
+ if !dest.CanSet() {
+ return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
+ }
+
+ slice, err := decodeStringSlice(src)
+ if err != nil {
+ return err
+ }
+
+ dest.Set(reflect.ValueOf(slice))
+ return nil
+}
+
+func decodeStringSlice(src interface{}) ([]string, error) {
+ if src == nil {
+ return nil, nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return nil, err
+ }
+
+ slice := make([]string, 0)
+
+ p := newArrayParser(b)
+ for {
+ elem, err := p.NextElem()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return nil, err
+ }
+ slice = append(slice, string(elem))
+ }
+
+ return slice, nil
+}
+
+func scanIntSliceValue(dest reflect.Value, src interface{}) error {
+ dest = reflect.Indirect(dest)
+ if !dest.CanSet() {
+ return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
+ }
+
+ slice, err := decodeIntSlice(src)
+ if err != nil {
+ return err
+ }
+
+ dest.Set(reflect.ValueOf(slice))
+ return nil
+}
+
+func decodeIntSlice(src interface{}) ([]int, error) {
+ if src == nil {
+ return nil, nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return nil, err
+ }
+
+ slice := make([]int, 0)
+
+ p := newArrayParser(b)
+ for {
+ elem, err := p.NextElem()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return nil, err
+ }
+
+ if elem == nil {
+ slice = append(slice, 0)
+ continue
+ }
+
+ n, err := strconv.Atoi(bytesToString(elem))
+ if err != nil {
+ return nil, err
+ }
+
+ slice = append(slice, n)
+ }
+
+ return slice, nil
+}
+
+func scanInt64SliceValue(dest reflect.Value, src interface{}) error {
+ dest = reflect.Indirect(dest)
+ if !dest.CanSet() {
+ return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
+ }
+
+ slice, err := decodeInt64Slice(src)
+ if err != nil {
+ return err
+ }
+
+ dest.Set(reflect.ValueOf(slice))
+ return nil
+}
+
+func decodeInt64Slice(src interface{}) ([]int64, error) {
+ if src == nil {
+ return nil, nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return nil, err
+ }
+
+ slice := make([]int64, 0)
+
+ p := newArrayParser(b)
+ for {
+ elem, err := p.NextElem()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return nil, err
+ }
+
+ if elem == nil {
+ slice = append(slice, 0)
+ continue
+ }
+
+ n, err := strconv.ParseInt(bytesToString(elem), 10, 64)
+ if err != nil {
+ return nil, err
+ }
+
+ slice = append(slice, n)
+ }
+
+ return slice, nil
+}
+
+func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
+ dest = reflect.Indirect(dest)
+ if !dest.CanSet() {
+ return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
+ }
+
+ slice, err := scanFloat64Slice(src)
+ if err != nil {
+ return err
+ }
+
+ dest.Set(reflect.ValueOf(slice))
+ return nil
+}
+
+func scanFloat64Slice(src interface{}) ([]float64, error) {
+ if src == -1 {
+ return nil, nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return nil, err
+ }
+
+ slice := make([]float64, 0)
+
+ p := newArrayParser(b)
+ for {
+ elem, err := p.NextElem()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return nil, err
+ }
+
+ if elem == nil {
+ slice = append(slice, 0)
+ continue
+ }
+
+ n, err := strconv.ParseFloat(bytesToString(elem), 64)
+ if err != nil {
+ return nil, err
+ }
+
+ slice = append(slice, n)
+ }
+
+ return slice, nil
+}
+
+func toBytes(src interface{}) ([]byte, error) {
+ switch src := src.(type) {
+ case string:
+ return stringToBytes(src), nil
+ case []byte:
+ return src, nil
+ default:
+ return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go
new file mode 100644
index 000000000..fb210751b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go
@@ -0,0 +1,150 @@
+package pgdialect
+
+import (
+ "database/sql"
+ "reflect"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/schema"
+)
+
+type Dialect struct {
+ tables *schema.Tables
+ features feature.Feature
+
+ appenderMap sync.Map
+ scannerMap sync.Map
+}
+
+func New() *Dialect {
+ d := new(Dialect)
+ d.tables = schema.NewTables(d)
+ d.features = feature.Returning |
+ feature.DefaultPlaceholder |
+ feature.DoubleColonCast |
+ feature.InsertTableAlias |
+ feature.DeleteTableAlias |
+ feature.TableCascade |
+ feature.TableIdentity |
+ feature.TableTruncate
+ return d
+}
+
+func (d *Dialect) Init(*sql.DB) {}
+
+func (d *Dialect) Name() dialect.Name {
+ return dialect.PG
+}
+
+func (d *Dialect) Features() feature.Feature {
+ return d.features
+}
+
+func (d *Dialect) Tables() *schema.Tables {
+ return d.tables
+}
+
+func (d *Dialect) OnTable(table *schema.Table) {
+ for _, field := range table.FieldMap {
+ d.onField(field)
+ }
+}
+
+func (d *Dialect) onField(field *schema.Field) {
+ field.DiscoveredSQLType = fieldSQLType(field)
+
+ if field.AutoIncrement {
+ switch field.DiscoveredSQLType {
+ case sqltype.SmallInt:
+ field.CreateTableSQLType = pgTypeSmallSerial
+ case sqltype.Integer:
+ field.CreateTableSQLType = pgTypeSerial
+ case sqltype.BigInt:
+ field.CreateTableSQLType = pgTypeBigSerial
+ }
+ }
+
+ if field.Tag.HasOption("array") {
+ field.Append = arrayAppender(field.IndirectType)
+ field.Scan = arrayScanner(field.IndirectType)
+ }
+}
+
+func (d *Dialect) IdentQuote() byte {
+ return '"'
+}
+
+func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte {
+ switch v := v.(type) {
+ case nil:
+ return dialect.AppendNull(b)
+ case bool:
+ return dialect.AppendBool(b, v)
+ case int:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int32:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int64:
+ return strconv.AppendInt(b, v, 10)
+ case uint:
+ return strconv.AppendInt(b, int64(v), 10)
+ case uint32:
+ return strconv.AppendInt(b, int64(v), 10)
+ case uint64:
+ return strconv.AppendInt(b, int64(v), 10)
+ case float32:
+ return dialect.AppendFloat32(b, v)
+ case float64:
+ return dialect.AppendFloat64(b, v)
+ case string:
+ return dialect.AppendString(b, v)
+ case time.Time:
+ return dialect.AppendTime(b, v)
+ case []byte:
+ return dialect.AppendBytes(b, v)
+ case schema.QueryAppender:
+ return schema.AppendQueryAppender(fmter, b, v)
+ default:
+ vv := reflect.ValueOf(v)
+ if vv.Kind() == reflect.Ptr && vv.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ appender := d.Appender(vv.Type())
+ return appender(fmter, b, vv)
+ }
+}
+
+func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc {
+ if v, ok := d.appenderMap.Load(typ); ok {
+ return v.(schema.AppenderFunc)
+ }
+
+ fn := schema.Appender(typ, customAppender)
+
+ if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
+ return v.(schema.AppenderFunc)
+ }
+ return fn
+}
+
+func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc {
+ return schema.FieldAppender(d, field)
+}
+
+func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc {
+ if v, ok := d.scannerMap.Load(typ); ok {
+ return v.(schema.ScannerFunc)
+ }
+
+ fn := scanner(typ)
+
+ if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
+ return v.(schema.ScannerFunc)
+ }
+ return fn
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod
new file mode 100644
index 000000000..0cad1ce5b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod
@@ -0,0 +1,7 @@
+module github.com/uptrace/bun/dialect/pgdialect
+
+go 1.16
+
+replace github.com/uptrace/bun => ../..
+
+require github.com/uptrace/bun v0.4.3
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum
new file mode 100644
index 000000000..4d0f1c1bb
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum
@@ -0,0 +1,22 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
+github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
+github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go
new file mode 100644
index 000000000..dff30b9c5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go
@@ -0,0 +1,11 @@
+// +build appengine
+
+package pgdialect
+
+func bytesToString(b []byte) string {
+ return string(b)
+}
+
+func stringToBytes(s string) []byte {
+ return []byte(s)
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go
new file mode 100644
index 000000000..9e22282f5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go
@@ -0,0 +1,28 @@
+package pgdialect
+
+import (
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/schema"
+)
+
+func scanner(typ reflect.Type) schema.ScannerFunc {
+ if typ.Kind() == reflect.Interface {
+ return scanInterface
+ }
+ return schema.Scanner(typ)
+}
+
+func scanInterface(dest reflect.Value, src interface{}) error {
+ if dest.IsNil() {
+ dest.Set(reflect.ValueOf(src))
+ return nil
+ }
+
+ dest = dest.Elem()
+ if fn := scanner(dest.Type()); fn != nil {
+ return fn(dest, src)
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go
new file mode 100644
index 000000000..4c2d8075d
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go
@@ -0,0 +1,104 @@
+package pgdialect
+
+import (
+ "encoding/json"
+ "net"
+ "reflect"
+ "time"
+
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/schema"
+)
+
+const (
+ // Date / Time
+ pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone
+ pgTypeDate = "DATE" // Date
+ pgTypeTime = "TIME" // Time without a time zone
+ pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
+ pgTypeInterval = "INTERVAL" // Time Interval
+
+ // Network Addresses
+ pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks
+ pgTypeCidr = "CIDR" // IPv4 or IPv6 networks
+ pgTypeMacaddr = "MACADDR" // MAC addresses
+
+ // Serial Types
+ pgTypeSmallSerial = "SMALLSERIAL" // 2 byte autoincrementing integer
+ pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer
+ pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
+
+ // Character Types
+ pgTypeChar = "CHAR" // fixed length string (blank padded)
+ pgTypeText = "TEXT" // variable length string without limit
+
+ // JSON Types
+ pgTypeJSON = "JSON" // text representation of json data
+ pgTypeJSONB = "JSONB" // binary representation of json data
+
+ // Binary Data Types
+ pgTypeBytea = "BYTEA" // binary string
+)
+
+var (
+ timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+ ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
+ ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
+ jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
+)
+
+func fieldSQLType(field *schema.Field) string {
+ if field.UserSQLType != "" {
+ return field.UserSQLType
+ }
+
+ if v, ok := field.Tag.Options["composite"]; ok {
+ return v
+ }
+
+ if _, ok := field.Tag.Options["hstore"]; ok {
+ return "hstore"
+ }
+
+ if _, ok := field.Tag.Options["array"]; ok {
+ switch field.IndirectType.Kind() {
+ case reflect.Slice, reflect.Array:
+ sqlType := sqlType(field.IndirectType.Elem())
+ return sqlType + "[]"
+ }
+ }
+
+ return sqlType(field.IndirectType)
+}
+
+func sqlType(typ reflect.Type) string {
+ switch typ {
+ case ipType:
+ return pgTypeInet
+ case ipNetType:
+ return pgTypeCidr
+ case jsonRawMessageType:
+ return pgTypeJSONB
+ }
+
+ sqlType := schema.DiscoverSQLType(typ)
+ switch sqlType {
+ case sqltype.Timestamp:
+ sqlType = pgTypeTimestampTz
+ }
+
+ switch typ.Kind() {
+ case reflect.Map, reflect.Struct:
+ if sqlType == sqltype.VarChar {
+ return pgTypeJSONB
+ }
+ return sqlType
+ case reflect.Array, reflect.Slice:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return pgTypeBytea
+ }
+ return pgTypeJSONB
+ }
+
+ return sqlType
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go
new file mode 100644
index 000000000..2a02a20b1
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go
@@ -0,0 +1,18 @@
+// +build !appengine
+
+package pgdialect
+
+import "unsafe"
+
+func bytesToString(b []byte) string {
+ return *(*string)(unsafe.Pointer(&b))
+}
+
+func stringToBytes(s string) []byte {
+ return *(*[]byte)(unsafe.Pointer(
+ &struct {
+ string
+ Cap int
+ }{s, len(s)},
+ ))
+}
diff --git a/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go
new file mode 100644
index 000000000..84a51d26d
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go
@@ -0,0 +1,14 @@
+package sqltype
+
+const (
+ Boolean = "BOOLEAN"
+ SmallInt = "SMALLINT"
+ Integer = "INTEGER"
+ BigInt = "BIGINT"
+ Real = "REAL"
+ DoublePrecision = "DOUBLE PRECISION"
+ VarChar = "VARCHAR"
+ Timestamp = "TIMESTAMP"
+ JSON = "JSON"
+ JSONB = "JSONB"
+)
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE b/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE
new file mode 100644
index 000000000..7ec81810c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE
@@ -0,0 +1,24 @@
+Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/README.md b/vendor/github.com/uptrace/bun/driver/pgdriver/README.md
new file mode 100644
index 000000000..8bb641a5c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/README.md
@@ -0,0 +1,36 @@
+# pgdriver
+
+[](https://pkg.go.dev/github.com/uptrace/bun/driver/pgdriver)
+
+pgdriver is a database/sql driver for PostgreSQL based on [go-pg](https://github.com/go-pg/pg) code.
+
+You can install it with:
+
+```shell
+github.com/uptrace/bun/driver/pgdriver
+```
+
+And then create a `sql.DB` using it:
+
+```go
+import _ "github.com/uptrace/bun/driver/pgdriver"
+
+dsn := "postgres://postgres:@localhost:5432/test"
+db, err := sql.Open("pg", dsn)
+```
+
+Alternatively:
+
+```go
+dsn := "postgres://postgres:@localhost:5432/test"
+db := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
+```
+
+[Benchmark](https://github.com/go-bun/bun-benchmark):
+
+```
+BenchmarkInsert/pg-12 7254 148380 ns/op 900 B/op 13 allocs/op
+BenchmarkInsert/pgx-12 6494 166391 ns/op 2076 B/op 26 allocs/op
+BenchmarkSelect/pg-12 9100 132952 ns/op 1417 B/op 18 allocs/op
+BenchmarkSelect/pgx-12 8199 154920 ns/op 3679 B/op 60 allocs/op
+```
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/column.go b/vendor/github.com/uptrace/bun/driver/pgdriver/column.go
new file mode 100644
index 000000000..5c3626943
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/column.go
@@ -0,0 +1,192 @@
+package pgdriver
+
+import (
+ "encoding/hex"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ pgBool = 16
+
+ pgInt2 = 21
+ pgInt4 = 23
+ pgInt8 = 20
+
+ pgFloat4 = 700
+ pgFloat8 = 701
+
+ pgText = 25
+ pgVarchar = 1043
+ pgBytea = 17
+
+ pgDate = 1082
+ pgTimestamp = 1114
+ pgTimestamptz = 1184
+)
+
+func readColumnValue(rd *reader, dataType int32, dataLen int) (interface{}, error) {
+ if dataLen == -1 {
+ return nil, nil
+ }
+
+ switch dataType {
+ case pgBool:
+ return readBoolCol(rd, dataLen)
+ case pgInt2:
+ return readIntCol(rd, dataLen, 16)
+ case pgInt4:
+ return readIntCol(rd, dataLen, 32)
+ case pgInt8:
+ return readIntCol(rd, dataLen, 64)
+ case pgFloat4:
+ return readFloatCol(rd, dataLen, 32)
+ case pgFloat8:
+ return readFloatCol(rd, dataLen, 64)
+ case pgTimestamp:
+ return readTimeCol(rd, dataLen)
+ case pgTimestamptz:
+ return readTimeCol(rd, dataLen)
+ case pgDate:
+ return readTimeCol(rd, dataLen)
+ case pgText, pgVarchar:
+ return readStringCol(rd, dataLen)
+ case pgBytea:
+ return readBytesCol(rd, dataLen)
+ }
+
+ b := make([]byte, dataLen)
+ if _, err := io.ReadFull(rd, b); err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+func readBoolCol(rd *reader, n int) (interface{}, error) {
+ tmp, err := rd.ReadTemp(n)
+ if err != nil {
+ return nil, err
+ }
+ return len(tmp) == 1 && (tmp[0] == 't' || tmp[0] == '1'), nil
+}
+
+func readIntCol(rd *reader, n int, bitSize int) (interface{}, error) {
+ if n <= 0 {
+ return 0, nil
+ }
+
+ tmp, err := rd.ReadTemp(n)
+ if err != nil {
+ return 0, err
+ }
+
+ return strconv.ParseInt(bytesToString(tmp), 10, bitSize)
+}
+
+func readFloatCol(rd *reader, n int, bitSize int) (interface{}, error) {
+ if n <= 0 {
+ return 0, nil
+ }
+
+ tmp, err := rd.ReadTemp(n)
+ if err != nil {
+ return 0, err
+ }
+
+ return strconv.ParseFloat(bytesToString(tmp), bitSize)
+}
+
+func readStringCol(rd *reader, n int) (interface{}, error) {
+ if n <= 0 {
+ return "", nil
+ }
+
+ b := make([]byte, n)
+
+ if _, err := io.ReadFull(rd, b); err != nil {
+ return nil, err
+ }
+
+ return bytesToString(b), nil
+}
+
+func readBytesCol(rd *reader, n int) (interface{}, error) {
+ if n <= 0 {
+ return []byte{}, nil
+ }
+
+ tmp, err := rd.ReadTemp(n)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(tmp) < 2 || tmp[0] != '\\' || tmp[1] != 'x' {
+ return nil, fmt.Errorf("pgdriver: can't parse bytea: %q", tmp)
+ }
+ tmp = tmp[2:] // Cut off "\x".
+
+ b := make([]byte, hex.DecodedLen(len(tmp)))
+ if _, err := hex.Decode(b, tmp); err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+func readTimeCol(rd *reader, n int) (interface{}, error) {
+ if n <= 0 {
+ return time.Time{}, nil
+ }
+
+ tmp, err := rd.ReadTemp(n)
+ if err != nil {
+ return time.Time{}, err
+ }
+
+ tm, err := parseTime(bytesToString(tmp))
+ if err != nil {
+ return time.Time{}, err
+ }
+ return tm, nil
+}
+
+const (
+ dateFormat = "2006-01-02"
+ timeFormat = "15:04:05.999999999"
+ timestampFormat = "2006-01-02 15:04:05.999999999"
+ timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00"
+ timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00"
+ timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07"
+)
+
+func parseTime(s string) (time.Time, error) {
+ switch l := len(s); {
+ case l < len("15:04:05"):
+ return time.Time{}, fmt.Errorf("pgdriver: can't parse time=%q", s)
+ case l <= len(timeFormat):
+ if s[2] == ':' {
+ return time.ParseInLocation(timeFormat, s, time.UTC)
+ }
+ return time.ParseInLocation(dateFormat, s, time.UTC)
+ default:
+ if s[10] == 'T' {
+ return time.Parse(time.RFC3339Nano, s)
+ }
+ if c := s[l-9]; c == '+' || c == '-' {
+ return time.Parse(timestamptzFormat, s)
+ }
+ if c := s[l-6]; c == '+' || c == '-' {
+ return time.Parse(timestamptzFormat2, s)
+ }
+ if c := s[l-3]; c == '+' || c == '-' {
+ if strings.HasSuffix(s, "+00") {
+ s = s[:len(s)-3]
+ return time.ParseInLocation(timestampFormat, s, time.UTC)
+ }
+ return time.Parse(timestamptzFormat3, s)
+ }
+ return time.ParseInLocation(timestampFormat, s, time.UTC)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/config.go b/vendor/github.com/uptrace/bun/driver/pgdriver/config.go
new file mode 100644
index 000000000..8e8abfe59
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/config.go
@@ -0,0 +1,233 @@
+package pgdriver
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+)
+
+type Config struct {
+ // Network type, either tcp or unix.
+ // Default is tcp.
+ Network string
+ // TCP host:port or Unix socket depending on Network.
+ Addr string
+ // Dial timeout for establishing new connections.
+ // Default is 5 seconds.
+ DialTimeout time.Duration
+ // Dialer creates new network connection and has priority over
+ // Network and Addr options.
+ Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // TLS config for secure connections.
+ TLSConfig *tls.Config
+
+ User string
+ Password string
+ Database string
+ AppName string
+
+ // Timeout for socket reads. If reached, commands will fail
+ // with a timeout instead of blocking.
+ ReadTimeout time.Duration
+ // Timeout for socket writes. If reached, commands will fail
+ // with a timeout instead of blocking.
+ WriteTimeout time.Duration
+}
+
+func newDefaultConfig() *Config {
+ host := env("PGHOST", "localhost")
+ port := env("PGPORT", "5432")
+
+ cfg := &Config{
+ Network: "tcp",
+ Addr: net.JoinHostPort(host, port),
+ DialTimeout: 5 * time.Second,
+
+ User: env("PGUSER", "postgres"),
+ Database: env("PGDATABASE", "postgres"),
+
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 5 * time.Second,
+ }
+
+ cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netDialer := &net.Dialer{
+ Timeout: cfg.DialTimeout,
+ KeepAlive: 5 * time.Minute,
+ }
+ return netDialer.DialContext(ctx, network, addr)
+ }
+
+ return cfg
+}
+
+type DriverOption func(*Connector)
+
+func WithAddr(addr string) DriverOption {
+ if addr == "" {
+ panic("addr is empty")
+ }
+ return func(d *Connector) {
+ d.cfg.Addr = addr
+ }
+}
+
+func WithTLSConfig(cfg *tls.Config) DriverOption {
+ return func(d *Connector) {
+ d.cfg.TLSConfig = cfg
+ }
+}
+
+func WithUser(user string) DriverOption {
+ if user == "" {
+ panic("user is empty")
+ }
+ return func(d *Connector) {
+ d.cfg.User = user
+ }
+}
+
+func WithPassword(password string) DriverOption {
+ return func(d *Connector) {
+ d.cfg.Password = password
+ }
+}
+
+func WithDatabase(database string) DriverOption {
+ if database == "" {
+ panic("database is empty")
+ }
+ return func(d *Connector) {
+ d.cfg.Database = database
+ }
+}
+
+func WithApplicationName(appName string) DriverOption {
+ return func(d *Connector) {
+ d.cfg.AppName = appName
+ }
+}
+
+func WithTimeout(timeout time.Duration) DriverOption {
+ return func(d *Connector) {
+ d.cfg.DialTimeout = timeout
+ d.cfg.ReadTimeout = timeout
+ d.cfg.WriteTimeout = timeout
+ }
+}
+
+func WithDialTimeout(dialTimeout time.Duration) DriverOption {
+ return func(d *Connector) {
+ d.cfg.DialTimeout = dialTimeout
+ }
+}
+
+func WithReadTimeout(readTimeout time.Duration) DriverOption {
+ return func(d *Connector) {
+ d.cfg.ReadTimeout = readTimeout
+ }
+}
+
+func WithWriteTimeout(writeTimeout time.Duration) DriverOption {
+ return func(d *Connector) {
+ d.cfg.WriteTimeout = writeTimeout
+ }
+}
+
+func WithDSN(dsn string) DriverOption {
+ return func(d *Connector) {
+ opts, err := parseDSN(dsn)
+ if err != nil {
+ panic(err)
+ }
+ for _, opt := range opts {
+ opt(d)
+ }
+ }
+}
+
+func parseDSN(dsn string) ([]DriverOption, error) {
+ u, err := url.Parse(dsn)
+ if err != nil {
+ return nil, err
+ }
+
+ if u.Scheme != "postgres" && u.Scheme != "postgresql" {
+ return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
+ }
+
+ query, err := url.ParseQuery(u.RawQuery)
+ if err != nil {
+ return nil, err
+ }
+
+ var opts []DriverOption
+
+ if u.Host != "" {
+ addr := u.Host
+ if !strings.Contains(addr, ":") {
+ addr += ":5432"
+ }
+ opts = append(opts, WithAddr(addr))
+ }
+ if u.User != nil {
+ opts = append(opts, WithUser(u.User.Username()))
+ if password, ok := u.User.Password(); ok {
+ opts = append(opts, WithPassword(password))
+ }
+ }
+ if len(u.Path) > 1 {
+ opts = append(opts, WithDatabase(u.Path[1:]))
+ }
+
+ if appName := query.Get("application_name"); appName != "" {
+ opts = append(opts, WithApplicationName(appName))
+ }
+ delete(query, "application_name")
+
+ if sslMode := query.Get("sslmode"); sslMode != "" {
+ switch sslMode {
+ case "verify-ca", "verify-full":
+ opts = append(opts, WithTLSConfig(new(tls.Config)))
+ case "allow", "prefer", "require":
+ opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
+ case "disable":
+ // no TLS config
+ default:
+ return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
+ }
+ } else {
+ opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
+ }
+ delete(query, "sslmode")
+
+ for key := range query {
+ return nil, fmt.Errorf("pgdriver: unsupported option=%q", key)
+ }
+
+ return opts, nil
+}
+
+func env(key, defValue string) string {
+ if s := os.Getenv(key); s != "" {
+ return s
+ }
+ return defValue
+}
+
+// verify is a method to make sure if the config is legitimate
+// in the case it detects any errors, it returns with a non-nil error
+// it can be extended to check other parameters
+func (c *Config) verify() error {
+ if c.User == "" {
+ return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go b/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go
new file mode 100644
index 000000000..d25c3adbc
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go
@@ -0,0 +1,606 @@
+package pgdriver
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "os"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+func init() {
+ sql.Register("pg", NewDriver())
+}
+
+type logging interface {
+ Printf(ctx context.Context, format string, v ...interface{})
+}
+
+type logger struct {
+ log *log.Logger
+}
+
+func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
+ _ = l.log.Output(2, fmt.Sprintf(format, v...))
+}
+
+var Logger logging = &logger{
+ log: log.New(os.Stderr, "pgdriver: ", log.LstdFlags|log.Lshortfile),
+}
+
+//------------------------------------------------------------------------------
+
+type Driver struct {
+ connector *Connector
+}
+
+var _ driver.DriverContext = (*Driver)(nil)
+
+func NewDriver() Driver {
+ return Driver{}
+}
+
+func (d Driver) OpenConnector(name string) (driver.Connector, error) {
+ opts, err := parseDSN(name)
+ if err != nil {
+ return nil, err
+ }
+ return NewConnector(opts...), nil
+}
+
+func (d Driver) Open(name string) (driver.Conn, error) {
+ connector, err := d.OpenConnector(name)
+ if err != nil {
+ return nil, err
+ }
+ return connector.Connect(context.TODO())
+}
+
+//------------------------------------------------------------------------------
+
+type DriverStats struct {
+ Queries uint64
+ Errors uint64
+}
+
+type Connector struct {
+ cfg *Config
+
+ stats DriverStats
+}
+
+func NewConnector(opts ...DriverOption) *Connector {
+ d := &Connector{cfg: newDefaultConfig()}
+ for _, opt := range opts {
+ opt(d)
+ }
+ return d
+}
+
+var _ driver.Connector = (*Connector)(nil)
+
+func (d *Connector) Connect(ctx context.Context) (driver.Conn, error) {
+ if err := d.cfg.verify(); err != nil {
+ return nil, err
+ }
+
+ return newConn(ctx, d)
+}
+
+func (d *Connector) Driver() driver.Driver {
+ return Driver{connector: d}
+}
+
+func (d *Connector) Config() *Config {
+ return d.cfg
+}
+
+func (d *Connector) Stats() DriverStats {
+ return DriverStats{
+ Queries: atomic.LoadUint64(&d.stats.Queries),
+ Errors: atomic.LoadUint64(&d.stats.Errors),
+ }
+}
+
+//------------------------------------------------------------------------------
+
+type Conn struct {
+ driver *Connector
+
+ netConn net.Conn
+ rd *reader
+
+ processID int32
+ secretKey int32
+
+ stmtCount int
+
+ closed int32
+}
+
+func newConn(ctx context.Context, driver *Connector) (*Conn, error) {
+ netConn, err := driver.cfg.Dialer(ctx, driver.cfg.Network, driver.cfg.Addr)
+ if err != nil {
+ return nil, err
+ }
+
+ cn := &Conn{
+ driver: driver,
+ netConn: netConn,
+ rd: newReader(netConn),
+ }
+
+ if cn.driver.cfg.TLSConfig != nil {
+ if err := enableSSL(ctx, cn, cn.driver.cfg.TLSConfig); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := startup(ctx, cn); err != nil {
+ return nil, err
+ }
+
+ return cn, nil
+}
+
+func (cn *Conn) reader(ctx context.Context, timeout time.Duration) *reader {
+ cn.setReadDeadline(ctx, timeout)
+ return cn.rd
+}
+
+func (cn *Conn) withWriter(
+ ctx context.Context,
+ timeout time.Duration,
+ fn func(wr *bufio.Writer) error,
+) error {
+ wr := getBufioWriter()
+
+ cn.setWriteDeadline(ctx, timeout)
+ wr.Reset(cn.netConn)
+
+ err := fn(wr)
+ if err == nil {
+ err = wr.Flush()
+ }
+
+ putBufioWriter(wr)
+
+ return err
+}
+
+var _ driver.Conn = (*Conn)(nil)
+
+func (cn *Conn) Prepare(query string) (driver.Stmt, error) {
+ if cn.isClosed() {
+ return nil, driver.ErrBadConn
+ }
+
+ ctx := context.TODO()
+
+ name := fmt.Sprintf("pgdriver-%d", cn.stmtCount)
+ cn.stmtCount++
+
+ if err := writeParseDescribeSync(ctx, cn, name, query); err != nil {
+ return nil, err
+ }
+
+ rowDesc, err := readParseDescribeSync(ctx, cn)
+ if err != nil {
+ return nil, err
+ }
+
+ return newStmt(cn, name, rowDesc), nil
+}
+
+func (cn *Conn) Close() error {
+ if !atomic.CompareAndSwapInt32(&cn.closed, 0, 1) {
+ return nil
+ }
+ return cn.netConn.Close()
+}
+
+func (cn *Conn) isClosed() bool {
+ return atomic.LoadInt32(&cn.closed) == 1
+}
+
+func (cn *Conn) Begin() (driver.Tx, error) {
+ return cn.BeginTx(context.Background(), driver.TxOptions{})
+}
+
+var _ driver.ConnBeginTx = (*Conn)(nil)
+
+func (cn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+ // No need to check if the conn is closed. ExecContext below handles that.
+
+ if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
+ return nil, errors.New("pgdriver: custom IsolationLevel is not supported")
+ }
+ if opts.ReadOnly {
+ return nil, errors.New("pgdriver: ReadOnly transactions are not supported")
+ }
+
+ if _, err := cn.ExecContext(ctx, "BEGIN", nil); err != nil {
+ return nil, err
+ }
+ return tx{cn: cn}, nil
+}
+
+var _ driver.ExecerContext = (*Conn)(nil)
+
+func (cn *Conn) ExecContext(
+ ctx context.Context, query string, args []driver.NamedValue,
+) (driver.Result, error) {
+ if cn.isClosed() {
+ return nil, driver.ErrBadConn
+ }
+ res, err := cn.exec(ctx, query, args)
+ if err != nil {
+ return nil, cn.checkBadConn(err)
+ }
+ return res, nil
+}
+
+func (cn *Conn) exec(
+ ctx context.Context, query string, args []driver.NamedValue,
+) (driver.Result, error) {
+ query, err := formatQuery(query, args)
+ if err != nil {
+ return nil, err
+ }
+ if err := writeQuery(ctx, cn, query); err != nil {
+ return nil, err
+ }
+ return readQuery(ctx, cn)
+}
+
+var _ driver.QueryerContext = (*Conn)(nil)
+
+func (cn *Conn) QueryContext(
+ ctx context.Context, query string, args []driver.NamedValue,
+) (driver.Rows, error) {
+ if cn.isClosed() {
+ return nil, driver.ErrBadConn
+ }
+ rows, err := cn.query(ctx, query, args)
+ if err != nil {
+ return nil, cn.checkBadConn(err)
+ }
+ return rows, nil
+}
+
+func (cn *Conn) query(
+ ctx context.Context, query string, args []driver.NamedValue,
+) (driver.Rows, error) {
+ query, err := formatQuery(query, args)
+ if err != nil {
+ return nil, err
+ }
+ if err := writeQuery(ctx, cn, query); err != nil {
+ return nil, err
+ }
+ return readQueryData(ctx, cn)
+}
+
+var _ driver.Pinger = (*Conn)(nil)
+
+func (cn *Conn) Ping(ctx context.Context) error {
+ _, err := cn.ExecContext(ctx, "SELECT 1", nil)
+ return err
+}
+
+func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) {
+ if timeout == -1 {
+ timeout = cn.driver.cfg.ReadTimeout
+ }
+ _ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
+}
+
+func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) {
+ if timeout == -1 {
+ timeout = cn.driver.cfg.WriteTimeout
+ }
+ _ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
+}
+
+func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ if timeout == 0 {
+ return time.Time{}
+ }
+ return time.Now().Add(timeout)
+ }
+
+ if timeout == 0 {
+ return deadline
+ }
+ if tm := time.Now().Add(timeout); tm.Before(deadline) {
+ return tm
+ }
+ return deadline
+}
+
+var _ driver.Validator = (*Conn)(nil)
+
+func (cn *Conn) IsValid() bool {
+ return !cn.isClosed()
+}
+
+func (cn *Conn) checkBadConn(err error) error {
+ if isBadConn(err, false) {
+ // Close and return driver.ErrBadConn next time the conn is used.
+ _ = cn.Close()
+ }
+ // Always return the original error.
+ return err
+}
+
+//------------------------------------------------------------------------------
+
+type rows struct {
+ cn *Conn
+ rowDesc *rowDescription
+ reusable bool
+ closed bool
+}
+
+var _ driver.Rows = (*rows)(nil)
+
+func newRows(cn *Conn, rowDesc *rowDescription, reusable bool) *rows {
+ return &rows{
+ cn: cn,
+ rowDesc: rowDesc,
+ reusable: reusable,
+ }
+}
+
+func (r *rows) Columns() []string {
+ if r.closed || r.rowDesc == nil {
+ return nil
+ }
+ return r.rowDesc.names
+}
+
+func (r *rows) Close() error {
+ if r.closed {
+ return nil
+ }
+ defer r.close()
+
+ for {
+ switch err := r.Next(nil); err {
+ case nil, io.EOF:
+ return nil
+ default: // unexpected error
+ _ = r.cn.Close()
+ return err
+ }
+ }
+}
+
+func (r *rows) close() {
+ r.closed = true
+
+ if r.rowDesc != nil {
+ if r.reusable {
+ rowDescPool.Put(r.rowDesc)
+ }
+ r.rowDesc = nil
+ }
+}
+
+func (r *rows) Next(dest []driver.Value) error {
+ if r.closed {
+ return io.EOF
+ }
+
+ eof, err := r.next(dest)
+ if err == io.EOF {
+ return io.ErrUnexpectedEOF
+ } else if err != nil {
+ return err
+ }
+ if eof {
+ return io.EOF
+ }
+ return nil
+}
+
+func (r *rows) next(dest []driver.Value) (eof bool, _ error) {
+ rd := r.cn.reader(context.TODO(), -1)
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return false, err
+ }
+
+ switch c {
+ case dataRowMsg:
+ return false, r.readDataRow(rd, dest)
+ case commandCompleteMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return false, err
+ }
+ case readyForQueryMsg:
+ r.close()
+
+ if err := rd.Discard(msgLen); err != nil {
+ return false, err
+ }
+
+ if firstErr != nil {
+ return false, firstErr
+ }
+ return true, nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return false, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ default:
+ return false, fmt.Errorf("pgdriver: Next: unexpected message %q", c)
+ }
+ }
+}
+
+func (r *rows) readDataRow(rd *reader, dest []driver.Value) error {
+ numCol, err := readInt16(rd)
+ if err != nil {
+ return err
+ }
+
+ if len(dest) != int(numCol) {
+ return fmt.Errorf("pgdriver: query returned %d columns, but Scan dest has %d items",
+ numCol, len(dest))
+ }
+
+ for colIdx := int16(0); colIdx < numCol; colIdx++ {
+ dataLen, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+
+ value, err := readColumnValue(rd, r.rowDesc.types[colIdx], int(dataLen))
+ if err != nil {
+ return err
+ }
+
+ if dest != nil {
+ dest[colIdx] = value
+ }
+ }
+
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+func parseResult(b []byte) (driver.RowsAffected, error) {
+ i := bytes.LastIndexByte(b, ' ')
+ if i == -1 {
+ return 0, nil
+ }
+
+ b = b[i+1 : len(b)-1]
+ affected, err := strconv.ParseUint(bytesToString(b), 10, 64)
+ if err != nil {
+ return 0, nil
+ }
+
+ return driver.RowsAffected(affected), nil
+}
+
+//------------------------------------------------------------------------------
+
+type tx struct {
+ cn *Conn
+}
+
+var _ driver.Tx = (*tx)(nil)
+
+func (tx tx) Commit() error {
+ _, err := tx.cn.ExecContext(context.Background(), "COMMIT", nil)
+ return err
+}
+
+func (tx tx) Rollback() error {
+ _, err := tx.cn.ExecContext(context.Background(), "ROLLBACK", nil)
+ return err
+}
+
+//------------------------------------------------------------------------------
+
+type stmt struct {
+ cn *Conn
+ name string
+ rowDesc *rowDescription
+}
+
+var (
+ _ driver.Stmt = (*stmt)(nil)
+ _ driver.StmtExecContext = (*stmt)(nil)
+ _ driver.StmtQueryContext = (*stmt)(nil)
+)
+
+func newStmt(cn *Conn, name string, rowDesc *rowDescription) *stmt {
+ return &stmt{
+ cn: cn,
+ name: name,
+ rowDesc: rowDesc,
+ }
+}
+
+func (stmt *stmt) Close() error {
+ if stmt.rowDesc != nil {
+ rowDescPool.Put(stmt.rowDesc)
+ stmt.rowDesc = nil
+ }
+
+ ctx := context.TODO()
+ if err := writeCloseStmt(ctx, stmt.cn, stmt.name); err != nil {
+ return err
+ }
+ if err := readCloseStmtComplete(ctx, stmt.cn); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (stmt *stmt) NumInput() int {
+ if stmt.rowDesc == nil {
+ return -1
+ }
+ return int(stmt.rowDesc.numInput)
+}
+
+func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
+ panic("not implemented")
+}
+
+func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil {
+ return nil, err
+ }
+ return readExtQuery(ctx, stmt.cn)
+}
+
+func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
+ panic("not implemented")
+}
+
+func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil {
+ return nil, err
+ }
+ return readExtQueryData(ctx, stmt.cn, stmt.rowDesc)
+}
+
+//------------------------------------------------------------------------------
+
+var bufioWriterPool = sync.Pool{
+ New: func() interface{} {
+ return bufio.NewWriter(nil)
+ },
+}
+
+func getBufioWriter() *bufio.Writer {
+ return bufioWriterPool.Get().(*bufio.Writer)
+}
+
+func putBufioWriter(wr *bufio.Writer) {
+ bufioWriterPool.Put(wr)
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/error.go b/vendor/github.com/uptrace/bun/driver/pgdriver/error.go
new file mode 100644
index 000000000..5f1f9fec6
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/error.go
@@ -0,0 +1,66 @@
+package pgdriver
+
+import (
+ "fmt"
+ "net"
+)
+
+// Error represents an error returned by PostgreSQL server
+// using PostgreSQL ErrorResponse protocol.
+//
+// https://www.postgresql.org/docs/current/static/protocol-message-formats.html
+type Error struct {
+ m map[byte]string
+}
+
+// Field returns a string value associated with an error field.
+//
+// https://www.postgresql.org/docs/current/static/protocol-error-fields.html
+func (err Error) Field(k byte) string {
+ return err.m[k]
+}
+
+// IntegrityViolation reports whether an error is a part of
+// Integrity Constraint Violation class of errors.
+//
+// https://www.postgresql.org/docs/current/static/errcodes-appendix.html
+func (err Error) IntegrityViolation() bool {
+ switch err.Field('C') {
+ case "23000", "23001", "23502", "23503", "23505", "23514", "23P01":
+ return true
+ default:
+ return false
+ }
+}
+
+func (err Error) Error() string {
+ return fmt.Sprintf("%s #%s %s",
+ err.Field('S'), err.Field('C'), err.Field('M'))
+}
+
+func isBadConn(err error, allowTimeout bool) bool {
+ if err == nil {
+ return false
+ }
+
+ if err, ok := err.(Error); ok {
+ switch err.Field('V') {
+ case "FATAL", "PANIC":
+ return true
+ }
+ switch err.Field('C') {
+ case "25P02", // current transaction is aborted
+ "57014": // canceling statement due to user request
+ return true
+ }
+ return false
+ }
+
+ if allowTimeout {
+ if err, ok := err.(net.Error); ok && err.Timeout() {
+ return !err.Temporary()
+ }
+ }
+
+ return true
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/format.go b/vendor/github.com/uptrace/bun/driver/pgdriver/format.go
new file mode 100644
index 000000000..c85967da5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/format.go
@@ -0,0 +1,188 @@
+package pgdriver
+
+import (
+ "database/sql/driver"
+ "encoding/hex"
+ "fmt"
+ "math"
+ "strconv"
+ "time"
+ "unicode/utf8"
+)
+
+func formatQuery(query string, args []driver.NamedValue) (string, error) {
+ if len(args) == 0 {
+ return query, nil
+ }
+
+ dst := make([]byte, 0, 2*len(query))
+
+ p := newParser(query)
+ for p.Valid() {
+ switch c := p.Next(); c {
+ case '$':
+ if i, ok := p.Number(); ok {
+ if i > len(args) {
+ return "", fmt.Errorf("pgdriver: got %d args, wanted %d", len(args), i)
+ }
+
+ var err error
+ dst, err = appendArg(dst, args[i-1].Value)
+ if err != nil {
+ return "", err
+ }
+ } else {
+ dst = append(dst, '$')
+ }
+ case '\'':
+ if b, ok := p.QuotedString(); ok {
+ dst = append(dst, b...)
+ } else {
+ dst = append(dst, '\'')
+ }
+ default:
+ dst = append(dst, c)
+ }
+ }
+
+ return bytesToString(dst), nil
+}
+
+func appendArg(b []byte, v interface{}) ([]byte, error) {
+ switch v := v.(type) {
+ case nil:
+ return append(b, "NULL"...), nil
+ case int64:
+ return strconv.AppendInt(b, v, 10), nil
+ case float64:
+ switch {
+ case math.IsNaN(v):
+ return append(b, "'NaN'"...), nil
+ case math.IsInf(v, 1):
+ return append(b, "'Infinity'"...), nil
+ case math.IsInf(v, -1):
+ return append(b, "'-Infinity'"...), nil
+ default:
+ return strconv.AppendFloat(b, v, 'f', -1, 64), nil
+ }
+ case bool:
+ if v {
+ return append(b, "TRUE"...), nil
+ }
+ return append(b, "FALSE"...), nil
+ case []byte:
+ if v == nil {
+ return append(b, "NULL"...), nil
+ }
+
+ b = append(b, `'\x`...)
+
+ s := len(b)
+ b = append(b, make([]byte, hex.EncodedLen(len(v)))...)
+ hex.Encode(b[s:], v)
+
+ b = append(b, "'"...)
+
+ return b, nil
+ case string:
+ b = append(b, '\'')
+ for _, r := range v {
+ if r == '\000' {
+ continue
+ }
+
+ if r == '\'' {
+ b = append(b, '\'', '\'')
+ continue
+ }
+
+ if r < utf8.RuneSelf {
+ b = append(b, byte(r))
+ continue
+ }
+ l := len(b)
+ if cap(b)-l < utf8.UTFMax {
+ b = append(b, make([]byte, utf8.UTFMax)...)
+ }
+ n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
+ b = b[:l+n]
+ }
+ b = append(b, '\'')
+ return b, nil
+ case time.Time:
+ if v.IsZero() {
+ return append(b, "NULL"...), nil
+ }
+ return v.UTC().AppendFormat(b, "'2006-01-02 15:04:05.999999-07:00'"), nil
+ default:
+ return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v)
+ }
+}
+
+type parser struct {
+ b []byte
+ i int
+}
+
+func newParser(s string) *parser {
+ return &parser{
+ b: stringToBytes(s),
+ }
+}
+
+func (p *parser) Valid() bool {
+ return p.i < len(p.b)
+}
+
+func (p *parser) Next() byte {
+ c := p.b[p.i]
+ p.i++
+ return c
+}
+
+func (p *parser) Number() (int, bool) {
+ start := p.i
+ end := len(p.b)
+
+ for i := p.i; i < len(p.b); i++ {
+ c := p.b[i]
+ if !isNum(c) {
+ end = i
+ break
+ }
+ }
+
+ p.i = end
+ b := p.b[start:end]
+
+ n, err := strconv.Atoi(bytesToString(b))
+ if err != nil {
+ return 0, false
+ }
+
+ return n, true
+}
+
+func (p *parser) QuotedString() ([]byte, bool) {
+ start := p.i - 1
+ end := len(p.b)
+
+ var c byte
+ for i := p.i; i < len(p.b); i++ {
+ next := p.b[i]
+ if c == '\'' && next != '\'' {
+ end = i
+ break
+ }
+ c = next
+ }
+
+ p.i = end
+ b := p.b[start:end]
+
+ return b, true
+}
+
+func isNum(c byte) bool {
+ return c >= '0' && c <= '9'
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod b/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod
new file mode 100644
index 000000000..0ebe475f9
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod
@@ -0,0 +1,11 @@
+module github.com/uptrace/bun/driver/pgdriver
+
+go 1.16
+
+replace github.com/uptrace/bun => ../..
+
+require (
+ github.com/stretchr/testify v1.7.0
+ github.com/uptrace/bun v0.4.3
+ mellium.im/sasl v0.2.1
+)
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum b/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum
new file mode 100644
index 000000000..db9059b35
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum
@@ -0,0 +1,27 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
+github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
+github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b h1:2b9XGzhjiYsYPnKXoEfL7klWZQIt8IfyRCz62gCqqlQ=
+golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
+mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go b/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go
new file mode 100644
index 000000000..937c80fa7
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go
@@ -0,0 +1,392 @@
+package pgdriver
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/uptrace/bun"
+)
+
+const pingChannel = "bun:ping"
+
+var (
+ errListenerClosed = errors.New("bun: listener is closed")
+ errPingTimeout = errors.New("bun: ping timeout")
+)
+
+type Listener struct {
+ db *bun.DB
+ driver *Connector
+
+ channels []string
+
+ mu sync.Mutex
+ cn *Conn
+ closed bool
+ exit chan struct{}
+}
+
+func NewListener(db *bun.DB) *Listener {
+ return &Listener{
+ db: db,
+ driver: db.Driver().(Driver).connector,
+ exit: make(chan struct{}),
+ }
+}
+
+// Close closes the listener, releasing any open resources.
+func (ln *Listener) Close() error {
+ return ln.withLock(func() error {
+ if ln.closed {
+ return errListenerClosed
+ }
+
+ ln.closed = true
+ close(ln.exit)
+
+ return ln.closeConn(errListenerClosed)
+ })
+}
+
+func (ln *Listener) withLock(fn func() error) error {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+ return fn()
+}
+
+func (ln *Listener) conn(ctx context.Context) (*Conn, error) {
+ if ln.closed {
+ return nil, errListenerClosed
+ }
+ if ln.cn != nil {
+ return ln.cn, nil
+ }
+
+ atomic.AddUint64(&ln.driver.stats.Queries, 1)
+
+ cn, err := ln._conn(ctx)
+ if err != nil {
+ atomic.AddUint64(&ln.driver.stats.Errors, 1)
+ return nil, err
+ }
+
+ ln.cn = cn
+ return cn, nil
+}
+
+func (ln *Listener) _conn(ctx context.Context) (*Conn, error) {
+ driverConn, err := ln.driver.Connect(ctx)
+ if err != nil {
+ return nil, err
+ }
+ cn := driverConn.(*Conn)
+
+ if len(ln.channels) > 0 {
+ err := ln.listen(ctx, cn, ln.channels...)
+ if err != nil {
+ _ = cn.Close()
+ return nil, err
+ }
+ }
+
+ return cn, nil
+}
+
+func (ln *Listener) checkConn(ctx context.Context, cn *Conn, err error, allowTimeout bool) {
+ _ = ln.withLock(func() error {
+ if ln.closed || ln.cn != cn {
+ return nil
+ }
+ if isBadConn(err, allowTimeout) {
+ ln.reconnect(ctx, err)
+ }
+ return nil
+ })
+}
+
+func (ln *Listener) reconnect(ctx context.Context, reason error) {
+ if ln.cn != nil {
+ Logger.Printf(ctx, "bun: discarding bad listener connection: %s", reason)
+ _ = ln.closeConn(reason)
+ }
+ _, _ = ln.conn(ctx)
+}
+
+func (ln *Listener) closeConn(reason error) error {
+ if ln.cn == nil {
+ return nil
+ }
+ err := ln.cn.Close()
+ ln.cn = nil
+ return err
+}
+
+// Listen starts listening for notifications on channels.
+func (ln *Listener) Listen(ctx context.Context, channels ...string) error {
+ var cn *Conn
+
+ if err := ln.withLock(func() error {
+ ln.channels = appendIfNotExists(ln.channels, channels...)
+
+ var err error
+ cn, err = ln.conn(ctx)
+ return err
+ }); err != nil {
+ return err
+ }
+
+ if err := ln.listen(ctx, cn, channels...); err != nil {
+ ln.checkConn(ctx, cn, err, false)
+ return err
+ }
+ return nil
+}
+
+func (ln *Listener) listen(ctx context.Context, cn *Conn, channels ...string) error {
+ for _, channel := range channels {
+ if err := writeQuery(ctx, cn, "LISTEN "+strconv.Quote(channel)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Unlisten stops listening for notifications on channels.
+func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error {
+ var cn *Conn
+
+ if err := ln.withLock(func() error {
+ ln.channels = removeIfExists(ln.channels, channels...)
+
+ var err error
+ cn, err = ln.conn(ctx)
+ return err
+ }); err != nil {
+ return err
+ }
+
+ if err := ln.unlisten(ctx, cn, channels...); err != nil {
+ ln.checkConn(ctx, cn, err, false)
+ return err
+ }
+ return nil
+}
+
+func (ln *Listener) unlisten(ctx context.Context, cn *Conn, channels ...string) error {
+ for _, channel := range channels {
+ if err := writeQuery(ctx, cn, "UNLISTEN "+strconv.Quote(channel)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Receive indefinitely waits for a notification. This is low-level API
+// and in most cases Channel should be used instead.
+func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) {
+ return ln.ReceiveTimeout(ctx, 0)
+}
+
+// ReceiveTimeout waits for a notification until timeout is reached.
+// This is low-level API and in most cases Channel should be used instead.
+func (ln *Listener) ReceiveTimeout(
+ ctx context.Context, timeout time.Duration,
+) (channel, payload string, err error) {
+ var cn *Conn
+
+ if err := ln.withLock(func() error {
+ var err error
+ cn, err = ln.conn(ctx)
+ return err
+ }); err != nil {
+ return "", "", err
+ }
+
+ rd := cn.reader(ctx, timeout)
+ channel, payload, err = readNotification(ctx, rd)
+ if err != nil {
+ ln.checkConn(ctx, cn, err, timeout > 0)
+ return "", "", err
+ }
+
+ return channel, payload, nil
+}
+
+// Channel returns a channel for concurrently receiving notifications.
+// It periodically sends Ping notification to test connection health.
+//
+// The channel is closed with Listener. Receive* APIs can not be used
+// after channel is created.
+func (ln *Listener) Channel(opts ...ChannelOption) <-chan Notification {
+ return newChannel(ln, opts).ch
+}
+
+//------------------------------------------------------------------------------
+
+// Notification received with LISTEN command.
+type Notification struct {
+ Channel string
+ Payload string
+}
+
+type ChannelOption func(c *channel)
+
+func WithChannelSize(size int) ChannelOption {
+ return func(c *channel) {
+ c.size = size
+ }
+}
+
+type channel struct {
+ ctx context.Context
+ ln *Listener
+
+ size int
+ pingTimeout time.Duration
+ chanSendTimeout time.Duration
+
+ ch chan Notification
+ pingCh chan struct{}
+}
+
+func newChannel(ln *Listener, opts []ChannelOption) *channel {
+ c := &channel{
+ ctx: context.TODO(),
+ ln: ln,
+
+ size: 100,
+ pingTimeout: 5 * time.Second,
+ chanSendTimeout: time.Minute,
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ c.ch = make(chan Notification, c.size)
+ c.pingCh = make(chan struct{}, 1)
+ _ = c.ln.Listen(c.ctx, pingChannel)
+ go c.startReceive()
+ go c.startPing()
+
+ return c
+}
+
+func (c *channel) startReceive() {
+ timer := time.NewTimer(time.Minute)
+ timer.Stop()
+
+ var errCount int
+ for {
+ channel, payload, err := c.ln.Receive(c.ctx)
+ if err != nil {
+ if err == errListenerClosed {
+ close(c.ch)
+ return
+ }
+
+ if errCount > 0 {
+ time.Sleep(500 * time.Millisecond)
+ }
+ errCount++
+
+ continue
+ }
+
+ errCount = 0
+
+ // Any notification is as good as a ping.
+ select {
+ case c.pingCh <- struct{}{}:
+ default:
+ }
+
+ switch channel {
+ case pingChannel:
+ // ignore
+ default:
+ timer.Reset(c.chanSendTimeout)
+ select {
+ case c.ch <- Notification{channel, payload}:
+ if !timer.Stop() {
+ <-timer.C
+ }
+ case <-timer.C:
+ Logger.Printf(
+ c.ctx,
+ "pgdriver: %s channel is full for %s (notification is dropped)",
+ c,
+ c.chanSendTimeout,
+ )
+ }
+ }
+ }
+}
+
+func (c *channel) startPing() {
+ timer := time.NewTimer(time.Minute)
+ timer.Stop()
+
+ healthy := true
+ for {
+ timer.Reset(c.pingTimeout)
+ select {
+ case <-c.pingCh:
+ healthy = true
+ if !timer.Stop() {
+ <-timer.C
+ }
+ case <-timer.C:
+ pingErr := c.ping(c.ctx)
+ if healthy {
+ healthy = false
+ } else {
+ if pingErr == nil {
+ pingErr = errPingTimeout
+ }
+ _ = c.ln.withLock(func() error {
+ c.ln.reconnect(c.ctx, pingErr)
+ return nil
+ })
+ }
+ case <-c.ln.exit:
+ return
+ }
+ }
+}
+
+func (c *channel) ping(ctx context.Context) error {
+ _, err := c.ln.db.ExecContext(ctx, "NOTIFY "+strconv.Quote(pingChannel))
+ return err
+}
+
+func appendIfNotExists(ss []string, es ...string) []string {
+loop:
+ for _, e := range es {
+ for _, s := range ss {
+ if s == e {
+ continue loop
+ }
+ }
+ ss = append(ss, e)
+ }
+ return ss
+}
+
+func removeIfExists(ss []string, es ...string) []string {
+ for _, e := range es {
+ for i, s := range ss {
+ if s == e {
+ last := len(ss) - 1
+ ss[i] = ss[last]
+ ss = ss[:last]
+ break
+ }
+ }
+ }
+ return ss
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go b/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go
new file mode 100644
index 000000000..327fc29fc
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go
@@ -0,0 +1,1127 @@
+package pgdriver
+
+import (
+ "bufio"
+ "context"
+ "crypto/md5"
+ "crypto/tls"
+ "database/sql"
+ "database/sql/driver"
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "strconv"
+ "sync"
+ "time"
+ "unicode/utf8"
+
+ "mellium.im/sasl"
+)
+
+// https://www.postgresql.org/docs/current/protocol-message-formats.html
+//nolint:deadcode,varcheck,unused
+const (
+ commandCompleteMsg = 'C'
+ errorResponseMsg = 'E'
+ noticeResponseMsg = 'N'
+ parameterStatusMsg = 'S'
+ authenticationOKMsg = 'R'
+ backendKeyDataMsg = 'K'
+ noDataMsg = 'n'
+ passwordMessageMsg = 'p'
+ terminateMsg = 'X'
+
+ saslInitialResponseMsg = 'p'
+ authenticationSASLContinueMsg = 'R'
+ saslResponseMsg = 'p'
+ authenticationSASLFinalMsg = 'R'
+
+ authenticationOK = 0
+ authenticationCleartextPassword = 3
+ authenticationMD5Password = 5
+ authenticationSASL = 10
+
+ notificationResponseMsg = 'A'
+
+ describeMsg = 'D'
+ parameterDescriptionMsg = 't'
+
+ queryMsg = 'Q'
+ readyForQueryMsg = 'Z'
+ emptyQueryResponseMsg = 'I'
+ rowDescriptionMsg = 'T'
+ dataRowMsg = 'D'
+
+ parseMsg = 'P'
+ parseCompleteMsg = '1'
+
+ bindMsg = 'B'
+ bindCompleteMsg = '2'
+
+ executeMsg = 'E'
+
+ syncMsg = 'S'
+ flushMsg = 'H'
+
+ closeMsg = 'C'
+ closeCompleteMsg = '3'
+
+ copyInResponseMsg = 'G'
+ copyOutResponseMsg = 'H'
+ copyDataMsg = 'd'
+ copyDoneMsg = 'c'
+)
+
+var errEmptyQuery = errors.New("pgdriver: query is empty")
+
+type reader struct {
+ *bufio.Reader
+ buf []byte
+}
+
+func newReader(r io.Reader) *reader {
+ return &reader{
+ Reader: bufio.NewReader(r),
+ buf: make([]byte, 128),
+ }
+}
+
+func (r *reader) ReadTemp(n int) ([]byte, error) {
+ if n <= len(r.buf) {
+ b := r.buf[:n]
+ _, err := io.ReadFull(r.Reader, b)
+ return b, err
+ }
+
+ b := make([]byte, n)
+ _, err := io.ReadFull(r.Reader, b)
+ return b, err
+}
+
+func (r *reader) Discard(n int) error {
+ _, err := r.ReadTemp(n)
+ return err
+}
+
+func enableSSL(ctx context.Context, cn *Conn, tlsConf *tls.Config) error {
+ if err := writeSSLMsg(ctx, cn); err != nil {
+ return err
+ }
+
+ rd := cn.reader(ctx, -1)
+
+ c, err := rd.ReadByte()
+ if err != nil {
+ return err
+ }
+ if c != 'S' {
+ return errors.New("pgdriver: SSL is not enabled on the server")
+ }
+
+ cn.netConn = tls.Client(cn.netConn, tlsConf)
+ rd.Reset(cn.netConn)
+
+ return nil
+}
+
+func writeSSLMsg(ctx context.Context, cn *Conn) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(0)
+ wb.WriteInt32(80877103)
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+//------------------------------------------------------------------------------
+
+func startup(ctx context.Context, cn *Conn) error {
+ if err := writeStartup(ctx, cn); err != nil {
+ return err
+ }
+
+ rd := cn.reader(ctx, -1)
+
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return err
+ }
+
+ switch c {
+ case backendKeyDataMsg:
+ processID, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+ secretKey, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+ cn.processID = processID
+ cn.secretKey = secretKey
+ case authenticationOKMsg:
+ if err := auth(ctx, cn, rd); err != nil {
+ return err
+ }
+ case readyForQueryMsg:
+ return rd.Discard(msgLen)
+ case parameterStatusMsg, noticeResponseMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return err
+ }
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return err
+ }
+ return e
+ default:
+ return fmt.Errorf("pgdriver: unexpected startup message: %q", c)
+ }
+ }
+}
+
+func writeStartup(ctx context.Context, cn *Conn) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(0)
+ wb.WriteInt32(196608)
+ wb.WriteString("user")
+ wb.WriteString(cn.driver.cfg.User)
+ wb.WriteString("database")
+ wb.WriteString(cn.driver.cfg.Database)
+ if cn.driver.cfg.AppName != "" {
+ wb.WriteString("application_name")
+ wb.WriteString(cn.driver.cfg.AppName)
+ }
+ wb.WriteString("")
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+//------------------------------------------------------------------------------
+
+func auth(ctx context.Context, cn *Conn, rd *reader) error {
+ num, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+
+ switch num {
+ case authenticationOK:
+ return nil
+ case authenticationCleartextPassword:
+ return authCleartext(ctx, cn, rd)
+ case authenticationMD5Password:
+ return authMD5(ctx, cn, rd)
+ case authenticationSASL:
+ if err := authSASL(ctx, cn, rd); err != nil {
+ return fmt.Errorf("pgdriver: SASL: %w", err)
+ }
+ return nil
+ default:
+ return fmt.Errorf("pgdriver: unknown authentication message: %q", num)
+ }
+}
+
+func authCleartext(ctx context.Context, cn *Conn, rd *reader) error {
+ if err := writePassword(ctx, cn, cn.driver.cfg.Password); err != nil {
+ return err
+ }
+ return readAuthOK(cn, rd)
+}
+
+func readAuthOK(cn *Conn, rd *reader) error {
+ c, _, err := readMessageType(rd)
+ if err != nil {
+ return err
+ }
+
+ switch c {
+ case authenticationOKMsg:
+ num, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+ if num != 0 {
+ return fmt.Errorf("pgdriver: unexpected authentication code: %q", num)
+ }
+ return nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return err
+ }
+ return e
+ default:
+ return fmt.Errorf("pgdriver: unknown password message: %q", c)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func authMD5(ctx context.Context, cn *Conn, rd *reader) error {
+ b, err := rd.ReadTemp(4)
+ if err != nil {
+ return err
+ }
+
+ secret := "md5" + md5s(md5s(cn.driver.cfg.Password+cn.driver.cfg.User)+string(b))
+ if err := writePassword(ctx, cn, secret); err != nil {
+ return err
+ }
+
+ return readAuthOK(cn, rd)
+}
+
+func writePassword(ctx context.Context, cn *Conn, password string) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(passwordMessageMsg)
+ wb.WriteString(password)
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func md5s(s string) string {
+ h := md5.Sum([]byte(s))
+ return hex.EncodeToString(h[:])
+}
+
+//------------------------------------------------------------------------------
+
+func authSASL(ctx context.Context, cn *Conn, rd *reader) error {
+ s, err := readString(rd)
+ if err != nil {
+ return err
+ }
+
+ var saslMech sasl.Mechanism
+
+ switch s {
+ case sasl.ScramSha256.Name:
+ saslMech = sasl.ScramSha256
+ case sasl.ScramSha256Plus.Name:
+ saslMech = sasl.ScramSha256Plus
+ default:
+ return fmt.Errorf("got %q, wanted %q", s, sasl.ScramSha256.Name)
+ }
+
+ c0, err := rd.ReadByte()
+ if err != nil {
+ return err
+ }
+ if c0 != 0 {
+ return fmt.Errorf("got %q, wanted %q", c0, 0)
+ }
+
+ creds := sasl.Credentials(func() (Username, Password, Identity []byte) {
+ return []byte(cn.driver.cfg.User), []byte(cn.driver.cfg.Password), nil
+ })
+ client := sasl.NewClient(saslMech, creds)
+
+ _, resp, err := client.Step(nil)
+ if err != nil {
+ return fmt.Errorf("client.Step 1 failed: %w", err)
+ }
+
+ if err := saslWriteInitialResponse(ctx, cn, saslMech, resp); err != nil {
+ return err
+ }
+
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return err
+ }
+
+ switch c {
+ case authenticationSASLContinueMsg:
+ c11, err := readInt32(rd)
+ if err != nil {
+ return err
+ }
+ if c11 != 11 {
+ return fmt.Errorf("got %q, wanted %q", c, 11)
+ }
+
+ b, err := rd.ReadTemp(msgLen - 4)
+ if err != nil {
+ return err
+ }
+
+ _, resp, err = client.Step(b)
+ if err != nil {
+ return fmt.Errorf("client.Step 2 failed: %w", err)
+ }
+
+ if err := saslWriteResponse(ctx, cn, resp); err != nil {
+ return err
+ }
+
+ resp, err = saslReadAuthFinal(cn, rd)
+ if err != nil {
+ return err
+ }
+
+ if _, _, err := client.Step(resp); err != nil {
+ return fmt.Errorf("client.Step 3 failed: %w", err)
+ }
+
+ if client.State() != sasl.ValidServerResponse {
+ return fmt.Errorf("got state=%q, wanted %q", client.State(), sasl.ValidServerResponse)
+ }
+
+ return nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return err
+ }
+ return e
+ default:
+ return fmt.Errorf("got %q, wanted %q", c, authenticationSASLContinueMsg)
+ }
+}
+
+func saslWriteInitialResponse(
+ ctx context.Context, cn *Conn, saslMech sasl.Mechanism, resp []byte,
+) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(saslInitialResponseMsg)
+ wb.WriteString(saslMech.Name)
+ wb.WriteInt32(int32(len(resp)))
+ if _, err := wb.Write(resp); err != nil {
+ return err
+ }
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func saslWriteResponse(ctx context.Context, cn *Conn, resp []byte) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(saslResponseMsg)
+ if _, err := wb.Write(resp); err != nil {
+ return err
+ }
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func saslReadAuthFinal(cn *Conn, rd *reader) ([]byte, error) {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case authenticationSASLFinalMsg:
+ c12, err := readInt32(rd)
+ if err != nil {
+ return nil, err
+ }
+ if c12 != 12 {
+ return nil, fmt.Errorf("got %q, wanted %q", c, 12)
+ }
+
+ resp := make([]byte, msgLen-4)
+ if _, err := io.ReadFull(rd, resp); err != nil {
+ return nil, err
+ }
+
+ if err := readAuthOK(cn, rd); err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ return nil, e
+ default:
+ return nil, fmt.Errorf("got %q, wanted %q", c, authenticationSASLFinalMsg)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func writeQuery(ctx context.Context, cn *Conn, query string) error {
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ if err := wr.WriteByte(queryMsg); err != nil {
+ return err
+ }
+
+ binary.BigEndian.PutUint32(cn.rd.buf, uint32(len(query)+5))
+ if _, err := wr.Write(cn.rd.buf[:4]); err != nil {
+ return err
+ }
+
+ if _, err := wr.WriteString(query); err != nil {
+ return err
+ }
+ if err := wr.WriteByte(0x0); err != nil {
+ return err
+ }
+
+ return nil
+ })
+}
+
+func readQuery(ctx context.Context, cn *Conn) (sql.Result, error) {
+ rd := cn.reader(ctx, -1)
+
+ var res driver.Result
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ case emptyQueryResponseMsg:
+ if firstErr == nil {
+ firstErr = errEmptyQuery
+ }
+ case commandCompleteMsg:
+ tmp, err := rd.ReadTemp(msgLen)
+ if err != nil {
+ firstErr = err
+ break
+ }
+
+ r, err := parseResult(tmp)
+ if err != nil {
+ firstErr = err
+ } else {
+ res = r
+ }
+ case describeMsg,
+ rowDescriptionMsg,
+ noticeResponseMsg,
+ parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case readyForQueryMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ return res, firstErr
+ default:
+ return nil, fmt.Errorf("pgdriver: Exec: unexpected message %q", c)
+ }
+ }
+}
+
+func readQueryData(ctx context.Context, cn *Conn) (*rows, error) {
+ rd := cn.reader(ctx, -1)
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case rowDescriptionMsg:
+ rowDesc, err := readRowDescription(rd)
+ if err != nil {
+ return nil, err
+ }
+ return newRows(cn, rowDesc, true), nil
+ case commandCompleteMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case readyForQueryMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ if firstErr != nil {
+ return nil, firstErr
+ }
+ return &rows{closed: true}, nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ case emptyQueryResponseMsg:
+ if firstErr == nil {
+ firstErr = errEmptyQuery
+ }
+ case noticeResponseMsg, parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("pgdriver: newRows: unexpected message %q", c)
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+
+var rowDescPool sync.Pool
+
+type rowDescription struct {
+ buf []byte
+ names []string
+ types []int32
+ numInput int16
+}
+
+func newRowDescription(numCol int) *rowDescription {
+ if numCol < 16 {
+ numCol = 16
+ }
+ return &rowDescription{
+ buf: make([]byte, 0, 16*numCol),
+ names: make([]string, 0, numCol),
+ types: make([]int32, 0, numCol),
+ numInput: -1,
+ }
+}
+
+func (d *rowDescription) reset(numCol int) {
+ d.buf = make([]byte, 0, 16*numCol)
+ d.names = d.names[:0]
+ d.types = d.types[:0]
+ d.numInput = -1
+}
+
+func (d *rowDescription) addName(name []byte) {
+ if len(d.buf)+len(name) > cap(d.buf) {
+ d.buf = make([]byte, 0, cap(d.buf))
+ }
+
+ i := len(d.buf)
+ d.buf = append(d.buf, name...)
+ d.names = append(d.names, bytesToString(d.buf[i:]))
+}
+
+func (d *rowDescription) addType(dataType int32) {
+ d.types = append(d.types, dataType)
+}
+
+func readRowDescription(rd *reader) (*rowDescription, error) {
+ numCol, err := readInt16(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ rowDesc, ok := rowDescPool.Get().(*rowDescription)
+ if !ok {
+ rowDesc = newRowDescription(int(numCol))
+ } else {
+ rowDesc.reset(int(numCol))
+ }
+
+ for i := 0; i < int(numCol); i++ {
+ name, err := rd.ReadSlice(0)
+ if err != nil {
+ return nil, err
+ }
+ rowDesc.addName(name[:len(name)-1])
+
+ if _, err := rd.ReadTemp(6); err != nil {
+ return nil, err
+ }
+
+ dataType, err := readInt32(rd)
+ if err != nil {
+ return nil, err
+ }
+ rowDesc.addType(dataType)
+
+ if _, err := rd.ReadTemp(8); err != nil {
+ return nil, err
+ }
+ }
+
+ return rowDesc, nil
+}
+
+//------------------------------------------------------------------------------
+
+func readNotification(ctx context.Context, rd *reader) (channel, payload string, err error) {
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return "", "", err
+ }
+
+ switch c {
+ case commandCompleteMsg, readyForQueryMsg, noticeResponseMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return "", "", err
+ }
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return "", "", err
+ }
+ return "", "", e
+ case notificationResponseMsg:
+ if err := rd.Discard(4); err != nil {
+ return "", "", err
+ }
+ channel, err = readString(rd)
+ if err != nil {
+ return "", "", err
+ }
+ payload, err = readString(rd)
+ if err != nil {
+ return "", "", err
+ }
+ return channel, payload, nil
+ default:
+ return "", "", fmt.Errorf("pgdriver: readNotification: unexpected message %q", c)
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func writeParseDescribeSync(ctx context.Context, cn *Conn, name, query string) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(parseMsg)
+ wb.WriteString(name)
+ wb.WriteString(query)
+ wb.WriteInt16(0)
+ wb.FinishMessage()
+
+ wb.StartMessage(describeMsg)
+ wb.WriteByte('S')
+ wb.WriteString(name)
+ wb.FinishMessage()
+
+ wb.StartMessage(syncMsg)
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func readParseDescribeSync(ctx context.Context, cn *Conn) (*rowDescription, error) {
+ rd := cn.reader(ctx, -1)
+ var numParam int16
+ var rowDesc *rowDescription
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case parseCompleteMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case rowDescriptionMsg: // response to DESCRIBE message.
+ rowDesc, err = readRowDescription(rd)
+ if err != nil {
+ return nil, err
+ }
+ rowDesc.numInput = numParam
+ case parameterDescriptionMsg: // response to DESCRIBE message.
+ numParam, err = readInt16(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ for i := 0; i < int(numParam); i++ {
+ if _, err := readInt32(rd); err != nil {
+ return nil, err
+ }
+ }
+ case noDataMsg: // response to DESCRIBE message.
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case readyForQueryMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ if firstErr != nil {
+ return nil, firstErr
+ }
+ return rowDesc, err
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ case noticeResponseMsg, parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("pgdriver: readParseDescribeSync: unexpected message %q", c)
+ }
+ }
+}
+
+func writeBindExecute(ctx context.Context, cn *Conn, name string, args []driver.NamedValue) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(bindMsg)
+ wb.WriteString("")
+ wb.WriteString(name)
+ wb.WriteInt16(0)
+ wb.WriteInt16(int16(len(args)))
+ for i := range args {
+ wb.StartParam()
+ bytes, err := appendStmtArg(wb.Bytes, args[i].Value)
+ if err != nil {
+ return err
+ }
+ if bytes != nil {
+ wb.Bytes = bytes
+ wb.FinishParam()
+ } else {
+ wb.FinishNullParam()
+ }
+ }
+ wb.WriteInt16(0)
+ wb.FinishMessage()
+
+ wb.StartMessage(executeMsg)
+ wb.WriteString("")
+ wb.WriteInt32(0)
+ wb.FinishMessage()
+
+ wb.StartMessage(syncMsg)
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func readExtQuery(ctx context.Context, cn *Conn) (driver.Result, error) {
+ rd := cn.reader(ctx, -1)
+ var res driver.Result
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case bindCompleteMsg, dataRowMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case commandCompleteMsg: // response to EXECUTE message.
+ tmp, err := rd.ReadTemp(msgLen)
+ if err != nil {
+ return nil, err
+ }
+
+ r, err := parseResult(tmp)
+ if err != nil {
+ if firstErr == nil {
+ firstErr = err
+ }
+ } else {
+ res = r
+ }
+ case readyForQueryMsg: // Response to SYNC message.
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ if firstErr != nil {
+ return nil, firstErr
+ }
+ return res, nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ case emptyQueryResponseMsg:
+ if firstErr == nil {
+ firstErr = errEmptyQuery
+ }
+ case noticeResponseMsg, parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("pgdriver: readExtQuery: unexpected message %q", c)
+ }
+ }
+}
+
+func readExtQueryData(ctx context.Context, cn *Conn, rowDesc *rowDescription) (*rows, error) {
+ rd := cn.reader(ctx, -1)
+ var firstErr error
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return nil, err
+ }
+
+ switch c {
+ case bindCompleteMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ return newRows(cn, rowDesc, false), nil
+ case commandCompleteMsg: // response to EXECUTE message.
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ case readyForQueryMsg: // Response to SYNC message.
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ if firstErr != nil {
+ return nil, firstErr
+ }
+ return &rows{closed: true}, nil
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return nil, err
+ }
+ if firstErr == nil {
+ firstErr = e
+ }
+ case emptyQueryResponseMsg:
+ if firstErr == nil {
+ firstErr = errEmptyQuery
+ }
+ case noticeResponseMsg, parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("pgdriver: readExtQueryData: unexpected message %q", c)
+ }
+ }
+}
+
+func writeCloseStmt(ctx context.Context, cn *Conn, name string) error {
+ wb := getWriteBuffer()
+ defer putWriteBuffer(wb)
+
+ wb.StartMessage(closeMsg)
+ wb.WriteByte('S') //nolint
+ wb.WriteString(name)
+ wb.FinishMessage()
+
+ wb.StartMessage(flushMsg)
+ wb.FinishMessage()
+
+ return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error {
+ _, err := wr.Write(wb.Bytes)
+ return err
+ })
+}
+
+func readCloseStmtComplete(ctx context.Context, cn *Conn) error {
+ rd := cn.reader(ctx, -1)
+ for {
+ c, msgLen, err := readMessageType(rd)
+ if err != nil {
+ return err
+ }
+
+ switch c {
+ case closeCompleteMsg:
+ return rd.Discard(msgLen)
+ case errorResponseMsg:
+ e, err := readError(rd)
+ if err != nil {
+ return err
+ }
+ return e
+ case noticeResponseMsg, parameterStatusMsg:
+ if err := rd.Discard(msgLen); err != nil {
+ return err
+ }
+ default:
+ return fmt.Errorf("pgdriver: readCloseCompleteMsg: unexpected message %q", c)
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func readMessageType(rd *reader) (byte, int, error) {
+ c, err := rd.ReadByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ l, err := readInt32(rd)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c, int(l) - 4, nil
+}
+
+func readInt16(rd *reader) (int16, error) {
+ b, err := rd.ReadTemp(2)
+ if err != nil {
+ return 0, err
+ }
+ return int16(binary.BigEndian.Uint16(b)), nil
+}
+
+func readInt32(rd *reader) (int32, error) {
+ b, err := rd.ReadTemp(4)
+ if err != nil {
+ return 0, err
+ }
+ return int32(binary.BigEndian.Uint32(b)), nil
+}
+
+func readString(rd *reader) (string, error) {
+ b, err := rd.ReadSlice(0)
+ if err != nil {
+ return "", err
+ }
+ return string(b[:len(b)-1]), nil
+}
+
+func readError(rd *reader) (error, error) {
+ m := make(map[byte]string)
+ for {
+ c, err := rd.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ if c == 0 {
+ break
+ }
+ s, err := readString(rd)
+ if err != nil {
+ return nil, err
+ }
+ m[c] = s
+ }
+ return Error{m: m}, nil
+}
+
+//------------------------------------------------------------------------------
+
+func appendStmtArg(b []byte, v driver.Value) ([]byte, error) {
+ switch v := v.(type) {
+ case nil:
+ return nil, nil
+ case int64:
+ return strconv.AppendInt(b, v, 10), nil
+ case float64:
+ switch {
+ case math.IsNaN(v):
+ return append(b, "NaN"...), nil
+ case math.IsInf(v, 1):
+ return append(b, "Infinity"...), nil
+ case math.IsInf(v, -1):
+ return append(b, "-Infinity"...), nil
+ default:
+ return strconv.AppendFloat(b, v, 'f', -1, 64), nil
+ }
+ case bool:
+ if v {
+ return append(b, "TRUE"...), nil
+ }
+ return append(b, "FALSE"...), nil
+ case []byte:
+ if v == nil {
+ return nil, nil
+ }
+
+ b = append(b, `\x`...)
+
+ s := len(b)
+ b = append(b, make([]byte, hex.EncodedLen(len(v)))...)
+ hex.Encode(b[s:], v)
+
+ return b, nil
+ case string:
+ for _, r := range v {
+ if r == 0 {
+ continue
+ }
+ if r < utf8.RuneSelf {
+ b = append(b, byte(r))
+ continue
+ }
+ l := len(b)
+ if cap(b)-l < utf8.UTFMax {
+ b = append(b, make([]byte, utf8.UTFMax)...)
+ }
+ n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
+ b = b[:l+n]
+ }
+ return b, nil
+ case time.Time:
+ if v.IsZero() {
+ return nil, nil
+ }
+ return v.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00"), nil
+ default:
+ return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go b/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go
new file mode 100644
index 000000000..fab151a78
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go
@@ -0,0 +1,11 @@
+// +build appengine
+
+package internal
+
+func bytesToString(b []byte) string {
+ return string(b)
+}
+
+func stringToBytes(s string) []byte {
+ return []byte(s)
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go b/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go
new file mode 100644
index 000000000..6ba868105
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go
@@ -0,0 +1,19 @@
+// +build !appengine
+
+package pgdriver
+
+import "unsafe"
+
+func bytesToString(b []byte) string {
+ return *(*string)(unsafe.Pointer(&b))
+}
+
+//nolint:deadcode,unused
+func stringToBytes(s string) []byte {
+ return *(*[]byte)(unsafe.Pointer(
+ &struct {
+ string
+ Cap int
+ }{s, len(s)},
+ ))
+}
diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go b/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go
new file mode 100644
index 000000000..cb683563d
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go
@@ -0,0 +1,112 @@
+package pgdriver
+
+import (
+ "encoding/binary"
+ "io"
+ "sync"
+)
+
+var wbPool = sync.Pool{
+ New: func() interface{} {
+ return newWriteBuffer()
+ },
+}
+
+func getWriteBuffer() *writeBuffer {
+ wb := wbPool.Get().(*writeBuffer)
+ return wb
+}
+
+func putWriteBuffer(wb *writeBuffer) {
+ wb.Reset()
+ wbPool.Put(wb)
+}
+
+type writeBuffer struct {
+ Bytes []byte
+
+ msgStart int
+ paramStart int
+}
+
+func newWriteBuffer() *writeBuffer {
+ return &writeBuffer{
+ Bytes: make([]byte, 0, 1024),
+ }
+}
+
+func (b *writeBuffer) Reset() {
+ b.Bytes = b.Bytes[:0]
+}
+
+func (b *writeBuffer) StartMessage(c byte) {
+ if c == 0 {
+ b.msgStart = len(b.Bytes)
+ b.Bytes = append(b.Bytes, 0, 0, 0, 0)
+ } else {
+ b.msgStart = len(b.Bytes) + 1
+ b.Bytes = append(b.Bytes, c, 0, 0, 0, 0)
+ }
+}
+
+func (b *writeBuffer) FinishMessage() {
+ binary.BigEndian.PutUint32(
+ b.Bytes[b.msgStart:], uint32(len(b.Bytes)-b.msgStart))
+}
+
+func (b *writeBuffer) Query() []byte {
+ return b.Bytes[b.msgStart+4 : len(b.Bytes)-1]
+}
+
+func (b *writeBuffer) StartParam() {
+ b.paramStart = len(b.Bytes)
+ b.Bytes = append(b.Bytes, 0, 0, 0, 0)
+}
+
+func (b *writeBuffer) FinishParam() {
+ binary.BigEndian.PutUint32(
+ b.Bytes[b.paramStart:], uint32(len(b.Bytes)-b.paramStart-4))
+}
+
+var nullParamLength = int32(-1)
+
+func (b *writeBuffer) FinishNullParam() {
+ binary.BigEndian.PutUint32(
+ b.Bytes[b.paramStart:], uint32(nullParamLength))
+}
+
+func (b *writeBuffer) Write(data []byte) (int, error) {
+ b.Bytes = append(b.Bytes, data...)
+ return len(data), nil
+}
+
+func (b *writeBuffer) WriteInt16(num int16) {
+ b.Bytes = append(b.Bytes, 0, 0)
+ binary.BigEndian.PutUint16(b.Bytes[len(b.Bytes)-2:], uint16(num))
+}
+
+func (b *writeBuffer) WriteInt32(num int32) {
+ b.Bytes = append(b.Bytes, 0, 0, 0, 0)
+ binary.BigEndian.PutUint32(b.Bytes[len(b.Bytes)-4:], uint32(num))
+}
+
+func (b *writeBuffer) WriteString(s string) {
+ b.Bytes = append(b.Bytes, s...)
+ b.Bytes = append(b.Bytes, 0)
+}
+
+func (b *writeBuffer) WriteBytes(data []byte) {
+ b.Bytes = append(b.Bytes, data...)
+ b.Bytes = append(b.Bytes, 0)
+}
+
+func (b *writeBuffer) WriteByte(c byte) error {
+ b.Bytes = append(b.Bytes, c)
+ return nil
+}
+
+func (b *writeBuffer) ReadFrom(r io.Reader) (int64, error) {
+ n, err := r.Read(b.Bytes[len(b.Bytes):cap(b.Bytes)])
+ b.Bytes = b.Bytes[:len(b.Bytes)+n]
+ return int64(n), err
+}
diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/json.go b/vendor/github.com/uptrace/bun/extra/bunjson/json.go
new file mode 100644
index 000000000..eff9d3f0e
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/extra/bunjson/json.go
@@ -0,0 +1,26 @@
+package bunjson
+
+import (
+ "encoding/json"
+ "io"
+)
+
+var _ Provider = (*StdProvider)(nil)
+
+type StdProvider struct{}
+
+func (StdProvider) Marshal(v interface{}) ([]byte, error) {
+ return json.Marshal(v)
+}
+
+func (StdProvider) Unmarshal(data []byte, v interface{}) error {
+ return json.Unmarshal(data, v)
+}
+
+func (StdProvider) NewEncoder(w io.Writer) Encoder {
+ return json.NewEncoder(w)
+}
+
+func (StdProvider) NewDecoder(r io.Reader) Decoder {
+ return json.NewDecoder(r)
+}
diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/provider.go b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go
new file mode 100644
index 000000000..7f810e122
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go
@@ -0,0 +1,43 @@
+package bunjson
+
+import (
+ "io"
+)
+
+var provider Provider = StdProvider{}
+
+func SetProvider(p Provider) {
+ provider = p
+}
+
+type Provider interface {
+ Marshal(v interface{}) ([]byte, error)
+ Unmarshal(data []byte, v interface{}) error
+ NewEncoder(w io.Writer) Encoder
+ NewDecoder(r io.Reader) Decoder
+}
+
+type Decoder interface {
+ Decode(v interface{}) error
+ UseNumber()
+}
+
+type Encoder interface {
+ Encode(v interface{}) error
+}
+
+func Marshal(v interface{}) ([]byte, error) {
+ return provider.Marshal(v)
+}
+
+func Unmarshal(data []byte, v interface{}) error {
+ return provider.Unmarshal(data, v)
+}
+
+func NewEncoder(w io.Writer) Encoder {
+ return provider.NewEncoder(w)
+}
+
+func NewDecoder(r io.Reader) Decoder {
+ return provider.NewDecoder(r)
+}
diff --git a/vendor/github.com/uptrace/bun/go.mod b/vendor/github.com/uptrace/bun/go.mod
new file mode 100644
index 000000000..92def2a3d
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/go.mod
@@ -0,0 +1,12 @@
+module github.com/uptrace/bun
+
+go 1.16
+
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/jinzhu/inflection v1.0.0
+ github.com/stretchr/testify v1.7.0
+ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
+ github.com/vmihailenco/msgpack/v5 v5.3.4
+ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect
+)
diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum
new file mode 100644
index 000000000..3bf0a4a3f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/go.sum
@@ -0,0 +1,23 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
+github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
+github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
+github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/vendor/github.com/uptrace/bun/hook.go b/vendor/github.com/uptrace/bun/hook.go
new file mode 100644
index 000000000..4cfa68fa6
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/hook.go
@@ -0,0 +1,98 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+ "sync/atomic"
+ "time"
+
+ "github.com/uptrace/bun/schema"
+)
+
+type QueryEvent struct {
+ DB *DB
+
+ QueryAppender schema.QueryAppender
+ Query string
+ QueryArgs []interface{}
+
+ StartTime time.Time
+ Result sql.Result
+ Err error
+
+ Stash map[interface{}]interface{}
+}
+
+type QueryHook interface {
+ BeforeQuery(context.Context, *QueryEvent) context.Context
+ AfterQuery(context.Context, *QueryEvent)
+}
+
+func (db *DB) beforeQuery(
+ ctx context.Context,
+ queryApp schema.QueryAppender,
+ query string,
+ queryArgs []interface{},
+) (context.Context, *QueryEvent) {
+ atomic.AddUint64(&db.stats.Queries, 1)
+
+ if len(db.queryHooks) == 0 {
+ return ctx, nil
+ }
+
+ event := &QueryEvent{
+ DB: db,
+
+ QueryAppender: queryApp,
+ Query: query,
+ QueryArgs: queryArgs,
+
+ StartTime: time.Now(),
+ }
+
+ for _, hook := range db.queryHooks {
+ ctx = hook.BeforeQuery(ctx, event)
+ }
+
+ return ctx, event
+}
+
+func (db *DB) afterQuery(
+ ctx context.Context,
+ event *QueryEvent,
+ res sql.Result,
+ err error,
+) {
+ switch err {
+ case nil, sql.ErrNoRows:
+ // nothing
+ default:
+ atomic.AddUint64(&db.stats.Errors, 1)
+ }
+
+ if event == nil {
+ return
+ }
+
+ event.Result = res
+ event.Err = err
+
+ db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
+}
+
+func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) {
+ for ; hookIndex >= 0; hookIndex-- {
+ db.queryHooks[hookIndex].AfterQuery(ctx, event)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func callBeforeScanHook(ctx context.Context, v reflect.Value) error {
+ return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx)
+}
+
+func callAfterScanHook(ctx context.Context, v reflect.Value) error {
+ return v.Interface().(schema.AfterScanHook).AfterScan(ctx)
+}
diff --git a/vendor/github.com/uptrace/bun/internal/flag.go b/vendor/github.com/uptrace/bun/internal/flag.go
new file mode 100644
index 000000000..b42f59df7
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/flag.go
@@ -0,0 +1,16 @@
+package internal
+
+type Flag uint64
+
+func (flag Flag) Has(other Flag) bool {
+ return flag&other == other
+}
+
+func (flag Flag) Set(other Flag) Flag {
+ return flag | other
+}
+
+func (flag Flag) Remove(other Flag) Flag {
+ flag &= ^other
+ return flag
+}
diff --git a/vendor/github.com/uptrace/bun/internal/hex.go b/vendor/github.com/uptrace/bun/internal/hex.go
new file mode 100644
index 000000000..6fae2bb78
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/hex.go
@@ -0,0 +1,43 @@
+package internal
+
+import (
+ fasthex "github.com/tmthrgd/go-hex"
+)
+
+type HexEncoder struct {
+ b []byte
+ written bool
+}
+
+func NewHexEncoder(b []byte) *HexEncoder {
+ return &HexEncoder{
+ b: b,
+ }
+}
+
+func (enc *HexEncoder) Bytes() []byte {
+ return enc.b
+}
+
+func (enc *HexEncoder) Write(b []byte) (int, error) {
+ if !enc.written {
+ enc.b = append(enc.b, '\'')
+ enc.b = append(enc.b, `\x`...)
+ enc.written = true
+ }
+
+ i := len(enc.b)
+ enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...)
+ fasthex.Encode(enc.b[i:], b)
+
+ return len(b), nil
+}
+
+func (enc *HexEncoder) Close() error {
+ if enc.written {
+ enc.b = append(enc.b, '\'')
+ } else {
+ enc.b = append(enc.b, "NULL"...)
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/internal/logger.go b/vendor/github.com/uptrace/bun/internal/logger.go
new file mode 100644
index 000000000..2e22a0893
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/logger.go
@@ -0,0 +1,27 @@
+package internal
+
+import (
+ "fmt"
+ "log"
+ "os"
+)
+
+var Warn = log.New(os.Stderr, "WARN: bun: ", log.LstdFlags)
+
+var Deprecated = log.New(os.Stderr, "DEPRECATED: bun: ", log.LstdFlags)
+
+type Logging interface {
+ Printf(format string, v ...interface{})
+}
+
+type logger struct {
+ log *log.Logger
+}
+
+func (l *logger) Printf(format string, v ...interface{}) {
+ _ = l.log.Output(2, fmt.Sprintf(format, v...))
+}
+
+var Logger Logging = &logger{
+ log: log.New(os.Stderr, "bun: ", log.LstdFlags|log.Lshortfile),
+}
diff --git a/vendor/github.com/uptrace/bun/internal/map_key.go b/vendor/github.com/uptrace/bun/internal/map_key.go
new file mode 100644
index 000000000..bb5fcca8c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/map_key.go
@@ -0,0 +1,67 @@
+package internal
+
+import "reflect"
+
+var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem()
+
+type MapKey struct {
+ iface interface{}
+}
+
+func NewMapKey(is []interface{}) MapKey {
+ return MapKey{
+ iface: newMapKey(is),
+ }
+}
+
+func newMapKey(is []interface{}) interface{} {
+ switch len(is) {
+ case 1:
+ ptr := new([1]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 2:
+ ptr := new([2]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 3:
+ ptr := new([3]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 4:
+ ptr := new([4]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 5:
+ ptr := new([5]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 6:
+ ptr := new([6]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 7:
+ ptr := new([7]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 8:
+ ptr := new([8]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 9:
+ ptr := new([9]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ case 10:
+ ptr := new([10]interface{})
+ copy((*ptr)[:], is)
+ return *ptr
+ default:
+ }
+
+ at := reflect.New(reflect.ArrayOf(len(is), ifaceType)).Elem()
+ for i, v := range is {
+ *(at.Index(i).Addr().Interface().(*interface{})) = v
+ }
+ return at.Interface()
+}
diff --git a/vendor/github.com/uptrace/bun/internal/parser/parser.go b/vendor/github.com/uptrace/bun/internal/parser/parser.go
new file mode 100644
index 000000000..cdfc0be16
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/parser/parser.go
@@ -0,0 +1,141 @@
+package parser
+
+import (
+ "bytes"
+ "strconv"
+
+ "github.com/uptrace/bun/internal"
+)
+
+type Parser struct {
+ b []byte
+ i int
+}
+
+func New(b []byte) *Parser {
+ return &Parser{
+ b: b,
+ }
+}
+
+func NewString(s string) *Parser {
+ return New(internal.Bytes(s))
+}
+
+func (p *Parser) Valid() bool {
+ return p.i < len(p.b)
+}
+
+func (p *Parser) Bytes() []byte {
+ return p.b[p.i:]
+}
+
+func (p *Parser) Read() byte {
+ if p.Valid() {
+ c := p.b[p.i]
+ p.Advance()
+ return c
+ }
+ return 0
+}
+
+func (p *Parser) Peek() byte {
+ if p.Valid() {
+ return p.b[p.i]
+ }
+ return 0
+}
+
+func (p *Parser) Advance() {
+ p.i++
+}
+
+func (p *Parser) Skip(skip byte) bool {
+ if p.Peek() == skip {
+ p.Advance()
+ return true
+ }
+ return false
+}
+
+func (p *Parser) SkipBytes(skip []byte) bool {
+ if len(skip) > len(p.b[p.i:]) {
+ return false
+ }
+ if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) {
+ return false
+ }
+ p.i += len(skip)
+ return true
+}
+
+func (p *Parser) ReadSep(sep byte) ([]byte, bool) {
+ ind := bytes.IndexByte(p.b[p.i:], sep)
+ if ind == -1 {
+ b := p.b[p.i:]
+ p.i = len(p.b)
+ return b, false
+ }
+
+ b := p.b[p.i : p.i+ind]
+ p.i += ind + 1
+ return b, true
+}
+
+func (p *Parser) ReadIdentifier() (string, bool) {
+ if p.i < len(p.b) && p.b[p.i] == '(' {
+ s := p.i + 1
+ if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 {
+ b := p.b[s : s+ind]
+ p.i = s + ind + 1
+ return internal.String(b), false
+ }
+ }
+
+ ind := len(p.b) - p.i
+ var alpha bool
+ for i, c := range p.b[p.i:] {
+ if isNum(c) {
+ continue
+ }
+ if isAlpha(c) || (i > 0 && alpha && c == '_') {
+ alpha = true
+ continue
+ }
+ ind = i
+ break
+ }
+ if ind == 0 {
+ return "", false
+ }
+ b := p.b[p.i : p.i+ind]
+ p.i += ind
+ return internal.String(b), !alpha
+}
+
+func (p *Parser) ReadNumber() int {
+ ind := len(p.b) - p.i
+ for i, c := range p.b[p.i:] {
+ if !isNum(c) {
+ ind = i
+ break
+ }
+ }
+ if ind == 0 {
+ return 0
+ }
+ n, err := strconv.Atoi(string(p.b[p.i : p.i+ind]))
+ if err != nil {
+ panic(err)
+ }
+ p.i += ind
+ return n
+}
+
+func isNum(c byte) bool {
+ return c >= '0' && c <= '9'
+}
+
+func isAlpha(c byte) bool {
+ return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
+}
diff --git a/vendor/github.com/uptrace/bun/internal/safe.go b/vendor/github.com/uptrace/bun/internal/safe.go
new file mode 100644
index 000000000..862ff0eb3
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/safe.go
@@ -0,0 +1,11 @@
+// +build appengine
+
+package internal
+
+func String(b []byte) string {
+ return string(b)
+}
+
+func Bytes(s string) []byte {
+ return []byte(s)
+}
diff --git a/vendor/github.com/uptrace/bun/internal/tagparser/parser.go b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go
new file mode 100644
index 000000000..8ef89248c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go
@@ -0,0 +1,147 @@
+package tagparser
+
+import (
+ "strings"
+)
+
+type Tag struct {
+ Name string
+ Options map[string]string
+}
+
+func (t Tag) HasOption(name string) bool {
+ _, ok := t.Options[name]
+ return ok
+}
+
+func Parse(s string) Tag {
+ p := parser{
+ s: s,
+ }
+ p.parse()
+ return p.tag
+}
+
+type parser struct {
+ s string
+ i int
+
+ tag Tag
+ seenName bool // for empty names
+}
+
+func (p *parser) setName(name string) {
+ if p.seenName {
+ p.addOption(name, "")
+ } else {
+ p.seenName = true
+ p.tag.Name = name
+ }
+}
+
+func (p *parser) addOption(key, value string) {
+ p.seenName = true
+ if key == "" {
+ return
+ }
+ if p.tag.Options == nil {
+ p.tag.Options = make(map[string]string)
+ }
+ p.tag.Options[key] = value
+}
+
+func (p *parser) parse() {
+ for p.valid() {
+ p.parseKeyValue()
+ if p.peek() == ',' {
+ p.i++
+ }
+ }
+}
+
+func (p *parser) parseKeyValue() {
+ start := p.i
+
+ for p.valid() {
+ switch c := p.read(); c {
+ case ',':
+ key := p.s[start : p.i-1]
+ p.setName(key)
+ return
+ case ':':
+ key := p.s[start : p.i-1]
+ value := p.parseValue()
+ p.addOption(key, value)
+ return
+ case '"':
+ key := p.parseQuotedValue()
+ p.setName(key)
+ return
+ }
+ }
+
+ key := p.s[start:p.i]
+ p.setName(key)
+}
+
+func (p *parser) parseValue() string {
+ start := p.i
+
+ for p.valid() {
+ switch c := p.read(); c {
+ case '"':
+ return p.parseQuotedValue()
+ case ',':
+ return p.s[start : p.i-1]
+ }
+ }
+
+ if p.i == start {
+ return ""
+ }
+ return p.s[start:p.i]
+}
+
+func (p *parser) parseQuotedValue() string {
+ if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' {
+ s := p.s[p.i : p.i+i]
+ p.i += i + 1
+ return s
+ }
+
+ b := make([]byte, 0, 16)
+
+ for p.valid() {
+ switch c := p.read(); c {
+ case '\\':
+ b = append(b, p.read())
+ case '"':
+ return string(b)
+ default:
+ b = append(b, c)
+ }
+ }
+
+ return ""
+}
+
+func (p *parser) valid() bool {
+ return p.i < len(p.s)
+}
+
+func (p *parser) read() byte {
+ if !p.valid() {
+ return 0
+ }
+ c := p.s[p.i]
+ p.i++
+ return c
+}
+
+func (p *parser) peek() byte {
+ if !p.valid() {
+ return 0
+ }
+ c := p.s[p.i]
+ return c
+}
diff --git a/vendor/github.com/uptrace/bun/internal/time.go b/vendor/github.com/uptrace/bun/internal/time.go
new file mode 100644
index 000000000..e4e0804b0
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/time.go
@@ -0,0 +1,41 @@
+package internal
+
+import (
+ "fmt"
+ "time"
+)
+
+const (
+ dateFormat = "2006-01-02"
+ timeFormat = "15:04:05.999999999"
+ timestampFormat = "2006-01-02 15:04:05.999999999"
+ timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00"
+ timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00"
+ timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07"
+)
+
+func ParseTime(s string) (time.Time, error) {
+ switch l := len(s); {
+ case l < len("15:04:05"):
+ return time.Time{}, fmt.Errorf("bun: can't parse time=%q", s)
+ case l <= len(timeFormat):
+ if s[2] == ':' {
+ return time.ParseInLocation(timeFormat, s, time.UTC)
+ }
+ return time.ParseInLocation(dateFormat, s, time.UTC)
+ default:
+ if s[10] == 'T' {
+ return time.Parse(time.RFC3339Nano, s)
+ }
+ if c := s[l-9]; c == '+' || c == '-' {
+ return time.Parse(timestamptzFormat, s)
+ }
+ if c := s[l-6]; c == '+' || c == '-' {
+ return time.Parse(timestamptzFormat2, s)
+ }
+ if c := s[l-3]; c == '+' || c == '-' {
+ return time.Parse(timestamptzFormat3, s)
+ }
+ return time.ParseInLocation(timestampFormat, s, time.UTC)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/internal/underscore.go b/vendor/github.com/uptrace/bun/internal/underscore.go
new file mode 100644
index 000000000..9de52fb7b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/underscore.go
@@ -0,0 +1,67 @@
+package internal
+
+func IsUpper(c byte) bool {
+ return c >= 'A' && c <= 'Z'
+}
+
+func IsLower(c byte) bool {
+ return c >= 'a' && c <= 'z'
+}
+
+func ToUpper(c byte) byte {
+ return c - 32
+}
+
+func ToLower(c byte) byte {
+ return c + 32
+}
+
+// Underscore converts "CamelCasedString" to "camel_cased_string".
+func Underscore(s string) string {
+ r := make([]byte, 0, len(s)+5)
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if IsUpper(c) {
+ if i > 0 && i+1 < len(s) && (IsLower(s[i-1]) || IsLower(s[i+1])) {
+ r = append(r, '_', ToLower(c))
+ } else {
+ r = append(r, ToLower(c))
+ }
+ } else {
+ r = append(r, c)
+ }
+ }
+ return string(r)
+}
+
+func CamelCased(s string) string {
+ r := make([]byte, 0, len(s))
+ upperNext := true
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if c == '_' {
+ upperNext = true
+ continue
+ }
+ if upperNext {
+ if IsLower(c) {
+ c = ToUpper(c)
+ }
+ upperNext = false
+ }
+ r = append(r, c)
+ }
+ return string(r)
+}
+
+func ToExported(s string) string {
+ if len(s) == 0 {
+ return s
+ }
+ if c := s[0]; IsLower(c) {
+ b := []byte(s)
+ b[0] = ToUpper(c)
+ return string(b)
+ }
+ return s
+}
diff --git a/vendor/github.com/uptrace/bun/internal/unsafe.go b/vendor/github.com/uptrace/bun/internal/unsafe.go
new file mode 100644
index 000000000..4bc79701f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/unsafe.go
@@ -0,0 +1,20 @@
+// +build !appengine
+
+package internal
+
+import "unsafe"
+
+// String converts byte slice to string.
+func String(b []byte) string {
+ return *(*string)(unsafe.Pointer(&b))
+}
+
+// Bytes converts string to byte slice.
+func Bytes(s string) []byte {
+ return *(*[]byte)(unsafe.Pointer(
+ &struct {
+ string
+ Cap int
+ }{s, len(s)},
+ ))
+}
diff --git a/vendor/github.com/uptrace/bun/internal/util.go b/vendor/github.com/uptrace/bun/internal/util.go
new file mode 100644
index 000000000..c831dc659
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/internal/util.go
@@ -0,0 +1,57 @@
+package internal
+
+import (
+ "reflect"
+)
+
+func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
+ if v.Kind() == reflect.Array {
+ var pos int
+ return func() reflect.Value {
+ v := v.Index(pos)
+ pos++
+ return v
+ }
+ }
+
+ elemType := v.Type().Elem()
+
+ if elemType.Kind() == reflect.Ptr {
+ elemType = elemType.Elem()
+ return func() reflect.Value {
+ if v.Len() < v.Cap() {
+ v.Set(v.Slice(0, v.Len()+1))
+ elem := v.Index(v.Len() - 1)
+ if elem.IsNil() {
+ elem.Set(reflect.New(elemType))
+ }
+ return elem.Elem()
+ }
+
+ elem := reflect.New(elemType)
+ v.Set(reflect.Append(v, elem))
+ return elem.Elem()
+ }
+ }
+
+ zero := reflect.Zero(elemType)
+ return func() reflect.Value {
+ if v.Len() < v.Cap() {
+ v.Set(v.Slice(0, v.Len()+1))
+ return v.Index(v.Len() - 1)
+ }
+
+ v.Set(reflect.Append(v, zero))
+ return v.Index(v.Len() - 1)
+ }
+}
+
+func Unwrap(err error) error {
+ u, ok := err.(interface {
+ Unwrap() error
+ })
+ if !ok {
+ return nil
+ }
+ return u.Unwrap()
+}
diff --git a/vendor/github.com/uptrace/bun/join.go b/vendor/github.com/uptrace/bun/join.go
new file mode 100644
index 000000000..4557f5bc0
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/join.go
@@ -0,0 +1,308 @@
+package bun
+
+import (
+ "context"
+ "reflect"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type join struct {
+ Parent *join
+ BaseModel tableModel
+ JoinModel tableModel
+ Relation *schema.Relation
+
+ ApplyQueryFunc func(*SelectQuery) *SelectQuery
+ columns []schema.QueryWithArgs
+}
+
+func (j *join) applyQuery(q *SelectQuery) {
+ if j.ApplyQueryFunc == nil {
+ return
+ }
+
+ var table *schema.Table
+ var columns []schema.QueryWithArgs
+
+ // Save state.
+ table, q.table = q.table, j.JoinModel.Table()
+ columns, q.columns = q.columns, nil
+
+ q = j.ApplyQueryFunc(q)
+
+ // Restore state.
+ q.table = table
+ j.columns, q.columns = q.columns, columns
+}
+
+func (j *join) Select(ctx context.Context, q *SelectQuery) error {
+ switch j.Relation.Type {
+ case schema.HasManyRelation:
+ return j.selectMany(ctx, q)
+ case schema.ManyToManyRelation:
+ return j.selectM2M(ctx, q)
+ }
+ panic("not reached")
+}
+
+func (j *join) selectMany(ctx context.Context, q *SelectQuery) error {
+ q = j.manyQuery(q)
+ if q == nil {
+ return nil
+ }
+ return q.Scan(ctx)
+}
+
+func (j *join) manyQuery(q *SelectQuery) *SelectQuery {
+ hasManyModel := newHasManyModel(j)
+ if hasManyModel == nil {
+ return nil
+ }
+
+ q = q.Model(hasManyModel)
+
+ var where []byte
+ if len(j.Relation.JoinFields) > 1 {
+ where = append(where, '(')
+ }
+ where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields)
+ if len(j.Relation.JoinFields) > 1 {
+ where = append(where, ')')
+ }
+ where = append(where, " IN ("...)
+ where = appendChildValues(
+ q.db.Formatter(),
+ where,
+ j.JoinModel.Root(),
+ j.JoinModel.ParentIndex(),
+ j.Relation.BaseFields,
+ )
+ where = append(where, ")"...)
+ q = q.Where(internal.String(where))
+
+ if j.Relation.PolymorphicField != nil {
+ q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
+ }
+
+ j.applyQuery(q)
+ q = q.Apply(j.hasManyColumns)
+
+ return q
+}
+
+func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery {
+ if j.Relation.M2MTable != nil {
+ q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*")
+ }
+
+ b := make([]byte, 0, 32)
+
+ if len(j.columns) > 0 {
+ for i, col := range j.columns {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ var err error
+ b, err = col.AppendQuery(q.db.fmter, b)
+ if err != nil {
+ q.err = err
+ return q
+ }
+ }
+ } else {
+ joinTable := j.JoinModel.Table()
+ b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields)
+ }
+
+ q = q.ColumnExpr(internal.String(b))
+
+ return q
+}
+
+func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error {
+ q = j.m2mQuery(q)
+ if q == nil {
+ return nil
+ }
+ return q.Scan(ctx)
+}
+
+func (j *join) m2mQuery(q *SelectQuery) *SelectQuery {
+ fmter := q.db.fmter
+
+ m2mModel := newM2MModel(j)
+ if m2mModel == nil {
+ return nil
+ }
+ q = q.Model(m2mModel)
+
+ index := j.JoinModel.ParentIndex()
+ baseTable := j.BaseModel.Table()
+
+ //nolint
+ var join []byte
+ join = append(join, "JOIN "...)
+ join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name))
+ join = append(join, " AS "...)
+ join = append(join, j.Relation.M2MTable.SQLAlias...)
+ join = append(join, " ON ("...)
+ for i, col := range j.Relation.M2MBaseFields {
+ if i > 0 {
+ join = append(join, ", "...)
+ }
+ join = append(join, j.Relation.M2MTable.SQLAlias...)
+ join = append(join, '.')
+ join = append(join, col.SQLName...)
+ }
+ join = append(join, ") IN ("...)
+ join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs)
+ join = append(join, ")"...)
+ q = q.Join(internal.String(join))
+
+ joinTable := j.JoinModel.Table()
+ for i, m2mJoinField := range j.Relation.M2MJoinFields {
+ joinField := j.Relation.JoinFields[i]
+ q = q.Where("?.? = ?.?",
+ joinTable.SQLAlias, joinField.SQLName,
+ j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName)
+ }
+
+ j.applyQuery(q)
+ q = q.Apply(j.hasManyColumns)
+
+ return q
+}
+
+func (j *join) hasParent() bool {
+ if j.Parent != nil {
+ switch j.Parent.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ return true
+ }
+ }
+ return false
+}
+
+func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte {
+ quote := fmter.IdentQuote()
+
+ b = append(b, quote)
+ b = appendAlias(b, j)
+ b = append(b, quote)
+ return b
+}
+
+func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte {
+ quote := fmter.IdentQuote()
+
+ b = append(b, quote)
+ b = appendAlias(b, j)
+ b = append(b, "__"...)
+ b = append(b, column...)
+ b = append(b, quote)
+ return b
+}
+
+func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
+ quote := fmter.IdentQuote()
+
+ if j.hasParent() {
+ b = append(b, quote)
+ b = appendAlias(b, j.Parent)
+ b = append(b, quote)
+ return b
+ }
+ return append(b, j.BaseModel.Table().SQLAlias...)
+}
+
+func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte {
+ b = append(b, '.')
+ b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...)
+ if flags.Has(deletedFlag) {
+ b = append(b, " IS NOT NULL"...)
+ } else {
+ b = append(b, " IS NULL"...)
+ }
+ return b
+}
+
+func appendAlias(b []byte, j *join) []byte {
+ if j.hasParent() {
+ b = appendAlias(b, j.Parent)
+ b = append(b, "__"...)
+ }
+ b = append(b, j.Relation.Field.Name...)
+ return b
+}
+
+func (j *join) appendHasOneJoin(
+ fmter schema.Formatter, b []byte, q *SelectQuery,
+) (_ []byte, err error) {
+ isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
+
+ b = append(b, "LEFT JOIN "...)
+ b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects))
+ b = append(b, " AS "...)
+ b = j.appendAlias(fmter, b)
+
+ b = append(b, " ON "...)
+
+ b = append(b, '(')
+ for i, baseField := range j.Relation.BaseFields {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ b = j.appendAlias(fmter, b)
+ b = append(b, '.')
+ b = append(b, j.Relation.JoinFields[i].SQLName...)
+ b = append(b, " = "...)
+ b = j.appendBaseAlias(fmter, b)
+ b = append(b, '.')
+ b = append(b, baseField.SQLName...)
+ }
+ b = append(b, ')')
+
+ if isSoftDelete {
+ b = append(b, " AND "...)
+ b = j.appendAlias(fmter, b)
+ b = j.appendSoftDelete(b, q.flags)
+ }
+
+ return b, nil
+}
+
+func appendChildValues(
+ fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field,
+) []byte {
+ seen := make(map[string]struct{})
+ walk(v, index, func(v reflect.Value) {
+ start := len(b)
+
+ if len(fields) > 1 {
+ b = append(b, '(')
+ }
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = f.AppendValue(fmter, b, v)
+ }
+ if len(fields) > 1 {
+ b = append(b, ')')
+ }
+ b = append(b, ", "...)
+
+ if _, ok := seen[string(b[start:])]; ok {
+ b = b[:start]
+ } else {
+ seen[string(b[start:])] = struct{}{}
+ }
+ })
+ if len(seen) > 0 {
+ b = b[:len(b)-2] // trim ", "
+ }
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/model.go b/vendor/github.com/uptrace/bun/model.go
new file mode 100644
index 000000000..c9f0f3583
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model.go
@@ -0,0 +1,207 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "reflect"
+ "time"
+
+ "github.com/uptrace/bun/schema"
+)
+
+var errNilModel = errors.New("bun: Model(nil)")
+
+var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+
+type Model interface {
+ ScanRows(ctx context.Context, rows *sql.Rows) (int, error)
+ Value() interface{}
+}
+
+type rowScanner interface {
+ ScanRow(ctx context.Context, rows *sql.Rows) error
+}
+
+type model interface {
+ Model
+}
+
+type tableModel interface {
+ model
+
+ schema.BeforeScanHook
+ schema.AfterScanHook
+ ScanColumn(column string, src interface{}) error
+
+ Table() *schema.Table
+ Relation() *schema.Relation
+
+ Join(string, func(*SelectQuery) *SelectQuery) *join
+ GetJoin(string) *join
+ GetJoins() []join
+ AddJoin(join) *join
+
+ Root() reflect.Value
+ ParentIndex() []int
+ Mount(reflect.Value)
+
+ updateSoftDeleteField() error
+}
+
+func newModel(db *DB, dest []interface{}) (model, error) {
+ if len(dest) == 1 {
+ return _newModel(db, dest[0], true)
+ }
+
+ values := make([]reflect.Value, len(dest))
+
+ for i, el := range dest {
+ v := reflect.ValueOf(el)
+ if v.Kind() != reflect.Ptr {
+ return nil, fmt.Errorf("bun: Scan(non-pointer %T)", dest)
+ }
+
+ v = v.Elem()
+ if v.Kind() != reflect.Slice {
+ return newScanModel(db, dest), nil
+ }
+
+ values[i] = v
+ }
+
+ return newSliceModel(db, dest, values), nil
+}
+
+func newSingleModel(db *DB, dest interface{}) (model, error) {
+ return _newModel(db, dest, false)
+}
+
+func _newModel(db *DB, dest interface{}, scan bool) (model, error) {
+ switch dest := dest.(type) {
+ case nil:
+ return nil, errNilModel
+ case Model:
+ return dest, nil
+ case sql.Scanner:
+ if !scan {
+ return nil, fmt.Errorf("bun: Model(unsupported %T)", dest)
+ }
+ return newScanModel(db, []interface{}{dest}), nil
+ }
+
+ v := reflect.ValueOf(dest)
+ if !v.IsValid() {
+ return nil, errNilModel
+ }
+ if v.Kind() != reflect.Ptr {
+ return nil, fmt.Errorf("bun: Model(non-pointer %T)", dest)
+ }
+
+ if v.IsNil() {
+ typ := v.Type().Elem()
+ if typ.Kind() == reflect.Struct {
+ return newStructTableModel(db, dest, db.Table(typ)), nil
+ }
+ return nil, fmt.Errorf("bun: Model(nil %T)", dest)
+ }
+
+ v = v.Elem()
+
+ switch v.Kind() {
+ case reflect.Map:
+ typ := v.Type()
+ if err := validMap(typ); err != nil {
+ return nil, err
+ }
+ mapPtr := v.Addr().Interface().(*map[string]interface{})
+ return newMapModel(db, mapPtr), nil
+ case reflect.Struct:
+ if v.Type() != timeType {
+ return newStructTableModelValue(db, dest, v), nil
+ }
+ case reflect.Slice:
+ switch elemType := sliceElemType(v); elemType.Kind() {
+ case reflect.Struct:
+ if elemType != timeType {
+ return newSliceTableModel(db, dest, v, elemType), nil
+ }
+ case reflect.Map:
+ if err := validMap(elemType); err != nil {
+ return nil, err
+ }
+ slicePtr := v.Addr().Interface().(*[]map[string]interface{})
+ return newMapSliceModel(db, slicePtr), nil
+ }
+ return newSliceModel(db, []interface{}{dest}, []reflect.Value{v}), nil
+ }
+
+ if scan {
+ return newScanModel(db, []interface{}{dest}), nil
+ }
+
+ return nil, fmt.Errorf("bun: Model(unsupported %T)", dest)
+}
+
+func newTableModelIndex(
+ db *DB,
+ table *schema.Table,
+ root reflect.Value,
+ index []int,
+ rel *schema.Relation,
+) (tableModel, error) {
+ typ := typeByIndex(table.Type, index)
+
+ if typ.Kind() == reflect.Struct {
+ return &structTableModel{
+ db: db,
+ table: table.Dialect().Tables().Get(typ),
+ rel: rel,
+
+ root: root,
+ index: index,
+ }, nil
+ }
+
+ if typ.Kind() == reflect.Slice {
+ structType := indirectType(typ.Elem())
+ if structType.Kind() == reflect.Struct {
+ m := sliceTableModel{
+ structTableModel: structTableModel{
+ db: db,
+ table: table.Dialect().Tables().Get(structType),
+ rel: rel,
+
+ root: root,
+ index: index,
+ },
+ }
+ m.init(typ)
+ return &m, nil
+ }
+ }
+
+ return nil, fmt.Errorf("bun: NewModel(%s)", typ)
+}
+
+func validMap(typ reflect.Type) error {
+ if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface {
+ return fmt.Errorf("bun: Model(unsupported %s) (expected *map[string]interface{})",
+ typ)
+ }
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+func isSingleRowModel(m model) bool {
+ switch m.(type) {
+ case *mapModel,
+ *structTableModel,
+ *scanModel:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/model_map.go b/vendor/github.com/uptrace/bun/model_map.go
new file mode 100644
index 000000000..81c1a4a3b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_map.go
@@ -0,0 +1,183 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+ "sort"
+
+ "github.com/uptrace/bun/schema"
+)
+
+type mapModel struct {
+ db *DB
+
+ dest *map[string]interface{}
+ m map[string]interface{}
+
+ rows *sql.Rows
+ columns []string
+ _columnTypes []*sql.ColumnType
+ scanIndex int
+}
+
+var _ model = (*mapModel)(nil)
+
+func newMapModel(db *DB, dest *map[string]interface{}) *mapModel {
+ m := &mapModel{
+ db: db,
+ dest: dest,
+ }
+ if dest != nil {
+ m.m = *dest
+ }
+ return m
+}
+
+func (m *mapModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *mapModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ if !rows.Next() {
+ return 0, rows.Err()
+ }
+
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.rows = rows
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ if m.m == nil {
+ m.m = make(map[string]interface{}, len(m.columns))
+ }
+
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ *m.dest = m.m
+
+ return 1, nil
+}
+
+func (m *mapModel) Scan(src interface{}) error {
+ if _, ok := src.([]byte); !ok {
+ return m.scanRaw(src)
+ }
+
+ columnTypes, err := m.columnTypes()
+ if err != nil {
+ return err
+ }
+
+ scanType := columnTypes[m.scanIndex].ScanType()
+ switch scanType.Kind() {
+ case reflect.Interface:
+ return m.scanRaw(src)
+ case reflect.Slice:
+ if scanType.Elem().Kind() == reflect.Uint8 {
+ return m.scanRaw(src)
+ }
+ }
+
+ dest := reflect.New(scanType).Elem()
+ if err := schema.Scanner(scanType)(dest, src); err != nil {
+ return err
+ }
+
+ return m.scanRaw(dest.Interface())
+}
+
+func (m *mapModel) columnTypes() ([]*sql.ColumnType, error) {
+ if m._columnTypes == nil {
+ columnTypes, err := m.rows.ColumnTypes()
+ if err != nil {
+ return nil, err
+ }
+ m._columnTypes = columnTypes
+ }
+ return m._columnTypes, nil
+}
+
+func (m *mapModel) scanRaw(src interface{}) error {
+ columnName := m.columns[m.scanIndex]
+ m.scanIndex++
+ m.m[columnName] = src
+ return nil
+}
+
+func (m *mapModel) appendColumnsValues(fmter schema.Formatter, b []byte) []byte {
+ keys := make([]string, 0, len(m.m))
+
+ for k := range m.m {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
+ b = append(b, " ("...)
+
+ for i, k := range keys {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = fmter.AppendIdent(b, k)
+ }
+
+ b = append(b, ") VALUES ("...)
+
+ isTemplate := fmter.IsNop()
+ for i, k := range keys {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ if isTemplate {
+ b = append(b, '?')
+ } else {
+ b = fmter.Dialect().Append(fmter, b, m.m[k])
+ }
+ }
+
+ b = append(b, ")"...)
+
+ return b
+}
+
+func (m *mapModel) appendSet(fmter schema.Formatter, b []byte) []byte {
+ keys := make([]string, 0, len(m.m))
+
+ for k := range m.m {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
+ isTemplate := fmter.IsNop()
+ for i, k := range keys {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ b = fmter.AppendIdent(b, k)
+ b = append(b, " = "...)
+ if isTemplate {
+ b = append(b, '?')
+ } else {
+ b = fmter.Dialect().Append(fmter, b, m.m[k])
+ }
+ }
+
+ return b
+}
+
+func makeDest(v interface{}, n int) []interface{} {
+ dest := make([]interface{}, n)
+ for i := range dest {
+ dest[i] = v
+ }
+ return dest
+}
diff --git a/vendor/github.com/uptrace/bun/model_map_slice.go b/vendor/github.com/uptrace/bun/model_map_slice.go
new file mode 100644
index 000000000..5c6f48e44
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_map_slice.go
@@ -0,0 +1,162 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "sort"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/schema"
+)
+
+type mapSliceModel struct {
+ mapModel
+ dest *[]map[string]interface{}
+
+ keys []string
+}
+
+var _ model = (*mapSliceModel)(nil)
+
+func newMapSliceModel(db *DB, dest *[]map[string]interface{}) *mapSliceModel {
+ return &mapSliceModel{
+ mapModel: mapModel{
+ db: db,
+ },
+ dest: dest,
+ }
+}
+
+func (m *mapSliceModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *mapSliceModel) SetCap(cap int) {
+ if cap > 100 {
+ cap = 100
+ }
+ if slice := *m.dest; len(slice) < cap {
+ *m.dest = make([]map[string]interface{}, 0, cap)
+ }
+}
+
+func (m *mapSliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.rows = rows
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ slice := *m.dest
+ if len(slice) > 0 {
+ slice = slice[:0]
+ }
+
+ var n int
+
+ for rows.Next() {
+ m.m = make(map[string]interface{}, len(m.columns))
+
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ slice = append(slice, m.m)
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ *m.dest = slice
+ return n, nil
+}
+
+func (m *mapSliceModel) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if err := m.initKeys(); err != nil {
+ return nil, err
+ }
+
+ for i, k := range m.keys {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = fmter.AppendIdent(b, k)
+ }
+
+ return b, nil
+}
+
+func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if err := m.initKeys(); err != nil {
+ return nil, err
+ }
+ slice := *m.dest
+
+ b = append(b, "VALUES "...)
+ if m.db.features.Has(feature.ValuesRow) {
+ b = append(b, "ROW("...)
+ } else {
+ b = append(b, '(')
+ }
+
+ if fmter.IsNop() {
+ for i := range m.keys {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = append(b, '?')
+ }
+ return b, nil
+ }
+
+ for i, el := range slice {
+ if i > 0 {
+ b = append(b, "), "...)
+ if m.db.features.Has(feature.ValuesRow) {
+ b = append(b, "ROW("...)
+ } else {
+ b = append(b, '(')
+ }
+ }
+
+ for j, key := range m.keys {
+ if j > 0 {
+ b = append(b, ", "...)
+ }
+ b = fmter.Dialect().Append(fmter, b, el[key])
+ }
+ }
+
+ b = append(b, ')')
+
+ return b, nil
+}
+
+func (m *mapSliceModel) initKeys() error {
+ if m.keys != nil {
+ return nil
+ }
+
+ slice := *m.dest
+ if len(slice) == 0 {
+ return errors.New("bun: map slice is empty")
+ }
+
+ first := slice[0]
+ keys := make([]string, 0, len(first))
+
+ for k := range first {
+ keys = append(keys, k)
+ }
+
+ sort.Strings(keys)
+ m.keys = keys
+
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/model_scan.go b/vendor/github.com/uptrace/bun/model_scan.go
new file mode 100644
index 000000000..6dd061fb2
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_scan.go
@@ -0,0 +1,54 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+)
+
+type scanModel struct {
+ db *DB
+
+ dest []interface{}
+ scanIndex int
+}
+
+var _ model = (*scanModel)(nil)
+
+func newScanModel(db *DB, dest []interface{}) *scanModel {
+ return &scanModel{
+ db: db,
+ dest: dest,
+ }
+}
+
+func (m *scanModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *scanModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ if !rows.Next() {
+ return 0, rows.Err()
+ }
+
+ dest := makeDest(m, len(m.dest))
+
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ return 1, nil
+}
+
+func (m *scanModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
+ return rows.Scan(m.dest...)
+}
+
+func (m *scanModel) Scan(src interface{}) error {
+ dest := reflect.ValueOf(m.dest[m.scanIndex])
+ m.scanIndex++
+
+ scanner := m.db.dialect.Scanner(dest.Type())
+ return scanner(dest, src)
+}
diff --git a/vendor/github.com/uptrace/bun/model_slice.go b/vendor/github.com/uptrace/bun/model_slice.go
new file mode 100644
index 000000000..afe804382
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_slice.go
@@ -0,0 +1,82 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type sliceInfo struct {
+ nextElem func() reflect.Value
+ scan schema.ScannerFunc
+}
+
+type sliceModel struct {
+ dest []interface{}
+ values []reflect.Value
+ scanIndex int
+ info []sliceInfo
+}
+
+var _ model = (*sliceModel)(nil)
+
+func newSliceModel(db *DB, dest []interface{}, values []reflect.Value) *sliceModel {
+ return &sliceModel{
+ dest: dest,
+ values: values,
+ }
+}
+
+func (m *sliceModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *sliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.info = make([]sliceInfo, len(m.values))
+ for i, v := range m.values {
+ if v.IsValid() && v.Len() > 0 {
+ v.Set(v.Slice(0, 0))
+ }
+
+ m.info[i] = sliceInfo{
+ nextElem: internal.MakeSliceNextElemFunc(v),
+ scan: schema.Scanner(v.Type().Elem()),
+ }
+ }
+
+ if len(columns) == 0 {
+ return 0, nil
+ }
+ dest := makeDest(m, len(columns))
+
+ var n int
+
+ for rows.Next() {
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+func (m *sliceModel) Scan(src interface{}) error {
+ info := m.info[m.scanIndex]
+ m.scanIndex++
+
+ dest := info.nextElem()
+ return info.scan(dest, src)
+}
diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go
new file mode 100644
index 000000000..e64b7088d
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_has_many.go
@@ -0,0 +1,149 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type hasManyModel struct {
+ *sliceTableModel
+ baseTable *schema.Table
+ rel *schema.Relation
+
+ baseValues map[internal.MapKey][]reflect.Value
+ structKey []interface{}
+}
+
+var _ tableModel = (*hasManyModel)(nil)
+
+func newHasManyModel(j *join) *hasManyModel {
+ baseTable := j.BaseModel.Table()
+ joinModel := j.JoinModel.(*sliceTableModel)
+ baseValues := baseValues(joinModel, j.Relation.BaseFields)
+ if len(baseValues) == 0 {
+ return nil
+ }
+ m := hasManyModel{
+ sliceTableModel: joinModel,
+ baseTable: baseTable,
+ rel: j.Relation,
+
+ baseValues: baseValues,
+ }
+ if !m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ }
+ return &m
+}
+
+func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ var n int
+
+ for rows.Next() {
+ if m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ } else {
+ m.strct.Set(m.table.ZeroValue)
+ }
+ m.structInited = false
+
+ m.scanIndex = 0
+ m.structKey = m.structKey[:0]
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ if err := m.parkStruct(); err != nil {
+ return 0, err
+ }
+
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+func (m *hasManyModel) Scan(src interface{}) error {
+ column := m.columns[m.scanIndex]
+ m.scanIndex++
+
+ field, err := m.table.Field(column)
+ if err != nil {
+ return err
+ }
+
+ if err := field.ScanValue(m.strct, src); err != nil {
+ return err
+ }
+
+ for _, f := range m.rel.JoinFields {
+ if f.Name == field.Name {
+ m.structKey = append(m.structKey, field.Value(m.strct).Interface())
+ break
+ }
+ }
+
+ return nil
+}
+
+func (m *hasManyModel) parkStruct() error {
+ baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
+ if !ok {
+ return fmt.Errorf(
+ "bun: has-many relation=%s does not have base %s with id=%q (check join conditions)",
+ m.rel.Field.GoName, m.baseTable, m.structKey)
+ }
+
+ for i, v := range baseValues {
+ if !m.sliceOfPtr {
+ v.Set(reflect.Append(v, m.strct))
+ continue
+ }
+
+ if i == 0 {
+ v.Set(reflect.Append(v, m.strct.Addr()))
+ continue
+ }
+
+ clone := reflect.New(m.strct.Type()).Elem()
+ clone.Set(m.strct)
+ v.Set(reflect.Append(v, clone.Addr()))
+ }
+
+ return nil
+}
+
+func baseValues(model tableModel, fields []*schema.Field) map[internal.MapKey][]reflect.Value {
+ fieldIndex := model.Relation().Field.Index
+ m := make(map[internal.MapKey][]reflect.Value)
+ key := make([]interface{}, 0, len(fields))
+ walk(model.Root(), model.ParentIndex(), func(v reflect.Value) {
+ key = modelKey(key[:0], v, fields)
+ mapKey := internal.NewMapKey(key)
+ m[mapKey] = append(m[mapKey], v.FieldByIndex(fieldIndex))
+ })
+ return m
+}
+
+func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} {
+ for _, f := range fields {
+ key = append(key, f.Value(strct).Interface())
+ }
+ return key
+}
diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go
new file mode 100644
index 000000000..4357e3a8e
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_m2m.go
@@ -0,0 +1,138 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type m2mModel struct {
+ *sliceTableModel
+ baseTable *schema.Table
+ rel *schema.Relation
+
+ baseValues map[internal.MapKey][]reflect.Value
+ structKey []interface{}
+}
+
+var _ tableModel = (*m2mModel)(nil)
+
+func newM2MModel(j *join) *m2mModel {
+ baseTable := j.BaseModel.Table()
+ joinModel := j.JoinModel.(*sliceTableModel)
+ baseValues := baseValues(joinModel, baseTable.PKs)
+ if len(baseValues) == 0 {
+ return nil
+ }
+ m := &m2mModel{
+ sliceTableModel: joinModel,
+ baseTable: baseTable,
+ rel: j.Relation,
+
+ baseValues: baseValues,
+ }
+ if !m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ }
+ return m
+}
+
+func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ var n int
+
+ for rows.Next() {
+ if m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ } else {
+ m.strct.Set(m.table.ZeroValue)
+ }
+ m.structInited = false
+
+ m.scanIndex = 0
+ m.structKey = m.structKey[:0]
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ if err := m.parkStruct(); err != nil {
+ return 0, err
+ }
+
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+func (m *m2mModel) Scan(src interface{}) error {
+ column := m.columns[m.scanIndex]
+ m.scanIndex++
+
+ field, ok := m.table.FieldMap[column]
+ if !ok {
+ return m.scanM2MColumn(column, src)
+ }
+
+ if err := field.ScanValue(m.strct, src); err != nil {
+ return err
+ }
+
+ for _, fk := range m.rel.M2MBaseFields {
+ if fk.Name == field.Name {
+ m.structKey = append(m.structKey, field.Value(m.strct).Interface())
+ break
+ }
+ }
+
+ return nil
+}
+
+func (m *m2mModel) scanM2MColumn(column string, src interface{}) error {
+ for _, field := range m.rel.M2MBaseFields {
+ if field.Name == column {
+ dest := reflect.New(field.IndirectType).Elem()
+ if err := field.Scan(dest, src); err != nil {
+ return err
+ }
+ m.structKey = append(m.structKey, dest.Interface())
+ break
+ }
+ }
+
+ _, err := m.scanColumn(column, src)
+ return err
+}
+
+func (m *m2mModel) parkStruct() error {
+ baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
+ if !ok {
+ return fmt.Errorf(
+ "bun: m2m relation=%s does not have base %s with key=%q (check join conditions)",
+ m.rel.Field.GoName, m.baseTable, m.structKey)
+ }
+
+ for _, v := range baseValues {
+ if m.sliceOfPtr {
+ v.Set(reflect.Append(v, m.strct.Addr()))
+ } else {
+ v.Set(reflect.Append(v, m.strct))
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/model_table_slice.go b/vendor/github.com/uptrace/bun/model_table_slice.go
new file mode 100644
index 000000000..67e7c71e7
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_slice.go
@@ -0,0 +1,113 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+
+ "github.com/uptrace/bun/schema"
+)
+
+type sliceTableModel struct {
+ structTableModel
+
+ slice reflect.Value
+ sliceLen int
+ sliceOfPtr bool
+ nextElem func() reflect.Value
+}
+
+var _ tableModel = (*sliceTableModel)(nil)
+
+func newSliceTableModel(
+ db *DB, dest interface{}, slice reflect.Value, elemType reflect.Type,
+) *sliceTableModel {
+ m := &sliceTableModel{
+ structTableModel: structTableModel{
+ db: db,
+ table: db.Table(elemType),
+ dest: dest,
+ root: slice,
+ },
+
+ slice: slice,
+ sliceLen: slice.Len(),
+ nextElem: makeSliceNextElemFunc(slice),
+ }
+ m.init(slice.Type())
+ return m
+}
+
+func (m *sliceTableModel) init(sliceType reflect.Type) {
+ switch sliceType.Elem().Kind() {
+ case reflect.Ptr, reflect.Interface:
+ m.sliceOfPtr = true
+ }
+}
+
+func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
+ return m.join(m.slice, name, apply)
+}
+
+func (m *sliceTableModel) Bind(bind reflect.Value) {
+ m.slice = bind.Field(m.index[len(m.index)-1])
+}
+
+func (m *sliceTableModel) SetCap(cap int) {
+ if cap > 100 {
+ cap = 100
+ }
+ if m.slice.Cap() < cap {
+ m.slice.Set(reflect.MakeSlice(m.slice.Type(), 0, cap))
+ }
+}
+
+func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ if m.slice.IsValid() && m.slice.Len() > 0 {
+ m.slice.Set(m.slice.Slice(0, 0))
+ }
+
+ var n int
+
+ for rows.Next() {
+ m.strct = m.nextElem()
+ m.structInited = false
+
+ if err := m.scanRow(ctx, rows, dest); err != nil {
+ return 0, err
+ }
+
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// Inherit these hooks from structTableModel.
+var (
+ _ schema.BeforeScanHook = (*sliceTableModel)(nil)
+ _ schema.AfterScanHook = (*sliceTableModel)(nil)
+)
+
+func (m *sliceTableModel) updateSoftDeleteField() error {
+ sliceLen := m.slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ strct := indirect(m.slice.Index(i))
+ fv := m.table.SoftDeleteField.Value(strct)
+ if err := m.table.UpdateSoftDeleteField(fv); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go
new file mode 100644
index 000000000..3bb0c14dd
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_struct.go
@@ -0,0 +1,345 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/schema"
+)
+
+type structTableModel struct {
+ db *DB
+ table *schema.Table
+
+ rel *schema.Relation
+ joins []join
+
+ dest interface{}
+ root reflect.Value
+ index []int
+
+ strct reflect.Value
+ structInited bool
+ structInitErr error
+
+ columns []string
+ scanIndex int
+}
+
+var _ tableModel = (*structTableModel)(nil)
+
+func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel {
+ return &structTableModel{
+ db: db,
+ table: table,
+ dest: dest,
+ }
+}
+
+func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel {
+ return &structTableModel{
+ db: db,
+ table: db.Table(v.Type()),
+ dest: dest,
+ root: v,
+ strct: v,
+ }
+}
+
+func (m *structTableModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *structTableModel) Table() *schema.Table {
+ return m.table
+}
+
+func (m *structTableModel) Relation() *schema.Relation {
+ return m.rel
+}
+
+func (m *structTableModel) Root() reflect.Value {
+ return m.root
+}
+
+func (m *structTableModel) Index() []int {
+ return m.index
+}
+
+func (m *structTableModel) ParentIndex() []int {
+ return m.index[:len(m.index)-len(m.rel.Field.Index)]
+}
+
+func (m *structTableModel) Mount(host reflect.Value) {
+ m.strct = host.FieldByIndex(m.rel.Field.Index)
+ m.structInited = false
+}
+
+func (m *structTableModel) initStruct() error {
+ if m.structInited {
+ return m.structInitErr
+ }
+ m.structInited = true
+
+ switch m.strct.Kind() {
+ case reflect.Invalid:
+ m.structInitErr = errNilModel
+ return m.structInitErr
+ case reflect.Interface:
+ m.strct = m.strct.Elem()
+ }
+
+ if m.strct.Kind() == reflect.Ptr {
+ if m.strct.IsNil() {
+ m.strct.Set(reflect.New(m.strct.Type().Elem()))
+ m.strct = m.strct.Elem()
+ } else {
+ m.strct = m.strct.Elem()
+ }
+ }
+
+ m.mountJoins()
+
+ return nil
+}
+
+func (m *structTableModel) mountJoins() {
+ for i := range m.joins {
+ j := &m.joins[i]
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ j.JoinModel.Mount(m.strct)
+ }
+ }
+}
+
+var _ schema.BeforeScanHook = (*structTableModel)(nil)
+
+func (m *structTableModel) BeforeScan(ctx context.Context) error {
+ if !m.table.HasBeforeScanHook() {
+ return nil
+ }
+ return callBeforeScanHook(ctx, m.strct.Addr())
+}
+
+var _ schema.AfterScanHook = (*structTableModel)(nil)
+
+func (m *structTableModel) AfterScan(ctx context.Context) error {
+ if !m.table.HasAfterScanHook() || !m.structInited {
+ return nil
+ }
+
+ var firstErr error
+
+ if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil {
+ firstErr = err
+ }
+
+ for _, j := range m.joins {
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ }
+
+ return firstErr
+}
+
+func (m *structTableModel) GetJoin(name string) *join {
+ for i := range m.joins {
+ j := &m.joins[i]
+ if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
+ return j
+ }
+ }
+ return nil
+}
+
+func (m *structTableModel) GetJoins() []join {
+ return m.joins
+}
+
+func (m *structTableModel) AddJoin(j join) *join {
+ m.joins = append(m.joins, j)
+ return &m.joins[len(m.joins)-1]
+}
+
+func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
+ return m.join(m.strct, name, apply)
+}
+
+func (m *structTableModel) join(
+ bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery,
+) *join {
+ path := strings.Split(name, ".")
+ index := make([]int, 0, len(path))
+
+ currJoin := join{
+ BaseModel: m,
+ JoinModel: m,
+ }
+ var lastJoin *join
+
+ for _, name := range path {
+ relation, ok := currJoin.JoinModel.Table().Relations[name]
+ if !ok {
+ return nil
+ }
+
+ currJoin.Relation = relation
+ index = append(index, relation.Field.Index...)
+
+ if j := currJoin.JoinModel.GetJoin(name); j != nil {
+ currJoin.BaseModel = j.BaseModel
+ currJoin.JoinModel = j.JoinModel
+
+ lastJoin = j
+ } else {
+ model, err := newTableModelIndex(m.db, m.table, bind, index, relation)
+ if err != nil {
+ return nil
+ }
+
+ currJoin.Parent = lastJoin
+ currJoin.BaseModel = currJoin.JoinModel
+ currJoin.JoinModel = model
+
+ lastJoin = currJoin.BaseModel.AddJoin(currJoin)
+ }
+ }
+
+ // No joins with such name.
+ if lastJoin == nil {
+ return nil
+ }
+ if apply != nil {
+ lastJoin.ApplyQueryFunc = apply
+ }
+
+ return lastJoin
+}
+
+func (m *structTableModel) updateSoftDeleteField() error {
+ fv := m.table.SoftDeleteField.Value(m.strct)
+ return m.table.UpdateSoftDeleteField(fv)
+}
+
+func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ if !rows.Next() {
+ return 0, rows.Err()
+ }
+
+ if err := m.ScanRow(ctx, rows); err != nil {
+ return 0, err
+ }
+
+ // For inserts, SQLite3 can return a row like it was inserted sucessfully and then
+ // an actual error for the next row. See issues/100.
+ if m.db.dialect.Name() == dialect.SQLite {
+ _ = rows.Next()
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ }
+
+ return 1, nil
+}
+
+func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
+ columns, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ return m.scanRow(ctx, rows, dest)
+}
+
+func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
+ if err := m.BeforeScan(ctx); err != nil {
+ return err
+ }
+
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return err
+ }
+
+ if err := m.AfterScan(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *structTableModel) Scan(src interface{}) error {
+ column := m.columns[m.scanIndex]
+ m.scanIndex++
+
+ return m.ScanColumn(unquote(column), src)
+}
+
+func (m *structTableModel) ScanColumn(column string, src interface{}) error {
+ if ok, err := m.scanColumn(column, src); ok {
+ return err
+ }
+ if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) {
+ return nil
+ }
+ return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
+}
+
+func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) {
+ if src != nil {
+ if err := m.initStruct(); err != nil {
+ return true, err
+ }
+ }
+
+ if field, ok := m.table.FieldMap[column]; ok {
+ return true, field.ScanValue(m.strct, src)
+ }
+
+ if joinName, column := splitColumn(column); joinName != "" {
+ if join := m.GetJoin(joinName); join != nil {
+ return true, join.JoinModel.ScanColumn(column, src)
+ }
+ if m.table.ModelName == joinName {
+ return true, m.ScanColumn(column, src)
+ }
+ }
+
+ return false, nil
+}
+
+func (m *structTableModel) AppendNamedArg(
+ fmter schema.Formatter, b []byte, name string,
+) ([]byte, bool) {
+ return m.table.AppendNamedArg(fmter, b, name, m.strct)
+}
+
+// sqlite3 sometimes does not unquote columns.
+func unquote(s string) string {
+ if s == "" {
+ return s
+ }
+ if s[0] == '"' && s[len(s)-1] == '"' {
+ return s[1 : len(s)-1]
+ }
+ return s
+}
+
+func splitColumn(s string) (string, string) {
+ if i := strings.Index(s, "__"); i >= 0 {
+ return s[:i], s[i+2:]
+ }
+ return "", s
+}
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
new file mode 100644
index 000000000..1a7c32720
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -0,0 +1,874 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+const (
+ wherePKFlag internal.Flag = 1 << iota
+ forceDeleteFlag
+ deletedFlag
+ allWithDeletedFlag
+)
+
+type withQuery struct {
+ name string
+ query schema.QueryAppender
+}
+
+// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx.
+type IConn interface {
+ QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
+ ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
+ QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
+}
+
+var (
+ _ IConn = (*sql.DB)(nil)
+ _ IConn = (*sql.Conn)(nil)
+ _ IConn = (*sql.Tx)(nil)
+ _ IConn = (*DB)(nil)
+ _ IConn = (*Conn)(nil)
+ _ IConn = (*Tx)(nil)
+)
+
+// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx.
+type IDB interface {
+ IConn
+
+ NewValues(model interface{}) *ValuesQuery
+ NewSelect() *SelectQuery
+ NewInsert() *InsertQuery
+ NewUpdate() *UpdateQuery
+ NewDelete() *DeleteQuery
+ NewCreateTable() *CreateTableQuery
+ NewDropTable() *DropTableQuery
+ NewCreateIndex() *CreateIndexQuery
+ NewDropIndex() *DropIndexQuery
+ NewTruncateTable() *TruncateTableQuery
+ NewAddColumn() *AddColumnQuery
+ NewDropColumn() *DropColumnQuery
+}
+
+var (
+ _ IConn = (*DB)(nil)
+ _ IConn = (*Conn)(nil)
+ _ IConn = (*Tx)(nil)
+)
+
+type baseQuery struct {
+ db *DB
+ conn IConn
+
+ model model
+ err error
+
+ tableModel tableModel
+ table *schema.Table
+
+ with []withQuery
+ modelTable schema.QueryWithArgs
+ tables []schema.QueryWithArgs
+ columns []schema.QueryWithArgs
+
+ flags internal.Flag
+}
+
+func (q *baseQuery) DB() *DB {
+ return q.db
+}
+
+func (q *baseQuery) GetModel() Model {
+ return q.model
+}
+
+func (q *baseQuery) setConn(db IConn) {
+ // Unwrap Bun wrappers to not call query hooks twice.
+ switch db := db.(type) {
+ case *DB:
+ q.conn = db.DB
+ case Conn:
+ q.conn = db.Conn
+ case Tx:
+ q.conn = db.Tx
+ default:
+ q.conn = db
+ }
+}
+
+// TODO: rename to setModel
+func (q *baseQuery) setTableModel(modeli interface{}) {
+ model, err := newSingleModel(q.db, modeli)
+ if err != nil {
+ q.setErr(err)
+ return
+ }
+
+ q.model = model
+ if tm, ok := model.(tableModel); ok {
+ q.tableModel = tm
+ q.table = tm.Table()
+ }
+}
+
+func (q *baseQuery) setErr(err error) {
+ if q.err == nil {
+ q.err = err
+ }
+}
+
+func (q *baseQuery) getModel(dest []interface{}) (model, error) {
+ if len(dest) == 0 {
+ if q.model != nil {
+ return q.model, nil
+ }
+ return nil, errNilModel
+ }
+ return newModel(q.db, dest)
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) checkSoftDelete() error {
+ if q.table == nil {
+ return errors.New("bun: can't use soft deletes without a table")
+ }
+ if q.table.SoftDeleteField == nil {
+ return fmt.Errorf("%s does not have a soft delete field", q.table)
+ }
+ if q.tableModel == nil {
+ return errors.New("bun: can't use soft deletes without a table model")
+ }
+ return nil
+}
+
+// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models.
+func (q *baseQuery) whereDeleted() {
+ if err := q.checkSoftDelete(); err != nil {
+ q.setErr(err)
+ return
+ }
+ q.flags = q.flags.Set(deletedFlag)
+ q.flags = q.flags.Remove(allWithDeletedFlag)
+}
+
+// AllWithDeleted changes query to return all rows including soft deleted ones.
+func (q *baseQuery) whereAllWithDeleted() {
+ if err := q.checkSoftDelete(); err != nil {
+ q.setErr(err)
+ return
+ }
+ q.flags = q.flags.Set(allWithDeletedFlag)
+ q.flags = q.flags.Remove(deletedFlag)
+}
+
+func (q *baseQuery) isSoftDelete() bool {
+ if q.table != nil {
+ return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
+ }
+ return false
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) addWith(name string, query schema.QueryAppender) {
+ q.with = append(q.with, withQuery{
+ name: name,
+ query: query,
+ })
+}
+
+func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if len(q.with) == 0 {
+ return b, nil
+ }
+
+ b = append(b, "WITH "...)
+ for i, with := range q.with {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ b = fmter.AppendIdent(b, with.name)
+ if q, ok := with.query.(schema.ColumnsAppender); ok {
+ b = append(b, " ("...)
+ b, err = q.AppendColumns(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ")"...)
+ }
+
+ b = append(b, " AS ("...)
+
+ b, err = with.query.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, ')')
+ }
+ b = append(b, ' ')
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) addTable(table schema.QueryWithArgs) {
+ q.tables = append(q.tables, table)
+}
+
+func (q *baseQuery) addColumn(column schema.QueryWithArgs) {
+ q.columns = append(q.columns, column)
+}
+
+func (q *baseQuery) excludeColumn(columns []string) {
+ if q.columns == nil {
+ for _, f := range q.table.Fields {
+ q.columns = append(q.columns, schema.UnsafeIdent(f.Name))
+ }
+ }
+
+ if len(columns) == 1 && columns[0] == "*" {
+ q.columns = make([]schema.QueryWithArgs, 0)
+ return
+ }
+
+ for _, column := range columns {
+ if !q._excludeColumn(column) {
+ q.setErr(fmt.Errorf("bun: can't find column=%q", column))
+ return
+ }
+ }
+}
+
+func (q *baseQuery) _excludeColumn(column string) bool {
+ for i, col := range q.columns {
+ if col.Args == nil && col.Query == column {
+ q.columns = append(q.columns[:i], q.columns[i+1:]...)
+ return true
+ }
+ }
+ return false
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) modelHasTableName() bool {
+ return !q.modelTable.IsZero() || q.table != nil
+}
+
+func (q *baseQuery) hasTables() bool {
+ return q.modelHasTableName() || len(q.tables) > 0
+}
+
+func (q *baseQuery) appendTables(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ return q._appendTables(fmter, b, false)
+}
+
+func (q *baseQuery) appendTablesWithAlias(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ return q._appendTables(fmter, b, true)
+}
+
+func (q *baseQuery) _appendTables(
+ fmter schema.Formatter, b []byte, withAlias bool,
+) (_ []byte, err error) {
+ startLen := len(b)
+
+ if q.modelHasTableName() {
+ if !q.modelTable.IsZero() {
+ b, err = q.modelTable.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects))
+ if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects {
+ b = append(b, " AS "...)
+ b = append(b, q.table.SQLAlias...)
+ }
+ }
+ }
+
+ for _, table := range q.tables {
+ if len(b) > startLen {
+ b = append(b, ", "...)
+ }
+ b, err = table.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) {
+ return q._appendFirstTable(fmter, b, false)
+}
+
+func (q *baseQuery) appendFirstTableWithAlias(
+ fmter schema.Formatter, b []byte,
+) ([]byte, error) {
+ return q._appendFirstTable(fmter, b, true)
+}
+
+func (q *baseQuery) _appendFirstTable(
+ fmter schema.Formatter, b []byte, withAlias bool,
+) ([]byte, error) {
+ if !q.modelTable.IsZero() {
+ return q.modelTable.AppendQuery(fmter, b)
+ }
+
+ if q.table != nil {
+ b = fmter.AppendQuery(b, string(q.table.SQLName))
+ if withAlias {
+ b = append(b, " AS "...)
+ b = append(b, q.table.SQLAlias...)
+ }
+ return b, nil
+ }
+
+ if len(q.tables) > 0 {
+ return q.tables[0].AppendQuery(fmter, b)
+ }
+
+ return nil, errors.New("bun: query does not have a table")
+}
+
+func (q *baseQuery) hasMultiTables() bool {
+ if q.modelHasTableName() {
+ return len(q.tables) >= 1
+ }
+ return len(q.tables) >= 2
+}
+
+func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ tables := q.tables
+ if !q.modelHasTableName() {
+ tables = tables[1:]
+ }
+ for i, table := range tables {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = table.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ for i, f := range q.columns {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+}
+
+func (q *baseQuery) getFields() ([]*schema.Field, error) {
+ table := q.tableModel.Table()
+
+ if len(q.columns) == 0 {
+ return table.Fields, nil
+ }
+
+ fields, err := q._getFields(false)
+ if err != nil {
+ return nil, err
+ }
+
+ return fields, nil
+}
+
+func (q *baseQuery) getDataFields() ([]*schema.Field, error) {
+ if len(q.columns) == 0 {
+ return q.table.DataFields, nil
+ }
+ return q._getFields(true)
+}
+
+func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {
+ fields := make([]*schema.Field, 0, len(q.columns))
+ for _, col := range q.columns {
+ if col.Args != nil {
+ continue
+ }
+
+ field, err := q.table.Field(col.Query)
+ if err != nil {
+ return nil, err
+ }
+
+ if omitPK && field.IsPK {
+ continue
+ }
+
+ fields = append(fields, field)
+ }
+ return fields, nil
+}
+
+func (q *baseQuery) scan(
+ ctx context.Context,
+ queryApp schema.QueryAppender,
+ query string,
+ model model,
+ hasDest bool,
+) (res result, _ error) {
+ ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
+
+ rows, err := q.conn.QueryContext(ctx, query)
+ if err != nil {
+ q.db.afterQuery(ctx, event, nil, err)
+ return res, err
+ }
+ defer rows.Close()
+
+ n, err := model.ScanRows(ctx, rows)
+ if err != nil {
+ q.db.afterQuery(ctx, event, nil, err)
+ return res, err
+ }
+
+ res.n = n
+ if n == 0 && hasDest && isSingleRowModel(model) {
+ err = sql.ErrNoRows
+ }
+
+ q.db.afterQuery(ctx, event, nil, err)
+
+ return res, err
+}
+
+func (q *baseQuery) exec(
+ ctx context.Context,
+ queryApp schema.QueryAppender,
+ query string,
+) (res result, _ error) {
+ ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
+
+ r, err := q.conn.ExecContext(ctx, query)
+ if err != nil {
+ q.db.afterQuery(ctx, event, nil, err)
+ return res, err
+ }
+
+ res.r = r
+
+ q.db.afterQuery(ctx, event, nil, err)
+ return res, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
+ if q.table == nil {
+ return b, false
+ }
+
+ if m, ok := q.tableModel.(*structTableModel); ok {
+ if b, ok := m.AppendNamedArg(fmter, b, name); ok {
+ return b, ok
+ }
+ }
+
+ switch name {
+ case "TableName":
+ b = fmter.AppendQuery(b, string(q.table.SQLName))
+ return b, true
+ case "TableAlias":
+ b = fmter.AppendQuery(b, string(q.table.SQLAlias))
+ return b, true
+ case "PKs":
+ b = appendColumns(b, "", q.table.PKs)
+ return b, true
+ case "TablePKs":
+ b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
+ return b, true
+ case "Columns":
+ b = appendColumns(b, "", q.table.Fields)
+ return b, true
+ case "TableColumns":
+ b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
+ return b, true
+ }
+
+ return b, false
+}
+
+func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte {
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ if len(table) > 0 {
+ b = append(b, table...)
+ b = append(b, '.')
+ }
+ b = append(b, f.SQLName...)
+ }
+ return b
+}
+
+func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter {
+ if fmter.IsNop() {
+ return fmter
+ }
+ return fmter.WithArg(model)
+}
+
+//------------------------------------------------------------------------------
+
+type whereBaseQuery struct {
+ baseQuery
+
+ where []schema.QueryWithSep
+}
+
+func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
+ q.where = append(q.where, where)
+}
+
+func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) {
+ if len(where) == 0 {
+ return
+ }
+
+ where[0].Sep = ""
+
+ q.addWhere(schema.SafeQueryWithSep("", nil, sep+"("))
+ q.where = append(q.where, where...)
+ q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
+}
+
+func (q *whereBaseQuery) mustAppendWhere(
+ fmter schema.Formatter, b []byte, withAlias bool,
+) ([]byte, error) {
+ if len(q.where) == 0 && !q.flags.Has(wherePKFlag) {
+ err := errors.New("bun: Update and Delete queries require at least one Where")
+ return nil, err
+ }
+ return q.appendWhere(fmter, b, withAlias)
+}
+
+func (q *whereBaseQuery) appendWhere(
+ fmter schema.Formatter, b []byte, withAlias bool,
+) (_ []byte, err error) {
+ if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) {
+ return b, nil
+ }
+
+ b = append(b, " WHERE "...)
+ startLen := len(b)
+
+ if len(q.where) > 0 {
+ b, err = appendWhere(fmter, b, q.where)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if q.isSoftDelete() {
+ if len(b) > startLen {
+ b = append(b, " AND "...)
+ }
+ if withAlias {
+ b = append(b, q.tableModel.Table().SQLAlias...)
+ b = append(b, '.')
+ }
+ b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...)
+ if q.flags.Has(deletedFlag) {
+ b = append(b, " IS NOT NULL"...)
+ } else {
+ b = append(b, " IS NULL"...)
+ }
+ }
+
+ if q.flags.Has(wherePKFlag) {
+ if len(b) > startLen {
+ b = append(b, " AND "...)
+ }
+ b, err = q.appendWherePK(fmter, b, withAlias)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func appendWhere(
+ fmter schema.Formatter, b []byte, where []schema.QueryWithSep,
+) (_ []byte, err error) {
+ for i, where := range where {
+ if i > 0 || where.Sep == "(" {
+ b = append(b, where.Sep...)
+ }
+
+ if where.Query == "" && where.Args == nil {
+ continue
+ }
+
+ b = append(b, '(')
+ b, err = where.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ')')
+ }
+ return b, nil
+}
+
+func (q *whereBaseQuery) appendWherePK(
+ fmter schema.Formatter, b []byte, withAlias bool,
+) (_ []byte, err error) {
+ if q.table == nil {
+ err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
+ return nil, err
+ }
+ if err := q.table.CheckPKs(); err != nil {
+ return nil, err
+ }
+
+ switch model := q.tableModel.(type) {
+ case *structTableModel:
+ return q.appendWherePKStruct(fmter, b, model, withAlias)
+ case *sliceTableModel:
+ return q.appendWherePKSlice(fmter, b, model, withAlias)
+ }
+
+ return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
+}
+
+func (q *whereBaseQuery) appendWherePKStruct(
+ fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
+) (_ []byte, err error) {
+ if !model.strct.IsValid() {
+ return nil, errNilModel
+ }
+
+ isTemplate := fmter.IsNop()
+ b = append(b, '(')
+ for i, f := range q.table.PKs {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ if withAlias {
+ b = append(b, q.table.SQLAlias...)
+ b = append(b, '.')
+ }
+ b = append(b, f.SQLName...)
+ b = append(b, " = "...)
+ if isTemplate {
+ b = append(b, '?')
+ } else {
+ b = f.AppendValue(fmter, b, model.strct)
+ }
+ }
+ b = append(b, ')')
+ return b, nil
+}
+
+func (q *whereBaseQuery) appendWherePKSlice(
+ fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
+) (_ []byte, err error) {
+ if len(q.table.PKs) > 1 {
+ b = append(b, '(')
+ }
+ if withAlias {
+ b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
+ } else {
+ b = appendColumns(b, "", q.table.PKs)
+ }
+ if len(q.table.PKs) > 1 {
+ b = append(b, ')')
+ }
+
+ b = append(b, " IN ("...)
+
+ isTemplate := fmter.IsNop()
+ slice := model.slice
+ sliceLen := slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ if i > 0 {
+ if isTemplate {
+ break
+ }
+ b = append(b, ", "...)
+ }
+
+ el := indirect(slice.Index(i))
+
+ if len(q.table.PKs) > 1 {
+ b = append(b, '(')
+ }
+ for i, f := range q.table.PKs {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ if isTemplate {
+ b = append(b, '?')
+ } else {
+ b = f.AppendValue(fmter, b, el)
+ }
+ }
+ if len(q.table.PKs) > 1 {
+ b = append(b, ')')
+ }
+ }
+
+ b = append(b, ')')
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+type returningQuery struct {
+ returning []schema.QueryWithArgs
+ returningFields []*schema.Field
+}
+
+func (q *returningQuery) addReturning(ret schema.QueryWithArgs) {
+ q.returning = append(q.returning, ret)
+}
+
+func (q *returningQuery) addReturningField(field *schema.Field) {
+ if len(q.returning) > 0 {
+ return
+ }
+ for _, f := range q.returningFields {
+ if f == field {
+ return
+ }
+ }
+ q.returningFields = append(q.returningFields, field)
+}
+
+func (q *returningQuery) hasReturning() bool {
+ if len(q.returning) == 1 {
+ switch q.returning[0].Query {
+ case "null", "NULL":
+ return false
+ }
+ }
+ return len(q.returning) > 0 || len(q.returningFields) > 0
+}
+
+func (q *returningQuery) appendReturning(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ if !q.hasReturning() {
+ return b, nil
+ }
+
+ b = append(b, " RETURNING "...)
+
+ for i, f := range q.returning {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(q.returning) > 0 {
+ return b, nil
+ }
+
+ b = appendColumns(b, "", q.returningFields)
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+type columnValue struct {
+ column string
+ value schema.QueryWithArgs
+}
+
+type customValueQuery struct {
+ modelValues map[string]schema.QueryWithArgs
+ extraValues []columnValue
+}
+
+func (q *customValueQuery) addValue(
+ table *schema.Table, column string, value string, args []interface{},
+) {
+ if _, ok := table.FieldMap[column]; ok {
+ if q.modelValues == nil {
+ q.modelValues = make(map[string]schema.QueryWithArgs)
+ }
+ q.modelValues[column] = schema.SafeQuery(value, args)
+ } else {
+ q.extraValues = append(q.extraValues, columnValue{
+ column: column,
+ value: schema.SafeQuery(value, args),
+ })
+ }
+}
+
+//------------------------------------------------------------------------------
+
+type setQuery struct {
+ set []schema.QueryWithArgs
+}
+
+func (q *setQuery) addSet(set schema.QueryWithArgs) {
+ q.set = append(q.set, set)
+}
+
+func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ for i, f := range q.set {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+type cascadeQuery struct {
+ restrict bool
+}
+
+func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte {
+ if !fmter.HasFeature(feature.TableCascade) {
+ return b
+ }
+ if q.restrict {
+ b = append(b, " RESTRICT"...)
+ } else {
+ b = append(b, " CASCADE"...)
+ }
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/query_column_add.go b/vendor/github.com/uptrace/bun/query_column_add.go
new file mode 100644
index 000000000..ce2f60bf0
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_column_add.go
@@ -0,0 +1,105 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type AddColumnQuery struct {
+ baseQuery
+}
+
+func NewAddColumnQuery(db *DB) *AddColumnQuery {
+ q := &AddColumnQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *AddColumnQuery) Conn(db IConn) *AddColumnQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *AddColumnQuery) Model(model interface{}) *AddColumnQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *AddColumnQuery) Table(tables ...string) *AddColumnQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *AddColumnQuery) TableExpr(query string, args ...interface{}) *AddColumnQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *AddColumnQuery) ModelTableExpr(query string, args ...interface{}) *AddColumnQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColumnQuery {
+ q.addColumn(schema.SafeQuery(query, args))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if len(q.columns) != 1 {
+ return nil, fmt.Errorf("bun: AddColumnQuery requires exactly one column")
+ }
+
+ b = append(b, "ALTER TABLE "...)
+
+ b, err = q.appendFirstTable(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " ADD "...)
+
+ b, err = q.columns[0].AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_column_drop.go b/vendor/github.com/uptrace/bun/query_column_drop.go
new file mode 100644
index 000000000..5684beeb3
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_column_drop.go
@@ -0,0 +1,112 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type DropColumnQuery struct {
+ baseQuery
+}
+
+func NewDropColumnQuery(db *DB) *DropColumnQuery {
+ q := &DropColumnQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *DropColumnQuery) Conn(db IConn) *DropColumnQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *DropColumnQuery) Model(model interface{}) *DropColumnQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropColumnQuery) Table(tables ...string) *DropColumnQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *DropColumnQuery) TableExpr(query string, args ...interface{}) *DropColumnQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *DropColumnQuery) ModelTableExpr(query string, args ...interface{}) *DropColumnQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropColumnQuery) Column(columns ...string) *DropColumnQuery {
+ for _, column := range columns {
+ q.addColumn(schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *DropColumnQuery) ColumnExpr(query string, args ...interface{}) *DropColumnQuery {
+ q.addColumn(schema.SafeQuery(query, args))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if len(q.columns) != 1 {
+ return nil, fmt.Errorf("bun: DropColumnQuery requires exactly one column")
+ }
+
+ b = append(b, "ALTER TABLE "...)
+
+ b, err = q.appendFirstTable(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " DROP COLUMN "...)
+
+ b, err = q.columns[0].AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_delete.go b/vendor/github.com/uptrace/bun/query_delete.go
new file mode 100644
index 000000000..c0c5039c7
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_delete.go
@@ -0,0 +1,256 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type DeleteQuery struct {
+ whereBaseQuery
+ returningQuery
+}
+
+func NewDeleteQuery(db *DB) *DeleteQuery {
+ q := &DeleteQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ },
+ }
+ return q
+}
+
+func (q *DeleteQuery) Conn(db IConn) *DeleteQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *DeleteQuery) Model(model interface{}) *DeleteQuery {
+ q.setTableModel(model)
+ return q
+}
+
+// Apply calls the fn passing the DeleteQuery as an argument.
+func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
+ return fn(q)
+}
+
+func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery {
+ q.addWith(name, query)
+ return q
+}
+
+func (q *DeleteQuery) Table(tables ...string) *DeleteQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *DeleteQuery) TableExpr(query string, args ...interface{}) *DeleteQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DeleteQuery) WherePK() *DeleteQuery {
+ q.flags = q.flags.Set(wherePKFlag)
+ return q
+}
+
+func (q *DeleteQuery) Where(query string, args ...interface{}) *DeleteQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
+ return q
+}
+
+func (q *DeleteQuery) WhereOr(query string, args ...interface{}) *DeleteQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
+ return q
+}
+
+func (q *DeleteQuery) WhereGroup(sep string, fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
+ saved := q.where
+ q.where = nil
+
+ q = fn(q)
+
+ where := q.where
+ q.where = saved
+
+ q.addWhereGroup(sep, where)
+
+ return q
+}
+
+func (q *DeleteQuery) WhereDeleted() *DeleteQuery {
+ q.whereDeleted()
+ return q
+}
+
+func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery {
+ q.whereAllWithDeleted()
+ return q
+}
+
+func (q *DeleteQuery) ForceDelete() *DeleteQuery {
+ q.flags = q.flags.Set(forceDeleteFlag)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+// Returning adds a RETURNING clause to the query.
+//
+// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
+func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery {
+ q.addReturning(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *DeleteQuery) hasReturning() bool {
+ if !q.db.features.Has(feature.Returning) {
+ return false
+ }
+ return q.returningQuery.hasReturning()
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ fmter = formatterWithModel(fmter, q)
+
+ if q.isSoftDelete() {
+ if err := q.tableModel.updateSoftDeleteField(); err != nil {
+ return nil, err
+ }
+
+ upd := UpdateQuery{
+ whereBaseQuery: q.whereBaseQuery,
+ returningQuery: q.returningQuery,
+ }
+ upd.Column(q.table.SoftDeleteField.Name)
+ return upd.AppendQuery(fmter, b)
+ }
+
+ q = q.WhereAllWithDeleted()
+ withAlias := q.db.features.Has(feature.DeleteTableAlias)
+
+ b, err = q.appendWith(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, "DELETE FROM "...)
+
+ if withAlias {
+ b, err = q.appendFirstTableWithAlias(fmter, b)
+ } else {
+ b, err = q.appendFirstTable(fmter, b)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ if q.hasMultiTables() {
+ b = append(b, " USING "...)
+ b, err = q.appendOtherTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ b, err = q.mustAppendWhere(fmter, b, withAlias)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(q.returning) > 0 {
+ b, err = q.appendReturning(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *DeleteQuery) isSoftDelete() bool {
+ return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag)
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if q.table != nil {
+ if err := q.beforeDeleteHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ var res sql.Result
+
+ if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
+ model, err := q.getModel(dest)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err = q.scan(ctx, q, query, model, hasDest)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ res, err = q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if q.table != nil {
+ if err := q.afterDeleteHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *DeleteQuery) beforeDeleteHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeDeleteHook); ok {
+ if err := hook.BeforeDelete(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *DeleteQuery) afterDeleteHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterDeleteHook); ok {
+ if err := hook.AfterDelete(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_index_create.go b/vendor/github.com/uptrace/bun/query_index_create.go
new file mode 100644
index 000000000..de7eb7aa0
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_index_create.go
@@ -0,0 +1,242 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type CreateIndexQuery struct {
+ whereBaseQuery
+
+ unique bool
+ fulltext bool
+ spatial bool
+ concurrently bool
+ ifNotExists bool
+
+ index schema.QueryWithArgs
+ using schema.QueryWithArgs
+ include []schema.QueryWithArgs
+}
+
+func NewCreateIndexQuery(db *DB) *CreateIndexQuery {
+ q := &CreateIndexQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ },
+ }
+ return q
+}
+
+func (q *CreateIndexQuery) Conn(db IConn) *CreateIndexQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *CreateIndexQuery) Model(model interface{}) *CreateIndexQuery {
+ q.setTableModel(model)
+ return q
+}
+
+func (q *CreateIndexQuery) Unique() *CreateIndexQuery {
+ q.unique = true
+ return q
+}
+
+func (q *CreateIndexQuery) Concurrently() *CreateIndexQuery {
+ q.concurrently = true
+ return q
+}
+
+func (q *CreateIndexQuery) IfNotExists() *CreateIndexQuery {
+ q.ifNotExists = true
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Index(query string) *CreateIndexQuery {
+ q.index = schema.UnsafeIdent(query)
+ return q
+}
+
+func (q *CreateIndexQuery) IndexExpr(query string, args ...interface{}) *CreateIndexQuery {
+ q.index = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Table(tables ...string) *CreateIndexQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *CreateIndexQuery) TableExpr(query string, args ...interface{}) *CreateIndexQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateIndexQuery) ModelTableExpr(query string, args ...interface{}) *CreateIndexQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+func (q *CreateIndexQuery) Using(query string, args ...interface{}) *CreateIndexQuery {
+ q.using = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Column(columns ...string) *CreateIndexQuery {
+ for _, column := range columns {
+ q.addColumn(schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *CreateIndexQuery) ColumnExpr(query string, args ...interface{}) *CreateIndexQuery {
+ q.addColumn(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateIndexQuery) ExcludeColumn(columns ...string) *CreateIndexQuery {
+ q.excludeColumn(columns)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Include(columns ...string) *CreateIndexQuery {
+ for _, column := range columns {
+ q.include = append(q.include, schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *CreateIndexQuery) IncludeExpr(query string, args ...interface{}) *CreateIndexQuery {
+ q.include = append(q.include, schema.SafeQuery(query, args))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Where(query string, args ...interface{}) *CreateIndexQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
+ return q
+}
+
+func (q *CreateIndexQuery) WhereOr(query string, args ...interface{}) *CreateIndexQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+
+ b = append(b, "CREATE "...)
+
+ if q.unique {
+ b = append(b, "UNIQUE "...)
+ }
+ if q.fulltext {
+ b = append(b, "FULLTEXT "...)
+ }
+ if q.spatial {
+ b = append(b, "SPATIAL "...)
+ }
+
+ b = append(b, "INDEX "...)
+
+ if q.concurrently {
+ b = append(b, "CONCURRENTLY "...)
+ }
+ if q.ifNotExists {
+ b = append(b, "IF NOT EXISTS "...)
+ }
+
+ b, err = q.index.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " ON "...)
+ b, err = q.appendFirstTable(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if !q.using.IsZero() {
+ b = append(b, " USING "...)
+ b, err = q.using.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ b = append(b, " ("...)
+ for i, col := range q.columns {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = col.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ b = append(b, ')')
+
+ if len(q.include) > 0 {
+ b = append(b, " INCLUDE ("...)
+ for i, col := range q.include {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = col.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ b = append(b, ')')
+ }
+
+ if len(q.where) > 0 {
+ b, err = appendWhere(fmter, b, q.where)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_index_drop.go b/vendor/github.com/uptrace/bun/query_index_drop.go
new file mode 100644
index 000000000..c922ff04f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_index_drop.go
@@ -0,0 +1,105 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type DropIndexQuery struct {
+ baseQuery
+ cascadeQuery
+
+ concurrently bool
+ ifExists bool
+
+ index schema.QueryWithArgs
+}
+
+func NewDropIndexQuery(db *DB) *DropIndexQuery {
+ q := &DropIndexQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *DropIndexQuery) Conn(db IConn) *DropIndexQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *DropIndexQuery) Model(model interface{}) *DropIndexQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropIndexQuery) Concurrently() *DropIndexQuery {
+ q.concurrently = true
+ return q
+}
+
+func (q *DropIndexQuery) IfExists() *DropIndexQuery {
+ q.ifExists = true
+ return q
+}
+
+func (q *DropIndexQuery) Restrict() *DropIndexQuery {
+ q.restrict = true
+ return q
+}
+
+func (q *DropIndexQuery) Index(query string, args ...interface{}) *DropIndexQuery {
+ q.index = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+
+ b = append(b, "DROP INDEX "...)
+
+ if q.concurrently {
+ b = append(b, "CONCURRENTLY "...)
+ }
+ if q.ifExists {
+ b = append(b, "IF EXISTS "...)
+ }
+
+ b, err = q.index.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = q.appendCascade(fmter, b)
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go
new file mode 100644
index 000000000..efddee407
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_insert.go
@@ -0,0 +1,551 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type InsertQuery struct {
+ whereBaseQuery
+ returningQuery
+ customValueQuery
+
+ onConflict schema.QueryWithArgs
+ setQuery
+
+ ignore bool
+ replace bool
+}
+
+func NewInsertQuery(db *DB) *InsertQuery {
+ q := &InsertQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ },
+ }
+ return q
+}
+
+func (q *InsertQuery) Conn(db IConn) *InsertQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *InsertQuery) Model(model interface{}) *InsertQuery {
+ q.setTableModel(model)
+ return q
+}
+
+// Apply calls the fn passing the SelectQuery as an argument.
+func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery {
+ return fn(q)
+}
+
+func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery {
+ q.addWith(name, query)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *InsertQuery) Table(tables ...string) *InsertQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *InsertQuery) Column(columns ...string) *InsertQuery {
+ for _, column := range columns {
+ q.addColumn(schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
+ q.excludeColumn(columns)
+ return q
+}
+
+// Value overwrites model value for the column in INSERT and UPDATE queries.
+func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery {
+ if q.table == nil {
+ q.err = errNilModel
+ return q
+ }
+ q.addValue(q.table, column, value, args)
+ return q
+}
+
+func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
+ return q
+}
+
+func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+// Returning adds a RETURNING clause to the query.
+//
+// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
+func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery {
+ q.addReturning(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *InsertQuery) hasReturning() bool {
+ if !q.db.features.Has(feature.Returning) {
+ return false
+ }
+ return q.returningQuery.hasReturning()
+}
+
+//------------------------------------------------------------------------------
+
+// Ignore generates an `INSERT IGNORE INTO` query (MySQL).
+func (q *InsertQuery) Ignore() *InsertQuery {
+ q.ignore = true
+ return q
+}
+
+// Replaces generates a `REPLACE INTO` query (MySQL).
+func (q *InsertQuery) Replace() *InsertQuery {
+ q.replace = true
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ fmter = formatterWithModel(fmter, q)
+
+ b, err = q.appendWith(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.replace {
+ b = append(b, "REPLACE "...)
+ } else {
+ b = append(b, "INSERT "...)
+ if q.ignore {
+ b = append(b, "IGNORE "...)
+ }
+ }
+ b = append(b, "INTO "...)
+
+ if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() {
+ b, err = q.appendFirstTableWithAlias(fmter, b)
+ } else {
+ b, err = q.appendFirstTable(fmter, b)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ b, err = q.appendColumnsValues(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b, err = q.appendOn(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.hasReturning() {
+ b, err = q.appendReturning(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *InsertQuery) appendColumnsValues(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ if q.hasMultiTables() {
+ if q.columns != nil {
+ b = append(b, " ("...)
+ b, err = q.appendColumns(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ")"...)
+ }
+
+ b = append(b, " SELECT * FROM "...)
+ b, err = q.appendOtherTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+ }
+
+ if m, ok := q.model.(*mapModel); ok {
+ return m.appendColumnsValues(fmter, b), nil
+ }
+ if _, ok := q.model.(*mapSliceModel); ok {
+ return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported")
+ }
+
+ if q.model == nil {
+ return nil, errNilModel
+ }
+
+ fields, err := q.getFields()
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " ("...)
+ b = q.appendFields(fmter, b, fields)
+ b = append(b, ") VALUES ("...)
+
+ switch model := q.tableModel.(type) {
+ case *structTableModel:
+ b, err = q.appendStructValues(fmter, b, fields, model.strct)
+ if err != nil {
+ return nil, err
+ }
+ case *sliceTableModel:
+ b, err = q.appendSliceValues(fmter, b, fields, model.slice)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel)
+ }
+
+ b = append(b, ')')
+
+ return b, nil
+}
+
+func (q *InsertQuery) appendStructValues(
+ fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value,
+) (_ []byte, err error) {
+ isTemplate := fmter.IsNop()
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ app, ok := q.modelValues[f.Name]
+ if ok {
+ b, err = app.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ q.addReturningField(f)
+ continue
+ }
+
+ switch {
+ case isTemplate:
+ b = append(b, '?')
+ case f.NullZero && f.HasZeroValue(strct):
+ if q.db.features.Has(feature.DefaultPlaceholder) {
+ b = append(b, "DEFAULT"...)
+ } else if f.SQLDefault != "" {
+ b = append(b, f.SQLDefault...)
+ } else {
+ b = append(b, "NULL"...)
+ }
+ q.addReturningField(f)
+ default:
+ b = f.AppendValue(fmter, b, strct)
+ }
+ }
+
+ for i, v := range q.extraValues {
+ if i > 0 || len(fields) > 0 {
+ b = append(b, ", "...)
+ }
+
+ b, err = v.value.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *InsertQuery) appendSliceValues(
+ fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value,
+) (_ []byte, err error) {
+ if fmter.IsNop() {
+ return q.appendStructValues(fmter, b, fields, reflect.Value{})
+ }
+
+ sliceLen := slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ if i > 0 {
+ b = append(b, "), ("...)
+ }
+ el := indirect(slice.Index(i))
+ b, err = q.appendStructValues(fmter, b, fields, el)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ for i, v := range q.extraValues {
+ if i > 0 || len(fields) > 0 {
+ b = append(b, ", "...)
+ }
+
+ b, err = v.value.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *InsertQuery) getFields() ([]*schema.Field, error) {
+ if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 {
+ return q.baseQuery.getFields()
+ }
+
+ var strct reflect.Value
+
+ switch model := q.tableModel.(type) {
+ case *structTableModel:
+ strct = model.strct
+ case *sliceTableModel:
+ if model.sliceLen == 0 {
+ return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type())
+ }
+ strct = indirect(model.slice.Index(0))
+ }
+
+ fields := make([]*schema.Field, 0, len(q.table.Fields))
+
+ for _, f := range q.table.Fields {
+ if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) {
+ q.addReturningField(f)
+ continue
+ }
+ fields = append(fields, f)
+ }
+
+ return fields, nil
+}
+
+func (q *InsertQuery) appendFields(
+ fmter schema.Formatter, b []byte, fields []*schema.Field,
+) []byte {
+ b = appendColumns(b, "", fields)
+ for i, v := range q.extraValues {
+ if i > 0 || len(fields) > 0 {
+ b = append(b, ", "...)
+ }
+ b = fmter.AppendIdent(b, v.column)
+ }
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery {
+ q.onConflict = schema.SafeQuery(s, args)
+ return q
+}
+
+func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery {
+ q.addSet(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.onConflict.IsZero() {
+ return b, nil
+ }
+
+ b = append(b, " ON "...)
+ b, err = q.onConflict.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(q.set) > 0 {
+ if fmter.HasFeature(feature.OnDuplicateKey) {
+ b = append(b, ' ')
+ } else {
+ b = append(b, " SET "...)
+ }
+
+ b, err = q.appendSet(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ } else if len(q.columns) > 0 {
+ fields, err := q.getDataFields()
+ if err != nil {
+ return nil, err
+ }
+
+ if len(fields) == 0 {
+ fields = q.tableModel.Table().DataFields
+ }
+
+ b = q.appendSetExcluded(b, fields)
+ }
+
+ b, err = q.appendWhere(fmter, b, true)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+}
+
+func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte {
+ b = append(b, " SET "...)
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = append(b, f.SQLName...)
+ b = append(b, " = EXCLUDED."...)
+ b = append(b, f.SQLName...)
+ }
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if q.table != nil {
+ if err := q.beforeInsertHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+ var res sql.Result
+
+ if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
+ model, err := q.getModel(dest)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err = q.scan(ctx, q, query, model, hasDest)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ res, err = q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := q.tryLastInsertID(res, dest); err != nil {
+ return nil, err
+ }
+ }
+
+ if q.table != nil {
+ if err := q.afterInsertHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *InsertQuery) beforeInsertHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok {
+ if err := hook.BeforeInsert(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok {
+ if err := hook.AfterInsert(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
+ if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 {
+ return nil
+ }
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ if id == 0 {
+ return nil
+ }
+
+ model, err := q.getModel(dest)
+ if err != nil {
+ return err
+ }
+
+ pk := q.table.PKs[0]
+ switch model := model.(type) {
+ case *structTableModel:
+ if err := pk.ScanValue(model.strct, id); err != nil {
+ return err
+ }
+ case *sliceTableModel:
+ sliceLen := model.slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ strct := indirect(model.slice.Index(i))
+ if err := pk.ScanValue(strct, id); err != nil {
+ return err
+ }
+ id++
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
new file mode 100644
index 000000000..1f63686ad
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -0,0 +1,830 @@
+package bun
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type union struct {
+ expr string
+ query *SelectQuery
+}
+
+type SelectQuery struct {
+ whereBaseQuery
+
+ distinctOn []schema.QueryWithArgs
+ joins []joinQuery
+ group []schema.QueryWithArgs
+ having []schema.QueryWithArgs
+ order []schema.QueryWithArgs
+ limit int32
+ offset int32
+ selFor schema.QueryWithArgs
+
+ union []union
+}
+
+func NewSelectQuery(db *DB) *SelectQuery {
+ return &SelectQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ },
+ }
+}
+
+func (q *SelectQuery) Conn(db IConn) *SelectQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *SelectQuery) Model(model interface{}) *SelectQuery {
+ q.setTableModel(model)
+ return q
+}
+
+// Apply calls the fn passing the SelectQuery as an argument.
+func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery {
+ return fn(q)
+}
+
+func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery {
+ q.addWith(name, query)
+ return q
+}
+
+func (q *SelectQuery) Distinct() *SelectQuery {
+ q.distinctOn = make([]schema.QueryWithArgs, 0)
+ return q
+}
+
+func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery {
+ q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Table(tables ...string) *SelectQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Column(columns ...string) *SelectQuery {
+ for _, column := range columns {
+ q.addColumn(schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery {
+ q.addColumn(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery {
+ q.excludeColumn(columns)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) WherePK() *SelectQuery {
+ q.flags = q.flags.Set(wherePKFlag)
+ return q
+}
+
+func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
+ return q
+}
+
+func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
+ return q
+}
+
+func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery {
+ saved := q.where
+ q.where = nil
+
+ q = fn(q)
+
+ where := q.where
+ q.where = saved
+
+ q.addWhereGroup(sep, where)
+
+ return q
+}
+
+func (q *SelectQuery) WhereDeleted() *SelectQuery {
+ q.whereDeleted()
+ return q
+}
+
+func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery {
+ q.whereAllWithDeleted()
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Group(columns ...string) *SelectQuery {
+ for _, column := range columns {
+ q.group = append(q.group, schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery {
+ q.group = append(q.group, schema.SafeQuery(group, args))
+ return q
+}
+
+func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery {
+ q.having = append(q.having, schema.SafeQuery(having, args))
+ return q
+}
+
+func (q *SelectQuery) Order(orders ...string) *SelectQuery {
+ for _, order := range orders {
+ if order == "" {
+ continue
+ }
+
+ index := strings.IndexByte(order, ' ')
+ if index == -1 {
+ q.order = append(q.order, schema.UnsafeIdent(order))
+ continue
+ }
+
+ field := order[:index]
+ sort := order[index+1:]
+
+ switch strings.ToUpper(sort) {
+ case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
+ "ASC NULLS LAST", "DESC NULLS LAST":
+ q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{
+ Ident(field),
+ Safe(sort),
+ }))
+ default:
+ q.order = append(q.order, schema.UnsafeIdent(order))
+ }
+ }
+ return q
+}
+
+func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery {
+ q.order = append(q.order, schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *SelectQuery) Limit(n int) *SelectQuery {
+ q.limit = int32(n)
+ return q
+}
+
+func (q *SelectQuery) Offset(n int) *SelectQuery {
+ q.offset = int32(n)
+ return q
+}
+
+func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery {
+ q.selFor = schema.SafeQuery(s, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" UNION ", other)
+}
+
+func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" UNION ALL ", other)
+}
+
+func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" INTERSECT ", other)
+}
+
+func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" INTERSECT ALL ", other)
+}
+
+func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" EXCEPT ", other)
+}
+
+func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery {
+ return q.addUnion(" EXCEPT ALL ", other)
+}
+
+func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery {
+ q.union = append(q.union, union{
+ expr: expr,
+ query: other,
+ })
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery {
+ q.joins = append(q.joins, joinQuery{
+ join: schema.SafeQuery(join, args),
+ })
+ return q
+}
+
+func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery {
+ return q.joinOn(cond, args, " AND ")
+}
+
+func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery {
+ return q.joinOn(cond, args, " OR ")
+}
+
+func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery {
+ if len(q.joins) == 0 {
+ q.err = errors.New("bun: query has no joins")
+ return q
+ }
+ j := &q.joins[len(q.joins)-1]
+ j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+// Relation adds a relation to the query. Relation name can be:
+// - RelationName to select all columns,
+// - RelationName.column_name,
+// - RelationName._ to join relation without selecting relation columns.
+func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery {
+ if q.tableModel == nil {
+ q.setErr(errNilModel)
+ return q
+ }
+
+ var fn func(*SelectQuery) *SelectQuery
+
+ if len(apply) == 1 {
+ fn = apply[0]
+ } else if len(apply) > 1 {
+ panic("only one apply function is supported")
+ }
+
+ join := q.tableModel.Join(name, fn)
+ if join == nil {
+ q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name))
+ return q
+ }
+
+ return q
+}
+
+func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error {
+ if q.tableModel == nil {
+ return nil
+ }
+ return q._forEachHasOneJoin(fn, q.tableModel.GetJoins())
+}
+
+func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error {
+ for i := range joins {
+ j := &joins[i]
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ if err := fn(j); err != nil {
+ return err
+ }
+ if err := q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error {
+ var err error
+ for i := range joins {
+ j := &joins[i]
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ err = q.selectJoins(ctx, j.JoinModel.GetJoins())
+ default:
+ err = j.Select(ctx, q.db.NewSelect())
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ return q.appendQuery(fmter, b, false)
+}
+
+func (q *SelectQuery) appendQuery(
+ fmter schema.Formatter, b []byte, count bool,
+) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ fmter = formatterWithModel(fmter, q)
+
+ cteCount := count && (len(q.group) > 0 || q.distinctOn != nil)
+ if cteCount {
+ b = append(b, "WITH _count_wrapper AS ("...)
+ }
+
+ if len(q.union) > 0 {
+ b = append(b, '(')
+ }
+
+ b, err = q.appendWith(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, "SELECT "...)
+
+ if len(q.distinctOn) > 0 {
+ b = append(b, "DISTINCT ON ("...)
+ for i, app := range q.distinctOn {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = app.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ b = append(b, ") "...)
+ } else if q.distinctOn != nil {
+ b = append(b, "DISTINCT "...)
+ }
+
+ if count && !cteCount {
+ b = append(b, "count(*)"...)
+ } else {
+ b, err = q.appendColumns(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if q.hasTables() {
+ b, err = q.appendTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if err := q.forEachHasOneJoin(func(j *join) error {
+ b = append(b, ' ')
+ b, err = j.appendHasOneJoin(fmter, b, q)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ for _, j := range q.joins {
+ b, err = j.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ b, err = q.appendWhere(fmter, b, true)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(q.group) > 0 {
+ b = append(b, " GROUP BY "...)
+ for i, f := range q.group {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ if len(q.having) > 0 {
+ b = append(b, " HAVING "...)
+ for i, f := range q.having {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ b = append(b, '(')
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ')')
+ }
+ }
+
+ if !count {
+ b, err = q.appendOrder(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.limit != 0 {
+ b = append(b, " LIMIT "...)
+ b = strconv.AppendInt(b, int64(q.limit), 10)
+ }
+
+ if q.offset != 0 {
+ b = append(b, " OFFSET "...)
+ b = strconv.AppendInt(b, int64(q.offset), 10)
+ }
+
+ if !q.selFor.IsZero() {
+ b = append(b, " FOR "...)
+ b, err = q.selFor.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ if len(q.union) > 0 {
+ b = append(b, ')')
+
+ for _, u := range q.union {
+ b = append(b, u.expr...)
+ b = append(b, '(')
+ b, err = u.query.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ')')
+ }
+ }
+
+ if cteCount {
+ b = append(b, ") SELECT count(*) FROM _count_wrapper"...)
+ }
+
+ return b, nil
+}
+
+func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ start := len(b)
+
+ switch {
+ case q.columns != nil:
+ for i, col := range q.columns {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ if col.Args == nil {
+ if field, ok := q.table.FieldMap[col.Query]; ok {
+ b = append(b, q.table.SQLAlias...)
+ b = append(b, '.')
+ b = append(b, field.SQLName...)
+ continue
+ }
+ }
+
+ b, err = col.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ case q.table != nil:
+ if len(q.table.Fields) > 10 && fmter.IsNop() {
+ b = append(b, q.table.SQLAlias...)
+ b = append(b, '.')
+ b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
+ } else {
+ b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
+ }
+ default:
+ b = append(b, '*')
+ }
+
+ if err := q.forEachHasOneJoin(func(j *join) error {
+ if len(b) != start {
+ b = append(b, ", "...)
+ start = len(b)
+ }
+
+ b, err = q.appendHasOneColumns(fmter, b, j)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ b = bytes.TrimSuffix(b, []byte(", "))
+
+ return b, nil
+}
+
+func (q *SelectQuery) appendHasOneColumns(
+ fmter schema.Formatter, b []byte, join *join,
+) (_ []byte, err error) {
+ join.applyQuery(q)
+
+ if join.columns != nil {
+ for i, col := range join.columns {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ if col.Args == nil {
+ if field, ok := q.table.FieldMap[col.Query]; ok {
+ b = join.appendAlias(fmter, b)
+ b = append(b, '.')
+ b = append(b, field.SQLName...)
+ b = append(b, " AS "...)
+ b = join.appendAliasColumn(fmter, b, field.Name)
+ continue
+ }
+ }
+
+ b, err = col.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+ }
+
+ for i, field := range join.JoinModel.Table().Fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = join.appendAlias(fmter, b)
+ b = append(b, '.')
+ b = append(b, field.SQLName...)
+ b = append(b, " AS "...)
+ b = join.appendAliasColumn(fmter, b, field.Name)
+ }
+ return b, nil
+}
+
+func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ b = append(b, " FROM "...)
+ return q.appendTablesWithAlias(fmter, b)
+}
+
+func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if len(q.order) > 0 {
+ b = append(b, " ORDER BY "...)
+
+ for i, f := range q.order {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b, err = f.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+ }
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+ return q.conn.QueryContext(ctx, query)
+}
+
+func (q *SelectQuery) Exec(ctx context.Context) (res sql.Result, err error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err = q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
+ model, err := q.getModel(dest)
+ if err != nil {
+ return err
+ }
+
+ if q.limit > 1 {
+ if model, ok := model.(interface{ SetCap(int) }); ok {
+ model.SetCap(int(q.limit))
+ }
+ }
+
+ if q.table != nil {
+ if err := q.beforeSelectHook(ctx); err != nil {
+ return err
+ }
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.scan(ctx, q, query, model, true)
+ if err != nil {
+ return err
+ }
+
+ if res.n > 0 {
+ if tableModel, ok := model.(tableModel); ok {
+ if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil {
+ return err
+ }
+ }
+ }
+
+ if q.table != nil {
+ if err := q.afterSelectHook(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (q *SelectQuery) beforeSelectHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok {
+ if err := hook.BeforeSelect(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *SelectQuery) afterSelectHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok {
+ if err := hook.AfterSelect(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *SelectQuery) Count(ctx context.Context) (int, error) {
+ qq := countQuery{q}
+
+ queryBytes, err := qq.appendQuery(q.db.fmter, nil, true)
+ if err != nil {
+ return 0, err
+ }
+
+ query := internal.String(queryBytes)
+ ctx, event := q.db.beforeQuery(ctx, qq, query, nil)
+
+ var num int
+ err = q.conn.QueryRowContext(ctx, query).Scan(&num)
+
+ q.db.afterQuery(ctx, event, nil, err)
+
+ return num, err
+}
+
+func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) {
+ var count int
+ var wg sync.WaitGroup
+ var mu sync.Mutex
+ var firstErr error
+
+ if q.limit >= 0 {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ if err := q.Scan(ctx, dest...); err != nil {
+ mu.Lock()
+ if firstErr == nil {
+ firstErr = err
+ }
+ mu.Unlock()
+ }
+ }()
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ var err error
+ count, err = q.Count(ctx)
+ if err != nil {
+ mu.Lock()
+ if firstErr == nil {
+ firstErr = err
+ }
+ mu.Unlock()
+ }
+ }()
+
+ wg.Wait()
+ return count, firstErr
+}
+
+//------------------------------------------------------------------------------
+
+type joinQuery struct {
+ join schema.QueryWithArgs
+ on []schema.QueryWithSep
+}
+
+func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ b = append(b, ' ')
+
+ b, err = j.join.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(j.on) > 0 {
+ b = append(b, " ON "...)
+ for i, on := range j.on {
+ if i > 0 {
+ b = append(b, on.Sep...)
+ }
+
+ b = append(b, '(')
+ b, err = on.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, ')')
+ }
+ }
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+type countQuery struct {
+ *SelectQuery
+}
+
+func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ return q.appendQuery(fmter, b, true)
+}
diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go
new file mode 100644
index 000000000..0a4b3567c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_table_create.go
@@ -0,0 +1,275 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "sort"
+ "strconv"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type CreateTableQuery struct {
+ baseQuery
+
+ temp bool
+ ifNotExists bool
+ varchar int
+
+ fks []schema.QueryWithArgs
+ partitionBy schema.QueryWithArgs
+ tablespace schema.QueryWithArgs
+}
+
+func NewCreateTableQuery(db *DB) *CreateTableQuery {
+ q := &CreateTableQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Temp() *CreateTableQuery {
+ q.temp = true
+ return q
+}
+
+func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
+ q.ifNotExists = true
+ return q
+}
+
+func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
+ q.varchar = n
+ return q
+}
+
+func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery {
+ q.fks = append(q.fks, schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if q.table == nil {
+ return nil, errNilModel
+ }
+
+ b = append(b, "CREATE "...)
+ if q.temp {
+ b = append(b, "TEMP "...)
+ }
+ b = append(b, "TABLE "...)
+ if q.ifNotExists {
+ b = append(b, "IF NOT EXISTS "...)
+ }
+ b, err = q.appendFirstTable(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " ("...)
+
+ for i, field := range q.table.Fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ b = append(b, field.SQLName...)
+ b = append(b, " "...)
+ b = q.appendSQLType(b, field)
+ if field.NotNull {
+ b = append(b, " NOT NULL"...)
+ }
+ if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement {
+ b = append(b, " AUTO_INCREMENT"...)
+ }
+ if field.SQLDefault != "" {
+ b = append(b, " DEFAULT "...)
+ b = append(b, field.SQLDefault...)
+ }
+ }
+
+ b = q.appendPKConstraint(b, q.table.PKs)
+ b = q.appendUniqueConstraints(fmter, b)
+ b, err = q.appenFKConstraints(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, ")"...)
+
+ if !q.partitionBy.IsZero() {
+ b = append(b, " PARTITION BY "...)
+ b, err = q.partitionBy.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if !q.tablespace.IsZero() {
+ b = append(b, " TABLESPACE "...)
+ b, err = q.tablespace.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
+ if field.CreateTableSQLType != field.DiscoveredSQLType {
+ return append(b, field.CreateTableSQLType...)
+ }
+
+ if q.varchar > 0 &&
+ field.CreateTableSQLType == sqltype.VarChar {
+ b = append(b, "varchar("...)
+ b = strconv.AppendInt(b, int64(q.varchar), 10)
+ b = append(b, ")"...)
+ return b
+ }
+
+ return append(b, field.CreateTableSQLType...)
+}
+
+func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
+ unique := q.table.Unique
+
+ keys := make([]string, 0, len(unique))
+ for key := range unique {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+
+ for _, key := range keys {
+ b = q.appendUniqueConstraint(fmter, b, key, unique[key])
+ }
+
+ return b
+}
+
+func (q *CreateTableQuery) appendUniqueConstraint(
+ fmter schema.Formatter, b []byte, name string, fields []*schema.Field,
+) []byte {
+ if name != "" {
+ b = append(b, ", CONSTRAINT "...)
+ b = fmter.AppendIdent(b, name)
+ } else {
+ b = append(b, ","...)
+ }
+ b = append(b, " UNIQUE ("...)
+ b = appendColumns(b, "", fields)
+ b = append(b, ")"...)
+
+ return b
+}
+
+func (q *CreateTableQuery) appenFKConstraints(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ for _, fk := range q.fks {
+ b = append(b, ", FOREIGN KEY "...)
+ b, err = fk.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+}
+
+func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte {
+ if len(pks) == 0 {
+ return b
+ }
+
+ b = append(b, ", PRIMARY KEY ("...)
+ b = appendColumns(b, "", pks)
+ b = append(b, ")"...)
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if err := q.beforeCreateTableHook(ctx); err != nil {
+ return nil, err
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.table != nil {
+ if err := q.afterCreateTableHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok {
+ if err := hook.BeforeCreateTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok {
+ if err := hook.AfterCreateTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_table_drop.go b/vendor/github.com/uptrace/bun/query_table_drop.go
new file mode 100644
index 000000000..2c30171c1
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_table_drop.go
@@ -0,0 +1,137 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type DropTableQuery struct {
+ baseQuery
+ cascadeQuery
+
+ ifExists bool
+}
+
+func NewDropTableQuery(db *DB) *DropTableQuery {
+ q := &DropTableQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *DropTableQuery) Conn(db IConn) *DropTableQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *DropTableQuery) Model(model interface{}) *DropTableQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropTableQuery) Table(tables ...string) *DropTableQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *DropTableQuery) TableExpr(query string, args ...interface{}) *DropTableQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *DropTableQuery) ModelTableExpr(query string, args ...interface{}) *DropTableQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropTableQuery) IfExists() *DropTableQuery {
+ q.ifExists = true
+ return q
+}
+
+func (q *DropTableQuery) Restrict() *DropTableQuery {
+ q.restrict = true
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+
+ b = append(b, "DROP TABLE "...)
+ if q.ifExists {
+ b = append(b, "IF EXISTS "...)
+ }
+
+ b, err = q.appendTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = q.appendCascade(fmter, b)
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if q.table != nil {
+ if err := q.beforeDropTableHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.table != nil {
+ if err := q.afterDropTableHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *DropTableQuery) beforeDropTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeDropTableHook); ok {
+ if err := hook.BeforeDropTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterDropTableHook); ok {
+ if err := hook.AfterDropTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_table_truncate.go b/vendor/github.com/uptrace/bun/query_table_truncate.go
new file mode 100644
index 000000000..1e4bef7f6
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_table_truncate.go
@@ -0,0 +1,121 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type TruncateTableQuery struct {
+ baseQuery
+ cascadeQuery
+
+ continueIdentity bool
+}
+
+func NewTruncateTableQuery(db *DB) *TruncateTableQuery {
+ q := &TruncateTableQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *TruncateTableQuery) Conn(db IConn) *TruncateTableQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *TruncateTableQuery) Model(model interface{}) *TruncateTableQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *TruncateTableQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery {
+ q.continueIdentity = true
+ return q
+}
+
+func (q *TruncateTableQuery) Restrict() *TruncateTableQuery {
+ q.restrict = true
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *TruncateTableQuery) AppendQuery(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+
+ if !fmter.HasFeature(feature.TableTruncate) {
+ b = append(b, "DELETE FROM "...)
+
+ b, err = q.appendTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+ }
+
+ b = append(b, "TRUNCATE TABLE "...)
+
+ b, err = q.appendTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.db.features.Has(feature.TableIdentity) {
+ if q.continueIdentity {
+ b = append(b, " CONTINUE IDENTITY"...)
+ } else {
+ b = append(b, " RESTART IDENTITY"...)
+ }
+ }
+
+ b = q.appendCascade(fmter, b)
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go
new file mode 100644
index 000000000..ea74e1419
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_update.go
@@ -0,0 +1,432 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type UpdateQuery struct {
+ whereBaseQuery
+ returningQuery
+ customValueQuery
+ setQuery
+
+ omitZero bool
+}
+
+func NewUpdateQuery(db *DB) *UpdateQuery {
+ q := &UpdateQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ },
+ }
+ return q
+}
+
+func (q *UpdateQuery) Conn(db IConn) *UpdateQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *UpdateQuery) Model(model interface{}) *UpdateQuery {
+ q.setTableModel(model)
+ return q
+}
+
+// Apply calls the fn passing the SelectQuery as an argument.
+func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
+ return fn(q)
+}
+
+func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery {
+ q.addWith(name, query)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) Table(tables ...string) *UpdateQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) Column(columns ...string) *UpdateQuery {
+ for _, column := range columns {
+ q.addColumn(schema.UnsafeIdent(column))
+ }
+ return q
+}
+
+func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery {
+ q.excludeColumn(columns)
+ return q
+}
+
+func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery {
+ q.addSet(schema.SafeQuery(query, args))
+ return q
+}
+
+// Value overwrites model value for the column in INSERT and UPDATE queries.
+func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery {
+ if q.table == nil {
+ q.err = errNilModel
+ return q
+ }
+ q.addValue(q.table, column, value, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) WherePK() *UpdateQuery {
+ q.flags = q.flags.Set(wherePKFlag)
+ return q
+}
+
+func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
+ return q
+}
+
+func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery {
+ q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
+ return q
+}
+
+func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
+ saved := q.where
+ q.where = nil
+
+ q = fn(q)
+
+ where := q.where
+ q.where = saved
+
+ q.addWhereGroup(sep, where)
+
+ return q
+}
+
+func (q *UpdateQuery) WhereDeleted() *UpdateQuery {
+ q.whereDeleted()
+ return q
+}
+
+func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery {
+ q.whereAllWithDeleted()
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+// Returning adds a RETURNING clause to the query.
+//
+// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
+func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery {
+ q.addReturning(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *UpdateQuery) hasReturning() bool {
+ if !q.db.features.Has(feature.Returning) {
+ return false
+ }
+ return q.returningQuery.hasReturning()
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ fmter = formatterWithModel(fmter, q)
+
+ withAlias := fmter.HasFeature(feature.UpdateMultiTable)
+
+ b, err = q.appendWith(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, "UPDATE "...)
+
+ if withAlias {
+ b, err = q.appendTablesWithAlias(fmter, b)
+ } else {
+ b, err = q.appendFirstTableWithAlias(fmter, b)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ b, err = q.mustAppendSet(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if !fmter.HasFeature(feature.UpdateMultiTable) {
+ b, err = q.appendOtherTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ b, err = q.mustAppendWhere(fmter, b, withAlias)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(q.returning) > 0 {
+ b, err = q.appendReturning(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ b = append(b, " SET "...)
+
+ if len(q.set) > 0 {
+ return q.appendSet(fmter, b)
+ }
+
+ if m, ok := q.model.(*mapModel); ok {
+ return m.appendSet(fmter, b), nil
+ }
+
+ if q.tableModel == nil {
+ return nil, errNilModel
+ }
+
+ switch model := q.tableModel.(type) {
+ case *structTableModel:
+ b, err = q.appendSetStruct(fmter, b, model)
+ if err != nil {
+ return nil, err
+ }
+ case *sliceTableModel:
+ return nil, errors.New("bun: to bulk Update, use CTE and VALUES")
+ default:
+ return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel)
+ }
+
+ return b, nil
+}
+
+func (q *UpdateQuery) appendSetStruct(
+ fmter schema.Formatter, b []byte, model *structTableModel,
+) ([]byte, error) {
+ fields, err := q.getDataFields()
+ if err != nil {
+ return nil, err
+ }
+
+ isTemplate := fmter.IsNop()
+ pos := len(b)
+ for _, f := range fields {
+ if q.omitZero && f.NullZero && f.HasZeroValue(model.strct) {
+ continue
+ }
+
+ if len(b) != pos {
+ b = append(b, ", "...)
+ pos = len(b)
+ }
+
+ b = append(b, f.SQLName...)
+ b = append(b, " = "...)
+
+ if isTemplate {
+ b = append(b, '?')
+ continue
+ }
+
+ app, ok := q.modelValues[f.Name]
+ if ok {
+ b, err = app.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ b = f.AppendValue(fmter, b, model.strct)
+ }
+ }
+
+ for i, v := range q.extraValues {
+ if i > 0 || len(fields) > 0 {
+ b = append(b, ", "...)
+ }
+
+ b = append(b, v.column...)
+ b = append(b, " = "...)
+
+ b, err = v.value.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if !q.hasMultiTables() {
+ return b, nil
+ }
+
+ b = append(b, " FROM "...)
+
+ b, err = q.whereBaseQuery.appendOtherTables(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) Bulk() *UpdateQuery {
+ model, ok := q.model.(*sliceTableModel)
+ if !ok {
+ q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model))
+ return q
+ }
+
+ return q.With("_data", q.db.NewValues(model)).
+ Model(model).
+ TableExpr("_data").
+ Set(q.updateSliceSet(model)).
+ Where(q.updateSliceWhere(model))
+}
+
+func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string {
+ var b []byte
+ for i, field := range model.table.DataFields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ if q.db.fmter.HasFeature(feature.UpdateMultiTable) {
+ b = append(b, model.table.SQLAlias...)
+ b = append(b, '.')
+ }
+ b = append(b, field.SQLName...)
+ b = append(b, " = _data."...)
+ b = append(b, field.SQLName...)
+ }
+ return internal.String(b)
+}
+
+func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string {
+ var b []byte
+ for i, pk := range model.table.PKs {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ b = append(b, model.table.SQLAlias...)
+ b = append(b, '.')
+ b = append(b, pk.SQLName...)
+ b = append(b, " = _data."...)
+ b = append(b, pk.SQLName...)
+ }
+ return internal.String(b)
+}
+
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if q.table != nil {
+ if err := q.beforeUpdateHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ var res sql.Result
+
+ if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
+ model, err := q.getModel(dest)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err = q.scan(ctx, q, query, model, hasDest)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ res, err = q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if q.table != nil {
+ if err := q.afterUpdateHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok {
+ if err := hook.BeforeUpdate(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok {
+ if err := hook.AfterUpdate(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// FQN returns a fully qualified column name. For MySQL, it returns the column name with
+// the table alias. For other RDBMS, it returns just the column name.
+func (q *UpdateQuery) FQN(name string) Ident {
+ if q.db.fmter.HasFeature(feature.UpdateMultiTable) {
+ return Ident(q.table.Alias + "." + name)
+ }
+ return Ident(name)
+}
diff --git a/vendor/github.com/uptrace/bun/query_values.go b/vendor/github.com/uptrace/bun/query_values.go
new file mode 100644
index 000000000..323ac68ef
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_values.go
@@ -0,0 +1,198 @@
+package bun
+
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/schema"
+)
+
+type ValuesQuery struct {
+ baseQuery
+ customValueQuery
+
+ withOrder bool
+}
+
+var _ schema.NamedArgAppender = (*ValuesQuery)(nil)
+
+func NewValuesQuery(db *DB, model interface{}) *ValuesQuery {
+ q := &ValuesQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ q.setTableModel(model)
+ return q
+}
+
+func (q *ValuesQuery) Conn(db IConn) *ValuesQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *ValuesQuery) WithOrder() *ValuesQuery {
+ q.withOrder = true
+ return q
+}
+
+func (q *ValuesQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
+ switch name {
+ case "Columns":
+ bb, err := q.AppendColumns(fmter, b)
+ if err != nil {
+ q.setErr(err)
+ return b, true
+ }
+ return bb, true
+ }
+ return b, false
+}
+
+// AppendColumns appends the table columns. It is used by CTE.
+func (q *ValuesQuery) AppendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if q.model == nil {
+ return nil, errNilModel
+ }
+
+ if q.tableModel != nil {
+ fields, err := q.getFields()
+ if err != nil {
+ return nil, err
+ }
+
+ b = appendColumns(b, "", fields)
+
+ if q.withOrder {
+ b = append(b, ", _order"...)
+ }
+
+ return b, nil
+ }
+
+ switch model := q.model.(type) {
+ case *mapSliceModel:
+ return model.appendColumns(fmter, b)
+ }
+
+ return nil, fmt.Errorf("bun: Values does not support %T", q.model)
+}
+
+func (q *ValuesQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if q.model == nil {
+ return nil, errNilModel
+ }
+
+ fmter = formatterWithModel(fmter, q)
+
+ if q.tableModel != nil {
+ fields, err := q.getFields()
+ if err != nil {
+ return nil, err
+ }
+ return q.appendQuery(fmter, b, fields)
+ }
+
+ switch model := q.model.(type) {
+ case *mapSliceModel:
+ return model.appendValues(fmter, b)
+ }
+
+ return nil, fmt.Errorf("bun: Values does not support %T", q.model)
+}
+
+func (q *ValuesQuery) appendQuery(
+ fmter schema.Formatter,
+ b []byte,
+ fields []*schema.Field,
+) (_ []byte, err error) {
+ b = append(b, "VALUES "...)
+ if q.db.features.Has(feature.ValuesRow) {
+ b = append(b, "ROW("...)
+ } else {
+ b = append(b, '(')
+ }
+
+ switch model := q.tableModel.(type) {
+ case *structTableModel:
+ b, err = q.appendValues(fmter, b, fields, model.strct)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.withOrder {
+ b = append(b, ", "...)
+ b = strconv.AppendInt(b, 0, 10)
+ }
+ case *sliceTableModel:
+ slice := model.slice
+ sliceLen := slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ if i > 0 {
+ b = append(b, "), "...)
+ if q.db.features.Has(feature.ValuesRow) {
+ b = append(b, "ROW("...)
+ } else {
+ b = append(b, '(')
+ }
+ }
+
+ b, err = q.appendValues(fmter, b, fields, slice.Index(i))
+ if err != nil {
+ return nil, err
+ }
+
+ if q.withOrder {
+ b = append(b, ", "...)
+ b = strconv.AppendInt(b, int64(i), 10)
+ }
+ }
+ default:
+ return nil, fmt.Errorf("bun: Values does not support %T", q.model)
+ }
+
+ b = append(b, ')')
+
+ return b, nil
+}
+
+func (q *ValuesQuery) appendValues(
+ fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value,
+) (_ []byte, err error) {
+ isTemplate := fmter.IsNop()
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ app, ok := q.modelValues[f.Name]
+ if ok {
+ b, err = app.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ continue
+ }
+
+ if isTemplate {
+ b = append(b, '?')
+ } else {
+ b = f.AppendValue(fmter, b, indirect(strct))
+ }
+
+ if fmter.HasFeature(feature.DoubleColonCast) {
+ b = append(b, "::"...)
+ b = append(b, f.UserSQLType...)
+ }
+ }
+ return b, nil
+}
diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go
new file mode 100644
index 000000000..68f7071c8
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/append.go
@@ -0,0 +1,93 @@
+package schema
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/vmihailenco/msgpack/v5"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+)
+
+func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
+ if field.Tag.HasOption("msgpack") {
+ return appendMsgpack
+ }
+
+ switch strings.ToUpper(field.UserSQLType) {
+ case sqltype.JSON, sqltype.JSONB:
+ return AppendJSONValue
+ }
+
+ return dialect.Appender(field.StructField.Type)
+}
+
+func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte {
+ switch v := v.(type) {
+ case nil:
+ return dialect.AppendNull(b)
+ case bool:
+ return dialect.AppendBool(b, v)
+ case int:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int32:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int64:
+ return strconv.AppendInt(b, v, 10)
+ case uint:
+ return strconv.AppendUint(b, uint64(v), 10)
+ case uint32:
+ return strconv.AppendUint(b, uint64(v), 10)
+ case uint64:
+ return strconv.AppendUint(b, v, 10)
+ case float32:
+ return dialect.AppendFloat32(b, v)
+ case float64:
+ return dialect.AppendFloat64(b, v)
+ case string:
+ return dialect.AppendString(b, v)
+ case time.Time:
+ return dialect.AppendTime(b, v)
+ case []byte:
+ return dialect.AppendBytes(b, v)
+ case QueryAppender:
+ return AppendQueryAppender(fmter, b, v)
+ default:
+ vv := reflect.ValueOf(v)
+ if vv.Kind() == reflect.Ptr && vv.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ appender := Appender(vv.Type(), custom)
+ return appender(fmter, b, vv)
+ }
+}
+
+func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
+ hexEnc := internal.NewHexEncoder(b)
+
+ enc := msgpack.GetEncoder()
+ defer msgpack.PutEncoder(enc)
+
+ enc.Reset(hexEnc)
+ if err := enc.EncodeValue(v); err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ if err := hexEnc.Close(); err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ return hexEnc.Bytes()
+}
+
+func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
+ bb, err := app.AppendQuery(fmter, b)
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return bb
+}
diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go
new file mode 100644
index 000000000..0c4677069
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/append_value.go
@@ -0,0 +1,237 @@
+package schema
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/extra/bunjson"
+ "github.com/uptrace/bun/internal"
+)
+
+var (
+ timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+ ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
+ ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
+ jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
+
+ driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+ queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
+)
+
+type (
+ AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
+ CustomAppender func(typ reflect.Type) AppenderFunc
+)
+
+var appenders = []AppenderFunc{
+ reflect.Bool: AppendBoolValue,
+ reflect.Int: AppendIntValue,
+ reflect.Int8: AppendIntValue,
+ reflect.Int16: AppendIntValue,
+ reflect.Int32: AppendIntValue,
+ reflect.Int64: AppendIntValue,
+ reflect.Uint: AppendUintValue,
+ reflect.Uint8: AppendUintValue,
+ reflect.Uint16: AppendUintValue,
+ reflect.Uint32: AppendUintValue,
+ reflect.Uint64: AppendUintValue,
+ reflect.Uintptr: nil,
+ reflect.Float32: AppendFloat32Value,
+ reflect.Float64: AppendFloat64Value,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: AppendJSONValue,
+ reflect.Chan: nil,
+ reflect.Func: nil,
+ reflect.Interface: nil,
+ reflect.Map: AppendJSONValue,
+ reflect.Ptr: nil,
+ reflect.Slice: AppendJSONValue,
+ reflect.String: AppendStringValue,
+ reflect.Struct: AppendJSONValue,
+ reflect.UnsafePointer: nil,
+}
+
+func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
+ switch typ {
+ case timeType:
+ return appendTimeValue
+ case ipType:
+ return appendIPValue
+ case ipNetType:
+ return appendIPNetValue
+ case jsonRawMessageType:
+ return appendJSONRawMessageValue
+ }
+
+ if typ.Implements(queryAppenderType) {
+ return appendQueryAppenderValue
+ }
+ if typ.Implements(driverValuerType) {
+ return driverValueAppender(custom)
+ }
+
+ kind := typ.Kind()
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(queryAppenderType) {
+ return addrAppender(appendQueryAppenderValue, custom)
+ }
+ if ptr.Implements(driverValuerType) {
+ return addrAppender(driverValueAppender(custom), custom)
+ }
+ }
+
+ switch kind {
+ case reflect.Interface:
+ return ifaceAppenderFunc(typ, custom)
+ case reflect.Ptr:
+ return ptrAppenderFunc(typ, custom)
+ case reflect.Slice:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return appendBytesValue
+ }
+ case reflect.Array:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return appendArrayBytesValue
+ }
+ }
+
+ if custom != nil {
+ if fn := custom(typ); fn != nil {
+ return fn
+ }
+ }
+ return appenders[typ.Kind()]
+}
+
+func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ elem := v.Elem()
+ appender := Appender(elem.Type(), custom)
+ return appender(fmter, b, elem)
+ }
+}
+
+func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
+ appender := Appender(typ.Elem(), custom)
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ return appender(fmter, b, v.Elem())
+ }
+}
+
+func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendBool(b, v.Bool())
+}
+
+func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendInt(b, v.Int(), 10)
+}
+
+func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendUint(b, v.Uint(), 10)
+}
+
+func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendFloat32(b, float32(v.Float()))
+}
+
+func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendFloat64(b, float64(v.Float()))
+}
+
+func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendBytes(b, v.Bytes())
+}
+
+func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.CanAddr() {
+ return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes())
+ }
+
+ tmp := make([]byte, v.Len())
+ reflect.Copy(reflect.ValueOf(tmp), v)
+ b = dialect.AppendBytes(b, tmp)
+ return b
+}
+
+func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendString(b, v.String())
+}
+
+func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ bb, err := bunjson.Marshal(v.Interface())
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ if len(bb) > 0 && bb[len(bb)-1] == '\n' {
+ bb = bb[:len(bb)-1]
+ }
+
+ return dialect.AppendJSON(b, bb)
+}
+
+func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ tm := v.Interface().(time.Time)
+ return dialect.AppendTime(b, tm)
+}
+
+func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ ip := v.Interface().(net.IP)
+ return dialect.AppendString(b, ip.String())
+}
+
+func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ ipnet := v.Interface().(net.IPNet)
+ return dialect.AppendString(b, ipnet.String())
+}
+
+func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ bytes := v.Bytes()
+ if bytes == nil {
+ return dialect.AppendNull(b)
+ }
+ return dialect.AppendString(b, internal.String(bytes))
+}
+
+func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
+}
+
+func driverValueAppender(custom CustomAppender) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom)
+ }
+}
+
+func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte {
+ value, err := v.Value()
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return Append(fmter, b, value, custom)
+}
+
+func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if !v.CanAddr() {
+ err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
+ return dialect.AppendError(b, err)
+ }
+ return fn(fmter, b, v.Addr())
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go
new file mode 100644
index 000000000..c50de715a
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/dialect.go
@@ -0,0 +1,99 @@
+package schema
+
+import (
+ "database/sql"
+ "reflect"
+ "sync"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+)
+
+type Dialect interface {
+ Init(db *sql.DB)
+
+ Name() dialect.Name
+ Features() feature.Feature
+
+ Tables() *Tables
+ OnTable(table *Table)
+
+ IdentQuote() byte
+ Append(fmter Formatter, b []byte, v interface{}) []byte
+ Appender(typ reflect.Type) AppenderFunc
+ FieldAppender(field *Field) AppenderFunc
+ Scanner(typ reflect.Type) ScannerFunc
+}
+
+//------------------------------------------------------------------------------
+
+type nopDialect struct {
+ tables *Tables
+ features feature.Feature
+
+ appenderMap sync.Map
+ scannerMap sync.Map
+}
+
+func newNopDialect() *nopDialect {
+ d := new(nopDialect)
+ d.tables = NewTables(d)
+ d.features = feature.Returning
+ return d
+}
+
+func (d *nopDialect) Init(*sql.DB) {}
+
+func (d *nopDialect) Name() dialect.Name {
+ return dialect.Invalid
+}
+
+func (d *nopDialect) Features() feature.Feature {
+ return d.features
+}
+
+func (d *nopDialect) Tables() *Tables {
+ return d.tables
+}
+
+func (d *nopDialect) OnField(field *Field) {}
+
+func (d *nopDialect) OnTable(table *Table) {}
+
+func (d *nopDialect) IdentQuote() byte {
+ return '"'
+}
+
+func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte {
+ return Append(fmter, b, v, nil)
+}
+
+func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc {
+ if v, ok := d.appenderMap.Load(typ); ok {
+ return v.(AppenderFunc)
+ }
+
+ fn := Appender(typ, nil)
+
+ if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
+ return v.(AppenderFunc)
+ }
+ return fn
+}
+
+func (d *nopDialect) FieldAppender(field *Field) AppenderFunc {
+ return FieldAppender(d, field)
+}
+
+func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc {
+ if v, ok := d.scannerMap.Load(typ); ok {
+ return v.(ScannerFunc)
+ }
+
+ fn := Scanner(typ)
+
+ if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
+ return v.(ScannerFunc)
+ }
+ return fn
+}
diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go
new file mode 100644
index 000000000..1e069b82f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/field.go
@@ -0,0 +1,117 @@
+package schema
+
+import (
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/internal/tagparser"
+)
+
+type Field struct {
+ StructField reflect.StructField
+
+ Tag tagparser.Tag
+ IndirectType reflect.Type
+ Index []int
+
+ Name string // SQL name, .e.g. id
+ SQLName Safe // escaped SQL name, e.g. "id"
+ GoName string // struct field name, e.g. Id
+
+ DiscoveredSQLType string
+ UserSQLType string
+ CreateTableSQLType string
+ SQLDefault string
+
+ OnDelete string
+ OnUpdate string
+
+ IsPK bool
+ NotNull bool
+ NullZero bool
+ AutoIncrement bool
+
+ Append AppenderFunc
+ Scan ScannerFunc
+ IsZero IsZeroerFunc
+}
+
+func (f *Field) String() string {
+ return f.Name
+}
+
+func (f *Field) Clone() *Field {
+ cp := *f
+ cp.Index = cp.Index[:len(f.Index):len(f.Index)]
+ return &cp
+}
+
+func (f *Field) Value(strct reflect.Value) reflect.Value {
+ return fieldByIndexAlloc(strct, f.Index)
+}
+
+func (f *Field) HasZeroValue(v reflect.Value) bool {
+ for _, idx := range f.Index {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return true
+ }
+ v = v.Elem()
+ }
+ v = v.Field(idx)
+ }
+ return f.IsZero(v)
+}
+
+func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
+ fv, ok := fieldByIndex(strct, f.Index)
+ if !ok {
+ return dialect.AppendNull(b)
+ }
+
+ if f.NullZero && f.IsZero(fv) {
+ return dialect.AppendNull(b)
+ }
+ if f.Append == nil {
+ panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type()))
+ }
+ return f.Append(fmter, b, fv)
+}
+
+func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
+ if f.Scan == nil {
+ return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
+ }
+ return f.Scan(fv, src)
+}
+
+func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
+ if src == nil {
+ if fv, ok := fieldByIndex(strct, f.Index); ok {
+ return f.ScanWithCheck(fv, src)
+ }
+ return nil
+ }
+
+ fv := fieldByIndexAlloc(strct, f.Index)
+ return f.ScanWithCheck(fv, src)
+}
+
+func (f *Field) markAsPK() {
+ f.IsPK = true
+ f.NotNull = true
+ f.NullZero = true
+}
+
+func indexEqual(ind1, ind2 []int) bool {
+ if len(ind1) != len(ind2) {
+ return false
+ }
+ for i, ind := range ind1 {
+ if ind != ind2[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go
new file mode 100644
index 000000000..7b26fbaca
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/formatter.go
@@ -0,0 +1,248 @@
+package schema
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/parser"
+)
+
+var nopFormatter = Formatter{
+ dialect: newNopDialect(),
+}
+
+type Formatter struct {
+ dialect Dialect
+ args *namedArgList
+}
+
+func NewFormatter(dialect Dialect) Formatter {
+ return Formatter{
+ dialect: dialect,
+ }
+}
+
+func NewNopFormatter() Formatter {
+ return nopFormatter
+}
+
+func (f Formatter) IsNop() bool {
+ return f.dialect.Name() == dialect.Invalid
+}
+
+func (f Formatter) Dialect() Dialect {
+ return f.dialect
+}
+
+func (f Formatter) IdentQuote() byte {
+ return f.dialect.IdentQuote()
+}
+
+func (f Formatter) AppendIdent(b []byte, ident string) []byte {
+ return dialect.AppendIdent(b, ident, f.IdentQuote())
+}
+
+func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte {
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ appender := f.dialect.Appender(v.Type())
+ return appender(f, b, v)
+}
+
+func (f Formatter) HasFeature(feature feature.Feature) bool {
+ return f.dialect.Features().Has(feature)
+}
+
+func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
+ return Formatter{
+ dialect: f.dialect,
+ args: f.args.WithArg(arg),
+ }
+}
+
+func (f Formatter) WithNamedArg(name string, value interface{}) Formatter {
+ return Formatter{
+ dialect: f.dialect,
+ args: f.args.WithArg(&namedArg{name: name, value: value}),
+ }
+}
+
+func (f Formatter) FormatQuery(query string, args ...interface{}) string {
+ if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
+ return query
+ }
+ return internal.String(f.AppendQuery(nil, query, args...))
+}
+
+func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte {
+ if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
+ return append(dst, query...)
+ }
+ return f.append(dst, parser.NewString(query), args)
+}
+
+func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
+ var namedArgs NamedArgAppender
+ if len(args) == 1 {
+ var ok bool
+ namedArgs, ok = args[0].(NamedArgAppender)
+ if !ok {
+ namedArgs, _ = newStructArgs(f, args[0])
+ }
+ }
+
+ var argIndex int
+ for p.Valid() {
+ b, ok := p.ReadSep('?')
+ if !ok {
+ dst = append(dst, b...)
+ continue
+ }
+ if len(b) > 0 && b[len(b)-1] == '\\' {
+ dst = append(dst, b[:len(b)-1]...)
+ dst = append(dst, '?')
+ continue
+ }
+ dst = append(dst, b...)
+
+ name, numeric := p.ReadIdentifier()
+ if name != "" {
+ if numeric {
+ idx, err := strconv.Atoi(name)
+ if err != nil {
+ goto restore_arg
+ }
+
+ if idx >= len(args) {
+ goto restore_arg
+ }
+
+ dst = f.appendArg(dst, args[idx])
+ continue
+ }
+
+ if namedArgs != nil {
+ dst, ok = namedArgs.AppendNamedArg(f, dst, name)
+ if ok {
+ continue
+ }
+ }
+
+ dst, ok = f.args.AppendNamedArg(f, dst, name)
+ if ok {
+ continue
+ }
+
+ restore_arg:
+ dst = append(dst, '?')
+ dst = append(dst, name...)
+ continue
+ }
+
+ if argIndex >= len(args) {
+ dst = append(dst, '?')
+ continue
+ }
+
+ arg := args[argIndex]
+ argIndex++
+
+ dst = f.appendArg(dst, arg)
+ }
+
+ return dst
+}
+
+func (f Formatter) appendArg(b []byte, arg interface{}) []byte {
+ switch arg := arg.(type) {
+ case QueryAppender:
+ bb, err := arg.AppendQuery(f, b)
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return bb
+ default:
+ return f.dialect.Append(f, b, arg)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+type NamedArgAppender interface {
+ AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
+}
+
+//------------------------------------------------------------------------------
+
+type namedArgList struct {
+ arg NamedArgAppender
+ next *namedArgList
+}
+
+func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
+ return &namedArgList{
+ arg: arg,
+ next: l,
+ }
+}
+
+func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ for l != nil && l.arg != nil {
+ if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
+ return b, true
+ }
+ l = l.next
+ }
+ return b, false
+}
+
+//------------------------------------------------------------------------------
+
+type namedArg struct {
+ name string
+ value interface{}
+}
+
+var _ NamedArgAppender = (*namedArg)(nil)
+
+func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ if a.name == name {
+ return fmter.appendArg(b, a.value), true
+ }
+ return b, false
+}
+
+//------------------------------------------------------------------------------
+
+var _ NamedArgAppender = (*structArgs)(nil)
+
+type structArgs struct {
+ table *Table
+ strct reflect.Value
+}
+
+func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) {
+ v := reflect.ValueOf(strct)
+ if !v.IsValid() {
+ return nil, false
+ }
+
+ v = reflect.Indirect(v)
+ if v.Kind() != reflect.Struct {
+ return nil, false
+ }
+
+ return &structArgs{
+ table: fmter.Dialect().Tables().Get(v.Type()),
+ strct: v,
+ }, true
+}
+
+func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ return m.table.AppendNamedArg(fmter, b, name, m.strct)
+}
diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go
new file mode 100644
index 000000000..5391981d5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/hook.go
@@ -0,0 +1,20 @@
+package schema
+
+import (
+ "context"
+ "reflect"
+)
+
+type BeforeScanHook interface {
+ BeforeScan(context.Context) error
+}
+
+var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
+
+//------------------------------------------------------------------------------
+
+type AfterScanHook interface {
+ AfterScan(context.Context) error
+}
+
+var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()
diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go
new file mode 100644
index 000000000..8d1baeb3f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/relation.go
@@ -0,0 +1,32 @@
+package schema
+
+import (
+ "fmt"
+)
+
+const (
+ InvalidRelation = iota
+ HasOneRelation
+ BelongsToRelation
+ HasManyRelation
+ ManyToManyRelation
+)
+
+type Relation struct {
+ Type int
+ Field *Field
+ JoinTable *Table
+ BaseFields []*Field
+ JoinFields []*Field
+
+ PolymorphicField *Field
+ PolymorphicValue string
+
+ M2MTable *Table
+ M2MBaseFields []*Field
+ M2MJoinFields []*Field
+}
+
+func (r *Relation) String() string {
+ return fmt.Sprintf("relation=%s", r.Field.GoName)
+}
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
new file mode 100644
index 000000000..0e66a860f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -0,0 +1,392 @@
+package schema
+
+import (
+ "bytes"
+ "database/sql"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/vmihailenco/msgpack/v5"
+
+ "github.com/uptrace/bun/extra/bunjson"
+ "github.com/uptrace/bun/internal"
+)
+
+var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
+
+type ScannerFunc func(dest reflect.Value, src interface{}) error
+
+var scanners = []ScannerFunc{
+ reflect.Bool: scanBool,
+ reflect.Int: scanInt64,
+ reflect.Int8: scanInt64,
+ reflect.Int16: scanInt64,
+ reflect.Int32: scanInt64,
+ reflect.Int64: scanInt64,
+ reflect.Uint: scanUint64,
+ reflect.Uint8: scanUint64,
+ reflect.Uint16: scanUint64,
+ reflect.Uint32: scanUint64,
+ reflect.Uint64: scanUint64,
+ reflect.Uintptr: scanUint64,
+ reflect.Float32: scanFloat64,
+ reflect.Float64: scanFloat64,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: nil,
+ reflect.Chan: nil,
+ reflect.Func: nil,
+ reflect.Map: scanJSON,
+ reflect.Ptr: nil,
+ reflect.Slice: scanJSON,
+ reflect.String: scanString,
+ reflect.Struct: scanJSON,
+ reflect.UnsafePointer: nil,
+}
+
+func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
+ if field.Tag.HasOption("msgpack") {
+ return scanMsgpack
+ }
+ if field.Tag.HasOption("json_use_number") {
+ return scanJSONUseNumber
+ }
+ return dialect.Scanner(field.StructField.Type)
+}
+
+func Scanner(typ reflect.Type) ScannerFunc {
+ kind := typ.Kind()
+
+ if kind == reflect.Ptr {
+ if fn := Scanner(typ.Elem()); fn != nil {
+ return ptrScanner(fn)
+ }
+ }
+
+ if typ.Implements(scannerType) {
+ return scanScanner
+ }
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(scannerType) {
+ return addrScanner(scanScanner)
+ }
+ }
+
+ switch typ {
+ case timeType:
+ return scanTime
+ case ipType:
+ return scanIP
+ case ipNetType:
+ return scanIPNet
+ case jsonRawMessageType:
+ return scanJSONRawMessage
+ }
+
+ return scanners[kind]
+}
+
+func scanBool(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetBool(false)
+ return nil
+ case bool:
+ dest.SetBool(src)
+ return nil
+ case int64:
+ dest.SetBool(src != 0)
+ return nil
+ case []byte:
+ if len(src) == 1 {
+ dest.SetBool(src[0] != '0')
+ return nil
+ }
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanInt64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetInt(0)
+ return nil
+ case int64:
+ dest.SetInt(src)
+ return nil
+ case uint64:
+ dest.SetInt(int64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseInt(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ case string:
+ n, err := strconv.ParseInt(src, 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanUint64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetUint(0)
+ return nil
+ case uint64:
+ dest.SetUint(src)
+ return nil
+ case int64:
+ dest.SetUint(uint64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseUint(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetUint(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanFloat64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetFloat(0)
+ return nil
+ case float64:
+ dest.SetFloat(src)
+ return nil
+ case []byte:
+ f, err := strconv.ParseFloat(internal.String(src), 64)
+ if err != nil {
+ return err
+ }
+ dest.SetFloat(f)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanString(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetString("")
+ return nil
+ case string:
+ dest.SetString(src)
+ return nil
+ case []byte:
+ dest.SetString(string(src))
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanTime(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = time.Time{}
+ return nil
+ case time.Time:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = src
+ return nil
+ case string:
+ srcTime, err := internal.ParseTime(src)
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ case []byte:
+ srcTime, err := internal.ParseTime(internal.String(src))
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanScanner(dest reflect.Value, src interface{}) error {
+ return dest.Interface().(sql.Scanner).Scan(src)
+}
+
+func scanMsgpack(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := msgpack.GetDecoder()
+ defer msgpack.PutDecoder(dec)
+
+ dec.Reset(bytes.NewReader(b))
+ return dec.DecodeValue(dest)
+}
+
+func scanJSON(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ return bunjson.Unmarshal(b, dest.Addr().Interface())
+}
+
+func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := bunjson.NewDecoder(bytes.NewReader(b))
+ dec.UseNumber()
+ return dec.Decode(dest.Addr().Interface())
+}
+
+func scanIP(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ ip := net.ParseIP(internal.String(b))
+ if ip == nil {
+ return fmt.Errorf("bun: invalid ip: %q", b)
+ }
+
+ ptr := dest.Addr().Interface().(*net.IP)
+ *ptr = ip
+
+ return nil
+}
+
+func scanIPNet(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ _, ipnet, err := net.ParseCIDR(internal.String(b))
+ if err != nil {
+ return err
+ }
+
+ ptr := dest.Addr().Interface().(*net.IPNet)
+ *ptr = *ipnet
+
+ return nil
+}
+
+func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ dest.SetBytes(nil)
+ return nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dest.SetBytes(b)
+ return nil
+}
+
+func addrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if !dest.CanAddr() {
+ return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
+ }
+ return fn(dest.Addr(), src)
+ }
+}
+
+func toBytes(src interface{}) ([]byte, error) {
+ switch src := src.(type) {
+ case string:
+ return internal.Bytes(src), nil
+ case []byte:
+ return src, nil
+ default:
+ return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
+ }
+}
+
+func ptrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ if !dest.CanAddr() {
+ if dest.IsNil() {
+ return nil
+ }
+ return fn(dest.Elem(), src)
+ }
+
+ if !dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return nil
+ }
+
+ if dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return fn(dest.Elem(), src)
+ }
+}
+
+func scanNull(dest reflect.Value) error {
+ if nilable(dest.Kind()) && dest.IsNil() {
+ return nil
+ }
+ dest.Set(reflect.New(dest.Type()).Elem())
+ return nil
+}
+
+func nilable(kind reflect.Kind) bool {
+ switch kind {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
+ return true
+ }
+ return false
+}
diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
new file mode 100644
index 000000000..7b538cd0c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
@@ -0,0 +1,76 @@
+package schema
+
+type QueryAppender interface {
+ AppendQuery(fmter Formatter, b []byte) ([]byte, error)
+}
+
+type ColumnsAppender interface {
+ AppendColumns(fmter Formatter, b []byte) ([]byte, error)
+}
+
+//------------------------------------------------------------------------------
+
+// Safe represents a safe SQL query.
+type Safe string
+
+var _ QueryAppender = (*Safe)(nil)
+
+func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ return append(b, s...), nil
+}
+
+//------------------------------------------------------------------------------
+
+// Ident represents a SQL identifier, for example, table or column name.
+type Ident string
+
+var _ QueryAppender = (*Ident)(nil)
+
+func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ return fmter.AppendIdent(b, string(s)), nil
+}
+
+//------------------------------------------------------------------------------
+
+type QueryWithArgs struct {
+ Query string
+ Args []interface{}
+}
+
+var _ QueryAppender = QueryWithArgs{}
+
+func SafeQuery(query string, args []interface{}) QueryWithArgs {
+ if query != "" && args == nil {
+ args = make([]interface{}, 0)
+ }
+ return QueryWithArgs{Query: query, Args: args}
+}
+
+func UnsafeIdent(ident string) QueryWithArgs {
+ return QueryWithArgs{Query: ident}
+}
+
+func (q QueryWithArgs) IsZero() bool {
+ return q.Query == "" && q.Args == nil
+}
+
+func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ if q.Args == nil {
+ return fmter.AppendIdent(b, q.Query), nil
+ }
+ return fmter.AppendQuery(b, q.Query, q.Args...), nil
+}
+
+//------------------------------------------------------------------------------
+
+type QueryWithSep struct {
+ QueryWithArgs
+ Sep string
+}
+
+func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep {
+ return QueryWithSep{
+ QueryWithArgs: SafeQuery(query, args),
+ Sep: sep,
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go
new file mode 100644
index 000000000..560f695c2
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/sqltype.go
@@ -0,0 +1,129 @@
+package schema
+
+import (
+ "bytes"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "time"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+)
+
+var (
+ bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
+ nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
+ nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
+ nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
+ nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
+ nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
+)
+
+var sqlTypes = []string{
+ reflect.Bool: sqltype.Boolean,
+ reflect.Int: sqltype.BigInt,
+ reflect.Int8: sqltype.SmallInt,
+ reflect.Int16: sqltype.SmallInt,
+ reflect.Int32: sqltype.Integer,
+ reflect.Int64: sqltype.BigInt,
+ reflect.Uint: sqltype.BigInt,
+ reflect.Uint8: sqltype.SmallInt,
+ reflect.Uint16: sqltype.SmallInt,
+ reflect.Uint32: sqltype.Integer,
+ reflect.Uint64: sqltype.BigInt,
+ reflect.Uintptr: sqltype.BigInt,
+ reflect.Float32: sqltype.Real,
+ reflect.Float64: sqltype.DoublePrecision,
+ reflect.Complex64: "",
+ reflect.Complex128: "",
+ reflect.Array: "",
+ reflect.Chan: "",
+ reflect.Func: "",
+ reflect.Interface: "",
+ reflect.Map: sqltype.VarChar,
+ reflect.Ptr: "",
+ reflect.Slice: sqltype.VarChar,
+ reflect.String: sqltype.VarChar,
+ reflect.Struct: sqltype.VarChar,
+ reflect.UnsafePointer: "",
+}
+
+func DiscoverSQLType(typ reflect.Type) string {
+ switch typ {
+ case timeType, nullTimeType, bunNullTimeType:
+ return sqltype.Timestamp
+ case nullBoolType:
+ return sqltype.Boolean
+ case nullFloatType:
+ return sqltype.DoublePrecision
+ case nullIntType:
+ return sqltype.BigInt
+ case nullStringType:
+ return sqltype.VarChar
+ }
+ return sqlTypes[typ.Kind()]
+}
+
+//------------------------------------------------------------------------------
+
+var jsonNull = []byte("null")
+
+// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL.
+type NullTime struct {
+ time.Time
+}
+
+var (
+ _ json.Marshaler = (*NullTime)(nil)
+ _ json.Unmarshaler = (*NullTime)(nil)
+ _ sql.Scanner = (*NullTime)(nil)
+ _ QueryAppender = (*NullTime)(nil)
+)
+
+func (tm NullTime) MarshalJSON() ([]byte, error) {
+ if tm.IsZero() {
+ return jsonNull, nil
+ }
+ return tm.Time.MarshalJSON()
+}
+
+func (tm *NullTime) UnmarshalJSON(b []byte) error {
+ if bytes.Equal(b, jsonNull) {
+ tm.Time = time.Time{}
+ return nil
+ }
+ return tm.Time.UnmarshalJSON(b)
+}
+
+func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ if tm.IsZero() {
+ return dialect.AppendNull(b), nil
+ }
+ return dialect.AppendTime(b, tm.Time), nil
+}
+
+func (tm *NullTime) Scan(src interface{}) error {
+ if src == nil {
+ tm.Time = time.Time{}
+ return nil
+ }
+
+ switch src := src.(type) {
+ case []byte:
+ newtm, err := internal.ParseTime(internal.String(src))
+ if err != nil {
+ return err
+ }
+
+ tm.Time = newtm
+ return nil
+ case time.Time:
+ tm.Time = src
+ return nil
+ default:
+ return fmt.Errorf("bun: can't scan %#v into NullTime", src)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go
new file mode 100644
index 000000000..eca18b781
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/table.go
@@ -0,0 +1,948 @@
+package schema
+
+import (
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jinzhu/inflection"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/tagparser"
+)
+
+const (
+ beforeScanHookFlag internal.Flag = 1 << iota
+ afterScanHookFlag
+)
+
+var (
+ baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
+ tableNameInflector = inflection.Plural
+)
+
+type BaseModel struct{}
+
+// SetTableNameInflector overrides the default func that pluralizes
+// model name to get table name, e.g. my_article becomes my_articles.
+func SetTableNameInflector(fn func(string) string) {
+ tableNameInflector = fn
+}
+
+// Table represents a SQL table created from Go struct.
+type Table struct {
+ dialect Dialect
+
+ Type reflect.Type
+ ZeroValue reflect.Value // reflect.Struct
+ ZeroIface interface{} // struct pointer
+
+ TypeName string
+ ModelName string
+
+ Name string
+ SQLName Safe
+ SQLNameForSelects Safe
+ Alias string
+ SQLAlias Safe
+
+ Fields []*Field // PKs + DataFields
+ PKs []*Field
+ DataFields []*Field
+
+ fieldsMapMu sync.RWMutex
+ FieldMap map[string]*Field
+
+ Relations map[string]*Relation
+ Unique map[string][]*Field
+
+ SoftDeleteField *Field
+ UpdateSoftDeleteField func(fv reflect.Value) error
+
+ allFields []*Field // read only
+ skippedFields []*Field
+
+ flags internal.Flag
+}
+
+func newTable(dialect Dialect, typ reflect.Type) *Table {
+ t := new(Table)
+ t.dialect = dialect
+ t.Type = typ
+ t.ZeroValue = reflect.New(t.Type).Elem()
+ t.ZeroIface = reflect.New(t.Type).Interface()
+ t.TypeName = internal.ToExported(t.Type.Name())
+ t.ModelName = internal.Underscore(t.Type.Name())
+ tableName := tableNameInflector(t.ModelName)
+ t.setName(tableName)
+ t.Alias = t.ModelName
+ t.SQLAlias = t.quoteIdent(t.ModelName)
+
+ hooks := []struct {
+ typ reflect.Type
+ flag internal.Flag
+ }{
+ {beforeScanHookType, beforeScanHookFlag},
+ {afterScanHookType, afterScanHookFlag},
+ }
+
+ typ = reflect.PtrTo(t.Type)
+ for _, hook := range hooks {
+ if typ.Implements(hook.typ) {
+ t.flags = t.flags.Set(hook.flag)
+ }
+ }
+
+ return t
+}
+
+func (t *Table) init1() {
+ t.initFields()
+}
+
+func (t *Table) init2() {
+ t.initInlines()
+ t.initRelations()
+ t.skippedFields = nil
+}
+
+func (t *Table) setName(name string) {
+ t.Name = name
+ t.SQLName = t.quoteIdent(name)
+ t.SQLNameForSelects = t.quoteIdent(name)
+ if t.SQLAlias == "" {
+ t.Alias = name
+ t.SQLAlias = t.quoteIdent(name)
+ }
+}
+
+func (t *Table) String() string {
+ return "model=" + t.TypeName
+}
+
+func (t *Table) CheckPKs() error {
+ if len(t.PKs) == 0 {
+ return fmt.Errorf("bun: %s does not have primary keys", t)
+ }
+ return nil
+}
+
+func (t *Table) addField(field *Field) {
+ t.Fields = append(t.Fields, field)
+ if field.IsPK {
+ t.PKs = append(t.PKs, field)
+ } else {
+ t.DataFields = append(t.DataFields, field)
+ }
+ t.FieldMap[field.Name] = field
+}
+
+func (t *Table) removeField(field *Field) {
+ t.Fields = removeField(t.Fields, field)
+ if field.IsPK {
+ t.PKs = removeField(t.PKs, field)
+ } else {
+ t.DataFields = removeField(t.DataFields, field)
+ }
+ delete(t.FieldMap, field.Name)
+}
+
+func (t *Table) fieldWithLock(name string) *Field {
+ t.fieldsMapMu.RLock()
+ field := t.FieldMap[name]
+ t.fieldsMapMu.RUnlock()
+ return field
+}
+
+func (t *Table) HasField(name string) bool {
+ _, ok := t.FieldMap[name]
+ return ok
+}
+
+func (t *Table) Field(name string) (*Field, error) {
+ field, ok := t.FieldMap[name]
+ if !ok {
+ return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
+ }
+ return field, nil
+}
+
+func (t *Table) fieldByGoName(name string) *Field {
+ for _, f := range t.allFields {
+ if f.GoName == name {
+ return f
+ }
+ }
+ return nil
+}
+
+func (t *Table) initFields() {
+ t.Fields = make([]*Field, 0, t.Type.NumField())
+ t.FieldMap = make(map[string]*Field, t.Type.NumField())
+ t.addFields(t.Type, nil)
+
+ if len(t.PKs) > 0 {
+ return
+ }
+ for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
+ if field, ok := t.FieldMap[name]; ok {
+ field.markAsPK()
+ t.PKs = []*Field{field}
+ t.DataFields = removeField(t.DataFields, field)
+ break
+ }
+ }
+ if len(t.PKs) == 1 {
+ switch t.PKs[0].IndirectType.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ t.PKs[0].AutoIncrement = true
+ }
+ }
+}
+
+func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
+ for i := 0; i < typ.NumField(); i++ {
+ f := typ.Field(i)
+
+ // Make a copy so slice is not shared between fields.
+ index := make([]int, len(baseIndex))
+ copy(index, baseIndex)
+
+ if f.Anonymous {
+ if f.Tag.Get("bun") == "-" {
+ continue
+ }
+ if f.Name == "BaseModel" && f.Type == baseModelType {
+ if len(index) == 0 {
+ t.processBaseModelField(f)
+ }
+ continue
+ }
+
+ fieldType := indirectType(f.Type)
+ if fieldType.Kind() != reflect.Struct {
+ continue
+ }
+ t.addFields(fieldType, append(index, f.Index...))
+
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+ if _, inherit := tag.Options["inherit"]; inherit {
+ embeddedTable := t.dialect.Tables().Ref(fieldType)
+ t.TypeName = embeddedTable.TypeName
+ t.SQLName = embeddedTable.SQLName
+ t.SQLNameForSelects = embeddedTable.SQLNameForSelects
+ t.Alias = embeddedTable.Alias
+ t.SQLAlias = embeddedTable.SQLAlias
+ t.ModelName = embeddedTable.ModelName
+ }
+
+ continue
+ }
+
+ field := t.newField(f, index)
+ if field != nil {
+ t.addField(field)
+ }
+ }
+}
+
+func (t *Table) processBaseModelField(f reflect.StructField) {
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+
+ if isKnownTableOption(tag.Name) {
+ internal.Warn.Printf(
+ "%s.%s tag name %q is also an option name; is it a mistake?",
+ t.TypeName, f.Name, tag.Name,
+ )
+ }
+
+ for name := range tag.Options {
+ if !isKnownTableOption(name) {
+ internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
+ }
+ }
+
+ if tag.Name != "" {
+ t.setName(tag.Name)
+ }
+
+ if s, ok := tag.Options["select"]; ok {
+ t.SQLNameForSelects = t.quoteTableName(s)
+ }
+
+ if s, ok := tag.Options["alias"]; ok {
+ t.Alias = s
+ t.SQLAlias = t.quoteIdent(s)
+ }
+}
+
+//nolint
+func (t *Table) newField(f reflect.StructField, index []int) *Field {
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+
+ if f.PkgPath != "" {
+ return nil
+ }
+
+ sqlName := internal.Underscore(f.Name)
+
+ if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
+ internal.Warn.Printf(
+ "%s.%s tag name %q is also an option name; is it a mistake?",
+ t.TypeName, f.Name, tag.Name,
+ )
+ }
+
+ for name := range tag.Options {
+ if !isKnownFieldOption(name) {
+ internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
+ }
+ }
+
+ skip := tag.Name == "-"
+ if !skip && tag.Name != "" {
+ sqlName = tag.Name
+ }
+
+ index = append(index, f.Index...)
+ if field := t.fieldWithLock(sqlName); field != nil {
+ if indexEqual(field.Index, index) {
+ return field
+ }
+ t.removeField(field)
+ }
+
+ field := &Field{
+ StructField: f,
+
+ Tag: tag,
+ IndirectType: indirectType(f.Type),
+ Index: index,
+
+ Name: sqlName,
+ GoName: f.Name,
+ SQLName: t.quoteIdent(sqlName),
+ }
+
+ field.NotNull = tag.HasOption("notnull")
+ field.NullZero = tag.HasOption("nullzero")
+ field.AutoIncrement = tag.HasOption("autoincrement")
+ if tag.HasOption("pk") {
+ field.markAsPK()
+ }
+ if tag.HasOption("allowzero") {
+ if tag.HasOption("nullzero") {
+ internal.Warn.Printf(
+ "%s.%s: nullzero and allowzero options are mutually exclusive",
+ t.TypeName, f.Name,
+ )
+ }
+ field.NullZero = false
+ }
+
+ if v, ok := tag.Options["unique"]; ok {
+ // Split the value by comma, this will allow multiple names to be specified.
+ // We can use this to create multiple named unique constraints where a single column
+ // might be included in multiple constraints.
+ for _, uniqueName := range strings.Split(v, ",") {
+ if t.Unique == nil {
+ t.Unique = make(map[string][]*Field)
+ }
+ t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
+ }
+ }
+ if s, ok := tag.Options["default"]; ok {
+ field.SQLDefault = s
+ }
+ if s, ok := field.Tag.Options["type"]; ok {
+ field.UserSQLType = s
+ }
+ field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
+ field.Append = t.dialect.FieldAppender(field)
+ field.Scan = FieldScanner(t.dialect, field)
+ field.IsZero = FieldZeroChecker(field)
+
+ if v, ok := tag.Options["alt"]; ok {
+ t.FieldMap[v] = field
+ }
+
+ t.allFields = append(t.allFields, field)
+ if skip {
+ t.skippedFields = append(t.skippedFields, field)
+ t.FieldMap[field.Name] = field
+ return nil
+ }
+
+ if _, ok := tag.Options["soft_delete"]; ok {
+ field.NullZero = true
+ t.SoftDeleteField = field
+ t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
+ }
+
+ return field
+}
+
+func (t *Table) initInlines() {
+ for _, f := range t.skippedFields {
+ if f.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(f, nil)
+ }
+ }
+}
+
+//---------------------------------------------------------------------------------------
+
+func (t *Table) initRelations() {
+ for i := 0; i < len(t.Fields); {
+ f := t.Fields[i]
+ if t.tryRelation(f) {
+ t.Fields = removeField(t.Fields, f)
+ t.DataFields = removeField(t.DataFields, f)
+ } else {
+ i++
+ }
+
+ if f.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(f, nil)
+ }
+ }
+}
+
+func (t *Table) tryRelation(field *Field) bool {
+ if rel, ok := field.Tag.Options["rel"]; ok {
+ t.initRelation(field, rel)
+ return true
+ }
+ if field.Tag.HasOption("m2m") {
+ t.addRelation(t.m2mRelation(field))
+ return true
+ }
+
+ if field.Tag.HasOption("join") {
+ internal.Warn.Printf(
+ `%s.%s option "join" requires a relation type`,
+ t.TypeName, field.GoName,
+ )
+ }
+
+ return false
+}
+
+func (t *Table) initRelation(field *Field, rel string) {
+ switch rel {
+ case "belongs-to":
+ t.addRelation(t.belongsToRelation(field))
+ case "has-one":
+ t.addRelation(t.hasOneRelation(field))
+ case "has-many":
+ t.addRelation(t.hasManyRelation(field))
+ default:
+ panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
+ }
+}
+
+func (t *Table) addRelation(rel *Relation) {
+ if t.Relations == nil {
+ t.Relations = make(map[string]*Relation)
+ }
+ _, ok := t.Relations[rel.Field.GoName]
+ if ok {
+ panic(fmt.Errorf("%s already has %s", t, rel))
+ }
+ t.Relations[rel.Field.GoName] = rel
+}
+
+func (t *Table) belongsToRelation(field *Field) *Relation {
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ if err := joinTable.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ rel := &Relation{
+ Type: HasOneRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ joinColumn := joinColumns[i]
+
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+ }
+ return rel
+ }
+
+ rel.JoinFields = joinTable.PKs
+ fkPrefix := internal.Underscore(field.GoName) + "_"
+ for _, joinPK := range joinTable.PKs {
+ fkName := fkPrefix + joinPK.Name
+ if fk := t.fieldWithLock(fkName); fk != nil {
+ rel.BaseFields = append(rel.BaseFields, fk)
+ continue
+ }
+
+ if fk := t.fieldWithLock(joinPK.Name); fk != nil {
+ rel.BaseFields = append(rel.BaseFields, fk)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on %s field)",
+ t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
+ ))
+ }
+ return rel
+}
+
+func (t *Table) hasOneRelation(field *Field) *Relation {
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ rel := &Relation{
+ Type: BelongsToRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
+ ))
+ }
+
+ joinColumn := joinColumns[i]
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
+ ))
+ }
+ }
+ return rel
+ }
+
+ rel.BaseFields = t.PKs
+ fkPrefix := internal.Underscore(t.ModelName) + "_"
+ for _, pk := range t.PKs {
+ fkName := fkPrefix + pk.Name
+ if f := joinTable.fieldWithLock(fkName); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ continue
+ }
+
+ if f := joinTable.fieldWithLock(pk.Name); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on %s field)",
+ field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
+ ))
+ }
+ return rel
+}
+
+func (t *Table) hasManyRelation(field *Field) *Relation {
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+ if field.IndirectType.Kind() != reflect.Slice {
+ panic(fmt.Errorf(
+ "bun: %s.%s has-many relation requires slice, got %q",
+ t.TypeName, field.GoName, field.IndirectType.Kind(),
+ ))
+ }
+
+ joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
+ polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
+ rel := &Relation{
+ Type: HasManyRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+ var polymorphicColumn string
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ joinColumn := joinColumns[i]
+
+ if isPolymorphic && baseColumn == "type" {
+ polymorphicColumn = joinColumn
+ continue
+ }
+
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+ }
+ } else {
+ rel.BaseFields = t.PKs
+ fkPrefix := internal.Underscore(t.ModelName) + "_"
+ if isPolymorphic {
+ polymorphicColumn = fkPrefix + "type"
+ }
+
+ for _, pk := range t.PKs {
+ joinColumn := fkPrefix + pk.Name
+ if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
+ rel.JoinFields = append(rel.JoinFields, fk)
+ continue
+ }
+
+ if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
+ rel.JoinFields = append(rel.JoinFields, fk)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s has-many %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on the field %s)",
+ t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
+ ))
+ }
+ }
+
+ if isPolymorphic {
+ rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
+ if rel.PolymorphicField == nil {
+ panic(fmt.Errorf(
+ "bun: %s has-many %s: %s must have polymorphic column %s",
+ t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
+ ))
+ }
+
+ if polymorphicValue == "" {
+ polymorphicValue = t.ModelName
+ }
+ rel.PolymorphicValue = polymorphicValue
+ }
+
+ return rel
+}
+
+func (t *Table) m2mRelation(field *Field) *Relation {
+ if field.IndirectType.Kind() != reflect.Slice {
+ panic(fmt.Errorf(
+ "bun: %s.%s m2m relation requires slice, got %q",
+ t.TypeName, field.GoName, field.IndirectType.Kind(),
+ ))
+ }
+ joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
+
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+ if err := joinTable.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ m2mTableName, ok := field.Tag.Options["m2m"]
+ if !ok {
+ panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
+ }
+
+ m2mTable := t.dialect.Tables().ByName(m2mTableName)
+ if m2mTable == nil {
+ panic(fmt.Errorf(
+ "bun: can't find m2m %s table (use db.RegisterModel)",
+ m2mTableName,
+ ))
+ }
+
+ rel := &Relation{
+ Type: ManyToManyRelation,
+ Field: field,
+ JoinTable: joinTable,
+ M2MTable: m2mTable,
+ }
+ var leftColumn, rightColumn string
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ left, right := parseRelationJoin(join)
+ leftColumn = left[0]
+ rightColumn = right[0]
+ } else {
+ leftColumn = t.TypeName
+ rightColumn = joinTable.TypeName
+ }
+
+ leftField := m2mTable.fieldByGoName(leftColumn)
+ if leftField == nil {
+ panic(fmt.Errorf(
+ "bun: %s many-to-many %s: %s must have field %s "+
+ "(to override, use tag join:LeftField=RightField on field %s.%s",
+ t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
+ ))
+ }
+
+ rightField := m2mTable.fieldByGoName(rightColumn)
+ if rightField == nil {
+ panic(fmt.Errorf(
+ "bun: %s many-to-many %s: %s must have field %s "+
+ "(to override, use tag join:LeftField=RightField on field %s.%s",
+ t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
+ ))
+ }
+
+ leftRel := m2mTable.belongsToRelation(leftField)
+ rel.BaseFields = leftRel.JoinFields
+ rel.M2MBaseFields = leftRel.BaseFields
+
+ rightRel := m2mTable.belongsToRelation(rightField)
+ rel.JoinFields = rightRel.JoinFields
+ rel.M2MJoinFields = rightRel.BaseFields
+
+ return rel
+}
+
+func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
+ if path == nil {
+ path = map[reflect.Type]struct{}{
+ t.Type: {},
+ }
+ }
+
+ if _, ok := path[field.IndirectType]; ok {
+ return
+ }
+ path[field.IndirectType] = struct{}{}
+
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ for _, f := range joinTable.allFields {
+ f = f.Clone()
+ f.GoName = field.GoName + "_" + f.GoName
+ f.Name = field.Name + "__" + f.Name
+ f.SQLName = t.quoteIdent(f.Name)
+ f.Index = appendNew(field.Index, f.Index...)
+
+ t.fieldsMapMu.Lock()
+ if _, ok := t.FieldMap[f.Name]; !ok {
+ t.FieldMap[f.Name] = f
+ }
+ t.fieldsMapMu.Unlock()
+
+ if f.IndirectType.Kind() != reflect.Struct {
+ continue
+ }
+
+ if _, ok := path[f.IndirectType]; !ok {
+ t.inlineFields(f, path)
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func (t *Table) Dialect() Dialect { return t.dialect }
+
+//------------------------------------------------------------------------------
+
+func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
+func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
+
+//------------------------------------------------------------------------------
+
+func (t *Table) AppendNamedArg(
+ fmter Formatter, b []byte, name string, strct reflect.Value,
+) ([]byte, bool) {
+ if field, ok := t.FieldMap[name]; ok {
+ return fmter.appendArg(b, field.Value(strct).Interface()), true
+ }
+ return b, false
+}
+
+func (t *Table) quoteTableName(s string) Safe {
+ // Don't quote if table name contains placeholder (?) or parentheses.
+ if strings.IndexByte(s, '?') >= 0 ||
+ strings.IndexByte(s, '(') >= 0 ||
+ strings.IndexByte(s, ')') >= 0 {
+ return Safe(s)
+ }
+ return t.quoteIdent(s)
+}
+
+func (t *Table) quoteIdent(s string) Safe {
+ return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
+}
+
+func appendNew(dst []int, src ...int) []int {
+ cp := make([]int, len(dst)+len(src))
+ copy(cp, dst)
+ copy(cp[len(dst):], src)
+ return cp
+}
+
+func isKnownTableOption(name string) bool {
+ switch name {
+ case "alias", "select":
+ return true
+ }
+ return false
+}
+
+func isKnownFieldOption(name string) bool {
+ switch name {
+ case "alias",
+ "type",
+ "array",
+ "hstore",
+ "composite",
+ "json_use_number",
+ "msgpack",
+ "notnull",
+ "nullzero",
+ "allowzero",
+ "default",
+ "unique",
+ "soft_delete",
+
+ "pk",
+ "autoincrement",
+ "rel",
+ "join",
+ "m2m",
+ "polymorphic":
+ return true
+ }
+ return false
+}
+
+func removeField(fields []*Field, field *Field) []*Field {
+ for i, f := range fields {
+ if f == field {
+ return append(fields[:i], fields[i+1:]...)
+ }
+ }
+ return fields
+}
+
+func parseRelationJoin(join string) ([]string, []string) {
+ ss := strings.Split(join, ",")
+ baseColumns := make([]string, len(ss))
+ joinColumns := make([]string, len(ss))
+ for i, s := range ss {
+ ss := strings.Split(strings.TrimSpace(s), "=")
+ if len(ss) != 2 {
+ panic(fmt.Errorf("can't parse relation join: %q", join))
+ }
+ baseColumns[i] = ss[0]
+ joinColumns[i] = ss[1]
+ }
+ return baseColumns, joinColumns
+}
+
+//------------------------------------------------------------------------------
+
+func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
+ typ := field.StructField.Type
+
+ switch typ {
+ case timeType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*time.Time)
+ *ptr = time.Now()
+ return nil
+ }
+ case nullTimeType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*sql.NullTime)
+ *ptr = sql.NullTime{Time: time.Now()}
+ return nil
+ }
+ case nullIntType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*sql.NullInt64)
+ *ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
+ return nil
+ }
+ }
+
+ switch field.IndirectType.Kind() {
+ case reflect.Int64:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*int64)
+ *ptr = time.Now().UnixNano()
+ return nil
+ }
+ case reflect.Ptr:
+ typ = typ.Elem()
+ default:
+ return softDeleteFieldUpdaterFallback(field)
+ }
+
+ switch typ { //nolint:gocritic
+ case timeType:
+ return func(fv reflect.Value) error {
+ now := time.Now()
+ fv.Set(reflect.ValueOf(&now))
+ return nil
+ }
+ }
+
+ switch typ.Kind() { //nolint:gocritic
+ case reflect.Int64:
+ return func(fv reflect.Value) error {
+ utime := time.Now().UnixNano()
+ fv.Set(reflect.ValueOf(&utime))
+ return nil
+ }
+ }
+
+ return softDeleteFieldUpdaterFallback(field)
+}
+
+func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
+ return func(fv reflect.Value) error {
+ return field.ScanWithCheck(fv, time.Now())
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go
new file mode 100644
index 000000000..d82d08f59
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/tables.go
@@ -0,0 +1,148 @@
+package schema
+
+import (
+ "fmt"
+ "reflect"
+ "sync"
+)
+
+type tableInProgress struct {
+ table *Table
+
+ init1Once sync.Once
+ init2Once sync.Once
+}
+
+func newTableInProgress(table *Table) *tableInProgress {
+ return &tableInProgress{
+ table: table,
+ }
+}
+
+func (inp *tableInProgress) init1() bool {
+ var inited bool
+ inp.init1Once.Do(func() {
+ inp.table.init1()
+ inited = true
+ })
+ return inited
+}
+
+func (inp *tableInProgress) init2() bool {
+ var inited bool
+ inp.init2Once.Do(func() {
+ inp.table.init2()
+ inited = true
+ })
+ return inited
+}
+
+type Tables struct {
+ dialect Dialect
+ tables sync.Map
+
+ mu sync.RWMutex
+ inProgress map[reflect.Type]*tableInProgress
+}
+
+func NewTables(dialect Dialect) *Tables {
+ return &Tables{
+ dialect: dialect,
+ inProgress: make(map[reflect.Type]*tableInProgress),
+ }
+}
+
+func (t *Tables) Register(models ...interface{}) {
+ for _, model := range models {
+ _ = t.Get(reflect.TypeOf(model).Elem())
+ }
+}
+
+func (t *Tables) Get(typ reflect.Type) *Table {
+ return t.table(typ, false)
+}
+
+func (t *Tables) Ref(typ reflect.Type) *Table {
+ return t.table(typ, true)
+}
+
+func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
+ if typ.Kind() != reflect.Struct {
+ panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
+ }
+
+ if v, ok := t.tables.Load(typ); ok {
+ return v.(*Table)
+ }
+
+ t.mu.Lock()
+
+ if v, ok := t.tables.Load(typ); ok {
+ t.mu.Unlock()
+ return v.(*Table)
+ }
+
+ var table *Table
+
+ inProgress := t.inProgress[typ]
+ if inProgress == nil {
+ table = newTable(t.dialect, typ)
+ inProgress = newTableInProgress(table)
+ t.inProgress[typ] = inProgress
+ } else {
+ table = inProgress.table
+ }
+
+ t.mu.Unlock()
+
+ inProgress.init1()
+ if allowInProgress {
+ return table
+ }
+
+ if inProgress.init2() {
+ t.mu.Lock()
+ delete(t.inProgress, typ)
+ t.tables.Store(typ, table)
+ t.mu.Unlock()
+ }
+
+ t.dialect.OnTable(table)
+
+ for _, field := range table.FieldMap {
+ if field.UserSQLType == "" {
+ field.UserSQLType = field.DiscoveredSQLType
+ }
+ if field.CreateTableSQLType == "" {
+ field.CreateTableSQLType = field.UserSQLType
+ }
+ }
+
+ return table
+}
+
+func (t *Tables) ByModel(name string) *Table {
+ var found *Table
+ t.tables.Range(func(key, value interface{}) bool {
+ t := value.(*Table)
+ if t.TypeName == name {
+ found = t
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+func (t *Tables) ByName(name string) *Table {
+ var found *Table
+ t.tables.Range(func(key, value interface{}) bool {
+ t := value.(*Table)
+ if t.Name == name {
+ found = t
+ return false
+ }
+ return true
+ })
+ return found
+}
diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go
new file mode 100644
index 000000000..6d474e4cc
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/util.go
@@ -0,0 +1,53 @@
+package schema
+
+import "reflect"
+
+func indirectType(t reflect.Type) reflect.Type {
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ return t
+}
+
+func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
+ if len(index) == 1 {
+ return v.Field(index[0]), true
+ }
+
+ for i, idx := range index {
+ if i > 0 {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return v, false
+ }
+ v = v.Elem()
+ }
+ }
+ v = v.Field(idx)
+ }
+ return v, true
+}
+
+func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
+ if len(index) == 1 {
+ return v.Field(index[0])
+ }
+
+ for i, idx := range index {
+ if i > 0 {
+ v = indirectNil(v)
+ }
+ v = v.Field(idx)
+ }
+ return v
+}
+
+func indirectNil(v reflect.Value) reflect.Value {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ v.Set(reflect.New(v.Type().Elem()))
+ }
+ v = v.Elem()
+ }
+ return v
+}
diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go
new file mode 100644
index 000000000..95efeee6b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go
@@ -0,0 +1,126 @@
+package schema
+
+import (
+ "database/sql/driver"
+ "reflect"
+)
+
+var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
+
+type isZeroer interface {
+ IsZero() bool
+}
+
+type IsZeroerFunc func(reflect.Value) bool
+
+func FieldZeroChecker(field *Field) IsZeroerFunc {
+ return zeroChecker(field.IndirectType)
+}
+
+func zeroChecker(typ reflect.Type) IsZeroerFunc {
+ if typ.Implements(isZeroerType) {
+ return isZeroInterface
+ }
+
+ kind := typ.Kind()
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(isZeroerType) {
+ return addrChecker(isZeroInterface)
+ }
+ }
+
+ switch kind {
+ case reflect.Array:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return isZeroBytes
+ }
+ return isZeroLen
+ case reflect.String:
+ return isZeroLen
+ case reflect.Bool:
+ return isZeroBool
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return isZeroInt
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return isZeroUint
+ case reflect.Float32, reflect.Float64:
+ return isZeroFloat
+ case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map:
+ return isNil
+ }
+
+ if typ.Implements(driverValuerType) {
+ return isZeroDriverValue
+ }
+
+ return notZero
+}
+
+func addrChecker(fn IsZeroerFunc) IsZeroerFunc {
+ return func(v reflect.Value) bool {
+ if !v.CanAddr() {
+ return false
+ }
+ return fn(v.Addr())
+ }
+}
+
+func isZeroInterface(v reflect.Value) bool {
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ return true
+ }
+ return v.Interface().(isZeroer).IsZero()
+}
+
+func isZeroDriverValue(v reflect.Value) bool {
+ if v.Kind() == reflect.Ptr {
+ return v.IsNil()
+ }
+
+ valuer := v.Interface().(driver.Valuer)
+ value, err := valuer.Value()
+ if err != nil {
+ return false
+ }
+ return value == nil
+}
+
+func isZeroLen(v reflect.Value) bool {
+ return v.Len() == 0
+}
+
+func isNil(v reflect.Value) bool {
+ return v.IsNil()
+}
+
+func isZeroBool(v reflect.Value) bool {
+ return !v.Bool()
+}
+
+func isZeroInt(v reflect.Value) bool {
+ return v.Int() == 0
+}
+
+func isZeroUint(v reflect.Value) bool {
+ return v.Uint() == 0
+}
+
+func isZeroFloat(v reflect.Value) bool {
+ return v.Float() == 0
+}
+
+func isZeroBytes(v reflect.Value) bool {
+ b := v.Slice(0, v.Len()).Bytes()
+ for _, c := range b {
+ if c != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+func notZero(v reflect.Value) bool {
+ return false
+}
diff --git a/vendor/github.com/uptrace/bun/util.go b/vendor/github.com/uptrace/bun/util.go
new file mode 100644
index 000000000..ce56be805
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/util.go
@@ -0,0 +1,114 @@
+package bun
+
+import "reflect"
+
+func indirect(v reflect.Value) reflect.Value {
+ switch v.Kind() {
+ case reflect.Interface:
+ return indirect(v.Elem())
+ case reflect.Ptr:
+ return v.Elem()
+ default:
+ return v
+ }
+}
+
+func walk(v reflect.Value, index []int, fn func(reflect.Value)) {
+ v = reflect.Indirect(v)
+ switch v.Kind() {
+ case reflect.Slice:
+ sliceLen := v.Len()
+ for i := 0; i < sliceLen; i++ {
+ visitField(v.Index(i), index, fn)
+ }
+ default:
+ visitField(v, index, fn)
+ }
+}
+
+func visitField(v reflect.Value, index []int, fn func(reflect.Value)) {
+ v = reflect.Indirect(v)
+ if len(index) > 0 {
+ v = v.Field(index[0])
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ return
+ }
+ walk(v, index[1:], fn)
+ } else {
+ fn(v)
+ }
+}
+
+func typeByIndex(t reflect.Type, index []int) reflect.Type {
+ for _, x := range index {
+ switch t.Kind() {
+ case reflect.Ptr:
+ t = t.Elem()
+ case reflect.Slice:
+ t = indirectType(t.Elem())
+ }
+ t = t.Field(x).Type
+ }
+ return indirectType(t)
+}
+
+func indirectType(t reflect.Type) reflect.Type {
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ return t
+}
+
+func sliceElemType(v reflect.Value) reflect.Type {
+ elemType := v.Type().Elem()
+ if elemType.Kind() == reflect.Interface && v.Len() > 0 {
+ return indirect(v.Index(0).Elem()).Type()
+ }
+ return indirectType(elemType)
+}
+
+func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
+ if v.Kind() == reflect.Array {
+ var pos int
+ return func() reflect.Value {
+ v := v.Index(pos)
+ pos++
+ return v
+ }
+ }
+
+ sliceType := v.Type()
+ elemType := sliceType.Elem()
+
+ if elemType.Kind() == reflect.Ptr {
+ elemType = elemType.Elem()
+ return func() reflect.Value {
+ if v.Len() < v.Cap() {
+ v.Set(v.Slice(0, v.Len()+1))
+ elem := v.Index(v.Len() - 1)
+ if elem.IsNil() {
+ elem.Set(reflect.New(elemType))
+ }
+ return elem.Elem()
+ }
+
+ elem := reflect.New(elemType)
+ v.Set(reflect.Append(v, elem))
+ return elem.Elem()
+ }
+ }
+
+ zero := reflect.Zero(elemType)
+ return func() reflect.Value {
+ l := v.Len()
+ c := v.Cap()
+
+ if l < c {
+ v.Set(v.Slice(0, l+1))
+ return v.Index(l)
+ }
+
+ v.Set(reflect.Append(v, zero))
+ return v.Index(l)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go
new file mode 100644
index 000000000..1baf9a39c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/version.go
@@ -0,0 +1,6 @@
+package bun
+
+// Version is the current release version.
+func Version() string {
+ return "0.4.3"
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index da1a09c9f..b8a8a26f7 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -247,9 +247,6 @@ github.com/go-fed/activity/streams/vocab
# github.com/go-fed/httpsig v1.1.0
## explicit
github.com/go-fed/httpsig
-# github.com/go-pg/pg/extra/pgdebug v0.2.0
-## explicit
-github.com/go-pg/pg/extra/pgdebug
# github.com/go-pg/pg/v10 v10.10.3
## explicit
github.com/go-pg/pg/v10
@@ -383,13 +380,29 @@ github.com/tdewolff/parse/v2/strconv
github.com/tmthrgd/go-hex
# github.com/ugorji/go/codec v1.2.6
github.com/ugorji/go/codec
+# github.com/uptrace/bun v0.4.3
+## explicit
+github.com/uptrace/bun
+github.com/uptrace/bun/dialect
+github.com/uptrace/bun/dialect/feature
+github.com/uptrace/bun/dialect/sqltype
+github.com/uptrace/bun/extra/bunjson
+github.com/uptrace/bun/internal
+github.com/uptrace/bun/internal/parser
+github.com/uptrace/bun/internal/tagparser
+github.com/uptrace/bun/schema
+# github.com/uptrace/bun/dialect/pgdialect v0.4.3
+## explicit
+github.com/uptrace/bun/dialect/pgdialect
+# github.com/uptrace/bun/driver/pgdriver v0.4.3
+## explicit
+github.com/uptrace/bun/driver/pgdriver
# github.com/urfave/cli/v2 v2.3.0
## explicit
github.com/urfave/cli/v2
# github.com/vmihailenco/bufpool v0.1.11
github.com/vmihailenco/bufpool
# github.com/vmihailenco/msgpack/v5 v5.3.4
-## explicit
github.com/vmihailenco/msgpack/v5
github.com/vmihailenco/msgpack/v5/msgpcode
# github.com/vmihailenco/tagparser v0.1.2
@@ -450,6 +463,7 @@ golang.org/x/text/transform
golang.org/x/text/unicode/bidi
golang.org/x/text/unicode/norm
# google.golang.org/appengine v1.6.7
+## explicit
google.golang.org/appengine/internal
google.golang.org/appengine/internal/base
google.golang.org/appengine/internal/datastore