mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-03 19:42:25 -06:00 
			
		
		
		
	Bumps [github.com/SherClockHolmes/webpush-go](https://github.com/SherClockHolmes/webpush-go) from 1.3.0 to 1.4.0. - [Release notes](https://github.com/SherClockHolmes/webpush-go/releases) - [Commits](https://github.com/SherClockHolmes/webpush-go/compare/v1.3.0...v1.4.0) --- updated-dependencies: - dependency-name: github.com/SherClockHolmes/webpush-go dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
		
			
				
	
	
		
			286 lines
		
	
	
	
		
			7.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			286 lines
		
	
	
	
		
			7.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package webpush
 | 
						||
 | 
						||
import (
 | 
						||
	"bytes"
 | 
						||
	"context"
 | 
						||
	"crypto/aes"
 | 
						||
	"crypto/cipher"
 | 
						||
	"crypto/elliptic"
 | 
						||
	"crypto/rand"
 | 
						||
	"crypto/sha256"
 | 
						||
	"encoding/base64"
 | 
						||
	"encoding/binary"
 | 
						||
	"errors"
 | 
						||
	"io"
 | 
						||
	"net/http"
 | 
						||
	"strconv"
 | 
						||
	"strings"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"golang.org/x/crypto/hkdf"
 | 
						||
)
 | 
						||
 | 
						||
// MaxRecordSize is the record size, in bytes, applied when Options.RecordSize
// is left at zero.
const MaxRecordSize = uint32(4096)

// ErrMaxPadExceeded is returned when the message plus its padding delimiter
// does not fit into the space available in the configured record size.
var ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length")
 | 
						||
 | 
						||
// saltFunc returns 16 bytes read from crypto/rand, used as the salt of the
// aes128gcm content-coding header. It is declared as a variable, presumably so
// tests can substitute a deterministic salt.
//
// On a read failure the (possibly partially filled) buffer is returned along
// with the error.
var saltFunc = func() ([]byte, error) {
	const saltLen = 16
	buf := make([]byte, saltLen)
	if _, readErr := io.ReadFull(rand.Reader, buf); readErr != nil {
		return buf, readErr
	}
	return buf, nil
}
 | 
						||
 | 
						||
// HTTPClient is the minimal client abstraction SendNotificationWithContext
// uses to deliver the notification request. *http.Client satisfies it, and a
// test double can be supplied through Options.HTTPClient.
type HTTPClient interface {
	// Do sends req and returns the push service's response.
	Do(req *http.Request) (*http.Response, error)
}
 | 
						||
 | 
						||
// Options are config and extra params needed to send a notification
 | 
						||
type Options struct {
 | 
						||
	HTTPClient      HTTPClient // Will replace with *http.Client by default if not included
 | 
						||
	RecordSize      uint32     // Limit the record size
 | 
						||
	Subscriber      string     // Sub in VAPID JWT token
 | 
						||
	Topic           string     // Set the Topic header to collapse a pending messages (Optional)
 | 
						||
	TTL             int        // Set the TTL on the endpoint POST request
 | 
						||
	Urgency         Urgency    // Set the Urgency header to change a message priority (Optional)
 | 
						||
	VAPIDPublicKey  string     // VAPID public key, passed in VAPID Authorization header
 | 
						||
	VAPIDPrivateKey string     // VAPID private key, used to sign VAPID JWT token
 | 
						||
	VapidExpiration time.Time  // optional expiration for VAPID JWT token (defaults to now + 12 hours)
 | 
						||
}
 | 
						||
 | 
						||
// Keys are the base64-encoded values from PushSubscription.getKey(). Either
// the standard or the URL-safe alphabet is accepted, with or without "="
// padding.
type Keys struct {
	// Auth is the subscriber's authentication secret.
	Auth string `json:"auth"`

	// P256dh is the subscriber's ECDH public key, expected to be an
	// uncompressed P-256 point.
	P256dh string `json:"p256dh"`
}
 | 
						||
 | 
						||
// Subscription represents a PushSubscription object from the Push API
 | 
						||
type Subscription struct {
 | 
						||
	Endpoint string `json:"endpoint"`
 | 
						||
	Keys     Keys   `json:"keys"`
 | 
						||
}
 | 
						||
 | 
						||
// SendNotification calls SendNotificationWithContext with default context for backwards-compatibility
 | 
						||
func SendNotification(message []byte, s *Subscription, options *Options) (*http.Response, error) {
 | 
						||
	return SendNotificationWithContext(context.Background(), message, s, options)
 | 
						||
}
 | 
						||
 | 
						||
// SendNotificationWithContext sends a push notification to a subscription's endpoint
 | 
						||
// Message Encryption for Web Push, and VAPID protocols.
 | 
						||
// FOR MORE INFORMATION SEE RFC8291: https://datatracker.ietf.org/doc/rfc8291
 | 
						||
func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) {
 | 
						||
	// Authentication secret (auth_secret)
 | 
						||
	authSecret, err := decodeSubscriptionKey(s.Keys.Auth)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// dh (Diffie Hellman)
 | 
						||
	dh, err := decodeSubscriptionKey(s.Keys.P256dh)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Generate 16 byte salt
 | 
						||
	salt, err := saltFunc()
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Create the ecdh_secret shared key pair
 | 
						||
	curve := elliptic.P256()
 | 
						||
 | 
						||
	// Application server key pairs (single use)
 | 
						||
	localPrivateKey, x, y, err := elliptic.GenerateKey(curve, rand.Reader)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	localPublicKey := elliptic.Marshal(curve, x, y)
 | 
						||
 | 
						||
	// Combine application keys with receiver's EC public key
 | 
						||
	sharedX, sharedY := elliptic.Unmarshal(curve, dh)
 | 
						||
	if sharedX == nil {
 | 
						||
		return nil, errors.New("Unmarshal Error: Public key is not a valid point on the curve")
 | 
						||
	}
 | 
						||
 | 
						||
	// Derive ECDH shared secret
 | 
						||
	sx, sy := curve.ScalarMult(sharedX, sharedY, localPrivateKey)
 | 
						||
	if !curve.IsOnCurve(sx, sy) {
 | 
						||
		return nil, errors.New("Encryption error: ECDH shared secret isn't on curve")
 | 
						||
	}
 | 
						||
	mlen := curve.Params().BitSize / 8
 | 
						||
	sharedECDHSecret := make([]byte, mlen)
 | 
						||
	sx.FillBytes(sharedECDHSecret)
 | 
						||
 | 
						||
	hash := sha256.New
 | 
						||
 | 
						||
	// ikm
 | 
						||
	prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00"))
 | 
						||
	prkInfoBuf.Write(dh)
 | 
						||
	prkInfoBuf.Write(localPublicKey)
 | 
						||
 | 
						||
	prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes())
 | 
						||
	ikm, err := getHKDFKey(prkHKDF, 32)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Derive Content Encryption Key
 | 
						||
	contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00")
 | 
						||
	contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo)
 | 
						||
	contentEncryptionKey, err := getHKDFKey(contentHKDF, 16)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Derive the Nonce
 | 
						||
	nonceInfo := []byte("Content-Encoding: nonce\x00")
 | 
						||
	nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo)
 | 
						||
	nonce, err := getHKDFKey(nonceHKDF, 12)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Cipher
 | 
						||
	c, err := aes.NewCipher(contentEncryptionKey)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	gcm, err := cipher.NewGCM(c)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Get the record size
 | 
						||
	recordSize := options.RecordSize
 | 
						||
	if recordSize == 0 {
 | 
						||
		recordSize = MaxRecordSize
 | 
						||
	}
 | 
						||
 | 
						||
	recordLength := int(recordSize) - 16
 | 
						||
 | 
						||
	// Encryption Content-Coding Header
 | 
						||
	recordBuf := bytes.NewBuffer(salt)
 | 
						||
 | 
						||
	rs := make([]byte, 4)
 | 
						||
	binary.BigEndian.PutUint32(rs, recordSize)
 | 
						||
 | 
						||
	recordBuf.Write(rs)
 | 
						||
	recordBuf.Write([]byte{byte(len(localPublicKey))})
 | 
						||
	recordBuf.Write(localPublicKey)
 | 
						||
 | 
						||
	// Data
 | 
						||
	dataBuf := bytes.NewBuffer(message)
 | 
						||
 | 
						||
	// Pad content to max record size - 16 - header
 | 
						||
	// Padding ending delimeter
 | 
						||
	dataBuf.Write([]byte("\x02"))
 | 
						||
	if err := pad(dataBuf, recordLength-recordBuf.Len()); err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	// Compose the ciphertext
 | 
						||
	ciphertext := gcm.Seal([]byte{}, nonce, dataBuf.Bytes(), nil)
 | 
						||
	recordBuf.Write(ciphertext)
 | 
						||
 | 
						||
	// POST request
 | 
						||
	req, err := http.NewRequest("POST", s.Endpoint, recordBuf)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	if ctx != nil {
 | 
						||
		req = req.WithContext(ctx)
 | 
						||
	}
 | 
						||
 | 
						||
	req.Header.Set("Content-Encoding", "aes128gcm")
 | 
						||
	req.Header.Set("Content-Type", "application/octet-stream")
 | 
						||
	req.Header.Set("TTL", strconv.Itoa(options.TTL))
 | 
						||
 | 
						||
	// Сheck the optional headers
 | 
						||
	if len(options.Topic) > 0 {
 | 
						||
		req.Header.Set("Topic", options.Topic)
 | 
						||
	}
 | 
						||
 | 
						||
	if isValidUrgency(options.Urgency) {
 | 
						||
		req.Header.Set("Urgency", string(options.Urgency))
 | 
						||
	}
 | 
						||
 | 
						||
	expiration := options.VapidExpiration
 | 
						||
	if expiration.IsZero() {
 | 
						||
		expiration = time.Now().Add(time.Hour * 12)
 | 
						||
	}
 | 
						||
 | 
						||
	// Get VAPID Authorization header
 | 
						||
	vapidAuthHeader, err := getVAPIDAuthorizationHeader(
 | 
						||
		s.Endpoint,
 | 
						||
		options.Subscriber,
 | 
						||
		options.VAPIDPublicKey,
 | 
						||
		options.VAPIDPrivateKey,
 | 
						||
		expiration,
 | 
						||
	)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	req.Header.Set("Authorization", vapidAuthHeader)
 | 
						||
 | 
						||
	// Send the request
 | 
						||
	var client HTTPClient
 | 
						||
	if options.HTTPClient != nil {
 | 
						||
		client = options.HTTPClient
 | 
						||
	} else {
 | 
						||
		client = &http.Client{}
 | 
						||
	}
 | 
						||
 | 
						||
	return client.Do(req)
 | 
						||
}
 | 
						||
 | 
						||
// decodeSubscriptionKey decodes a base64-encoded subscription key such as
// Keys.Auth or Keys.P256dh.
//
// Missing "=" padding is restored first. The standard alphabet is tried
// before the URL-safe one. If both fail, the URL-safe decoder's error is
// returned.
func decodeSubscriptionKey(key string) ([]byte, error) {
	// Restore "=" padding up to a multiple of 4 characters.
	if rem := len(key) % 4; rem != 0 {
		key += strings.Repeat("=", 4-rem)
	}

	// Named "decoded" rather than "bytes" so the bytes package is not shadowed.
	if decoded, err := base64.StdEncoding.DecodeString(key); err == nil {
		return decoded, nil
	}
	return base64.URLEncoding.DecodeString(key)
}
 | 
						||
 | 
						||
// getHKDFKey reads exactly length bytes of key material from r, which in this
// file is always a reader returned by hkdf.New.
//
// If r cannot supply length bytes, it returns a nil key and the error from
// io.ReadFull (io.EOF or io.ErrUnexpectedEOF, or the reader's own error).
func getHKDFKey(r io.Reader, length int) ([]byte, error) {
	key := make([]byte, length)
	// io.ReadFull returns a nil error only when key was filled completely, so
	// no separate byte-count check is needed.
	if _, err := io.ReadFull(r, key); err != nil {
		return nil, err
	}
	return key, nil
}
 | 
						||
 | 
						||
func pad(payload *bytes.Buffer, maxPadLen int) error {
 | 
						||
	payloadLen := payload.Len()
 | 
						||
	if payloadLen > maxPadLen {
 | 
						||
		return ErrMaxPadExceeded
 | 
						||
	}
 | 
						||
 | 
						||
	padLen := maxPadLen - payloadLen
 | 
						||
 | 
						||
	padding := make([]byte, padLen)
 | 
						||
	payload.Write(padding)
 | 
						||
 | 
						||
	return nil
 | 
						||
}
 |