mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 20:12:26 -05:00 
			
		
		
		
	[bugfix] Close reader gracefully when streaming recache of remote media to fileserver api caller (#1281)
* close pipereader on failed data function * gently slurp the bytes * readability updates * go fmt * tidy up file server tests + add more cases * start moving io wrappers to separate iotools package. Remove use of buffering while piping recache stream Signed-off-by: kim <grufwub@gmail.com> * add license text Signed-off-by: kim <grufwub@gmail.com> Co-authored-by: kim <grufwub@gmail.com>
This commit is contained in:
		
					parent
					
						
							
								0871f5d181
							
						
					
				
			
			
				commit
				
					
						6ebdc306ed
					
				
			
		
					 8 changed files with 503 additions and 214 deletions
				
			
		
							
								
								
									
										109
									
								
								internal/api/client/fileserver/fileserver_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								internal/api/client/fileserver/fileserver_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,109 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package fileserver_test | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/suite" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/concurrency" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/email" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||
| 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||
| ) | ||||
| 
 | ||||
| type FileserverTestSuite struct { | ||||
| 	// standard suite interfaces | ||||
| 	suite.Suite | ||||
| 	db           db.DB | ||||
| 	storage      *storage.Driver | ||||
| 	federator    federation.Federator | ||||
| 	tc           typeutils.TypeConverter | ||||
| 	processor    processing.Processor | ||||
| 	mediaManager media.Manager | ||||
| 	oauthServer  oauth.Server | ||||
| 	emailSender  email.Sender | ||||
| 
 | ||||
| 	// standard suite models | ||||
| 	testTokens       map[string]*gtsmodel.Token | ||||
| 	testClients      map[string]*gtsmodel.Client | ||||
| 	testApplications map[string]*gtsmodel.Application | ||||
| 	testUsers        map[string]*gtsmodel.User | ||||
| 	testAccounts     map[string]*gtsmodel.Account | ||||
| 	testAttachments  map[string]*gtsmodel.MediaAttachment | ||||
| 
 | ||||
| 	// item being tested | ||||
| 	fileServer *fileserver.FileServer | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	TEST INFRASTRUCTURE | ||||
| */ | ||||
| 
 | ||||
| func (suite *FileserverTestSuite) SetupSuite() { | ||||
| 	testrig.InitTestConfig() | ||||
| 	testrig.InitTestLog() | ||||
| 
 | ||||
| 	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) | ||||
| 	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) | ||||
| 
 | ||||
| 	suite.db = testrig.NewTestDB() | ||||
| 	suite.storage = testrig.NewInMemoryStorage() | ||||
| 	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) | ||||
| 	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) | ||||
| 
 | ||||
| 	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker) | ||||
| 	suite.tc = testrig.NewTestTypeConverter(suite.db) | ||||
| 	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) | ||||
| 	suite.oauthServer = testrig.NewTestOauthServer(suite.db) | ||||
| 
 | ||||
| 	suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer) | ||||
| } | ||||
| 
 | ||||
| func (suite *FileserverTestSuite) SetupTest() { | ||||
| 	testrig.StandardDBSetup(suite.db, nil) | ||||
| 	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") | ||||
| 	suite.testTokens = testrig.NewTestTokens() | ||||
| 	suite.testClients = testrig.NewTestClients() | ||||
| 	suite.testApplications = testrig.NewTestApplications() | ||||
| 	suite.testUsers = testrig.NewTestUsers() | ||||
| 	suite.testAccounts = testrig.NewTestAccounts() | ||||
| 	suite.testAttachments = testrig.NewTestAttachments() | ||||
| } | ||||
| 
 | ||||
| func (suite *FileserverTestSuite) TearDownSuite() { | ||||
| 	if err := suite.db.Stop(context.Background()); err != nil { | ||||
| 		log.Panicf("error closing db connection: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (suite *FileserverTestSuite) TearDownTest() { | ||||
| 	testrig.StandardDBTeardown(suite.db) | ||||
| 	testrig.StandardStorageTeardown(suite.storage) | ||||
| } | ||||
|  | @ -19,7 +19,9 @@ | |||
| package fileserver | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 
 | ||||
|  | @ -120,5 +122,14 @@ func (m *FileServer) ServeFile(c *gin.Context) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	c.DataFromReader(http.StatusOK, content.ContentLength, format, content.Content, nil) | ||||
| 	// try to slurp the first few bytes to make sure we have something | ||||
| 	b := bytes.NewBuffer(make([]byte, 0, 64)) | ||||
| 	if _, err := io.CopyN(b, content.Content, 64); err != nil { | ||||
| 		err = fmt.Errorf("ServeFile: error reading from content: %w", err) | ||||
| 		api.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// we're good, return the slurped bytes + the rest of the content | ||||
| 	c.DataFromReader(http.StatusOK, content.ContentLength, format, io.MultiReader(b, content.Content), nil) | ||||
| } | ||||
|  |  | |||
|  | @ -20,196 +20,251 @@ package fileserver_test | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/stretchr/testify/suite" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/concurrency" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/email" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||
| 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||
| ) | ||||
| 
 | ||||
| type ServeFileTestSuite struct { | ||||
| 	// standard suite interfaces | ||||
| 	suite.Suite | ||||
| 	db           db.DB | ||||
| 	storage      *storage.Driver | ||||
| 	federator    federation.Federator | ||||
| 	tc           typeutils.TypeConverter | ||||
| 	processor    processing.Processor | ||||
| 	mediaManager media.Manager | ||||
| 	oauthServer  oauth.Server | ||||
| 	emailSender  email.Sender | ||||
| 
 | ||||
| 	// standard suite models | ||||
| 	testTokens       map[string]*gtsmodel.Token | ||||
| 	testClients      map[string]*gtsmodel.Client | ||||
| 	testApplications map[string]*gtsmodel.Application | ||||
| 	testUsers        map[string]*gtsmodel.User | ||||
| 	testAccounts     map[string]*gtsmodel.Account | ||||
| 	testAttachments  map[string]*gtsmodel.MediaAttachment | ||||
| 
 | ||||
| 	// item being tested | ||||
| 	fileServer *fileserver.FileServer | ||||
| 	FileserverTestSuite | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	TEST INFRASTRUCTURE | ||||
| */ | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) SetupSuite() { | ||||
| 	// setup standard items | ||||
| 	testrig.InitTestConfig() | ||||
| 	testrig.InitTestLog() | ||||
| 
 | ||||
| 	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) | ||||
| 	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) | ||||
| 
 | ||||
| 	suite.db = testrig.NewTestDB() | ||||
| 	suite.storage = testrig.NewInMemoryStorage() | ||||
| 	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) | ||||
| 	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) | ||||
| 
 | ||||
| 	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker) | ||||
| 	suite.tc = testrig.NewTestTypeConverter(suite.db) | ||||
| 	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) | ||||
| 	suite.oauthServer = testrig.NewTestOauthServer(suite.db) | ||||
| 
 | ||||
| 	// setup module being tested | ||||
| 	suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TearDownSuite() { | ||||
| 	if err := suite.db.Stop(context.Background()); err != nil { | ||||
| 		log.Panicf("error closing db connection: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) SetupTest() { | ||||
| 	testrig.StandardDBSetup(suite.db, nil) | ||||
| 	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") | ||||
| 	suite.testTokens = testrig.NewTestTokens() | ||||
| 	suite.testClients = testrig.NewTestClients() | ||||
| 	suite.testApplications = testrig.NewTestApplications() | ||||
| 	suite.testUsers = testrig.NewTestUsers() | ||||
| 	suite.testAccounts = testrig.NewTestAccounts() | ||||
| 	suite.testAttachments = testrig.NewTestAttachments() | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TearDownTest() { | ||||
| 	testrig.StandardDBTeardown(suite.db) | ||||
| 	testrig.StandardStorageTeardown(suite.storage) | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	ACTUAL TESTS | ||||
| */ | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeOriginalFileSuccessful() { | ||||
| 	targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"] | ||||
| 	suite.True(ok) | ||||
| 	suite.NotNil(targetAttachment) | ||||
| 
 | ||||
| // GetFile is just a convenience function to save repetition in this test suite. | ||||
| // It takes the required params to serve a file, calls the handler, and returns | ||||
| // the http status code, the response headers, and the parsed body bytes. | ||||
| func (suite *ServeFileTestSuite) GetFile( | ||||
| 	accountID string, | ||||
| 	mediaType media.Type, | ||||
| 	mediaSize media.Size, | ||||
| 	filename string, | ||||
| ) (code int, headers http.Header, body []byte) { | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||
| 	ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.URL, nil) | ||||
| 	ctx.Request.Header.Set("accept", "*/*") | ||||
| 
 | ||||
| 	// normally the router would populate these params from the path values, | ||||
| 	// but because we're calling the ServeFile function directly, we need to set them manually. | ||||
| 	ctx.Params = gin.Params{ | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.AccountIDKey, | ||||
| 			Value: targetAttachment.AccountID, | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.MediaTypeKey, | ||||
| 			Value: string(media.TypeAttachment), | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.MediaSizeKey, | ||||
| 			Value: string(media.SizeOriginal), | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.FileNameKey, | ||||
| 			Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID), | ||||
| 		}, | ||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||
| 	ctx.Request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/whatever", nil) | ||||
| 	ctx.Request.Header.Set("accept", "*/*") | ||||
| 	ctx.AddParam(fileserver.AccountIDKey, accountID) | ||||
| 	ctx.AddParam(fileserver.MediaTypeKey, string(mediaType)) | ||||
| 	ctx.AddParam(fileserver.MediaSizeKey, string(mediaSize)) | ||||
| 	ctx.AddParam(fileserver.FileNameKey, filename) | ||||
| 
 | ||||
| 	suite.fileServer.ServeFile(ctx) | ||||
| 	code = recorder.Code | ||||
| 	headers = recorder.Result().Header | ||||
| 
 | ||||
| 	var err error | ||||
| 	body, err = ioutil.ReadAll(recorder.Body) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	// call the function we're testing and check status code | ||||
| 	suite.fileServer.ServeFile(ctx) | ||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||
| 	suite.EqualValues("image/jpeg", recorder.Header().Get("content-type")) | ||||
| 
 | ||||
| 	b, err := ioutil.ReadAll(recorder.Body) | ||||
| 	suite.NoError(err) | ||||
| 	suite.NotNil(b) | ||||
| 
 | ||||
| 	fileInStorage, err := suite.storage.Get(ctx, targetAttachment.File.Path) | ||||
| 	suite.NoError(err) | ||||
| 	suite.NotNil(fileInStorage) | ||||
| 	suite.Equal(b, fileInStorage) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeSmallFileSuccessful() { | ||||
| 	targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"] | ||||
| 	suite.True(ok) | ||||
| 	suite.NotNil(targetAttachment) | ||||
| // UncacheAttachment is a convenience function that uncaches the targetAttachment by | ||||
| // removing its associated files from storage, and updating the database. | ||||
| func (suite *ServeFileTestSuite) UncacheAttachment(targetAttachment *gtsmodel.MediaAttachment) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||
| 	ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.Thumbnail.URL, nil) | ||||
| 	ctx.Request.Header.Set("accept", "*/*") | ||||
| 	cached := false | ||||
| 	targetAttachment.Cached = &cached | ||||
| 
 | ||||
| 	// normally the router would populate these params from the path values, | ||||
| 	// but because we're calling the ServeFile function directly, we need to set them manually. | ||||
| 	ctx.Params = gin.Params{ | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.AccountIDKey, | ||||
| 			Value: targetAttachment.AccountID, | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.MediaTypeKey, | ||||
| 			Value: string(media.TypeAttachment), | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.MediaSizeKey, | ||||
| 			Value: string(media.SizeSmall), | ||||
| 		}, | ||||
| 		gin.Param{ | ||||
| 			Key:   fileserver.FileNameKey, | ||||
| 			Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID), | ||||
| 		}, | ||||
| 	if err := suite.db.UpdateByID(ctx, targetAttachment, targetAttachment.ID, "cached"); err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 	if err := suite.storage.Delete(ctx, targetAttachment.File.Path); err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 	if err := suite.storage.Delete(ctx, targetAttachment.Thumbnail.Path); err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeOriginalLocalFileOK() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	// call the function we're testing and check status code | ||||
| 	suite.fileServer.ServeFile(ctx) | ||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||
| 	suite.EqualValues("image/jpeg", recorder.Header().Get("content-type")) | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeOriginal, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	b, err := ioutil.ReadAll(recorder.Body) | ||||
| 	suite.NoError(err) | ||||
| 	suite.NotNil(b) | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| 	fileInStorage, err := suite.storage.Get(ctx, targetAttachment.Thumbnail.Path) | ||||
| 	suite.NoError(err) | ||||
| 	suite.NotNil(fileInStorage) | ||||
| 	suite.Equal(b, fileInStorage) | ||||
| func (suite *ServeFileTestSuite) TestServeSmallLocalFileOK() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeSmall, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileOK() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeOriginal, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeSmallRemoteFileOK() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeSmall, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecache() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	// uncache the attachment so we'll have to refetch it from the 'remote' instance | ||||
| 	suite.UncacheAttachment(targetAttachment) | ||||
| 
 | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeOriginal, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecache() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 	fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	// uncache the attachment so we'll have to refetch it from the 'remote' instance | ||||
| 	suite.UncacheAttachment(targetAttachment) | ||||
| 
 | ||||
| 	code, headers, body := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeSmall, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusOK, code) | ||||
| 	suite.Equal("image/jpeg", headers.Get("content-type")) | ||||
| 	suite.Equal(fileInStorage, body) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecacheNotFound() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 
 | ||||
| 	// uncache the attachment *and* set the remote URL to something that will return a 404 | ||||
| 	suite.UncacheAttachment(targetAttachment) | ||||
| 	targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee" | ||||
| 	if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	code, _, _ := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeOriginal, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusNotFound, code) | ||||
| } | ||||
| 
 | ||||
| func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecacheNotFound() { | ||||
| 	targetAttachment := >smodel.MediaAttachment{} | ||||
| 	*targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||
| 
 | ||||
| 	// uncache the attachment *and* set the remote URL to something that will return a 404 | ||||
| 	suite.UncacheAttachment(targetAttachment) | ||||
| 	targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee" | ||||
| 	if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil { | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	code, _, _ := suite.GetFile( | ||||
| 		targetAttachment.AccountID, | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeSmall, | ||||
| 		targetAttachment.ID+".jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusNotFound, code) | ||||
| } | ||||
| 
 | ||||
| // Callers trying to get some random-ass file that doesn't exist should just get a 404 | ||||
| func (suite *ServeFileTestSuite) TestServeFileNotFound() { | ||||
| 	code, _, _ := suite.GetFile( | ||||
| 		"01GMMY4G9B0QEG0PQK5Q5JGJWZ", | ||||
| 		media.TypeAttachment, | ||||
| 		media.SizeOriginal, | ||||
| 		"01GMMY68Y7E5DJ3CA3Y9SS8524.jpeg", | ||||
| 	) | ||||
| 
 | ||||
| 	suite.Equal(http.StatusNotFound, code) | ||||
| } | ||||
| 
 | ||||
| func TestServeFileTestSuite(t *testing.T) { | ||||
|  |  | |||
							
								
								
									
										121
									
								
								internal/iotools/io.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/iotools/io.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,121 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package iotools | ||||
| 
 | ||||
| import ( | ||||
| 	"io" | ||||
| ) | ||||
| 
 | ||||
| // ReadFnCloser takes an io.Reader and wraps it to use the provided function to implement io.Closer. | ||||
| func ReadFnCloser(r io.Reader, close func() error) io.ReadCloser { | ||||
| 	return &readFnCloser{ | ||||
| 		Reader: r, | ||||
| 		close:  close, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type readFnCloser struct { | ||||
| 	io.Reader | ||||
| 	close func() error | ||||
| } | ||||
| 
 | ||||
| func (r *readFnCloser) Close() error { | ||||
| 	return r.close() | ||||
| } | ||||
| 
 | ||||
| // WriteFnCloser takes an io.Writer and wraps it to use the provided function to implement io.Closer. | ||||
| func WriteFnCloser(w io.Writer, close func() error) io.WriteCloser { | ||||
| 	return &writeFnCloser{ | ||||
| 		Writer: w, | ||||
| 		close:  close, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type writeFnCloser struct { | ||||
| 	io.Writer | ||||
| 	close func() error | ||||
| } | ||||
| 
 | ||||
| func (r *writeFnCloser) Close() error { | ||||
| 	return r.close() | ||||
| } | ||||
| 
 | ||||
| // SilentReader wraps an io.Reader to silence any | ||||
| // error output during reads. Instead they are stored | ||||
| // and accessible (not concurrency safe!) via .Error(). | ||||
| type SilentReader struct { | ||||
| 	io.Reader | ||||
| 	err error | ||||
| } | ||||
| 
 | ||||
| // SilenceReader wraps an io.Reader within SilentReader{}. | ||||
| func SilenceReader(r io.Reader) *SilentReader { | ||||
| 	return &SilentReader{Reader: r} | ||||
| } | ||||
| 
 | ||||
| func (r *SilentReader) Read(b []byte) (int, error) { | ||||
| 	n, err := r.Reader.Read(b) | ||||
| 	if err != nil { | ||||
| 		// Store error for now | ||||
| 		if r.err == nil { | ||||
| 			r.err = err | ||||
| 		} | ||||
| 
 | ||||
| 		// Pretend we're happy | ||||
| 		// to continue reading. | ||||
| 		n = len(b) | ||||
| 	} | ||||
| 	return n, nil | ||||
| } | ||||
| 
 | ||||
| func (r *SilentReader) Error() error { | ||||
| 	return r.err | ||||
| } | ||||
| 
 | ||||
| // SilentWriter wraps an io.Writer to silence any | ||||
| // error output during writes. Instead they are stored | ||||
| // and accessible (not concurrency safe!) via .Error(). | ||||
| type SilentWriter struct { | ||||
| 	io.Writer | ||||
| 	err error | ||||
| } | ||||
| 
 | ||||
| // SilenceWriter wraps an io.Writer within SilentWriter{}. | ||||
| func SilenceWriter(w io.Writer) *SilentWriter { | ||||
| 	return &SilentWriter{Writer: w} | ||||
| } | ||||
| 
 | ||||
| func (w *SilentWriter) Write(b []byte) (int, error) { | ||||
| 	n, err := w.Writer.Write(b) | ||||
| 	if err != nil { | ||||
| 		// Store error for now | ||||
| 		if w.err == nil { | ||||
| 			w.err = err | ||||
| 		} | ||||
| 
 | ||||
| 		// Pretend we're happy | ||||
| 		// to continue writing. | ||||
| 		n = len(b) | ||||
| 	} | ||||
| 	return n, nil | ||||
| } | ||||
| 
 | ||||
| func (w *SilentWriter) Error() error { | ||||
| 	return w.err | ||||
| } | ||||
|  | @ -19,7 +19,6 @@ | |||
| package media | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
|  | @ -29,7 +28,7 @@ import ( | |||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/iotools" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||
|  | @ -135,7 +134,6 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount | |||
| 	} | ||||
| 
 | ||||
| 	var data media.DataFunc | ||||
| 	var postDataCallback media.PostDataCallbackFunc | ||||
| 
 | ||||
| 	if mediaSize == media.SizeSmall { | ||||
| 		// if it's the thumbnail that's requested then the user will have to wait a bit while we process the | ||||
|  | @ -155,7 +153,7 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount | |||
| 		// | ||||
| 		// this looks a bit like this: | ||||
| 		// | ||||
| 		//                http fetch                   buffered pipe | ||||
| 		//                http fetch                       pipe | ||||
| 		// remote server ------------> data function ----------------> api caller | ||||
| 		//                                   | | ||||
| 		//                                   | tee | ||||
|  | @ -163,54 +161,58 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount | |||
| 		//                                   ▼ | ||||
| 		//                            instance storage | ||||
| 
 | ||||
| 		// Buffer each end of the pipe, so that if the caller drops the connection during the flow, the tee | ||||
| 		// reader can continue without having to worry about tee-ing into a closed or blocked pipe. | ||||
| 		// This pipe will connect the caller to the in-process media retrieval... | ||||
| 		pipeReader, pipeWriter := io.Pipe() | ||||
| 		bufferedWriter := bufio.NewWriterSize(pipeWriter, int(attachmentContent.ContentLength)) | ||||
| 		bufferedReader := bufio.NewReaderSize(pipeReader, int(attachmentContent.ContentLength)) | ||||
| 
 | ||||
| 		// the caller will read from the buffered reader, so it doesn't matter if they drop out without reading everything | ||||
| 		attachmentContent.Content = io.NopCloser(bufferedReader) | ||||
| 		// Wrap the output pipe to silence any errors during the actual media | ||||
| 		// streaming process. We catch the error later but they must be silenced | ||||
| 		// during stream to prevent interruptions to storage of the actual media. | ||||
| 		silencedWriter := iotools.SilenceWriter(pipeWriter) | ||||
| 
 | ||||
| 		// Pass the reader side of the pipe to the caller to slurp from. | ||||
| 		attachmentContent.Content = pipeReader | ||||
| 
 | ||||
| 		// Create a data function which injects the writer end of the pipe | ||||
| 		// into the data retrieval process. If something goes wrong while | ||||
| 		// doing the data retrieval, we hang up the underlying pipeReader | ||||
| 		// to indicate to the caller that no data is available. It's up to | ||||
| 		// the caller of this processor function to handle that gracefully. | ||||
| 		data = func(innerCtx context.Context) (io.ReadCloser, int64, error) { | ||||
| 			t, err := p.transportController.NewTransportForUsername(innerCtx, requestingUsername) | ||||
| 			if err != nil { | ||||
| 				// propagate the transport error to read end of pipe. | ||||
| 				_ = pipeWriter.CloseWithError(fmt.Errorf("error getting transport for user: %w", err)) | ||||
| 				return nil, 0, err | ||||
| 			} | ||||
| 
 | ||||
| 			readCloser, fileSize, err := t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) | ||||
| 			if err != nil { | ||||
| 				// propagate the dereference error to read end of pipe. | ||||
| 				_ = pipeWriter.CloseWithError(fmt.Errorf("error dereferencing media: %w", err)) | ||||
| 				return nil, 0, err | ||||
| 			} | ||||
| 
 | ||||
| 			// Make a TeeReader so that everything read from the readCloser by the media manager will be written into the bufferedWriter. | ||||
| 			// We wrap this in a teeReadCloser which implements io.ReadCloser, so that whoever uses the teeReader can close the readCloser | ||||
| 			// when they're done with it. | ||||
| 			trc := teeReadCloser{ | ||||
| 				teeReader: io.TeeReader(readCloser, bufferedWriter), | ||||
| 				close:     readCloser.Close, | ||||
| 			} | ||||
| 			// Make a TeeReader so that everything read from the readCloser, | ||||
| 			// aka the remote instance, will also be written into the pipe. | ||||
| 			teeReader := io.TeeReader(readCloser, silencedWriter) | ||||
| 
 | ||||
| 			return trc, fileSize, nil | ||||
| 		} | ||||
| 
 | ||||
| 		// close the pipewriter after data has been piped into it, so the reader on the other side doesn't block; | ||||
| 		// we don't need to close the reader here because that's the caller's responsibility | ||||
| 		postDataCallback = func(innerCtx context.Context) error { | ||||
| 			// close the underlying pipe writer when we're done with it | ||||
| 			// Wrap teereader to implement original readcloser's close, | ||||
| 			// and also ensuring that we close the pipe from write end. | ||||
| 			return iotools.ReadFnCloser(teeReader, func() error { | ||||
| 				defer func() { | ||||
| 				if err := pipeWriter.Close(); err != nil { | ||||
| 					log.Errorf("getAttachmentContent: error closing pipeWriter: %s", err) | ||||
| 				} | ||||
| 					// We use the error (if any) encountered by the | ||||
| 					// silenced writer to close connection to make sure it | ||||
| 					// gets propagated to the attachment.Content reader. | ||||
| 					_ = pipeWriter.CloseWithError(silencedWriter.Error()) | ||||
| 				}() | ||||
| 
 | ||||
| 			// and flush the buffered writer into the buffer of the reader | ||||
| 			return bufferedWriter.Flush() | ||||
| 				return readCloser.Close() | ||||
| 			}), fileSize, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// put the media recached in the queue | ||||
| 	processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, postDataCallback, wantedMediaID) | ||||
| 	processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, nil, wantedMediaID) | ||||
| 	if err != nil { | ||||
| 		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error recaching media: %s", err)) | ||||
| 	} | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ | |||
| package media_test | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"io" | ||||
| 	"path" | ||||
|  | @ -143,9 +144,13 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() { | |||
| 	suite.NotNil(content) | ||||
| 
 | ||||
| 	// only read the first kilobyte and then stop | ||||
| 	b := make([]byte, 1024) | ||||
| 	_, err = content.Content.Read(b) | ||||
| 	suite.NoError(err) | ||||
| 	b := make([]byte, 0, 1024) | ||||
| 	if !testrig.WaitFor(func() bool { | ||||
| 		read, err := io.CopyN(bytes.NewBuffer(b), content.Content, 1024) | ||||
| 		return err == nil && read == 1024 | ||||
| 	}) { | ||||
| 		suite.FailNow("timed out trying to read first 1024 bytes") | ||||
| 	} | ||||
| 
 | ||||
| 	// close the reader | ||||
| 	suite.NoError(content.Content.Close()) | ||||
|  |  | |||
|  | @ -20,7 +20,6 @@ package media | |||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | @ -62,16 +61,3 @@ func parseFocus(focus string) (focusx, focusy float32, err error) { | |||
| 	focusy = float32(fy) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type teeReadCloser struct { | ||||
| 	teeReader io.Reader | ||||
| 	close     func() error | ||||
| } | ||||
| 
 | ||||
| func (t teeReadCloser) Read(p []byte) (n int, err error) { | ||||
| 	return t.teeReader.Read(p) | ||||
| } | ||||
| 
 | ||||
| func (t teeReadCloser) Close() error { | ||||
| 	return t.close() | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue