mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-11-02 13:42:25 -06:00
[bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513)
* fix the sort direction of domain cache child nodes ...
* add more domain cache test cases
* add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!)
* remove unused field (this was a previous attempt at a fix)
* remove debugging println statements 😇
This commit is contained in:
parent
87bb596a02
commit
dfc7656579
3 changed files with 114 additions and 19 deletions
47
internal/cache/domain/domain.go
vendored
47
internal/cache/domain/domain.go
vendored
|
|
@ -19,10 +19,9 @@ package domain
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Cache provides a means of caching domains in memory to reduce
|
||||
|
|
@ -57,6 +56,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
|
|||
return false, fmt.Errorf("error reloading cache: %w", err)
|
||||
}
|
||||
|
||||
// Ensure the domains being inserted into the cache
|
||||
// are sorted by number of domain parts. i.e. those
|
||||
// with less parts are inserted last, else this can
|
||||
// allow domains to fall through the matching code!
|
||||
slices.SortFunc(domains, func(a, b string) int {
|
||||
const k = +1
|
||||
an := strings.Count(a, ".")
|
||||
bn := strings.Count(b, ".")
|
||||
switch {
|
||||
case an < bn:
|
||||
return +k
|
||||
case an > bn:
|
||||
return -k
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
|
||||
// Allocate new radix trie
|
||||
// node to store matches.
|
||||
ptr = new(root)
|
||||
|
|
@ -94,13 +111,13 @@ type root struct{ root node }
|
|||
|
||||
// Add will add the given domain to the radix trie.
|
||||
func (r *root) Add(domain string) {
|
||||
r.root.add(strings.Split(domain, "."))
|
||||
r.root.Add(strings.Split(domain, "."))
|
||||
}
|
||||
|
||||
// Match will return whether the given domain matches
|
||||
// an existing stored domain in this radix trie.
|
||||
func (r *root) Match(domain string) bool {
|
||||
return r.root.match(strings.Split(domain, "."))
|
||||
return r.root.Match(strings.Split(domain, "."))
|
||||
}
|
||||
|
||||
// Sort will sort the entire radix trie ensuring that
|
||||
|
|
@ -114,7 +131,7 @@ func (r *root) Sort() {
|
|||
// String returns a string representation of node (and its descendants).
|
||||
func (r *root) String() string {
|
||||
buf := new(strings.Builder)
|
||||
r.root.writestr(buf, "")
|
||||
r.root.WriteStr(buf, "")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
|
|
@ -123,7 +140,7 @@ type node struct {
|
|||
child []*node
|
||||
}
|
||||
|
||||
func (n *node) add(parts []string) {
|
||||
func (n *node) Add(parts []string) {
|
||||
if len(parts) == 0 {
|
||||
panic("invalid domain")
|
||||
}
|
||||
|
|
@ -165,7 +182,7 @@ func (n *node) add(parts []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n *node) match(parts []string) bool {
|
||||
func (n *node) Match(parts []string) bool {
|
||||
for len(parts) > 0 {
|
||||
// Pop next domain part.
|
||||
i := len(parts) - 1
|
||||
|
|
@ -226,8 +243,16 @@ func (n *node) getChild(part string) *node {
|
|||
|
||||
func (n *node) sort() {
|
||||
// Sort this node's slice of child nodes.
|
||||
slices.SortFunc(n.child, func(i, j *node) bool {
|
||||
return i.part < j.part
|
||||
slices.SortFunc(n.child, func(i, j *node) int {
|
||||
const k = -1
|
||||
switch {
|
||||
case i.part < j.part:
|
||||
return +k
|
||||
case i.part > j.part:
|
||||
return -k
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
|
||||
// Sort each child node's children.
|
||||
|
|
@ -236,7 +261,7 @@ func (n *node) sort() {
|
|||
}
|
||||
}
|
||||
|
||||
func (n *node) writestr(buf *strings.Builder, prefix string) {
|
||||
func (n *node) WriteStr(buf *strings.Builder, prefix string) {
|
||||
if prefix != "" {
|
||||
// Suffix joining '.'
|
||||
prefix += "."
|
||||
|
|
@ -251,6 +276,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
|
|||
|
||||
// Iterate through node children.
|
||||
for _, child := range n.child {
|
||||
child.writestr(buf, prefix)
|
||||
child.WriteStr(buf, prefix)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue