mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 15:52:25 -05:00 
			
		
		
		
	[bugfix] Fix NegotiateAccept with multi accept
There's a bug in Gin's NegotiateFormat that doesn't handle the presence of multilpe accept headers. This lifts the code from the PR @tsmethurst sent a year ago to Gin into our codebase to fix the issue.
This commit is contained in:
		
					parent
					
						
							
								2478d83c84
							
						
					
				
			
			
				commit
				
					
						7050112af1
					
				
			
		
					 4 changed files with 122 additions and 3 deletions
				
			
		|  | @ -20,6 +20,7 @@ package util | |||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | @ -108,10 +109,63 @@ func NegotiateAccept(c *gin.Context, offers ...MIME) (string, error) { | |||
| 		return strings[0], nil | ||||
| 	} | ||||
| 
 | ||||
| 	format := c.NegotiateFormat(strings...) | ||||
| 	format := NegotiateFormat(c, strings...) | ||||
| 	if format == "" { | ||||
| 		return "", fmt.Errorf("no format can be offered for requested Accept header(s) %s; this endpoint offers %s", accepts, offers) | ||||
| 	} | ||||
| 
 | ||||
| 	return format, nil | ||||
| } | ||||
| 
 | ||||
| // This is the exact same thing as gin.Context.NegotiateFormat except it contains | ||||
| // tsmethurst's fix to make it work properly with multiple accept headers. | ||||
| // | ||||
| // https://github.com/gin-gonic/gin/pull/3156 | ||||
| func NegotiateFormat(c *gin.Context, offered ...string) string { | ||||
| 	if len(offered) == 0 { | ||||
| 		panic("you must provide at least one offer") | ||||
| 	} | ||||
| 
 | ||||
| 	if c.Accepted == nil { | ||||
| 		for _, a := range c.Request.Header.Values("Accept") { | ||||
| 			c.Accepted = append(c.Accepted, parseAccept(a)...) | ||||
| 		} | ||||
| 	} | ||||
| 	if len(c.Accepted) == 0 { | ||||
| 		return offered[0] | ||||
| 	} | ||||
| 	for _, accepted := range c.Accepted { | ||||
| 		for _, offer := range offered { | ||||
| 			// According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers, | ||||
| 			// therefore we can just iterate over the string without casting it into []rune | ||||
| 			i := 0 | ||||
| 			for ; i < len(accepted); i++ { | ||||
| 				if accepted[i] == '*' || offer[i] == '*' { | ||||
| 					return offer | ||||
| 				} | ||||
| 				if accepted[i] != offer[i] { | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 			if i == len(accepted) { | ||||
| 				return offer | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // https://github.com/gin-gonic/gin/blob/4787b8203b79012877ac98d7806422da3a678ba2/utils.go#L103 | ||||
| func parseAccept(acceptHeader string) []string { | ||||
| 	parts := strings.Split(acceptHeader, ",") | ||||
| 	out := make([]string, 0, len(parts)) | ||||
| 	for _, part := range parts { | ||||
| 		if i := strings.IndexByte(part, ';'); i > 0 { | ||||
| 			part = part[:i] | ||||
| 		} | ||||
| 		if part = strings.TrimSpace(part); part != "" { | ||||
| 			out = append(out, part) | ||||
| 		} | ||||
| 	} | ||||
| 	return out | ||||
| } | ||||
|  |  | |||
							
								
								
									
										65
									
								
								internal/api/util/negotiate_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								internal/api/util/negotiate_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,65 @@ | |||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| 
 | ||||
| type testMIMES []MIME | ||||
| 
 | ||||
| func (tm testMIMES) String(t *testing.T) string { | ||||
| 	t.Helper() | ||||
| 
 | ||||
| 	res := tm.StringS(t) | ||||
| 	return strings.Join(res, ",") | ||||
| } | ||||
| 
 | ||||
| func (tm testMIMES) StringS(t *testing.T) []string { | ||||
| 	t.Helper() | ||||
| 
 | ||||
| 	res := make([]string, 0, len(tm)) | ||||
| 	for _, m := range tm { | ||||
| 		res = append(res, string(m)) | ||||
| 	} | ||||
| 	return res | ||||
| } | ||||
| 
 | ||||
| func TestNegotiateFormat(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		incoming []string | ||||
| 		offered  testMIMES | ||||
| 		format   string | ||||
| 	}{ | ||||
| 		{incoming: testMIMES{AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/json"}, | ||||
| 		{incoming: testMIMES{AppJRDJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/jrd+json"}, | ||||
| 		{incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON}, format: "application/jrd+json"}, | ||||
| 		{incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJSON}, format: "application/json"}, | ||||
| 		{incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{AppJSON, AppXML}, format: "application/xml"}, | ||||
| 		{incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{TextHTML, AppXML}, format: "text/html"}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		name := "incoming:" + strings.Join(tt.incoming, ",") + " offered:" + tt.offered.String(t) | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			tt := tt | ||||
| 			t.Parallel() | ||||
| 
 | ||||
| 			c, _ := gin.CreateTestContext(httptest.NewRecorder()) | ||||
| 			c.Request = &http.Request{ | ||||
| 				Header: make(http.Header), | ||||
| 			} | ||||
| 			for _, header := range tt.incoming { | ||||
| 				c.Request.Header.Add("accept", header) | ||||
| 			} | ||||
| 
 | ||||
| 			format := NegotiateFormat(c, tt.offered.StringS(t)...) | ||||
| 			if tt.format != format { | ||||
| 				t.Fatalf("expected format: '%s', got format: '%s'", tt.format, format) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | @ -73,7 +73,7 @@ func (m *Module) profileGETHandler(c *gin.Context) { | |||
| 
 | ||||
| 	// if we're getting an AP request on this endpoint we | ||||
| 	// should render the account's AP representation instead | ||||
| 	accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) | ||||
| 	accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) | ||||
| 	if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) { | ||||
| 		m.returnAPProfile(ctx, c, username, accept) | ||||
| 		return | ||||
|  |  | |||
|  | @ -90,7 +90,7 @@ func (m *Module) threadGETHandler(c *gin.Context) { | |||
| 
 | ||||
| 	// if we're getting an AP request on this endpoint we | ||||
| 	// should render the status's AP representation instead | ||||
| 	accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) | ||||
| 	accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) | ||||
| 	if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) { | ||||
| 		m.returnAPStatus(ctx, c, username, statusID, accept) | ||||
| 		return | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue