Web Push: test notification policy

This commit is contained in:
Vyr Cossont 2025-01-31 13:18:59 -08:00
commit cfeae0330c
3 changed files with 104 additions and 28 deletions

View file

@ -67,8 +67,7 @@ func (r *realSender) Send(
relevantSubscriptions := slices.DeleteFunc( relevantSubscriptions := slices.DeleteFunc(
subscriptions, subscriptions,
func(subscription *gtsmodel.WebPushSubscription) bool { func(subscription *gtsmodel.WebPushSubscription) bool {
// Remove subscriptions that don't want this type of notification. return r.shouldSkipSubscription(ctx, notification, subscription)
return !subscription.NotificationFlags.Get(notification.NotificationType)
}, },
) )
if len(relevantSubscriptions) == 0 { if len(relevantSubscriptions) == 0 {
@ -117,6 +116,68 @@ func (r *realSender) Send(
return nil return nil
} }
// shouldSkipSubscription returns true if this subscription is not relevant to this notification.
func (r *realSender) shouldSkipSubscription(
ctx context.Context,
notification *gtsmodel.Notification,
subscription *gtsmodel.WebPushSubscription,
) bool {
// Remove subscriptions that don't want this type of notification.
if !subscription.NotificationFlags.Get(notification.NotificationType) {
return true
}
// Check against subscription's notification policy.
switch subscription.Policy {
case gtsmodel.WebPushNotificationPolicyAll:
// Allow notifications from any account.
return false
case gtsmodel.WebPushNotificationPolicyFollowed:
// Allow if the subscription account follows the notifying account.
isFollowing, err := r.state.DB.IsFollowing(ctx, subscription.AccountID, notification.OriginAccountID)
if err != nil {
log.Errorf(
ctx,
"error checking whether account %s follows account %s: %v",
subscription.AccountID,
notification.OriginAccountID,
err,
)
return true
}
return !isFollowing
case gtsmodel.WebPushNotificationPolicyFollower:
// Allow if the notifying account follows the subscription account.
isFollowing, err := r.state.DB.IsFollowing(ctx, notification.OriginAccountID, subscription.AccountID)
if err != nil {
log.Errorf(
ctx,
"error checking whether account %s follows account %s: %v",
notification.OriginAccountID,
subscription.AccountID,
err,
)
return true
}
return !isFollowing
case gtsmodel.WebPushNotificationPolicyNone:
// This subscription doesn't want any push notifications.
return true
default:
log.Errorf(
ctx,
"unknown Web Push notification policy for subscription with token ID %s: %d",
subscription.TokenID,
subscription.Policy,
)
return true
}
}
// sendToSubscription sends a notification to a single Web Push subscription. // sendToSubscription sends a notification to a single Web Push subscription.
func (r *realSender) sendToSubscription( func (r *realSender) sendToSubscription(
ctx context.Context, ctx context.Context,

View file

@ -23,7 +23,6 @@ import (
"net/http" "net/http"
"testing" "testing"
"time" "time"
// for go:linkname // for go:linkname
_ "unsafe" _ "unsafe"
@ -43,6 +42,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/subscriptions" "github.com/superseriousbusiness/gotosocial/internal/subscriptions"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/webpush" "github.com/superseriousbusiness/gotosocial/internal/webpush"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -62,16 +62,7 @@ type RealSenderStandardTestSuite struct {
webPushSender webpush.Sender webPushSender webpush.Sender
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
testEmojis map[string]*gtsmodel.Emoji
testNotifications map[string]*gtsmodel.Notification testNotifications map[string]*gtsmodel.Notification
testWebPushSubscriptions map[string]*gtsmodel.WebPushSubscription testWebPushSubscriptions map[string]*gtsmodel.WebPushSubscription
@ -81,16 +72,7 @@ type RealSenderStandardTestSuite struct {
} }
func (suite *RealSenderStandardTestSuite) SetupSuite() { func (suite *RealSenderStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
suite.testEmojis = testrig.NewTestEmojis()
suite.testNotifications = testrig.NewTestNotifications() suite.testNotifications = testrig.NewTestNotifications()
suite.testWebPushSubscriptions = testrig.NewTestWebPushSubscriptions() suite.testWebPushSubscriptions = testrig.NewTestWebPushSubscriptions()
} }
@ -184,14 +166,16 @@ func (rc *notifyingReadCloser) Close() error {
// Simulate sending a push notification with the suite's fake web client. // Simulate sending a push notification with the suite's fake web client.
func (suite *RealSenderStandardTestSuite) simulatePushNotification( func (suite *RealSenderStandardTestSuite) simulatePushNotification(
notificationID string,
statusCode int, statusCode int,
expectSend bool,
expectDeletedSubscription bool, expectDeletedSubscription bool,
) error { ) error {
// Don't let the test run forever if the push notification was not sent for some reason. // Don't let the test run forever if the push notification was not sent for some reason.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
notification, err := suite.state.DB.GetNotificationByID(ctx, suite.testNotifications["local_account_1_like"].ID) notification, err := suite.state.DB.GetNotificationByID(ctx, notificationID)
if !suite.NoError(err) { if !suite.NoError(err) {
suite.FailNow("Couldn't fetch notification to send") suite.FailNow("Couldn't fetch notification to send")
} }
@ -221,6 +205,14 @@ func (suite *RealSenderStandardTestSuite) simulatePushNotification(
case <-ctx.Done(): case <-ctx.Done():
contextExpired = true contextExpired = true
} }
// In some cases we expect the notification *not* to be sent.
if !expectSend {
suite.False(bodyClosed)
suite.True(contextExpired)
return nil
}
suite.True(bodyClosed) suite.True(bodyClosed)
suite.False(contextExpired) suite.False(contextExpired)
@ -240,25 +232,48 @@ func (suite *RealSenderStandardTestSuite) simulatePushNotification(
// Test a successful response to sending a push notification. // Test a successful response to sending a push notification.
func (suite *RealSenderStandardTestSuite) TestSendSuccess() { func (suite *RealSenderStandardTestSuite) TestSendSuccess() {
suite.NoError(suite.simulatePushNotification(http.StatusOK, false)) notificationID := suite.testNotifications["local_account_1_like"].ID
suite.NoError(suite.simulatePushNotification(notificationID, http.StatusOK, true, false))
} }
// Test a rate-limiting response to sending a push notification. // Test a rate-limiting response to sending a push notification.
// This should not delete the subscription. // This should not delete the subscription.
func (suite *RealSenderStandardTestSuite) TestRateLimited() { func (suite *RealSenderStandardTestSuite) TestRateLimited() {
suite.NoError(suite.simulatePushNotification(http.StatusTooManyRequests, false)) notificationID := suite.testNotifications["local_account_1_like"].ID
suite.NoError(suite.simulatePushNotification(notificationID, http.StatusTooManyRequests, true, false))
} }
// Test a non-special-cased client error response to sending a push notification. // Test a non-special-cased client error response to sending a push notification.
// This should delete the subscription. // This should delete the subscription.
func (suite *RealSenderStandardTestSuite) TestClientError() { func (suite *RealSenderStandardTestSuite) TestClientError() {
suite.NoError(suite.simulatePushNotification(http.StatusBadRequest, true)) notificationID := suite.testNotifications["local_account_1_like"].ID
suite.NoError(suite.simulatePushNotification(notificationID, http.StatusBadRequest, true, true))
} }
// Test a server error response to sending a push notification. // Test a server error response to sending a push notification.
// This should not delete the subscription. // This should not delete the subscription.
func (suite *RealSenderStandardTestSuite) TestServerError() { func (suite *RealSenderStandardTestSuite) TestServerError() {
suite.NoError(suite.simulatePushNotification(http.StatusInternalServerError, false)) notificationID := suite.testNotifications["local_account_1_like"].ID
suite.NoError(suite.simulatePushNotification(notificationID, http.StatusInternalServerError, true, false))
}
// Don't send a push notification if it doesn't match policy.
func (suite *RealSenderStandardTestSuite) TestSendPolicyMismatch() {
// Setup: create a new notification from an account that the subscribed account doesn't follow.
notification := &gtsmodel.Notification{
ID: "01JJZ2Y9Z8E1XKT90EHZ5KZBDW",
NotificationType: gtsmodel.NotificationFavourite,
TargetAccountID: suite.testAccounts["local_account_1"].ID,
OriginAccountID: suite.testAccounts["remote_account_1"].ID,
StatusID: "01F8MHAMCHF6Y650WCRSCP4WMY",
Read: util.Ptr(false),
}
if err := suite.db.PutNotification(context.Background(), notification); !suite.NoError(err) {
suite.FailNow(err.Error())
return
}
suite.NoError(suite.simulatePushNotification(notification.ID, 0, false, false))
} }
func TestRealSenderStandardTestSuite(t *testing.T) { func TestRealSenderStandardTestSuite(t *testing.T) {

View file

@ -3610,7 +3610,7 @@ func NewTestWebPushSubscriptions() map[string]*gtsmodel.WebPushSubscription {
gtsmodel.NotificationPendingReply, gtsmodel.NotificationPendingReply,
gtsmodel.NotificationPendingReblog, gtsmodel.NotificationPendingReblog,
}), }),
Policy: gtsmodel.WebPushNotificationPolicyAll, Policy: gtsmodel.WebPushNotificationPolicyFollowed,
}, },
} }
} }