rename function, strip port from domain validation

This commit is contained in:
kim 2025-01-14 13:08:51 +00:00
commit 8db867b6df
10 changed files with 57 additions and 40 deletions

View file

@ -28,6 +28,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -51,6 +52,7 @@ func (suite *InstancePatchTestSuite) instancePatch(fieldName string, fileName st
ctx := suite.newContext(recorder, http.MethodPatch, instance.InstanceInformationPathV1, requestBody.Bytes(), w.FormDataContentType(), true) ctx := suite.newContext(recorder, http.MethodPatch, instance.InstanceInformationPathV1, requestBody.Bytes(), w.FormDataContentType(), true)
suite.instanceModule.InstanceUpdatePATCHHandler(ctx) suite.instanceModule.InstanceUpdatePATCHHandler(ctx)
middleware.Logger(false)(ctx)
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()

View file

@ -140,7 +140,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
var err error var err error
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err = util.Punify_(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -39,7 +39,7 @@ type domainDB struct {
func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) { func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) {
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
allow.Domain, err = util.PunifyValidate(allow.Domain) allow.Domain, err = util.PunifySafely(allow.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err) return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err)
} }
@ -59,7 +59,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) { func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -114,7 +114,7 @@ func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel
func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow, columns ...string) (err error) { func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow, columns ...string) (err error) {
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
allow.Domain, err = util.PunifyValidate(allow.Domain) allow.Domain, err = util.PunifySafely(allow.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err) return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err)
} }
@ -143,7 +143,7 @@ func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", domain, err) return gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -167,7 +167,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
block.Domain, err = util.PunifyValidate(block.Domain) block.Domain, err = util.PunifySafely(block.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", block.Domain, err) return gtserror.Newf("error punifying domain %s: %w", block.Domain, err)
} }
@ -187,7 +187,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) { func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -244,7 +244,7 @@ func (d *domainDB) UpdateDomainBlock(ctx context.Context, block *gtsmodel.Domain
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
block.Domain, err = util.PunifyValidate(block.Domain) block.Domain, err = util.PunifySafely(block.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", block.Domain, err) return gtserror.Newf("error punifying domain %s: %w", block.Domain, err)
} }
@ -273,7 +273,7 @@ func (d *domainDB) UpdateDomainBlock(ctx context.Context, block *gtsmodel.Domain
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error { func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", domain, err) return gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -294,7 +294,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) { func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return false, gtserror.Newf("error punifying domain %s: %w", domain, err) return false, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }

View file

@ -169,7 +169,7 @@ func (d *domainDB) GetDomainPermissionDrafts(
var err error var err error
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err = util.Punify_(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -240,7 +240,7 @@ func (d *domainDB) PutDomainPermissionDraft(
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
draft.Domain, err = util.PunifyValidate(draft.Domain) draft.Domain, err = util.PunifySafely(draft.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", draft.Domain, err) return gtserror.Newf("error punifying domain %s: %w", draft.Domain, err)
} }

View file

@ -41,7 +41,7 @@ func (d *domainDB) PutDomainPermissionExclude(
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
exclude.Domain, err = util.PunifyValidate(exclude.Domain) exclude.Domain, err = util.PunifySafely(exclude.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", exclude.Domain, err) return gtserror.Newf("error punifying domain %s: %w", exclude.Domain, err)
} }
@ -61,7 +61,7 @@ func (d *domainDB) PutDomainPermissionExclude(
func (d *domainDB) IsDomainPermissionExcluded(ctx context.Context, domain string) (bool, error) { func (d *domainDB) IsDomainPermissionExcluded(ctx context.Context, domain string) (bool, error) {
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err := util.Punify_(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return false, gtserror.Newf("error punifying domain %s: %w", domain, err) return false, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -180,7 +180,7 @@ func (d *domainDB) GetDomainPermissionExcludes(
var err error var err error
// Normalize domain as punycode for lookup. // Normalize domain as punycode for lookup.
domain, err = util.Punify_(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }

View file

@ -161,7 +161,7 @@ func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.
var err error var err error
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err = util.Punify_(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
@ -268,7 +268,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
instance.Domain, err = util.PunifyValidate(instance.Domain) instance.Domain, err = util.PunifySafely(instance.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
} }
@ -285,7 +285,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
// Normalize the domain as punycode, note the extra // Normalize the domain as punycode, note the extra
// validation step for domain name write operations. // validation step for domain name write operations.
instance.Domain, err = util.PunifyValidate(instance.Domain) instance.Domain, err = util.PunifySafely(instance.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
} }
@ -356,7 +356,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
var err error var err error
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err = util.Punify_(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }

View file

@ -627,7 +627,7 @@ func permsFromCSV(
// Normalize + validate domain. // Normalize + validate domain.
domainRaw := record[*domainI] domainRaw := record[*domainI]
domain, err := util.PunifyValidate(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue
@ -700,7 +700,7 @@ func permsFromJSON(
// Normalize + validate domain. // Normalize + validate domain.
domainRaw := apiPerm.Domain.Domain domainRaw := apiPerm.Domain.Domain
domain, err := util.PunifyValidate(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue
@ -756,7 +756,7 @@ func permsFromPlain(
for _, domainRaw := range domains { for _, domainRaw := range domains {
// Normalize + validate domain as ASCII. // Normalize + validate domain as ASCII.
domain, err := util.PunifyValidate(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue

View file

@ -76,7 +76,7 @@ func prepWebfingerReq(ctx context.Context, loc, domain, username string) (*http.
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) { func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
// Remotes seem to prefer having their punycode // Remotes seem to prefer having their punycode
// domain used in webfinger requests, so let's oblige. // domain used in webfinger requests, so let's oblige.
punyDomain, err := util.Punify_(targetDomain) punyDomain, err := util.Punify(targetDomain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying %s: %w", targetDomain, err) return nil, gtserror.Newf("error punifying %s: %w", targetDomain, err)
} }

View file

@ -31,15 +31,32 @@ var (
verifyProfile = *idna.Lookup verifyProfile = *idna.Lookup
) )
// PunifyValidate validates the provided domain name, // PunifySafely validates the provided domain name,
// and converts unicode chars to ASCII, i.e. punified form. // and converts unicode chars to ASCII, i.e. punified form.
func PunifyValidate(domain string) (string, error) { func PunifySafely(domain string) (string, error) {
domain, err := verifyProfile.ToASCII(domain) if i := strings.LastIndexByte(domain, ':'); i >= 0 {
return strings.ToLower(domain), err
// If there is a port included in domain, we
// strip it as colon is invalid in a hostname.
domain, port := domain[:i], domain[i:]
domain, err := verifyProfile.ToASCII(domain)
if err != nil {
return "", err
}
// Then rebuild with port after.
domain = strings.ToLower(domain)
return domain + port, nil
} else { //nolint:revive
// Otherwise we just punify domain as-is.
domain, err := verifyProfile.ToASCII(domain)
return strings.ToLower(domain), err
}
} }
// Punify is a faster form of ValidatePunify() without validation. // Punify is a faster form of PunifySafely() without validation.
func Punify_(domain string) (string, error) { func Punify(domain string) (string, error) {
domain, err := punifyProfile.ToASCII(domain) domain, err := punifyProfile.ToASCII(domain)
return strings.ToLower(domain), err return strings.ToLower(domain), err
} }
@ -62,7 +79,7 @@ func URIMatches(expect *url.URL, uris ...*url.URL) (ok bool, err error) {
*punyURI = *expect *punyURI = *expect
// Set punified expected URL host. // Set punified expected URL host.
punyURI.Host, err = Punify_(expect.Host) punyURI.Host, err = Punify(expect.Host)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -76,7 +93,7 @@ func URIMatches(expect *url.URL, uris ...*url.URL) (ok bool, err error) {
// strings to check against. // strings to check against.
for _, uri := range uris { for _, uri := range uris {
*punyURI = *uri *punyURI = *uri
punyURI.Host, err = Punify_(uri.Host) punyURI.Host, err = Punify(uri.Host)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -91,12 +108,11 @@ func URIMatches(expect *url.URL, uris ...*url.URL) (ok bool, err error) {
return false, nil return false, nil
} }
// PunifyURIToStr returns a new copy of URI with the // PunifyURI returns a new copy of URI with the 'host'
// 'host' part converted to punycode with DomainToASCII. // part converted to punycode with PunifySafely().
// This can potentially be expensive doing extra domain // For simple comparisons prefer the faster URIMatches().
// verification for storage, for simple checks prefer URIMatches().
func PunifyURI(in *url.URL) (*url.URL, error) { func PunifyURI(in *url.URL) (*url.URL, error) {
punyHost, err := PunifyValidate(in.Host) punyHost, err := PunifySafely(in.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,11 +123,10 @@ func PunifyURI(in *url.URL) (*url.URL, error) {
} }
// PunifyURIToStr returns given URI serialized with the // PunifyURIToStr returns given URI serialized with the
// 'host' part converted to punycode with DomainToASCII. // 'host' part converted to punycode with PunifySafely().
// This can potentially be expensive doing extra domain // For simple comparisons prefer the faster URIMatches().
// verification for storage, for simple checks prefer URIMatches().
func PunifyURIToStr(in *url.URL) (string, error) { func PunifyURIToStr(in *url.URL) (string, error) {
punyHost, err := PunifyValidate(in.Host) punyHost, err := PunifySafely(in.Host)
if err != nil { if err != nil {
return "", err return "", err
} }