mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-28 17:46:15 -06:00
[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)
* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests
Signed-off-by: kim <grufwub@gmail.com>
* update code comments, defer mutex unlocks
Signed-off-by: kim <grufwub@gmail.com>
* add count to 'performing request' log message
Signed-off-by: kim <grufwub@gmail.com>
* reduce repeated conversions of same url.URL object
Signed-off-by: kim <grufwub@gmail.com>
* move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue
Signed-off-by: kim <grufwub@gmail.com>
* fix security advisories regarding max outgoing conns, max rsp body size
- implemented by a new httpclient.Client{} that wraps an underlying
client with a queue to limit connections, and limit reader wrapping
a response body with a configured maximum size
- update pub.HttpClient args passed around to be this new httpclient.Client{}
Signed-off-by: kim <grufwub@gmail.com>
* add httpclient tests, move ip validation to separate package + change mechanism
Signed-off-by: kim <grufwub@gmail.com>
* fix merge conflicts
Signed-off-by: kim <grufwub@gmail.com>
* use singular mutex in transport rather than separate signer mus
Signed-off-by: kim <grufwub@gmail.com>
* improved useragent string
Signed-off-by: kim <grufwub@gmail.com>
* add note regarding missing test
Signed-off-by: kim <grufwub@gmail.com>
* remove useragent field from transport (instead store in controller)
Signed-off-by: kim <grufwub@gmail.com>
* shutup linter
Signed-off-by: kim <grufwub@gmail.com>
* reset other signing headers on each loop iteration
Signed-off-by: kim <grufwub@gmail.com>
* respect request ctx during retry-backoff sleep period
Signed-off-by: kim <grufwub@gmail.com>
* use external pkg with docs explaining performance "hack"
Signed-off-by: kim <grufwub@gmail.com>
* use http package constants instead of string method literals
Signed-off-by: kim <grufwub@gmail.com>
* add license file headers
Signed-off-by: kim <grufwub@gmail.com>
* update code comment to match new func names
Signed-off-by: kim <grufwub@gmail.com>
* updates to user-agent string
Signed-off-by: kim <grufwub@gmail.com>
* update signed testrig models to fit with new transport logic (instead uses separate signer now)
Signed-off-by: kim <grufwub@gmail.com>
* fuck you linter
Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
parent
4ac508f037
commit
223025fc27
61 changed files with 1801 additions and 435 deletions
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
|
@ -20,7 +21,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
|
@ -38,7 +39,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
|
@ -37,7 +38,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
|
@ -47,7 +48,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
|
@ -45,7 +46,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
|
|||
testrig.StandardDBSetup(suite.db, nil)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -30,7 +31,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
|
|||
func (suite *UserStandardTestSuite) SetupTest() {
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
suite.testTokens = testrig.NewTestTokens()
|
||||
suite.testClients = testrig.NewTestClients()
|
||||
suite.testApplications = testrig.NewTestApplications()
|
||||
|
|
|
|||
|
|
@ -33,11 +33,11 @@ import (
|
|||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -32,7 +33,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
|
|
|
|||
|
|
@ -33,9 +33,9 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
|
|||
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
|
||||
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -37,7 +38,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
|
|
|
|||
|
|
@ -31,10 +31,10 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
|
|||
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
|
||||
viper.Set(config.Keys.Host, "gts.example.org")
|
||||
viper.Set(config.Keys.AccountDomain, "example.org")
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
|
||||
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
|
||||
|
||||
|
|
@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
|
|||
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
|
||||
viper.Set(config.Keys.Host, "gts.example.org")
|
||||
viper.Set(config.Keys.AccountDomain, "example.org")
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
|
||||
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
package worker
|
||||
package concurrency
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -12,17 +12,17 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
|
||||
type Worker[MsgType any] struct {
|
||||
// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
|
||||
type WorkerPool[MsgType any] struct {
|
||||
workers runners.WorkerPool
|
||||
process func(context.Context, MsgType) error
|
||||
prefix string // contains type prefix for logging
|
||||
}
|
||||
|
||||
// New returns a new Worker[MsgType] with given number of workers and queue ratio,
|
||||
// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
|
||||
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
|
||||
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
|
||||
func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
||||
func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
|
||||
var zero MsgType
|
||||
|
||||
if workers < 1 {
|
||||
|
|
@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
|||
msgType := reflect.TypeOf(zero).String()
|
||||
_, msgType = path.Split(msgType)
|
||||
|
||||
w := &Worker[MsgType]{
|
||||
w := &WorkerPool[MsgType]{
|
||||
workers: runners.NewWorkerPool(workers, workers*queueRatio),
|
||||
process: nil,
|
||||
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
|
||||
|
|
@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
|||
}
|
||||
|
||||
// Start will attempt to start the underlying worker pool, or return error.
|
||||
func (w *Worker[MsgType]) Start() error {
|
||||
func (w *WorkerPool[MsgType]) Start() error {
|
||||
logrus.Infof("%s starting", w.prefix)
|
||||
|
||||
// Check processor was set
|
||||
|
|
@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
|
|||
}
|
||||
|
||||
// Stop will attempt to stop the underlying worker pool, or return error.
|
||||
func (w *Worker[MsgType]) Stop() error {
|
||||
func (w *WorkerPool[MsgType]) Stop() error {
|
||||
logrus.Infof("%s stopping", w.prefix)
|
||||
|
||||
// Attempt to stop pool
|
||||
|
|
@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
|
|||
}
|
||||
|
||||
// SetProcessor will set the Worker's processor function, which is called for each queued message.
|
||||
func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
|
||||
func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
|
||||
if w.process != nil {
|
||||
logrus.Panicf("%s Worker.process is already set", w.prefix)
|
||||
}
|
||||
|
|
@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
|
|||
}
|
||||
|
||||
// Queue will queue provided message to be processed with there's a free worker.
|
||||
func (w *Worker[MsgType]) Queue(msg MsgType) {
|
||||
func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
|
||||
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
|
||||
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
|
||||
)
|
||||
|
|
@ -29,12 +29,12 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
|
|||
|
||||
return response, nil
|
||||
}
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
mockClient := testrig.NewMockHTTPClient(do)
|
||||
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,10 +28,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
|
|||
)
|
||||
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
sentMessages := []*url.URL{}
|
||||
|
|
@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
|
|||
)
|
||||
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
sentMessages := []*url.URL{}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,10 @@ import (
|
|||
"codeberg.org/gruf/go-mutexes"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
|
||||
|
|
@ -44,12 +44,12 @@ type DB interface {
|
|||
type federatingDB struct {
|
||||
locks mutexes.MutexMap
|
||||
db db.DB
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
typeConverter typeutils.TypeConverter
|
||||
}
|
||||
|
||||
// New returns a DB interface using the given database and config
|
||||
func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
|
||||
func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
|
||||
fdb := federatingDB{
|
||||
locks: mutexes.NewMap(-1, -1), // use defaults
|
||||
db: db,
|
||||
|
|
|
|||
|
|
@ -23,12 +23,12 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
|
|||
suite.Suite
|
||||
db db.DB
|
||||
tc typeutils.TypeConverter
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
fromFederator chan messages.FromFederator
|
||||
federatingDB federatingdb.DB
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
|
|||
func (suite *FederatingDBTestSuite) SetupTest() {
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
|
||||
suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.fromFederator = make(chan messages.FromFederator, 10)
|
||||
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
|
||||
suite.fromFederator <- msg
|
||||
|
|
|
|||
|
|
@ -28,10 +28,10 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
|
|||
// the activity we're gonna use
|
||||
activity := suite.testActivities["dm_for_zork"]
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
|
||||
|
|
@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
|
|||
sendingAccount := suite.testAccounts["remote_account_1"]
|
||||
inboxAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
// now setup module being tested, with the mock transport controller
|
||||
|
|
|
|||
199
internal/httpclient/client.go
Normal file
199
internal/httpclient/client.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
|
||||
var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
|
||||
|
||||
// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
|
||||
var ErrBodyTooLarge = errors.New("body size too large")
|
||||
|
||||
// dialer is the base net.Dialer used by all package-created http.Transports.
|
||||
var dialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: &net.Resolver{Dial: nil},
|
||||
}
|
||||
|
||||
// Config provides configuration details for setting up a new
|
||||
// instance of httpclient.Client{}. Within are a subset of the
|
||||
// configuration values passed to initialized http.Transport{}
|
||||
// and http.Client{}, along with httpclient.Client{} specific.
|
||||
type Config struct {
|
||||
// MaxOpenConns limits the max number of concurrent open connections.
|
||||
MaxOpenConns int
|
||||
|
||||
// MaxIdleConns: see http.Transport{}.MaxIdleConns.
|
||||
MaxIdleConns int
|
||||
|
||||
// ReadBufferSize: see http.Transport{}.ReadBufferSize.
|
||||
ReadBufferSize int
|
||||
|
||||
// WriteBufferSize: see http.Transport{}.WriteBufferSize.
|
||||
WriteBufferSize int
|
||||
|
||||
// MaxBodySize determines the maximum fetchable body size.
|
||||
MaxBodySize int64
|
||||
|
||||
// Timeout: see http.Client{}.Timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// DisableCompression: see http.Transport{}.DisableCompression.
|
||||
DisableCompression bool
|
||||
|
||||
// AllowRanges allows outgoing communications to given IP nets.
|
||||
AllowRanges []netip.Prefix
|
||||
|
||||
// BlockRanges blocks outgoing communiciations to given IP nets.
|
||||
BlockRanges []netip.Prefix
|
||||
}
|
||||
|
||||
// Client wraps an underlying http.Client{} to provide the following:
|
||||
// - setting a maximum received request body size, returning error on
|
||||
// large content lengths, and using a limited reader in all other
|
||||
// cases to protect against forged / unknown content-lengths
|
||||
// - protection from server side request forgery (SSRF) by only dialing
|
||||
// out to known public IP prefixes, configurable with allows/blocks
|
||||
// - limit number of concurrent requests, else blocking until a slot
|
||||
// is available (context channels still respected)
|
||||
type Client struct {
|
||||
client http.Client
|
||||
queue chan struct{}
|
||||
bmax int64
|
||||
}
|
||||
|
||||
// New returns a new instance of Client initialized using configuration.
|
||||
func New(cfg Config) *Client {
|
||||
var c Client
|
||||
|
||||
// Copy global
|
||||
d := dialer
|
||||
|
||||
if cfg.MaxOpenConns <= 0 {
|
||||
// By default base this value on GOMAXPROCS.
|
||||
maxprocs := runtime.GOMAXPROCS(0)
|
||||
cfg.MaxOpenConns = maxprocs * 10
|
||||
}
|
||||
|
||||
if cfg.MaxIdleConns <= 0 {
|
||||
// By default base this value on MaxOpenConns
|
||||
cfg.MaxIdleConns = cfg.MaxOpenConns * 10
|
||||
}
|
||||
|
||||
if cfg.MaxBodySize <= 0 {
|
||||
// By default set this to a reasonable 40MB
|
||||
cfg.MaxBodySize = 40 * 1024 * 1024
|
||||
}
|
||||
|
||||
// Protect dialer with IP range sanitizer
|
||||
d.Control = (&sanitizer{
|
||||
allow: cfg.AllowRanges,
|
||||
block: cfg.BlockRanges,
|
||||
}).Sanitize
|
||||
|
||||
// Prepare client fields
|
||||
c.bmax = cfg.MaxBodySize
|
||||
c.queue = make(chan struct{}, cfg.MaxOpenConns)
|
||||
c.client.Timeout = cfg.Timeout
|
||||
|
||||
// Set underlying HTTP client roundtripper
|
||||
c.client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: d.DialContext,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ReadBufferSize: cfg.ReadBufferSize,
|
||||
WriteBufferSize: cfg.WriteBufferSize,
|
||||
DisableCompression: cfg.DisableCompression,
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// Do will perform given request when an available slot in the queue is available,
|
||||
// and block until this time. For returned values, this follows the same semantics
|
||||
// as the standard http.Client{}.Do() implementation except that response body will
|
||||
// be wrapped by an io.LimitReader() to limit response body sizes.
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
select {
|
||||
// Request context cancelled
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
|
||||
// Slot in queue acquired
|
||||
case c.queue <- struct{}{}:
|
||||
// NOTE:
|
||||
// Ideally here we would set the slot release to happen either
|
||||
// on error return, or via callback from the response body closer.
|
||||
// However when implementing this, there appear deadlocks between
|
||||
// the channel queue here and the media manager worker pool. So
|
||||
// currently we only place a limit on connections dialing out, but
|
||||
// there may still be more connections open than len(c.queue) given
|
||||
// that connections may not be closed until response body is closed.
|
||||
// The current implementation will reduce the viability of denial of
|
||||
// service attacks, but if there are future issues heed this advice :]
|
||||
defer func() { <-c.queue }()
|
||||
}
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check response body not too large
|
||||
if rsp.ContentLength > c.bmax {
|
||||
return nil, ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Seperate the body implementers
|
||||
rbody := (io.Reader)(rsp.Body)
|
||||
cbody := (io.Closer)(rsp.Body)
|
||||
|
||||
var limit int64
|
||||
|
||||
if limit = rsp.ContentLength; limit < 0 {
|
||||
// If unknown, use max as reader limit
|
||||
limit = c.bmax
|
||||
}
|
||||
|
||||
// Don't trust them, limit body reads
|
||||
rbody = io.LimitReader(rbody, limit)
|
||||
|
||||
// Wrap body with limit
|
||||
rsp.Body = &struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{rbody, cbody}
|
||||
|
||||
return rsp, nil
|
||||
}
|
||||
154
internal/httpclient/client_test.go
Normal file
154
internal/httpclient/client_test.go
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
|
||||
)
|
||||
|
||||
var privateIPs = []string{
|
||||
"http://127.0.0.1:80",
|
||||
"http://0.0.0.0:80",
|
||||
"http://192.168.0.1:80",
|
||||
"http://192.168.1.0:80",
|
||||
"http://10.0.0.0:80",
|
||||
"http://172.16.0.0:80",
|
||||
"http://10.255.255.255:80",
|
||||
"http://172.31.255.255:80",
|
||||
"http://255.255.255.255:80",
|
||||
}
|
||||
|
||||
var bodies = []string{
|
||||
"hello world!",
|
||||
"{}",
|
||||
`{"key": "value", "some": "kinda bullshit"}`,
|
||||
"body with\r\nnewlines",
|
||||
}
|
||||
|
||||
// Note:
|
||||
// There is no test for the .MaxOpenConns implementation
|
||||
// in the httpclient.Client{}, due to the difficult to test
|
||||
// this. The block is only held for the actual dial out to
|
||||
// the connection, so the usual test of blocking and holding
|
||||
// open this queue slot to check we can't open another isn't
|
||||
// an easy test here.
|
||||
|
||||
func TestHTTPClientSmallBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientExactBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), len(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientLargeBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), len(body)-1)
|
||||
}
|
||||
}
|
||||
|
||||
func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
|
||||
var (
|
||||
handler http.HandlerFunc
|
||||
|
||||
expect []byte
|
||||
|
||||
expectErr error
|
||||
)
|
||||
|
||||
// If this is a larger body, reslice and
|
||||
// set error so we know what to expect
|
||||
expect = body
|
||||
if max < len(body) {
|
||||
expect = expect[:max]
|
||||
expectErr = httpclient.ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Create new HTTP client with maximum body size
|
||||
client := httpclient.New(httpclient.Config{
|
||||
MaxBodySize: int64(max),
|
||||
DisableCompression: true,
|
||||
AllowRanges: []netip.Prefix{
|
||||
// Loopback (used by server)
|
||||
netip.MustParsePrefix("127.0.0.1/8"),
|
||||
},
|
||||
})
|
||||
|
||||
// Set simple body-writing test handler
|
||||
handler = func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, _ = rw.Write(body)
|
||||
}
|
||||
|
||||
// Start the test server
|
||||
srv := httptest.NewServer(handler)
|
||||
defer srv.Close()
|
||||
|
||||
// Wrap body to provide reader iface
|
||||
rbody := bytes.NewReader(body)
|
||||
|
||||
// Create the test HTTP request
|
||||
req, _ := http.NewRequest("POST", srv.URL, rbody)
|
||||
|
||||
// Perform the test request
|
||||
rsp, err := client.Do(req)
|
||||
if !errors.Is(err, expectErr) {
|
||||
t.Fatalf("error performing client request: %v", err)
|
||||
} else if err != nil {
|
||||
return // expected error
|
||||
}
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Read response body into memory
|
||||
check, err := io.ReadAll(rsp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading response body: %v", err)
|
||||
}
|
||||
|
||||
// Check actual response body matches expected
|
||||
if !bytes.Equal(expect, check) {
|
||||
t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientPrivateIP(t *testing.T) {
|
||||
client := httpclient.New(httpclient.Config{})
|
||||
|
||||
for _, addr := range privateIPs {
|
||||
// Prepare request to private IP
|
||||
req, _ := http.NewRequest("GET", addr, nil)
|
||||
|
||||
// Perform the HTTP request
|
||||
_, err := client.Do(req)
|
||||
if !errors.Is(err, httpclient.ErrReservedAddr) {
|
||||
t.Errorf("dialing private address did not return expected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
64
internal/httpclient/sanitizer.go
Normal file
64
internal/httpclient/sanitizer.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/netutil"
|
||||
)
|
||||
|
||||
type sanitizer struct {
|
||||
allow []netip.Prefix
|
||||
block []netip.Prefix
|
||||
}
|
||||
|
||||
// Sanitize implements the required net.Dialer.Control function signature.
|
||||
func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
|
||||
// Parse IP+port from addr
|
||||
ipport, err := netip.ParseAddrPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seperate the IP
|
||||
ip := ipport.Addr()
|
||||
|
||||
// Check if this is explicitly allowed
|
||||
for i := 0; i < len(s.allow); i++ {
|
||||
if s.allow[i].Contains(ip) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Now check if explicity blocked
|
||||
for i := 0; i < len(s.block); i++ {
|
||||
if s.block[i].Contains(ip) {
|
||||
return ErrReservedAddr
|
||||
}
|
||||
}
|
||||
|
||||
// Validate this is a safe IP
|
||||
if !netutil.ValidateIP(ip) {
|
||||
return ErrReservedAddr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -27,9 +27,9 @@ import (
|
|||
"github.com/robfig/cron/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
|
||||
|
|
@ -79,8 +79,8 @@ type Manager interface {
|
|||
type manager struct {
|
||||
db db.DB
|
||||
storage *kv.KVStore
|
||||
emojiWorker *worker.Worker[*ProcessingEmoji]
|
||||
mediaWorker *worker.Worker[*ProcessingMedia]
|
||||
emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
|
||||
mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
|
||||
stopCronJobs func() error
|
||||
}
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ type manager struct {
|
|||
// A worker pool will also be initialized for the manager, to ensure that only
|
||||
// a limited number of media will be processed in parallel. The numbers of workers
|
||||
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
|
||||
// See internal/worker.New() documentation for further information.
|
||||
// See internal/concurrency.NewWorkerPool() documentation for further information.
|
||||
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
||||
m := &manager{
|
||||
db: database,
|
||||
|
|
@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
|||
}
|
||||
|
||||
// Prepare the media worker pool
|
||||
m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
|
||||
m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
|
||||
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
|
|
@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
|||
})
|
||||
|
||||
// Prepare the emoji worker pool
|
||||
m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
|
||||
m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
|
||||
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
|
|
|
|||
78
internal/netutil/validate.go
Normal file
78
internal/netutil/validate.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
var (
|
||||
// IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
|
||||
IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
|
||||
|
||||
// IPvReserved contains IPv4 reserved IP prefixes.
|
||||
IPv4Reserved = [...]netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/8"), // Current network
|
||||
netip.MustParsePrefix("10.0.0.0/8"), // Private
|
||||
netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
|
||||
netip.MustParsePrefix("127.0.0.0/8"), // Loopback
|
||||
netip.MustParsePrefix("169.254.0.0/16"), // Link-local
|
||||
netip.MustParsePrefix("172.16.0.0/12"), // Private
|
||||
netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
|
||||
netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
|
||||
netip.MustParsePrefix("192.168.0.0/16"), // Private
|
||||
netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
|
||||
netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("224.0.0.0/4"), // Multicast
|
||||
netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
|
||||
}
|
||||
)
|
||||
|
||||
// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
|
||||
func ValidateAddr(s string) bool {
|
||||
ipport, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return ValidateIP(ipport.Addr())
|
||||
}
|
||||
|
||||
// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
|
||||
func ValidateIP(ip netip.Addr) bool {
|
||||
switch {
|
||||
// IPv4: check if IPv4 in reserved nets
|
||||
case ip.Is4():
|
||||
for _, reserved := range IPv4Reserved {
|
||||
if reserved.Contains(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
// IPv6: check if in global unicast (public internet)
|
||||
case ip.Is6():
|
||||
return IPv6GlobalUnicast.Contains(ip)
|
||||
|
||||
// Assume malicious by default
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"mime/multipart"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
|
|
@ -33,7 +34,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
)
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ type Processor interface {
|
|||
type processor struct {
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
oauthServer oauth.Server
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
|
|
@ -94,7 +94,7 @@ type processor struct {
|
|||
}
|
||||
|
||||
// New returns a new account processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -35,7 +36,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
|
||||
suite.fromClientAPIChan <- msg
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@ import (
|
|||
"mime/multipart"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor wraps a bunch of functions for processing admin actions.
|
||||
|
|
@ -47,12 +47,12 @@ type Processor interface {
|
|||
type processor struct {
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// New returns a new admin processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
|
|
@ -33,7 +34,6 @@ import (
|
|||
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
|
|||
|
||||
return response, nil
|
||||
}
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
mockClient := testrig.NewMockHTTPClient(do)
|
||||
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -44,7 +45,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/timeline"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
|
||||
|
|
@ -237,8 +237,8 @@ type Processor interface {
|
|||
|
||||
// processor just implements the Processor interface
|
||||
type processor struct {
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
|
||||
federator federation.Federator
|
||||
tc typeutils.TypeConverter
|
||||
|
|
@ -271,8 +271,8 @@ func NewProcessor(
|
|||
storage *kv.KVStore,
|
||||
db db.DB,
|
||||
emailSender email.Sender,
|
||||
clientWorker *worker.Worker[messages.FromClientAPI],
|
||||
fedWorker *worker.Worker[messages.FromFederator],
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator],
|
||||
) Processor {
|
||||
parseMentionFunc := GetParseMentionFunc(db, federator)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/timeline"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
|
|||
}, nil
|
||||
})
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"context"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -29,7 +30,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor wraps a bunch of functions for processing statuses.
|
||||
|
|
@ -74,12 +74,12 @@ type processor struct {
|
|||
db db.DB
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
parseMention gtsmodel.ParseMentionFunc
|
||||
}
|
||||
|
||||
// New returns a new status processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
db: db,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ package status_test
|
|||
import (
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -30,7 +31,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
|
|||
storage *kv.KVStore
|
||||
mediaManager media.Manager
|
||||
federator federation.Federator
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
|
||||
// standard suite models
|
||||
testTokens map[string]*gtsmodel.Token
|
||||
|
|
@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
|
||||
suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
|
|
|
|||
|
|
@ -20,13 +20,17 @@ package transport
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/go-fed/httpsig"
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
"codeberg.org/gruf/go-cache/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
|
|
@ -37,109 +41,85 @@ import (
|
|||
|
||||
// Controller generates transports for use in making federation requests to other servers.
|
||||
type Controller interface {
|
||||
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
|
||||
// NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
|
||||
NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
|
||||
|
||||
// NewTransportForUsername searches for account with username, and returns result of .NewTransport().
|
||||
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
|
||||
}
|
||||
|
||||
type controller struct {
|
||||
db db.DB
|
||||
clock pub.Clock
|
||||
client pub.HttpClient
|
||||
appAgent string
|
||||
|
||||
// dereferenceFollowersShortcut is a shortcut to dereference followers of an
|
||||
// account on this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
|
||||
// dereferenceUserShortcut is a shortcut to dereference followers an account on
|
||||
// this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
}
|
||||
|
||||
func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
|
||||
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
followers, err := federatingDB.Followers(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(followers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
}
|
||||
|
||||
func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
|
||||
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
user, err := federatingDB.Get(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
db db.DB
|
||||
fedDB federatingdb.DB
|
||||
clock pub.Clock
|
||||
client pub.HttpClient
|
||||
cache cache.Cache[string, *transport]
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// NewController returns an implementation of the Controller interface for creating new transports
|
||||
func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
|
||||
applicationName := viper.GetString(config.Keys.ApplicationName)
|
||||
host := viper.GetString(config.Keys.Host)
|
||||
appAgent := fmt.Sprintf("%s %s", applicationName, host)
|
||||
|
||||
return &controller{
|
||||
db: db,
|
||||
clock: clock,
|
||||
client: client,
|
||||
appAgent: appAgent,
|
||||
dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
|
||||
dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
|
||||
// Determine build information
|
||||
build, _ := debug.ReadBuildInfo()
|
||||
|
||||
c := &controller{
|
||||
db: db,
|
||||
fedDB: federatingDB,
|
||||
clock: clock,
|
||||
client: client,
|
||||
cache: cache.New[string, *transport](),
|
||||
userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
|
||||
}
|
||||
|
||||
// Transport cache has TTL=1hr freq=1m
|
||||
c.cache.SetTTL(time.Hour, false)
|
||||
if !c.cache.Start(time.Minute) {
|
||||
logrus.Panic("failed to start transport controller cache")
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
|
||||
func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
|
||||
prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
|
||||
digestAlgo := httpsig.DigestSha256
|
||||
getHeaders := []string{httpsig.RequestTarget, "host", "date"}
|
||||
postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
|
||||
func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
|
||||
// Generate public key string for cache key
|
||||
//
|
||||
// NOTE: it is safe to use the public key as the cache
|
||||
// key here as we are generating it ourselves from the
|
||||
// private key. If we were simply using a public key
|
||||
// provided as argument that would absolutely NOT be safe.
|
||||
pubStr := privkeyToPublicStr(privkey)
|
||||
|
||||
getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating get signer: %s", err)
|
||||
// First check for cached transport
|
||||
transp, ok := c.cache.Get(pubStr)
|
||||
if ok {
|
||||
return transp, nil
|
||||
}
|
||||
|
||||
postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating post signer: %s", err)
|
||||
// Create the transport
|
||||
transp = &transport{
|
||||
controller: c,
|
||||
pubKeyID: pubKeyID,
|
||||
privkey: privkey,
|
||||
}
|
||||
|
||||
sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
|
||||
// Cache this transport under pubkey
|
||||
if !c.cache.Put(pubStr, transp) {
|
||||
var cached *transport
|
||||
|
||||
return &transport{
|
||||
client: c.client,
|
||||
appAgent: c.appAgent,
|
||||
gofedAgent: "(go-fed/activity v1.0.0)",
|
||||
clock: c.clock,
|
||||
pubKeyID: pubKeyID,
|
||||
privkey: privkey,
|
||||
sigTransport: sigTransport,
|
||||
getSigner: getSigner,
|
||||
getSignerMu: &sync.Mutex{},
|
||||
dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
|
||||
dereferenceUserShortcut: c.dereferenceUserShortcut,
|
||||
}, nil
|
||||
cached, ok = c.cache.Get(pubStr)
|
||||
if !ok {
|
||||
// Some ridiculous race cond.
|
||||
c.cache.Set(pubStr, transp)
|
||||
} else {
|
||||
// Use already cached
|
||||
transp = cached
|
||||
}
|
||||
}
|
||||
|
||||
return transp, nil
|
||||
}
|
||||
|
||||
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
|
||||
|
|
@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
|
|||
}
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// dereferenceLocalFollowers is a shortcut to dereference followers of an
|
||||
// account on this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
followers, err := c.fedDB.Followers(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(followers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
|
||||
// dereferenceLocalUser is a shortcut to dereference followers an account on
|
||||
// this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
user, err := c.fedDB.Get(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
|
||||
// privkeyToPublicStr will create a string representation of RSA public key from private.
|
||||
func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
|
||||
b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
|
||||
return byteutil.B2S(b)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,13 +19,14 @@
|
|||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
|
@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
|
||||
return t.sigTransport.Deliver(ctx, b, to)
|
||||
urlStr := to.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
|
||||
req.Header.Add("Accept-Charset", "utf-8")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", to.Host)
|
||||
|
||||
resp, err := t.POST(req, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if code := resp.StatusCode; code != http.StatusOK &&
|
||||
code != http.StatusCreated && code != http.StatusAccepted {
|
||||
return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,32 +20,55 @@ package transport
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
||||
)
|
||||
|
||||
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
l := logrus.WithField("func", "Dereference")
|
||||
|
||||
// if the request is to us, we can shortcut for certain URIs rather than going through
|
||||
// the normal request flow, thereby saving time and energy
|
||||
if iri.Host == viper.GetString(config.Keys.Host) {
|
||||
if uris.IsFollowersPath(iri) {
|
||||
// the request is for followers of one of our accounts, which we can shortcut
|
||||
return t.dereferenceFollowersShortcut(ctx, iri)
|
||||
return t.controller.dereferenceLocalFollowers(ctx, iri)
|
||||
}
|
||||
|
||||
if uris.IsUserPath(iri) {
|
||||
// the request is for one of our accounts, which we can shortcut
|
||||
return t.dereferenceUserShortcut(ctx, iri)
|
||||
return t.controller.dereferenceLocalUser(ctx, iri)
|
||||
}
|
||||
}
|
||||
|
||||
// the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
return t.sigTransport.Dereference(ctx, iri)
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
// Prepare new HTTP request to endpoint
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
|
||||
req.Header.Add("Accept-Charset", "utf-8")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(rsp.Body)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
|
|||
}
|
||||
|
||||
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
|
||||
l := logrus.WithField("func", "dereferenceByAPIV1Instance")
|
||||
|
||||
cleanIRI := &url.URL{
|
||||
Scheme: iri.Scheme,
|
||||
Host: iri.Host,
|
||||
Path: "api/v1/instance",
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", cleanIRI.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
// Build IRI just once
|
||||
iriStr := cleanIRI.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("response bytes was len 0")
|
||||
}
|
||||
|
||||
|
|
@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
|
|||
}
|
||||
|
||||
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
|
||||
l := logrus.WithField("func", "callNodeInfoWellKnown")
|
||||
|
||||
cleanIRI := &url.URL{
|
||||
Scheme: iri.Scheme,
|
||||
Host: iri.Host,
|
||||
Path: ".well-known/nodeinfo",
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", cleanIRI.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Build IRI just once
|
||||
iriStr := cleanIRI.String()
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
|
||||
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
|
||||
}
|
||||
|
||||
|
|
@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
|
|||
}
|
||||
|
||||
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
|
||||
l := logrus.WithField("func", "callNodeInfo")
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("callNodeInfo: response bytes was len 0")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,34 +24,31 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
|
||||
l := logrus.WithField("func", "DereferenceMedia")
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
// Prepare HTTP request to this media's IRI
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
return resp.Body, int(resp.ContentLength), nil
|
||||
|
||||
return rsp.Body, int(rsp.ContentLength), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,46 +23,36 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
|
||||
l := logrus.WithField("func", "Finger")
|
||||
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
|
||||
l.Debugf("performing GET to %s", urlString)
|
||||
// Prepare URL string
|
||||
urlStr := "https://" +
|
||||
targetDomain +
|
||||
"/.well-known/webfinger?resource=acct:" +
|
||||
targetUsername + "@" + targetDomain
|
||||
|
||||
iri, err := url.Parse(urlString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
// Generate new GET request from URL string
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Accept", "application/jrd+json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
|
||||
return ioutil.ReadAll(rsp.Body)
|
||||
}
|
||||
|
|
|
|||
43
internal/transport/signing.go
Normal file
43
internal/transport/signing.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/go-fed/httpsig"
|
||||
)
|
||||
|
||||
var (
|
||||
// http signer preferences
|
||||
prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
|
||||
digestAlgo = httpsig.DigestSha256
|
||||
getHeaders = []string{httpsig.RequestTarget, "host", "date"}
|
||||
postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
|
||||
)
|
||||
|
||||
// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
|
||||
func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
|
||||
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
|
||||
return sig, err
|
||||
}
|
||||
|
||||
// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
|
||||
func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
|
||||
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
|
||||
return sig, err
|
||||
}
|
||||
|
|
@ -21,11 +21,18 @@ package transport
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
errorsv2 "codeberg.org/gruf/go-errors/v2"
|
||||
"github.com/go-fed/httpsig"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
|
@ -43,28 +50,148 @@ type Transport interface {
|
|||
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
|
||||
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
|
||||
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
|
||||
// SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
|
||||
SigTransport() pub.Transport
|
||||
}
|
||||
|
||||
// transport implements the Transport interface
|
||||
type transport struct {
|
||||
client pub.HttpClient
|
||||
appAgent string
|
||||
gofedAgent string
|
||||
clock pub.Clock
|
||||
pubKeyID string
|
||||
privkey crypto.PrivateKey
|
||||
sigTransport *pub.HttpSigTransport
|
||||
getSigner httpsig.Signer
|
||||
getSignerMu *sync.Mutex
|
||||
controller *controller
|
||||
pubKeyID string
|
||||
privkey crypto.PrivateKey
|
||||
|
||||
// shortcuts for dereferencing things that exist on our instance without making an http call to ourself
|
||||
|
||||
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
signerExp time.Time
|
||||
getSigner httpsig.Signer
|
||||
postSigner httpsig.Signer
|
||||
signerMu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *transport) SigTransport() pub.Transport {
|
||||
return t.sigTransport
|
||||
// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
|
||||
func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
|
||||
if r.Method != http.MethodGet {
|
||||
return nil, errors.New("must be GET request")
|
||||
}
|
||||
return t.do(r, func(r *http.Request) error {
|
||||
return t.signGET(r)
|
||||
}, retryOn...)
|
||||
}
|
||||
|
||||
// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
|
||||
func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
|
||||
if r.Method != http.MethodPost {
|
||||
return nil, errors.New("must be POST request")
|
||||
}
|
||||
return t.do(r, func(r *http.Request) error {
|
||||
return t.signPOST(r, body)
|
||||
}, retryOn...)
|
||||
}
|
||||
|
||||
func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
|
||||
const maxRetries = 5
|
||||
backoff := time.Second * 2
|
||||
|
||||
// Start a log entry for this request
|
||||
l := logrus.WithFields(logrus.Fields{
|
||||
"pubKeyID": t.pubKeyID,
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
})
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
// Reset signing header fields
|
||||
now := t.controller.clock.Now().UTC()
|
||||
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
r.Header.Del("Signature")
|
||||
r.Header.Del("Digest")
|
||||
|
||||
// Perform request signing
|
||||
if err := signer(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.Infof("performing request")
|
||||
|
||||
// Attempt to perform request
|
||||
rsp, err := t.controller.client.Do(r)
|
||||
if err == nil { //nolint shutup linter
|
||||
// TooManyRequest means we need to slow
|
||||
// down and retry our request. Codes over
|
||||
// 500 generally indicate temp. outages.
|
||||
if code := rsp.StatusCode; code < 500 &&
|
||||
code != http.StatusTooManyRequests &&
|
||||
!containsInt(retryOn, rsp.StatusCode) {
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
// Generate error from status code for logging
|
||||
err = errors.New(`http response "` + rsp.Status + `"`)
|
||||
} else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
|
||||
// Return early if context has cancelled
|
||||
return nil, err
|
||||
} else if strings.Contains(err.Error(), "stopped after 10 redirects") {
|
||||
// Don't bother if net/http returned after too many redirects
|
||||
return nil, err
|
||||
} else if errors.As(err, &x509.UnknownAuthorityError{}) {
|
||||
// Unknown authority errors we do NOT recover from
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
|
||||
|
||||
select {
|
||||
// Request ctx cancelled
|
||||
case <-r.Context().Done():
|
||||
return nil, r.Context().Err()
|
||||
|
||||
// Backoff for some time
|
||||
case <-time.After(backoff):
|
||||
backoff *= 2
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("transport reached max retries")
|
||||
}
|
||||
|
||||
// signGET will safely sign an HTTP GET request.
|
||||
func (t *transport) signGET(r *http.Request) (err error) {
|
||||
t.safesign(func() {
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// signPOST will safely sign an HTTP POST request for given body.
|
||||
func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
|
||||
t.safesign(func() {
|
||||
err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// safesign will perform sign function within mutex protection,
|
||||
// and ensured that httpsig.Signers are up-to-date.
|
||||
func (t *transport) safesign(sign func()) {
|
||||
// Perform within mu safety
|
||||
t.signerMu.Lock()
|
||||
defer t.signerMu.Unlock()
|
||||
|
||||
if now := time.Now(); now.After(t.signerExp) {
|
||||
const expiry = 120
|
||||
|
||||
// Signers have expired and require renewal
|
||||
t.getSigner, _ = NewGETSigner(expiry)
|
||||
t.postSigner, _ = NewPOSTSigner(expiry)
|
||||
t.signerExp = now.Add(time.Second * expiry)
|
||||
}
|
||||
|
||||
// Perform signing
|
||||
sign()
|
||||
}
|
||||
|
||||
// containsInt checks if slice contains check.
|
||||
func containsInt(slice []int, check int) bool {
|
||||
for _, i := range slice {
|
||||
if i == check {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue