diff --git a/internal/subscriptions/subscriptions_test.go b/internal/subscriptions/subscriptions_test.go index d86d98691..2014835ae 100644 --- a/internal/subscriptions/subscriptions_test.go +++ b/internal/subscriptions/subscriptions_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/subscriptions" @@ -814,6 +815,122 @@ func (suite *SubscriptionsTestSuite) TestAdoption() { suite.Equal(testSubscription.ID, existingBlock3.SubscriptionID) } +func (suite *SubscriptionsTestSuite) TestDomainAllowsAndBlocks() { + var ( + ctx = context.Background() + testStructs = testrig.SetupTestStructs(rMediaPath, rTemplatePath) + testAccount = suite.testAccounts["admin_account"] + subscriptions = subscriptions.New( + testStructs.State, + testStructs.TransportController, + testStructs.TypeConverter, + ) + + // Create a subscription for a CSV list of goodies. + testAllowSubscription = >smodel.DomainPermissionSubscription{ + ID: "01JGE681TQSBPAV59GZXPKE62H", + Priority: 255, + Title: "goodies!", + PermissionType: gtsmodel.DomainPermissionAllow, + AsDraft: util.Ptr(false), + AdoptOrphans: util.Ptr(false), + CreatedByAccountID: testAccount.ID, + CreatedByAccount: testAccount, + URI: "https://lists.example.org/goodies", + ContentType: gtsmodel.DomainPermSubContentTypePlain, + } + + testBlockSubscription = >smodel.DomainPermissionSubscription{ + ID: "01JPMVY19TKZND838Z7Y6S4EG8", + Priority: 255, + Title: "baddies!", + PermissionType: gtsmodel.DomainPermissionBlock, + AsDraft: util.Ptr(false), + AdoptOrphans: util.Ptr(false), + CreatedByAccountID: testAccount.ID, + CreatedByAccount: testAccount, + URI: "https://lists.example.org/baddies.csv", + ContentType: gtsmodel.DomainPermSubContentTypeCSV, + } + + ) + defer testrig.TearDownTestStructs(testStructs) + + // Store test subscriptions. + if err := testStructs.State.DB.PutDomainPermissionSubscription( + ctx, testAllowSubscription, + ); err != nil { + suite.FailNow(err.Error()) + } + if err := testStructs.State.DB.PutDomainPermissionSubscription( + ctx, testBlockSubscription, + ); err != nil { + suite.FailNow(err.Error()) + } + + // Put the instance in allowlist mode. + config.SetInstanceFederationMode("allowlist") + + // Fetch + process subscribed perms in order. + var order [2]gtsmodel.DomainPermissionType + if config.GetInstanceFederationMode() == config.InstanceFederationModeBlocklist { + order = [2]gtsmodel.DomainPermissionType{ + gtsmodel.DomainPermissionAllow, + gtsmodel.DomainPermissionBlock, + } + } else { + order = [2]gtsmodel.DomainPermissionType{ + gtsmodel.DomainPermissionBlock, + gtsmodel.DomainPermissionAllow, + } + } + for _, permType := range order { + subscriptions.ProcessDomainPermissionSubscriptions(ctx, permType) + } + + // We should now have allows for each + // domain on the subscribed allow list. + for _, domain := range []string{ + "people.we.like.com", + "goodeggs.org", + "allowthesefolks.church", + } { + var ( + perm gtsmodel.DomainPermission + err error + ) + if !testrig.WaitFor(func() bool { + perm, err = testStructs.State.DB.GetDomainAllow(ctx, domain) + return err == nil + }) { + suite.FailNowf("", "timed out waiting for domain %s", domain) + } + + suite.Equal(testAllowSubscription.ID, perm.GetSubscriptionID()) + } + + // And blocks for for each domain + // on the subscribed block list. + for _, domain := range []string{ + "bumfaces.net", + "peepee.poopoo", + "nothanks.com", + } { + var ( + perm gtsmodel.DomainPermission + err error + ) + if !testrig.WaitFor(func() bool { + perm, err = testStructs.State.DB.GetDomainBlock(ctx, domain) + return err == nil + }) { + suite.FailNowf("", "timed out waiting for domain %s", domain) + } + + suite.Equal(testBlockSubscription.ID, perm.GetSubscriptionID()) + } +} + func TestSubscriptionTestSuite(t *testing.T) { suite.Run(t, new(SubscriptionsTestSuite)) } diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index a6b0dd801..bbcb3901d 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -640,6 +640,10 @@ nothanks.com` } ]` jsonRespETag = "\"don't modify me daddy\"" + allowsResp = `people.we.like.com +goodeggs.org +allowthesefolks.church` + allowsRespETag = "\"never change\"" ) switch req.URL.String() { @@ -720,6 +724,36 @@ nothanks.com` } responseContentLength = len(responseBytes) + case "https://lists.example.org/goodies.csv": + extraHeaders = map[string]string{ + "Last-Modified": lastModified, + "ETag": allowsRespETag, + } + if req.Header.Get("If-None-Match") == allowsRespETag { + // Cached. + responseCode = http.StatusNotModified + } else { + responseBytes = []byte(allowsResp) + responseContentType = textCSV + responseCode = http.StatusOK + } + responseContentLength = len(responseBytes) + + case "https://lists.example.org/goodies": + extraHeaders = map[string]string{ + "Last-Modified": lastModified, + "ETag": allowsRespETag, + } + if req.Header.Get("If-None-Match") == allowsRespETag { + // Cached. + responseCode = http.StatusNotModified + } else { + responseBytes = []byte(allowsResp) + responseContentType = textPlain + responseCode = http.StatusOK + } + responseContentLength = len(responseBytes) + default: responseCode = http.StatusNotFound responseBytes = []byte(`{"error":"not found"}`)