api: return structured error on unauthorized push

This commit implements a structured error response system for the Ollama API, replacing
ad-hoc error handling and string parsing with proper error types and codes through a new
ErrorResponse struct. Instead of relying on regex to parse error messages for SSH keys,
the API now passes this data in a structured format with standardized fields for error
messages, codes, and additional data. This structured approach makes the API more
maintainable and reliable while improving the developer experience by enabling
programmatic error handling, consistent error formats, and better error
documentation.
This commit is contained in:
Bruce MacDonald
2024-12-12 11:49:36 -08:00
parent ae9165d661
commit 9e190ac4d9
10 changed files with 255 additions and 104 deletions

View File

@@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
var errorResponse ErrorResponse
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
switch errorResponse.Code {
case ErrCodeUnknownKey:
return ErrUnknownOllamaKey{
Message: errorResponse.Message,
Key: errorResponse.Data["key"].(string),
}
}
if errorResponse.Message != "" {
return errors.New(errorResponse.Message)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
ErrorMessage: errorResponse.Message,
}
}

View File

@@ -1,6 +1,12 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
@@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
func TestStream(t *testing.T) {
tests := []struct {
name string
serverResponse []string
statusCode int
expectedError error
}{
{
name: "unknown key error",
serverResponse: []string{
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
},
statusCode: http.StatusUnauthorized,
expectedError: &ErrUnknownOllamaKey{
Message: "unauthorized access",
Key: "test-key",
},
},
{
name: "general error message",
serverResponse: []string{
`{"error":"something went wrong"}`,
},
statusCode: http.StatusInternalServerError,
expectedError: fmt.Errorf("something went wrong"),
},
{
name: "malformed json response",
serverResponse: []string{
`{invalid-json`,
},
statusCode: http.StatusOK,
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(tt.statusCode)
for _, resp := range tt.serverResponse {
fmt.Fprintln(w, resp)
}
}))
defer server.Close()
baseURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
client := &Client{
http: server.Client(),
base: baseURL,
}
var responses [][]byte
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
responses = append(responses, bts)
return nil
})
// Error checking
if tt.expectedError == nil {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
return
}
if err == nil {
t.Fatalf("expected error %v, got nil", tt.expectedError)
}
// Check for specific error types
var unknownKeyErr ErrUnknownOllamaKey
if errors.As(tt.expectedError, &unknownKeyErr) {
var gotErr ErrUnknownOllamaKey
if !errors.As(err, &gotErr) {
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
}
if unknownKeyErr.Key != gotErr.Key {
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
}
if unknownKeyErr.Message != gotErr.Message {
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
}
return
}
var statusErr StatusError
if errors.As(tt.expectedError, &statusErr) {
var gotErr StatusError
if !errors.As(err, &gotErr) {
t.Fatalf("expected StatusError, got %T", err)
}
if statusErr.StatusCode != gotErr.StatusCode {
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
}
if statusErr.ErrorMessage != gotErr.ErrorMessage {
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
}
return
}
// For other errors, compare error strings
if err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %q, got %q", tt.expectedError, err)
}
})
}
}

74
api/errors.go Normal file
View File

@@ -0,0 +1,74 @@
package api
import (
"fmt"
"slices"
"strings"
)
const InvalidModelNameErrMsg = "invalid model name"
// API error responses
// ErrorCode represents a standardized error code identifier
type ErrorCode string
const (
ErrCodeUnknownKey ErrorCode = "unknown_key"
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
)
// ErrorResponse implements a structured error interface
type ErrorResponse struct {
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
Data map[string]any `json:"data"` // Additional error specific data, if any
}
func (e ErrorResponse) Error() string {
return e.Message
}
type ErrUnknownOllamaKey struct {
Message string
Key string
}
func (e ErrUnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
}
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
// The user should only be told to add the key if it is the same one that exists locally
if slices.Index(localKeys, e.Key) == -1 {
return e.Message
}
return fmt.Sprintf(`%s
Your ollama key is:
%s
Add your key at:
https://ollama.com/settings/keys`, e.Message, e.Key)
}
// StatusError is an error with an HTTP status code and message,
// it is parsed on the client-side and not returned from the API
type StatusError struct {
StatusCode int // e.g. 200
Status string // e.g. "200 OK"
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}

View File

@@ -12,27 +12,6 @@ import (
"time"
)
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
Status string
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte