[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)

* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests

Signed-off-by: kim <grufwub@gmail.com>

* update code comments, defer mutex unlocks

Signed-off-by: kim <grufwub@gmail.com>

* add count to 'performing request' log message

Signed-off-by: kim <grufwub@gmail.com>

* reduce repeated conversions of same url.URL object

Signed-off-by: kim <grufwub@gmail.com>

* move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue

Signed-off-by: kim <grufwub@gmail.com>

* fix security advisories regarding max outgoing conns, max rsp body size

- implemented by a new httpclient.Client{} that wraps an underlying
  client with a queue to limit connections, and limit reader wrapping
  a response body with a configured maximum size
- update pub.HttpClient args passed around to be this new httpclient.Client{}

Signed-off-by: kim <grufwub@gmail.com>

* add httpclient tests, move ip validation to separate package + change mechanism

Signed-off-by: kim <grufwub@gmail.com>

* fix merge conflicts

Signed-off-by: kim <grufwub@gmail.com>

* use singular mutex in transport rather than separate signer mus

Signed-off-by: kim <grufwub@gmail.com>

* improved useragent string

Signed-off-by: kim <grufwub@gmail.com>

* add note regarding missing test

Signed-off-by: kim <grufwub@gmail.com>

* remove useragent field from transport (instead store in controller)

Signed-off-by: kim <grufwub@gmail.com>

* shutup linter

Signed-off-by: kim <grufwub@gmail.com>

* reset other signing headers on each loop iteration

Signed-off-by: kim <grufwub@gmail.com>

* respect request ctx during retry-backoff sleep period

Signed-off-by: kim <grufwub@gmail.com>

* use external pkg with docs explaining performance "hack"

Signed-off-by: kim <grufwub@gmail.com>

* use http package constants instead of string method literals

Signed-off-by: kim <grufwub@gmail.com>

* add license file headers

Signed-off-by: kim <grufwub@gmail.com>

* update code comment to match new func names

Signed-off-by: kim <grufwub@gmail.com>

* updates to user-agent string

Signed-off-by: kim <grufwub@gmail.com>

* update signed testrig models to fit with new transport logic (instead uses separate signer now)

Signed-off-by: kim <grufwub@gmail.com>

* fuck you linter

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2022-05-15 10:16:43 +01:00 committed by GitHub
commit 223025fc27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
61 changed files with 1801 additions and 435 deletions

View file

@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -20,7 +21,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -29,6 +29,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -38,7 +39,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -31,6 +31,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -28,6 +28,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -37,6 +37,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -47,7 +48,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -35,6 +35,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -45,7 +46,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)

View file

@ -22,6 +22,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()

View file

@ -33,11 +33,11 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -31,8 +31,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -33,8 +33,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -32,8 +32,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -32,7 +33,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)

View file

@ -33,9 +33,9 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
signedRequest := derefRequests["foss_satan_dereference_zork"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)

View file

@ -31,10 +31,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)

View file

@ -1,4 +1,4 @@
package worker
package concurrency
import (
"context"
@ -12,17 +12,17 @@ import (
"github.com/sirupsen/logrus"
)
// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
type Worker[MsgType any] struct {
// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
type WorkerPool[MsgType any] struct {
workers runners.WorkerPool
process func(context.Context, MsgType) error
prefix string // contains type prefix for logging
}
// New returns a new Worker[MsgType] with given number of workers and queue ratio,
// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
var zero MsgType
if workers < 1 {
@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
msgType := reflect.TypeOf(zero).String()
_, msgType = path.Split(msgType)
w := &Worker[MsgType]{
w := &WorkerPool[MsgType]{
workers: runners.NewWorkerPool(workers, workers*queueRatio),
process: nil,
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
}
// Start will attempt to start the underlying worker pool, or return error.
func (w *Worker[MsgType]) Start() error {
func (w *WorkerPool[MsgType]) Start() error {
logrus.Infof("%s starting", w.prefix)
// Check processor was set
@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
}
// Stop will attempt to stop the underlying worker pool, or return error.
func (w *Worker[MsgType]) Stop() error {
func (w *WorkerPool[MsgType]) Stop() error {
logrus.Infof("%s stopping", w.prefix)
// Attempt to stop pool
@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
}
// SetProcessor will set the Worker's processor function, which is called for each queued message.
func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
if w.process != nil {
logrus.Panicf("%s Worker.process is already set", w.prefix)
}
@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
}
// Queue will queue provided message to be processed with there's a free worker.
func (w *Worker[MsgType]) Queue(msg MsgType) {
func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
)

View file

@ -29,12 +29,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
return response, nil
}
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}

View file

@ -28,10 +28,10 @@ import (
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}

View file

@ -24,10 +24,10 @@ import (
"codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
@ -44,12 +44,12 @@ type DB interface {
type federatingDB struct {
locks mutexes.MutexMap
db db.DB
fedWorker *worker.Worker[messages.FromFederator]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
typeConverter typeutils.TypeConverter
}
// New returns a DB interface using the given database and config
func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults
db: db,

View file

@ -23,12 +23,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
suite.Suite
db db.DB
tc typeutils.TypeConverter
fedWorker *worker.Worker[messages.FromFederator]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator
federatingDB federatingdb.DB
@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.fromFederator = make(chan messages.FromFederator, 10)
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
suite.fromFederator <- msg

View file

@ -28,10 +28,10 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
// the activity we're gonna use
activity := suite.testActivities["dm_for_zork"]
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
// now setup module being tested, with the mock transport controller

View file

@ -0,0 +1,199 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient
import (
"errors"
"io"
"net"
"net/http"
"net/netip"
"runtime"
"time"
)
// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
var ErrBodyTooLarge = errors.New("body size too large")
// dialer is the base net.Dialer used by all package-created http.Transports.
var dialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: &net.Resolver{Dial: nil},
}
// Config provides configuration details for setting up a new
// instance of httpclient.Client{}. Within are a subset of the
// configuration values passed to initialized http.Transport{}
// and http.Client{}, along with httpclient.Client{} specific.
type Config struct {
// MaxOpenConns limits the max number of concurrent open connections.
MaxOpenConns int
// MaxIdleConns: see http.Transport{}.MaxIdleConns.
MaxIdleConns int
// ReadBufferSize: see http.Transport{}.ReadBufferSize.
ReadBufferSize int
// WriteBufferSize: see http.Transport{}.WriteBufferSize.
WriteBufferSize int
// MaxBodySize determines the maximum fetchable body size.
MaxBodySize int64
// Timeout: see http.Client{}.Timeout.
Timeout time.Duration
// DisableCompression: see http.Transport{}.DisableCompression.
DisableCompression bool
// AllowRanges allows outgoing communications to given IP nets.
AllowRanges []netip.Prefix
// BlockRanges blocks outgoing communiciations to given IP nets.
BlockRanges []netip.Prefix
}
// Client wraps an underlying http.Client{} to provide the following:
// - setting a maximum received request body size, returning error on
// large content lengths, and using a limited reader in all other
// cases to protect against forged / unknown content-lengths
// - protection from server side request forgery (SSRF) by only dialing
// out to known public IP prefixes, configurable with allows/blocks
// - limit number of concurrent requests, else blocking until a slot
// is available (context channels still respected)
type Client struct {
client http.Client
queue chan struct{}
bmax int64
}
// New returns a new instance of Client initialized using configuration.
func New(cfg Config) *Client {
var c Client
// Copy global
d := dialer
if cfg.MaxOpenConns <= 0 {
// By default base this value on GOMAXPROCS.
maxprocs := runtime.GOMAXPROCS(0)
cfg.MaxOpenConns = maxprocs * 10
}
if cfg.MaxIdleConns <= 0 {
// By default base this value on MaxOpenConns
cfg.MaxIdleConns = cfg.MaxOpenConns * 10
}
if cfg.MaxBodySize <= 0 {
// By default set this to a reasonable 40MB
cfg.MaxBodySize = 40 * 1024 * 1024
}
// Protect dialer with IP range sanitizer
d.Control = (&sanitizer{
allow: cfg.AllowRanges,
block: cfg.BlockRanges,
}).Sanitize
// Prepare client fields
c.bmax = cfg.MaxBodySize
c.queue = make(chan struct{}, cfg.MaxOpenConns)
c.client.Timeout = cfg.Timeout
// Set underlying HTTP client roundtripper
c.client.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
DialContext: d.DialContext,
MaxIdleConns: cfg.MaxIdleConns,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ReadBufferSize: cfg.ReadBufferSize,
WriteBufferSize: cfg.WriteBufferSize,
DisableCompression: cfg.DisableCompression,
}
return &c
}
// Do will perform given request when an available slot in the queue is available,
// and block until this time. For returned values, this follows the same semantics
// as the standard http.Client{}.Do() implementation except that response body will
// be wrapped by an io.LimitReader() to limit response body sizes.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
select {
// Request context cancelled
case <-req.Context().Done():
return nil, req.Context().Err()
// Slot in queue acquired
case c.queue <- struct{}{}:
// NOTE:
// Ideally here we would set the slot release to happen either
// on error return, or via callback from the response body closer.
// However when implementing this, there appear deadlocks between
// the channel queue here and the media manager worker pool. So
// currently we only place a limit on connections dialing out, but
// there may still be more connections open than len(c.queue) given
// that connections may not be closed until response body is closed.
// The current implementation will reduce the viability of denial of
// service attacks, but if there are future issues heed this advice :]
defer func() { <-c.queue }()
}
// Perform the HTTP request
rsp, err := c.client.Do(req)
if err != nil {
return nil, err
}
// Check response body not too large
if rsp.ContentLength > c.bmax {
return nil, ErrBodyTooLarge
}
// Seperate the body implementers
rbody := (io.Reader)(rsp.Body)
cbody := (io.Closer)(rsp.Body)
var limit int64
if limit = rsp.ContentLength; limit < 0 {
// If unknown, use max as reader limit
limit = c.bmax
}
// Don't trust them, limit body reads
rbody = io.LimitReader(rbody, limit)
// Wrap body with limit
rsp.Body = &struct {
io.Reader
io.Closer
}{rbody, cbody}
return rsp, nil
}

View file

@ -0,0 +1,154 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient_test
import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
)
var privateIPs = []string{
"http://127.0.0.1:80",
"http://0.0.0.0:80",
"http://192.168.0.1:80",
"http://192.168.1.0:80",
"http://10.0.0.0:80",
"http://172.16.0.0:80",
"http://10.255.255.255:80",
"http://172.31.255.255:80",
"http://255.255.255.255:80",
}
var bodies = []string{
"hello world!",
"{}",
`{"key": "value", "some": "kinda bullshit"}`,
"body with\r\nnewlines",
}
// Note:
// There is no test for the .MaxOpenConns implementation
// in the httpclient.Client{}, due to the difficult to test
// this. The block is only held for the actual dial out to
// the connection, so the usual test of blocking and holding
// open this queue slot to check we can't open another isn't
// an easy test here.
func TestHTTPClientSmallBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
}
}
func TestHTTPClientExactBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), len(body))
}
}
func TestHTTPClientLargeBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), len(body)-1)
}
}
func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
var (
handler http.HandlerFunc
expect []byte
expectErr error
)
// If this is a larger body, reslice and
// set error so we know what to expect
expect = body
if max < len(body) {
expect = expect[:max]
expectErr = httpclient.ErrBodyTooLarge
}
// Create new HTTP client with maximum body size
client := httpclient.New(httpclient.Config{
MaxBodySize: int64(max),
DisableCompression: true,
AllowRanges: []netip.Prefix{
// Loopback (used by server)
netip.MustParsePrefix("127.0.0.1/8"),
},
})
// Set simple body-writing test handler
handler = func(rw http.ResponseWriter, r *http.Request) {
_, _ = rw.Write(body)
}
// Start the test server
srv := httptest.NewServer(handler)
defer srv.Close()
// Wrap body to provide reader iface
rbody := bytes.NewReader(body)
// Create the test HTTP request
req, _ := http.NewRequest("POST", srv.URL, rbody)
// Perform the test request
rsp, err := client.Do(req)
if !errors.Is(err, expectErr) {
t.Fatalf("error performing client request: %v", err)
} else if err != nil {
return // expected error
}
defer rsp.Body.Close()
// Read response body into memory
check, err := io.ReadAll(rsp.Body)
if err != nil {
t.Fatalf("error reading response body: %v", err)
}
// Check actual response body matches expected
if !bytes.Equal(expect, check) {
t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
}
}
func TestHTTPClientPrivateIP(t *testing.T) {
client := httpclient.New(httpclient.Config{})
for _, addr := range privateIPs {
// Prepare request to private IP
req, _ := http.NewRequest("GET", addr, nil)
// Perform the HTTP request
_, err := client.Do(req)
if !errors.Is(err, httpclient.ErrReservedAddr) {
t.Errorf("dialing private address did not return expected error: %v", err)
}
}
}

View file

@ -0,0 +1,64 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient
import (
"net/netip"
"syscall"
"github.com/superseriousbusiness/gotosocial/internal/netutil"
)
type sanitizer struct {
allow []netip.Prefix
block []netip.Prefix
}
// Sanitize implements the required net.Dialer.Control function signature.
func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
// Parse IP+port from addr
ipport, err := netip.ParseAddrPort(addr)
if err != nil {
return err
}
// Seperate the IP
ip := ipport.Addr()
// Check if this is explicitly allowed
for i := 0; i < len(s.allow); i++ {
if s.allow[i].Contains(ip) {
return nil
}
}
// Now check if explicity blocked
for i := 0; i < len(s.block); i++ {
if s.block[i].Contains(ip) {
return ErrReservedAddr
}
}
// Validate this is a safe IP
if !netutil.ValidateIP(ip) {
return ErrReservedAddr
}
return nil
}

View file

@ -27,9 +27,9 @@ import (
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
@ -79,8 +79,8 @@ type Manager interface {
type manager struct {
db db.DB
storage *kv.KVStore
emojiWorker *worker.Worker[*ProcessingEmoji]
mediaWorker *worker.Worker[*ProcessingMedia]
emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
stopCronJobs func() error
}
@ -89,7 +89,7 @@ type manager struct {
// A worker pool will also be initialized for the manager, to ensure that only
// a limited number of media will be processed in parallel. The numbers of workers
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
// See internal/worker.New() documentation for further information.
// See internal/concurrency.NewWorkerPool() documentation for further information.
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
m := &manager{
db: database,
@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
}
// Prepare the media worker pool
m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
if err := ctx.Err(); err != nil {
return err
@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
})
// Prepare the emoji worker pool
m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
if err := ctx.Err(); err != nil {
return err

View file

@ -0,0 +1,78 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package netutil
import (
"net/netip"
)
var (
// IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
// IPvReserved contains IPv4 reserved IP prefixes.
IPv4Reserved = [...]netip.Prefix{
netip.MustParsePrefix("0.0.0.0/8"), // Current network
netip.MustParsePrefix("10.0.0.0/8"), // Private
netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
netip.MustParsePrefix("127.0.0.0/8"), // Loopback
netip.MustParsePrefix("169.254.0.0/16"), // Link-local
netip.MustParsePrefix("172.16.0.0/12"), // Private
netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
netip.MustParsePrefix("192.168.0.0/16"), // Private
netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
netip.MustParsePrefix("224.0.0.0/4"), // Multicast
netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
}
)
// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
func ValidateAddr(s string) bool {
ipport, err := netip.ParseAddrPort(s)
if err != nil {
return false
}
return ValidateIP(ipport.Addr())
}
// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
func ValidateIP(ip netip.Addr) bool {
switch {
// IPv4: check if IPv4 in reserved nets
case ip.Is4():
for _, reserved := range IPv4Reserved {
if reserved.Contains(ip) {
return false
}
}
return true
// IPv6: check if in global unicast (public internet)
case ip.Is6():
return IPv6GlobalUnicast.Contains(ip)
// Assume malicious by default
default:
return false
}
}

View file

@ -23,6 +23,7 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -33,7 +34,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/oauth2/v4"
)
@ -84,7 +84,7 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server
filter visibility.Filter
formatter text.Formatter
@ -94,7 +94,7 @@ type processor struct {
}
// New returns a new account processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,

View file

@ -24,6 +24,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -35,7 +36,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
suite.fromClientAPIChan <- msg
return nil

View file

@ -23,13 +23,13 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing admin actions.
@ -47,12 +47,12 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
db db.DB
}
// New returns a new admin processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,

View file

@ -26,6 +26,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
@ -33,7 +34,6 @@ import (
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
return response, nil
}
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}

View file

@ -25,6 +25,7 @@ import (
"codeberg.org/gruf/go-store/kv"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -44,7 +45,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
@ -237,8 +237,8 @@ type Processor interface {
// processor just implements the Processor interface
type processor struct {
clientWorker *worker.Worker[messages.FromClientAPI]
fedWorker *worker.Worker[messages.FromFederator]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
federator federation.Federator
tc typeutils.TypeConverter
@ -271,8 +271,8 @@ func NewProcessor(
storage *kv.KVStore,
db db.DB,
emailSender email.Sender,
clientWorker *worker.Worker[messages.FromClientAPI],
fedWorker *worker.Worker[messages.FromFederator],
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
fedWorker *concurrency.WorkerPool[messages.FromFederator],
) Processor {
parseMentionFunc := GetParseMentionFunc(db, federator)

View file

@ -29,6 +29,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
}, nil
})
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)

View file

@ -22,6 +22,7 @@ import (
"context"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -29,7 +30,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing statuses.
@ -74,12 +74,12 @@ type processor struct {
db db.DB
filter visibility.Filter
formatter text.Formatter
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
db: db,

View file

@ -21,6 +21,7 @@ package status_test
import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
storage *kv.KVStore
mediaManager media.Manager
federator federation.Federator
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
// standard suite models
testTokens map[string]*gtsmodel.Token
@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
suite.storage = testrig.NewTestStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)

View file

@ -20,13 +20,17 @@ package transport
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"fmt"
"net/url"
"sync"
"runtime/debug"
"time"
"github.com/go-fed/httpsig"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-cache/v2"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
@ -37,109 +41,85 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
// NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
// NewTransportForUsername searches for account with username, and returns result of .NewTransport().
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
db db.DB
clock pub.Clock
client pub.HttpClient
appAgent string
// dereferenceFollowersShortcut is a shortcut to dereference followers of an
// account on this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
// dereferenceUserShortcut is a shortcut to dereference followers an account on
// this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
}
func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
followers, err := federatingDB.Followers(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(followers)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
}
func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
user, err := federatingDB.Get(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(user)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
db db.DB
fedDB federatingdb.DB
clock pub.Clock
client pub.HttpClient
cache cache.Cache[string, *transport]
userAgent string
}
// NewController returns an implementation of the Controller interface for creating new transports
func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
applicationName := viper.GetString(config.Keys.ApplicationName)
host := viper.GetString(config.Keys.Host)
appAgent := fmt.Sprintf("%s %s", applicationName, host)
return &controller{
db: db,
clock: clock,
client: client,
appAgent: appAgent,
dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
// Determine build information
build, _ := debug.ReadBuildInfo()
c := &controller{
db: db,
fedDB: federatingDB,
clock: clock,
client: client,
cache: cache.New[string, *transport](),
userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
}
// Transport cache has TTL=1hr freq=1m
c.cache.SetTTL(time.Hour, false)
if !c.cache.Start(time.Minute) {
logrus.Panic("failed to start transport controller cache")
}
return c
}
// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
digestAlgo := httpsig.DigestSha256
getHeaders := []string{httpsig.RequestTarget, "host", "date"}
postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
// Generate public key string for cache key
//
// NOTE: it is safe to use the public key as the cache
// key here as we are generating it ourselves from the
// private key. If we were simply using a public key
// provided as argument that would absolutely NOT be safe.
pubStr := privkeyToPublicStr(privkey)
getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
if err != nil {
return nil, fmt.Errorf("error creating get signer: %s", err)
// First check for cached transport
transp, ok := c.cache.Get(pubStr)
if ok {
return transp, nil
}
postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
if err != nil {
return nil, fmt.Errorf("error creating post signer: %s", err)
// Create the transport
transp = &transport{
controller: c,
pubKeyID: pubKeyID,
privkey: privkey,
}
sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
// Cache this transport under pubkey
if !c.cache.Put(pubStr, transp) {
var cached *transport
return &transport{
client: c.client,
appAgent: c.appAgent,
gofedAgent: "(go-fed/activity v1.0.0)",
clock: c.clock,
pubKeyID: pubKeyID,
privkey: privkey,
sigTransport: sigTransport,
getSigner: getSigner,
getSignerMu: &sync.Mutex{},
dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
dereferenceUserShortcut: c.dereferenceUserShortcut,
}, nil
cached, ok = c.cache.Get(pubStr)
if !ok {
// Some ridiculous race cond.
c.cache.Set(pubStr, transp)
} else {
// Use already cached
transp = cached
}
}
return transp, nil
}
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
}
return transport, nil
}
// dereferenceLocalFollowers is a shortcut to dereference followers of an
// account on this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
followers, err := c.fedDB.Followers(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(followers)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
// dereferenceLocalUser is a shortcut to dereference followers an account on
// this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
user, err := c.fedDB.Get(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(user)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
// privkeyToPublicStr will create a string representation of RSA public key from private.
func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
return byteutil.B2S(b)
}

View file

@ -19,13 +19,14 @@
package transport
import (
"bytes"
"context"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
return nil
}
logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
return t.sigTransport.Deliver(ctx, b, to)
urlStr := to.String()
req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
if err != nil {
return err
}
req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
req.Header.Add("Accept-Charset", "utf-8")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", to.Host)
resp, err := t.POST(req, b)
if err != nil {
return err
}
defer resp.Body.Close()
if code := resp.StatusCode; code != http.StatusOK &&
code != http.StatusCreated && code != http.StatusAccepted {
return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
}
return nil
}

View file

@ -20,32 +20,55 @@ package transport
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
l := logrus.WithField("func", "Dereference")
// if the request is to us, we can shortcut for certain URIs rather than going through
// the normal request flow, thereby saving time and energy
if iri.Host == viper.GetString(config.Keys.Host) {
if uris.IsFollowersPath(iri) {
// the request is for followers of one of our accounts, which we can shortcut
return t.dereferenceFollowersShortcut(ctx, iri)
return t.controller.dereferenceLocalFollowers(ctx, iri)
}
if uris.IsUserPath(iri) {
// the request is for one of our accounts, which we can shortcut
return t.dereferenceUserShortcut(ctx, iri)
return t.controller.dereferenceLocalUser(ctx, iri)
}
}
// the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
l.Debugf("performing GET to %s", iri.String())
return t.sigTransport.Dereference(ctx, iri)
// Build IRI just once
iriStr := iri.String()
// Prepare new HTTP request to endpoint
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
req.Header.Add("Accept-Charset", "utf-8")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, err
}
defer rsp.Body.Close()
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
return ioutil.ReadAll(rsp.Body)
}

View file

@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
}
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
l := logrus.WithField("func", "dereferenceByAPIV1Instance")
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: "api/v1/instance",
}
l.Debugf("performing GET to %s", cleanIRI.String())
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", cleanIRI.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
// Build IRI just once
iriStr := cleanIRI.String()
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
if len(b) == 0 {
req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
} else if len(b) == 0 {
return nil, errors.New("response bytes was len 0")
}
@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
}
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
l := logrus.WithField("func", "callNodeInfoWellKnown")
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: ".well-known/nodeinfo",
}
l.Debugf("performing GET to %s", cleanIRI.String())
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
// Build IRI just once
iriStr := cleanIRI.String()
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", cleanIRI.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if len(b) == 0 {
} else if len(b) == 0 {
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
}
@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
}
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
l := logrus.WithField("func", "callNodeInfo")
// Build IRI just once
iriStr := iri.String()
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if len(b) == 0 {
} else if len(b) == 0 {
return nil, errors.New("callNodeInfo: response bytes was len 0")
}

View file

@ -24,34 +24,31 @@ import (
"io"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
)
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
l := logrus.WithField("func", "DereferenceMedia")
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
// Build IRI just once
iriStr := iri.String()
// Prepare HTTP request to this media's IRI
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, 0, err
}
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, 0, err
}
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, 0, err
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
resp, err := t.client.Do(req)
if err != nil {
return nil, 0, err
}
if resp.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
}
return resp.Body, int(resp.ContentLength), nil
return rsp.Body, int(rsp.ContentLength), nil
}

View file

@ -23,46 +23,36 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
)
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
l := logrus.WithField("func", "Finger")
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
l.Debugf("performing GET to %s", urlString)
// Prepare URL string
urlStr := "https://" +
targetDomain +
"/.well-known/webfinger?resource=acct:" +
targetUsername + "@" + targetDomain
iri, err := url.Parse(urlString)
if err != nil {
return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
}
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
// Generate new GET request from URL string
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", req.URL.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
if err != nil {
return nil, err
defer rsp.Body.Close()
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
}
return ioutil.ReadAll(resp.Body)
return ioutil.ReadAll(rsp.Body)
}

View file

@ -0,0 +1,43 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package transport
import (
"github.com/go-fed/httpsig"
)
var (
// http signer preferences
prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
digestAlgo = httpsig.DigestSha256
getHeaders = []string{httpsig.RequestTarget, "host", "date"}
postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
)
// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
return sig, err
}
// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
return sig, err
}

View file

@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
"crypto/x509"
"errors"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
// SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
client pub.HttpClient
appAgent string
gofedAgent string
clock pub.Clock
pubKeyID string
privkey crypto.PrivateKey
sigTransport *pub.HttpSigTransport
getSigner httpsig.Signer
getSignerMu *sync.Mutex
controller *controller
pubKeyID string
privkey crypto.PrivateKey
// shortcuts for dereferencing things that exist on our instance without making an http call to ourself
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
signerExp time.Time
getSigner httpsig.Signer
postSigner httpsig.Signer
signerMu sync.Mutex
}
func (t *transport) SigTransport() pub.Transport {
return t.sigTransport
// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
if r.Method != http.MethodGet {
return nil, errors.New("must be GET request")
}
return t.do(r, func(r *http.Request) error {
return t.signGET(r)
}, retryOn...)
}
// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
if r.Method != http.MethodPost {
return nil, errors.New("must be POST request")
}
return t.do(r, func(r *http.Request) error {
return t.signPOST(r, body)
}, retryOn...)
}
func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
const maxRetries = 5
backoff := time.Second * 2
// Start a log entry for this request
l := logrus.WithFields(logrus.Fields{
"pubKeyID": t.pubKeyID,
"method": r.Method,
"url": r.URL.String(),
})
for i := 0; i < maxRetries; i++ {
// Reset signing header fields
now := t.controller.clock.Now().UTC()
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
r.Header.Del("Signature")
r.Header.Del("Digest")
// Perform request signing
if err := signer(r); err != nil {
return nil, err
}
l.Infof("performing request")
// Attempt to perform request
rsp, err := t.controller.client.Do(r)
if err == nil { //nolint shutup linter
// TooManyRequest means we need to slow
// down and retry our request. Codes over
// 500 generally indicate temp. outages.
if code := rsp.StatusCode; code < 500 &&
code != http.StatusTooManyRequests &&
!containsInt(retryOn, rsp.StatusCode) {
return rsp, nil
}
// Generate error from status code for logging
err = errors.New(`http response "` + rsp.Status + `"`)
} else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
// Return early if context has cancelled
return nil, err
} else if strings.Contains(err.Error(), "stopped after 10 redirects") {
// Don't bother if net/http returned after too many redirects
return nil, err
} else if errors.As(err, &x509.UnknownAuthorityError{}) {
// Unknown authority errors we do NOT recover from
return nil, err
}
l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
select {
// Request ctx cancelled
case <-r.Context().Done():
return nil, r.Context().Err()
// Backoff for some time
case <-time.After(backoff):
backoff *= 2
}
}
return nil, errors.New("transport reached max retries")
}
// signGET will safely sign an HTTP GET request.
func (t *transport) signGET(r *http.Request) (err error) {
t.safesign(func() {
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
})
return
}
// signPOST will safely sign an HTTP POST request for given body.
func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
t.safesign(func() {
err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
})
return
}
// safesign will perform sign function within mutex protection,
// and ensured that httpsig.Signers are up-to-date.
func (t *transport) safesign(sign func()) {
// Perform within mu safety
t.signerMu.Lock()
defer t.signerMu.Unlock()
if now := time.Now(); now.After(t.signerExp) {
const expiry = 120
// Signers have expired and require renewal
t.getSigner, _ = NewGETSigner(expiry)
t.postSigner, _ = NewPOSTSigner(expiry)
t.signerExp = now.Add(time.Second * expiry)
}
// Perform signing
sign()
}
// containsInt checks if slice contains check.
func containsInt(slice []int, check int) bool {
for _, i := range slice {
if i == check {
return true
}
}
return false
}