mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 15:42:25 -05:00 
			
		
		
		
	[feature] domain block wildcarding (#1178)
* for domain block lookups, lookup along subdomain parts
Signed-off-by: kim <grufwub@gmail.com>
* only lookup up to a max of 5 domain parts to prevent DOS, limit inserted domains to max of 5 subdomains
Signed-off-by: kim <grufwub@gmail.com>
* add test for domain block wildcarding
Signed-off-by: kim <grufwub@gmail.com>
* check cached status first, increase cached domain time
Signed-off-by: kim <grufwub@gmail.com>
* fix domain wildcard part building logic
Signed-off-by: kim <grufwub@gmail.com>
* create separate domain.BlockCache{} type to hold all domain blocks in memory
Signed-off-by: kim <grufwub@gmail.com>
* remove unused variable
Signed-off-by: kim <grufwub@gmail.com>
* add docs and test to domain block cache, check for domain == host in domain block getter funcs
Signed-off-by: kim <grufwub@gmail.com>
* add license text
Signed-off-by: kim <grufwub@gmail.com>
* check order in which we check primary cache
Signed-off-by: kim <grufwub@gmail.com>
* add better documentation of how domain block checking is performed
Signed-off-by: kim <grufwub@gmail.com>
* change
Signed-off-by: kim <grufwub@gmail.com>
Signed-off-by: kim <grufwub@gmail.com>
	
	
This commit is contained in:
		
					parent
					
						
							
								8703933df4
							
						
					
				
			
			
				commit
				
					
						69dd5fed2c
					
				
			
		
					 5 changed files with 350 additions and 39 deletions
				
			
		
							
								
								
									
										170
									
								
								internal/cache/domain/domain.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								internal/cache/domain/domain.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,170 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | package domain | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"codeberg.org/gruf/go-cache/v3/ttl" | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // BlockCache provides a means of caching domain blocks in memory to reduce load | ||||||
|  | // on an underlying storage mechanism, e.g. a database. | ||||||
|  | // | ||||||
|  | // It consists of a TTL primary cache that stores calculated domain string to block results, | ||||||
|  | // that on cache miss is filled by calculating block status by iterating over a list of all of | ||||||
|  | // the domain blocks stored in memory. This reduces CPU usage required by not need needing to | ||||||
|  | // iterate through a possible 100-1000s long block list, while saving memory by having a primary | ||||||
|  | // cache of limited size that evicts stale entries. The raw list of all domain blocks should in | ||||||
|  | // most cases be negligible when it comes to memory usage. | ||||||
|  | // | ||||||
|  | // The in-memory block list is kept up-to-date by means of a passed loader function during every | ||||||
|  | // call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to | ||||||
|  | // hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate | ||||||
|  | // the cache, e.g. when a domain block is added / deleted from the database. It will drop the current | ||||||
|  | // list of domain blocks and clear all entries from the primary cache. | ||||||
|  | type BlockCache struct { | ||||||
|  | 	pcache *ttl.Cache[string, bool] // primary cache of domains -> block results | ||||||
|  | 	blocks []block                  // raw list of all domain blocks, nil => not loaded. | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New returns a new initialized BlockCache instance with given primary cache capacity and TTL. | ||||||
|  | func New(pcap int, pttl time.Duration) *BlockCache { | ||||||
|  | 	c := new(BlockCache) | ||||||
|  | 	c.pcache = new(ttl.Cache[string, bool]) | ||||||
|  | 	c.pcache.Init(0, pcap, pttl) | ||||||
|  | 	return c | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started. | ||||||
|  | func (b *BlockCache) Start(pfreq time.Duration) bool { | ||||||
|  | 	return b.pcache.Start(pfreq) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped. | ||||||
|  | func (b *BlockCache) Stop() bool { | ||||||
|  | 	return b.pcache.Stop() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it. | ||||||
|  | // NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock. | ||||||
|  | func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) { | ||||||
|  | 	var blocked bool | ||||||
|  | 
 | ||||||
|  | 	// Acquire cache lock | ||||||
|  | 	b.pcache.Lock() | ||||||
|  | 	defer b.pcache.Unlock() | ||||||
|  | 
 | ||||||
|  | 	// Check primary cache for result | ||||||
|  | 	entry, ok := b.pcache.Cache.Get(domain) | ||||||
|  | 	if ok { | ||||||
|  | 		return entry.Value, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if b.blocks == nil { | ||||||
|  | 		// Cache is not hydrated | ||||||
|  | 		// | ||||||
|  | 		// Load domains from callback | ||||||
|  | 		domains, err := load() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return false, fmt.Errorf("error reloading cache: %w", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Drop all domain blocks and recreate | ||||||
|  | 		b.blocks = make([]block, len(domains)) | ||||||
|  | 
 | ||||||
|  | 		for i, domain := range domains { | ||||||
|  | 			// Store pre-split labels for each domain block | ||||||
|  | 			b.blocks[i].labels = dns.SplitDomainName(domain) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Split domain into it separate labels | ||||||
|  | 	labels := dns.SplitDomainName(domain) | ||||||
|  | 
 | ||||||
|  | 	// Compare this to our stored blocks | ||||||
|  | 	for _, block := range b.blocks { | ||||||
|  | 		if block.Blocks(labels) { | ||||||
|  | 			blocked = true | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Store block result in primary cache | ||||||
|  | 	b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{ | ||||||
|  | 		Key:    domain, | ||||||
|  | 		Value:  blocked, | ||||||
|  | 		Expiry: time.Now().Add(b.pcache.TTL), | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	return blocked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Clear will drop the currently loaded domain list, and clear the primary cache. | ||||||
|  | // This will trigger a reload on next call to .IsBlocked(). | ||||||
|  | func (b *BlockCache) Clear() { | ||||||
|  | 	// Drop all blocks. | ||||||
|  | 	b.pcache.Lock() | ||||||
|  | 	b.blocks = nil | ||||||
|  | 	b.pcache.Unlock() | ||||||
|  | 
 | ||||||
|  | 	// Clear needs to be done _outside_ of | ||||||
|  | 	// lock, as also acquires a mutex lock. | ||||||
|  | 	b.pcache.Clear() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // block represents a domain block, and stores the | ||||||
|  | // deconstructed labels of a singular domain block. | ||||||
|  | // e.g. []string{"gts", "superseriousbusiness", "org"}. | ||||||
|  | type block struct { | ||||||
|  | 	labels []string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Blocks checks whether the separated domain labels of an | ||||||
|  | // incoming domain matches the stored (receiving struct) block. | ||||||
|  | func (b block) Blocks(labels []string) bool { | ||||||
|  | 	// Calculate length difference | ||||||
|  | 	d := len(labels) - len(b.labels) | ||||||
|  | 	if d < 0 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Iterate backwards through domain block's | ||||||
|  | 	// labels, omparing against the incoming domain's. | ||||||
|  | 	// | ||||||
|  | 	// So for the following input: | ||||||
|  | 	// labels   = []string{"mail", "google", "com"} | ||||||
|  | 	// b.labels = []string{"google", "com"} | ||||||
|  | 	// | ||||||
|  | 	// These would be matched in reverse order along | ||||||
|  | 	// the entirety of the block object's labels: | ||||||
|  | 	// "com"    => match | ||||||
|  | 	// "google" => match | ||||||
|  | 	// | ||||||
|  | 	// And so would reach the end and return true. | ||||||
|  | 	for i := len(b.labels) - 1; i >= 0; i-- { | ||||||
|  | 		if b.labels[i] != labels[i+d] { | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return true | ||||||
|  | } | ||||||
							
								
								
									
										85
									
								
								internal/cache/domain/domain_test.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								internal/cache/domain/domain_test.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,85 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | package domain_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/cache/domain" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestBlockCache(t *testing.T) { | ||||||
|  | 	c := domain.New(100, time.Second) | ||||||
|  | 
 | ||||||
|  | 	blocks := []string{ | ||||||
|  | 		"google.com", | ||||||
|  | 		"google.co.uk", | ||||||
|  | 		"pleroma.bad.host", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	loader := func() ([]string, error) { | ||||||
|  | 		t.Log("load: returning blocked domains") | ||||||
|  | 		return blocks, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check a list of known blocked domains. | ||||||
|  | 	for _, domain := range []string{ | ||||||
|  | 		"google.com", | ||||||
|  | 		"mail.google.com", | ||||||
|  | 		"google.co.uk", | ||||||
|  | 		"mail.google.co.uk", | ||||||
|  | 		"pleroma.bad.host", | ||||||
|  | 		"dev.pleroma.bad.host", | ||||||
|  | 	} { | ||||||
|  | 		t.Logf("checking domain is blocked: %s", domain) | ||||||
|  | 		if b, _ := c.IsBlocked(domain, loader); !b { | ||||||
|  | 			t.Errorf("domain should be blocked: %s", domain) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check a list of known unblocked domains. | ||||||
|  | 	for _, domain := range []string{ | ||||||
|  | 		"askjeeves.com", | ||||||
|  | 		"ask-kim.co.uk", | ||||||
|  | 		"google.ie", | ||||||
|  | 		"mail.google.ie", | ||||||
|  | 		"gts.bad.host", | ||||||
|  | 		"mastodon.bad.host", | ||||||
|  | 	} { | ||||||
|  | 		t.Logf("checking domain isn't blocked: %s", domain) | ||||||
|  | 		if b, _ := c.IsBlocked(domain, loader); b { | ||||||
|  | 			t.Errorf("domain should not be blocked: %s", domain) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Clear the cache | ||||||
|  | 	c.Clear() | ||||||
|  | 
 | ||||||
|  | 	knownErr := errors.New("known error") | ||||||
|  | 
 | ||||||
|  | 	// Check that reload is actually performed and returns our error | ||||||
|  | 	if _, err := c.IsBlocked("", func() ([]string, error) { | ||||||
|  | 		t.Log("load: returning known error") | ||||||
|  | 		return nil, knownErr | ||||||
|  | 	}); !errors.Is(err, knownErr) { | ||||||
|  | 		t.Errorf("is blocked did not return expected error: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										21
									
								
								internal/cache/gts.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								internal/cache/gts.go
									
										
									
									
										vendored
									
									
								
							|  | @ -20,6 +20,7 @@ package cache | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"codeberg.org/gruf/go-cache/v3/result" | 	"codeberg.org/gruf/go-cache/v3/result" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/cache/domain" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| ) | ) | ||||||
|  | @ -41,8 +42,8 @@ type GTSCaches interface { | ||||||
| 	// Block provides access to the gtsmodel Block (account) database cache. | 	// Block provides access to the gtsmodel Block (account) database cache. | ||||||
| 	Block() *result.Cache[*gtsmodel.Block] | 	Block() *result.Cache[*gtsmodel.Block] | ||||||
| 
 | 
 | ||||||
| 	// DomainBlock provides access to the gtsmodel DomainBlock database cache. | 	// DomainBlock provides access to the domain block database cache. | ||||||
| 	DomainBlock() *result.Cache[*gtsmodel.DomainBlock] | 	DomainBlock() *domain.BlockCache | ||||||
| 
 | 
 | ||||||
| 	// Emoji provides access to the gtsmodel Emoji database cache. | 	// Emoji provides access to the gtsmodel Emoji database cache. | ||||||
| 	Emoji() *result.Cache[*gtsmodel.Emoji] | 	Emoji() *result.Cache[*gtsmodel.Emoji] | ||||||
|  | @ -74,7 +75,7 @@ func NewGTS() GTSCaches { | ||||||
| type gtsCaches struct { | type gtsCaches struct { | ||||||
| 	account       *result.Cache[*gtsmodel.Account] | 	account       *result.Cache[*gtsmodel.Account] | ||||||
| 	block         *result.Cache[*gtsmodel.Block] | 	block         *result.Cache[*gtsmodel.Block] | ||||||
| 	domainBlock   *result.Cache[*gtsmodel.DomainBlock] | 	domainBlock   *domain.BlockCache | ||||||
| 	emoji         *result.Cache[*gtsmodel.Emoji] | 	emoji         *result.Cache[*gtsmodel.Emoji] | ||||||
| 	emojiCategory *result.Cache[*gtsmodel.EmojiCategory] | 	emojiCategory *result.Cache[*gtsmodel.EmojiCategory] | ||||||
| 	mention       *result.Cache[*gtsmodel.Mention] | 	mention       *result.Cache[*gtsmodel.Mention] | ||||||
|  | @ -151,7 +152,7 @@ func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] { | ||||||
| 	return c.block | 	return c.block | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *gtsCaches) DomainBlock() *result.Cache[*gtsmodel.DomainBlock] { | func (c *gtsCaches) DomainBlock() *domain.BlockCache { | ||||||
| 	return c.domainBlock | 	return c.domainBlock | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -212,14 +213,10 @@ func (c *gtsCaches) initBlock() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *gtsCaches) initDomainBlock() { | func (c *gtsCaches) initDomainBlock() { | ||||||
| 	c.domainBlock = result.NewSized([]result.Lookup{ | 	c.domainBlock = domain.New( | ||||||
| 		{Name: "Domain"}, | 		config.GetCacheGTSDomainBlockMaxSize(), | ||||||
| 	}, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { | 		config.GetCacheGTSDomainBlockTTL(), | ||||||
| 		d2 := new(gtsmodel.DomainBlock) | 	) | ||||||
| 		*d2 = *d1 |  | ||||||
| 		return d2 |  | ||||||
| 	}, config.GetCacheGTSDomainBlockMaxSize()) |  | ||||||
| 	c.domainBlock.SetTTL(config.GetCacheGTSDomainBlockTTL(), true) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *gtsCaches) initEmoji() { | func (c *gtsCaches) initEmoji() { | ||||||
|  |  | ||||||
|  | @ -50,46 +50,52 @@ func normalizeDomain(domain string) (out string, err error) { | ||||||
| func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { | func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { | ||||||
| 	var err error | 	var err error | ||||||
| 
 | 
 | ||||||
|  | 	// Normalize the domain as punycode | ||||||
| 	block.Domain, err = normalizeDomain(block.Domain) | 	block.Domain, err = normalizeDomain(block.Domain) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return d.state.Caches.GTS.DomainBlock().Store(block, func() error { | 	// Attempt to store domain in DB | ||||||
| 		_, err := d.conn.NewInsert(). | 	if _, err := d.conn.NewInsert(). | ||||||
| 		Model(block). | 		Model(block). | ||||||
| 			Exec(ctx) | 		Exec(ctx); err != nil { | ||||||
| 		return d.conn.ProcessError(err) | 		return d.conn.ProcessError(err) | ||||||
| 	}) | 	} | ||||||
|  | 
 | ||||||
|  | 	// Clear the domain block cache (for later reload) | ||||||
|  | 	d.state.Caches.GTS.DomainBlock().Clear() | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { | func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { | ||||||
| 	var err error | 	var err error | ||||||
| 
 | 
 | ||||||
|  | 	// Normalize the domain as punycode | ||||||
| 	domain, err = normalizeDomain(domain) | 	domain, err = normalizeDomain(domain) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return d.state.Caches.GTS.DomainBlock().Load("Domain", func() (*gtsmodel.DomainBlock, error) { |  | ||||||
| 	// Check for easy case, domain referencing *us* | 	// Check for easy case, domain referencing *us* | ||||||
| 		if domain == "" || domain == config.GetAccountDomain() { | 	if domain == "" || domain == config.GetAccountDomain() || | ||||||
|  | 		domain == config.GetHost() { | ||||||
| 		return nil, db.ErrNoEntries | 		return nil, db.ErrNoEntries | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var block gtsmodel.DomainBlock | 	var block gtsmodel.DomainBlock | ||||||
| 
 | 
 | ||||||
|  | 	// Look for block matching domain in DB | ||||||
| 	q := d.conn. | 	q := d.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&block). | 		Model(&block). | ||||||
| 			Where("? = ?", bun.Ident("domain_block.domain"), domain). | 		Where("? = ?", bun.Ident("domain_block.domain"), domain) | ||||||
| 			Limit(1) |  | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, d.conn.ProcessError(err) | 		return nil, d.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &block, nil | 	return &block, nil | ||||||
| 	}, domain) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { | func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { | ||||||
|  | @ -108,20 +114,41 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro | ||||||
| 		return d.conn.ProcessError(err) | 		return d.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Clear domain from cache | 	// Clear the domain block cache (for later reload) | ||||||
| 	d.state.Caches.GTS.DomainBlock().Invalidate(domain) | 	d.state.Caches.GTS.DomainBlock().Clear() | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { | func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { | ||||||
| 	block, err := d.GetDomainBlock(ctx, domain) | 	// Normalize the domain as punycode | ||||||
| 	if err == nil || err == db.ErrNoEntries { | 	domain, err := normalizeDomain(domain) | ||||||
| 		return (block != nil), nil | 	if err != nil { | ||||||
| 	} |  | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Check for easy case, domain referencing *us* | ||||||
|  | 	if domain == "" || domain == config.GetAccountDomain() || | ||||||
|  | 		domain == config.GetHost() { | ||||||
|  | 		return false, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check the cache for a domain block (hydrating the cache with callback if necessary) | ||||||
|  | 	return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { | ||||||
|  | 		var domains []string | ||||||
|  | 
 | ||||||
|  | 		// Scan list of all blocked domains from DB | ||||||
|  | 		q := d.conn.NewSelect(). | ||||||
|  | 			Table("domain_blocks"). | ||||||
|  | 			Column("domain") | ||||||
|  | 		if err := q.Scan(ctx, &domains); err != nil { | ||||||
|  | 			return nil, d.conn.ProcessError(err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return domains, nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { | func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { | ||||||
| 	for _, domain := range domains { | 	for _, domain := range domains { | ||||||
| 		if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { | 		if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { | ||||||
|  |  | ||||||
|  | @ -56,6 +56,38 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() { | ||||||
| 	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) | 	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	domainBlock := >smodel.DomainBlock{ | ||||||
|  | 		ID:                 "01G204214Y9TNJEBX39C7G88SW", | ||||||
|  | 		Domain:             "bad.apples", | ||||||
|  | 		CreatedByAccountID: suite.testAccounts["admin_account"].ID, | ||||||
|  | 		CreatedByAccount:   suite.testAccounts["admin_account"], | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// no domain block exists for the given domain yet | ||||||
|  | 	blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.False(blocked) | ||||||
|  | 
 | ||||||
|  | 	err = suite.db.CreateDomainBlock(ctx, domainBlock) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 	// Start with the base block domain | ||||||
|  | 	domain := domainBlock.Domain | ||||||
|  | 
 | ||||||
|  | 	for _, part := range []string{"extra", "domain", "parts"} { | ||||||
|  | 		// Prepend the next domain part | ||||||
|  | 		domain = part + "." + domain | ||||||
|  | 
 | ||||||
|  | 		// Check that domain block is wildcarded for this subdomain | ||||||
|  | 		blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 		suite.True(blocked) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { | func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue