mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 22:54:05 +02:00
server: reject unexpected auth hosts (#13738)
Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow. Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com> * gofmt --------- Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
This commit is contained in:
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
|||||||
return redirectURL, nil
|
return redirectURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||||
redirectURL, err := challenge.URL()
|
redirectURL, err := challenge.URL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||||
|
if redirectURL.Host != originalHost {
|
||||||
|
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||||
|
}
|
||||||
|
|
||||||
sha256sum := sha256.Sum256(nil)
|
sha256sum := sha256.Sum256(nil)
|
||||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||||
|
|
||||||
|
|||||||
113
server/auth_test.go
Normal file
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
realm string
|
||||||
|
originalHost string
|
||||||
|
wantMismatch bool
|
||||||
|
}{
|
||||||
|
{"https://example.com/token", "example.com", false},
|
||||||
|
{"https://example.com/token", "other.com", true},
|
||||||
|
{"https://example.com/token", "localhost:8000", true},
|
||||||
|
{"https://localhost:5000/token", "localhost:5000", false},
|
||||||
|
{"https://localhost:5000/token", "localhost:6000", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.originalHost, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||||
|
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||||
|
|
||||||
|
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||||
|
if tt.wantMismatch && !isMismatch {
|
||||||
|
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||||
|
}
|
||||||
|
if !tt.wantMismatch && isMismatch {
|
||||||
|
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRegistryChallenge(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
wantRealm, wantService, wantScope string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||||
|
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||||
|
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||||
|
},
|
||||||
|
{"", "", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := parseRegistryChallenge(tt.input)
|
||||||
|
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||||
|
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||||
|
tt.input, result.Realm, result.Service, result.Scope,
|
||||||
|
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryChallengeURL(t *testing.T) {
|
||||||
|
challenge := registryChallenge{
|
||||||
|
Realm: "https://auth.example.com/token",
|
||||||
|
Service: "registry",
|
||||||
|
Scope: "repo:foo:pull repo:bar:push",
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := challenge.URL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("URL() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Host != "auth.example.com" {
|
||||||
|
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||||
|
}
|
||||||
|
if u.Path != "/token" {
|
||||||
|
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
if q.Get("service") != "registry" {
|
||||||
|
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||||
|
}
|
||||||
|
if scopes := q["scope"]; len(scopes) != 2 {
|
||||||
|
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||||
|
}
|
||||||
|
if q.Get("ts") == "" {
|
||||||
|
t.Error("missing ts")
|
||||||
|
}
|
||||||
|
if q.Get("nonce") == "" {
|
||||||
|
t.Error("missing nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nonces should differ between calls
|
||||||
|
u2, _ := challenge.URL()
|
||||||
|
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||||
|
t.Error("nonce should be unique per call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||||
|
challenge := registryChallenge{Realm: "://invalid"}
|
||||||
|
if _, err := challenge.URL(); err == nil {
|
||||||
|
t.Error("expected error for invalid URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
Realm: challenge.Realm,
|
Realm: challenge.Realm,
|
||||||
Service: challenge.Service,
|
Service: challenge.Service,
|
||||||
Scope: challenge.Scope,
|
Scope: challenge.Scope,
|
||||||
})
|
}, base.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||||
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
Realm: challenge.Realm,
|
Realm: challenge.Realm,
|
||||||
Service: challenge.Service,
|
Service: challenge.Service,
|
||||||
Scope: challenge.Scope,
|
Scope: challenge.Scope,
|
||||||
})
|
}, base.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||||
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
|
|
||||||
// Handle authentication error with one retry
|
// Handle authentication error with one retry
|
||||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||||
token, err := getAuthorizationToken(ctx, challenge)
|
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
|||||||
case resp.StatusCode == http.StatusUnauthorized:
|
case resp.StatusCode == http.StatusUnauthorized:
|
||||||
w.Rollback()
|
w.Rollback()
|
||||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||||
token, err := getAuthorizationToken(ctx, challenge)
|
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user