| 
									
										
										
										
											2023-03-12 16:00:57 +01:00
										 |  |  | // GoToSocial | 
					
						
							|  |  |  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | 
					
						
							|  |  |  | // SPDX-License-Identifier: AGPL-3.0-or-later | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Affero General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
					
						
							|  |  |  | // GNU Affero General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Affero General Public License | 
					
						
							|  |  |  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | package domain | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 	"slices" | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 	"strings" | 
					
						
							|  |  |  | 	"sync/atomic" | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // Cache provides a means of caching domains in memory to reduce | 
					
						
							|  |  |  | // load on an underlying storage mechanism, e.g. a database. | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | // | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // The in-memory domain list is kept up-to-date by means of a passed | 
					
						
							|  |  |  | // loader function during every call to .Matches(). In the case of | 
					
						
							|  |  |  | // a nil internal domain list, the loader function is called to hydrate | 
					
						
							|  |  |  | // the cache with the latest list of domains. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The .Clear() function can be used to invalidate the cache, | 
					
						
							|  |  |  | // e.g. when an entry is added / deleted from the database. | 
					
						
							|  |  |  | type Cache struct { | 
					
						
							|  |  |  | 	// current domain cache radix trie. | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 	rootptr atomic.Pointer[root] | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // Matches checks whether domain matches an entry in the cache. | 
					
						
							|  |  |  | // If the cache is not currently loaded, then the provided load | 
					
						
							|  |  |  | // function is used to hydrate it. | 
					
						
							|  |  |  | func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, error) { | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 	// Load the current | 
					
						
							|  |  |  | 	// root pointer value. | 
					
						
							|  |  |  | 	ptr := c.rootptr.Load() | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 	if ptr == nil { | 
					
						
							|  |  |  | 		// Cache is not hydrated. | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 		// | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 		// Load domains from callback. | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 		domains, err := load() | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return false, fmt.Errorf("error reloading cache: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 		// Ensure the domains being inserted into the cache | 
					
						
							|  |  |  | 		// are sorted by number of domain parts. i.e. those | 
					
						
							|  |  |  | 		// with less parts are inserted last, else this can | 
					
						
							|  |  |  | 		// allow domains to fall through the matching code! | 
					
						
							|  |  |  | 		slices.SortFunc(domains, func(a, b string) int { | 
					
						
							|  |  |  | 			const k = +1 | 
					
						
							|  |  |  | 			an := strings.Count(a, ".") | 
					
						
							|  |  |  | 			bn := strings.Count(b, ".") | 
					
						
							|  |  |  | 			switch { | 
					
						
							|  |  |  | 			case an < bn: | 
					
						
							|  |  |  | 				return +k | 
					
						
							|  |  |  | 			case an > bn: | 
					
						
							|  |  |  | 				return -k | 
					
						
							|  |  |  | 			default: | 
					
						
							|  |  |  | 				return 0 | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 		// Allocate new radix trie | 
					
						
							|  |  |  | 		// node to store matches. | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 		ptr = new(root) | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 		// Add each domain to the trie. | 
					
						
							|  |  |  | 		for _, domain := range domains { | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 			ptr.Add(domain) | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 		// Sort the trie. | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 		ptr.Sort() | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 		// Store new node ptr. | 
					
						
							|  |  |  | 		c.rootptr.Store(ptr) | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 	// Look for match in trie node. | 
					
						
							|  |  |  | 	return ptr.Match(domain), nil | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | // Clear will drop the currently loaded domain list, | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // triggering a reload on next call to .Matches(). | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | func (c *Cache) Clear() { c.rootptr.Store(nil) } | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // String returns a string representation of stored domains in cache. | 
					
						
							|  |  |  | func (c *Cache) String() string { | 
					
						
							| 
									
										
										
										
											2023-12-18 14:18:25 +00:00
										 |  |  | 	if ptr := c.rootptr.Load(); ptr != nil { | 
					
						
							|  |  |  | 		return ptr.String() | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	return "<empty>" | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // root is the root node in the domain cache radix trie. this is the singular access point to the trie. | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | type root struct{ root node } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Add will add the given domain to the radix trie. | 
					
						
							|  |  |  | func (r *root) Add(domain string) { | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 	r.root.Add(strings.Split(domain, ".")) | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Match will return whether the given domain matches | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // an existing stored domain in this radix trie. | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | func (r *root) Match(domain string) bool { | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 	return r.root.Match(strings.Split(domain, ".")) | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Sort will sort the entire radix trie ensuring that | 
					
						
							|  |  |  | // child nodes are stored in alphabetical order. This | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | // MUST be done to finalize the domain cache in order | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | // to speed up the binary search of node child parts. | 
					
						
							|  |  |  | func (r *root) Sort() { | 
					
						
							|  |  |  | 	r.root.sort() | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | // String returns a string representation of node (and its descendants). | 
					
						
							|  |  |  | func (r *root) String() string { | 
					
						
							|  |  |  | 	buf := new(strings.Builder) | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 	r.root.WriteStr(buf, "") | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 	return buf.String() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | type node struct { | 
					
						
							|  |  |  | 	part  string | 
					
						
							|  |  |  | 	child []*node | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | func (n *node) Add(parts []string) { | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 	if len(parts) == 0 { | 
					
						
							|  |  |  | 		panic("invalid domain") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for { | 
					
						
							|  |  |  | 		// Pop next domain part. | 
					
						
							|  |  |  | 		i := len(parts) - 1 | 
					
						
							|  |  |  | 		part := parts[i] | 
					
						
							|  |  |  | 		parts = parts[:i] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var nn *node | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Look for existing child node | 
					
						
							|  |  |  | 		// that matches next domain part. | 
					
						
							|  |  |  | 		for _, child := range n.child { | 
					
						
							|  |  |  | 			if child.part == part { | 
					
						
							|  |  |  | 				nn = child | 
					
						
							|  |  |  | 				break | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if nn == nil { | 
					
						
							|  |  |  | 			// Alloc new child node. | 
					
						
							|  |  |  | 			nn = &node{part: part} | 
					
						
							|  |  |  | 			n.child = append(n.child, nn) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if len(parts) == 0 { | 
					
						
							|  |  |  | 			// Drop all children here as | 
					
						
							| 
									
										
										
										
											2023-09-21 12:12:04 +02:00
										 |  |  | 			// this is a higher-level domain | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 			// than that we previously had. | 
					
						
							|  |  |  | 			nn.child = nil | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Re-iter with | 
					
						
							|  |  |  | 		// child node. | 
					
						
							|  |  |  | 		n = nn | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | func (n *node) Match(parts []string) bool { | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 	for len(parts) > 0 { | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 		// Pop next domain part. | 
					
						
							|  |  |  | 		i := len(parts) - 1 | 
					
						
							|  |  |  | 		part := parts[i] | 
					
						
							|  |  |  | 		parts = parts[:i] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Look for existing child | 
					
						
							|  |  |  | 		// that matches next part. | 
					
						
							|  |  |  | 		nn := n.getChild(part) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if nn == nil { | 
					
						
							|  |  |  | 			// No match :( | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 			return false | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		if len(nn.child) == 0 { | 
					
						
							|  |  |  | 			// It's a match! | 
					
						
							|  |  |  | 			return true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Re-iter with | 
					
						
							|  |  |  | 		// child node. | 
					
						
							|  |  |  | 		n = nn | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Ran out of parts | 
					
						
							|  |  |  | 	// without a match. | 
					
						
							|  |  |  | 	return false | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // getChild fetches child node with given domain part string | 
					
						
							|  |  |  | // using a binary search. THIS ASSUMES CHILDREN ARE SORTED. | 
					
						
							|  |  |  | func (n *node) getChild(part string) *node { | 
					
						
							|  |  |  | 	i, j := 0, len(n.child) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for i < j { | 
					
						
							|  |  |  | 		// avoid overflow when computing h | 
					
						
							|  |  |  | 		h := int(uint(i+j) >> 1) | 
					
						
							|  |  |  | 		// i ≤ h < j | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if n.child[h].part < part { | 
					
						
							|  |  |  | 			// preserves: | 
					
						
							|  |  |  | 			// n.child[i-1].part != part | 
					
						
							|  |  |  | 			i = h + 1 | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			// preserves: | 
					
						
							|  |  |  | 			// n.child[h].part == part | 
					
						
							|  |  |  | 			j = h | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if i >= len(n.child) || n.child[i].part != part { | 
					
						
							|  |  |  | 		return nil // no match | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 	return n.child[i] | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (n *node) sort() { | 
					
						
							|  |  |  | 	// Sort this node's slice of child nodes. | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 	slices.SortFunc(n.child, func(i, j *node) int { | 
					
						
							|  |  |  | 		const k = -1 | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case i.part < j.part: | 
					
						
							|  |  |  | 			return +k | 
					
						
							|  |  |  | 		case i.part > j.part: | 
					
						
							|  |  |  | 			return -k | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return 0 | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-05-01 11:36:46 +01:00
										 |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Sort each child node's children. | 
					
						
							|  |  |  | 	for _, child := range n.child { | 
					
						
							|  |  |  | 		child.sort() | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-12-14 09:55:36 +00:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | func (n *node) WriteStr(buf *strings.Builder, prefix string) { | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 	if prefix != "" { | 
					
						
							|  |  |  | 		// Suffix joining '.' | 
					
						
							|  |  |  | 		prefix += "." | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Append current part. | 
					
						
							|  |  |  | 	prefix += n.part | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Dump current prefix state. | 
					
						
							|  |  |  | 	buf.WriteString(prefix) | 
					
						
							|  |  |  | 	buf.WriteByte('\n') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Iterate through node children. | 
					
						
							|  |  |  | 	for _, child := range n.child { | 
					
						
							| 
									
										
										
										
											2024-01-09 13:12:43 +00:00
										 |  |  | 		child.WriteStr(buf, prefix) | 
					
						
							| 
									
										
										
										
											2023-05-09 15:18:51 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | } |