[bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513)

* fix the sort direction of domain cache child nodes ...

* add more domain cache test cases

* add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!)

* remove unused field (this was a previous attempt at a fix)

* remove debugging println statements 😇
This commit is contained in:
kim 2024-01-09 13:12:43 +00:00 committed by GitHub
commit dfc7656579
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 114 additions and 19 deletions

View file

@ -19,10 +19,9 @@ package domain
import (
"fmt"
"slices"
"strings"
"sync/atomic"
"golang.org/x/exp/slices"
)
// Cache provides a means of caching domains in memory to reduce
@ -57,6 +56,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
return false, fmt.Errorf("error reloading cache: %w", err)
}
// Ensure the domains being inserted into the cache
// are sorted by number of domain parts. i.e. those
// with less parts are inserted last, else this can
// allow domains to fall through the matching code!
slices.SortFunc(domains, func(a, b string) int {
const k = +1
an := strings.Count(a, ".")
bn := strings.Count(b, ".")
switch {
case an < bn:
return +k
case an > bn:
return -k
default:
return 0
}
})
// Allocate new radix trie
// node to store matches.
ptr = new(root)
@ -94,13 +111,13 @@ type root struct{ root node }
// Add will add the given domain to the radix trie.
func (r *root) Add(domain string) {
r.root.add(strings.Split(domain, "."))
r.root.Add(strings.Split(domain, "."))
}
// Match will return whether the given domain matches
// an existing stored domain in this radix trie.
func (r *root) Match(domain string) bool {
return r.root.match(strings.Split(domain, "."))
return r.root.Match(strings.Split(domain, "."))
}
// Sort will sort the entire radix trie ensuring that
@ -114,7 +131,7 @@ func (r *root) Sort() {
// String returns a string representation of node (and its descendants).
func (r *root) String() string {
buf := new(strings.Builder)
r.root.writestr(buf, "")
r.root.WriteStr(buf, "")
return buf.String()
}
@ -123,7 +140,7 @@ type node struct {
child []*node
}
func (n *node) add(parts []string) {
func (n *node) Add(parts []string) {
if len(parts) == 0 {
panic("invalid domain")
}
@ -165,7 +182,7 @@ func (n *node) add(parts []string) {
}
}
func (n *node) match(parts []string) bool {
func (n *node) Match(parts []string) bool {
for len(parts) > 0 {
// Pop next domain part.
i := len(parts) - 1
@ -226,8 +243,16 @@ func (n *node) getChild(part string) *node {
func (n *node) sort() {
// Sort this node's slice of child nodes.
slices.SortFunc(n.child, func(i, j *node) bool {
return i.part < j.part
slices.SortFunc(n.child, func(i, j *node) int {
const k = -1
switch {
case i.part < j.part:
return +k
case i.part > j.part:
return -k
default:
return 0
}
})
// Sort each child node's children.
@ -236,7 +261,7 @@ func (n *node) sort() {
}
}
func (n *node) writestr(buf *strings.Builder, prefix string) {
func (n *node) WriteStr(buf *strings.Builder, prefix string) {
if prefix != "" {
// Suffix joining '.'
prefix += "."
@ -251,6 +276,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
// Iterate through node children.
for _, child := range n.child {
child.writestr(buf, prefix)
child.WriteStr(buf, prefix)
}
}