Compare commits

...

3 Commits

Author SHA1 Message Date
ParthSareen
f69453457d hi 2026-04-08 18:25:11 -07:00
Devon Rifkin
152b922265 checkpoint 2026-02-02 15:42:57 -08:00
Devon Rifkin
ba75143e71 WIP aliases 2026-02-02 14:59:58 -08:00
5 changed files with 582 additions and 3 deletions

309
server/aliases.go Normal file
View File

@@ -0,0 +1,309 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const (
routerConfigFilename = "config.json"
routerConfigVersion = 1
)
var errAliasCycle = errors.New("alias cycle detected")
type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
}
type routerConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}
type aliasStore struct {
mu sync.RWMutex
path string
entries map[string]aliasEntry // normalized alias -> entry
}
func newAliasStore(path string) (*aliasStore, error) {
store := &aliasStore{
path: path,
entries: make(map[string]aliasEntry),
}
if err := store.load(); err != nil {
return nil, err
}
return store, nil
}
func (s *aliasStore) load() error {
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var cfg routerConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
if cfg.Version != 0 && cfg.Version != routerConfigVersion {
return fmt.Errorf("unsupported router config version %d", cfg.Version)
}
for _, entry := range cfg.Aliases {
aliasName := model.ParseName(entry.Alias)
if !aliasName.IsValid() {
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
continue
}
targetName := model.ParseName(entry.Target)
if !targetName.IsValid() {
slog.Warn("invalid alias target in router config", "target", entry.Target)
continue
}
canonicalAlias := displayAliasName(aliasName)
canonicalTarget := displayAliasName(targetName)
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
Alias: canonicalAlias,
Target: canonicalTarget,
}
}
return nil
}
func (s *aliasStore) saveLocked() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
entries := make([]aliasEntry, 0, len(s.entries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := routerConfig{
Version: routerConfigVersion,
Aliases: entries,
}
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
return err
}
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return err
}
if err := os.Chmod(f.Name(), 0o644); err != nil {
_ = os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), s.path)
}
func (s *aliasStore) ResolveName(name model.Name) (model.Name, bool, error) {
key := normalizeAliasKey(name)
s.mu.RLock()
entry, ok := s.entries[key]
s.mu.RUnlock()
if !ok {
return name, false, nil
}
// If a local model exists, do not allow alias shadowing.
exists, err := localModelExists(name)
if err != nil {
return name, false, err
}
if exists {
return name, false, nil
}
visited := map[string]struct{}{key: {}}
targetKey := normalizeAliasKeyString(entry.Target)
current := entry.Target
for {
targetName := model.ParseName(current)
if !targetName.IsValid() {
return name, false, fmt.Errorf("alias target %q is invalid", current)
}
if _, seen := visited[targetKey]; seen {
return name, false, errAliasCycle
}
visited[targetKey] = struct{}{}
s.mu.RLock()
next, ok := s.entries[targetKey]
s.mu.RUnlock()
if !ok {
return targetName, true, nil
}
current = next.Target
targetKey = normalizeAliasKeyString(current)
}
}
func (s *aliasStore) Set(alias, target model.Name) error {
aliasKey := normalizeAliasKey(alias)
targetKey := normalizeAliasKey(target)
if aliasKey == targetKey {
return fmt.Errorf("alias cannot point to itself")
}
s.mu.Lock()
defer s.mu.Unlock()
visited := map[string]struct{}{aliasKey: {}}
currentKey := targetKey
for {
if _, seen := visited[currentKey]; seen {
return errAliasCycle
}
visited[currentKey] = struct{}{}
next, ok := s.entries[currentKey]
if !ok {
break
}
currentKey = normalizeAliasKeyString(next.Target)
}
s.entries[aliasKey] = aliasEntry{
Alias: displayAliasName(alias),
Target: displayAliasName(target),
}
return s.saveLocked()
}
func (s *aliasStore) Delete(alias model.Name) (bool, error) {
aliasKey := normalizeAliasKey(alias)
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.entries[aliasKey]; !ok {
return false, nil
}
delete(s.entries, aliasKey)
return true, s.saveLocked()
}
func (s *aliasStore) List() []aliasEntry {
s.mu.RLock()
defer s.mu.RUnlock()
entries := make([]aliasEntry, 0, len(s.entries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
return entries
}
func normalizeAliasKey(name model.Name) string {
return strings.ToLower(displayAliasName(name))
}
func normalizeAliasKeyString(value string) string {
n := model.ParseName(value)
if !n.IsValid() {
return strings.ToLower(strings.TrimSpace(value))
}
return normalizeAliasKey(n)
}
func displayAliasName(n model.Name) string {
display := n.DisplayShortest()
if strings.EqualFold(n.Tag, "latest") {
if idx := strings.LastIndex(display, ":"); idx != -1 {
return display[:idx]
}
}
return display
}
func localModelExists(name model.Name) (bool, error) {
manifests, err := manifest.Manifests(true)
if err != nil {
return false, err
}
needle := name.String()
for existing := range manifests {
if strings.EqualFold(existing.String(), needle) {
return true, nil
}
}
return false, nil
}
func routerConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".ollama", routerConfigFilename)
}
return filepath.Join(home, ".ollama", routerConfigFilename)
}
func (s *Server) aliasStore() (*aliasStore, error) {
s.aliasesOnce.Do(func() {
s.aliases, s.aliasesErr = newAliasStore(routerConfigPath())
})
return s.aliases, s.aliasesErr
}
func (s *Server) resolveModelAliasName(name model.Name) (model.Name, bool, error) {
store, err := s.aliasStore()
if err != nil {
return name, false, err
}
if store == nil {
return name, false, nil
}
return store.ResolveName(name)
}

View File

@@ -22,6 +22,7 @@ import (
"os/signal" "os/signal"
"slices" "slices"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
@@ -81,6 +82,9 @@ type Server struct {
addr net.Addr addr net.Addr
sched *Scheduler sched *Scheduler
defaultNumCtx int defaultNumCtx int
aliasesOnce sync.Once
aliases *aliasStore
aliasesErr error
} }
func init() { func init() {
@@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
resolvedName, _, err := s.resolveModelAliasName(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
// We cannot currently consolidate this into GetModel because all we'll // We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure. // induce infinite recursion given the current code structure.
name, err := getExistingName(name) name, err = getExistingName(name)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return return
@@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler) r.POST("/api/copy", s.CopyHandler)
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference // Inference
r.GET("/api/ps", s.PsHandler) r.GET("/api/ps", s.PsHandler)
@@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
name, err := getExistingName(name) resolvedName, _, err := s.resolveModelAliasName(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
name, err = getExistingName(name)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
m, err := GetModel(req.Model) m, err := GetModel(name.String())
if err != nil { if err != nil {
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):

138
server/routes_aliases.go Normal file
View File

@@ -0,0 +1,138 @@
package server
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/types/model"
)
type aliasListResponse struct {
Aliases []aliasEntry `json:"aliases"`
}
type aliasDeleteRequest struct {
Alias string `json:"alias"`
}
func (s *Server) ListAliasesHandler(c *gin.Context) {
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var aliases []aliasEntry
if store != nil {
aliases = store.List()
}
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}
func (s *Server) CreateAliasHandler(c *gin.Context) {
var req aliasEntry
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
req.Target = strings.TrimSpace(req.Target)
if req.Alias == "" || req.Target == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
return
}
aliasName := model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
targetName := model.ParseName(req.Target)
if !targetName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
return
}
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
return
}
exists, err := localModelExists(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if exists {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := store.Set(aliasName, targetName); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, errAliasCycle) {
status = http.StatusBadRequest
}
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, aliasEntry{Alias: displayAliasName(aliasName), Target: displayAliasName(targetName)})
}
func (s *Server) DeleteAliasHandler(c *gin.Context) {
var req aliasDeleteRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
if req.Alias == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
return
}
aliasName := model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
deleted, err := store.Delete(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !deleted {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}

View File

@@ -0,0 +1,110 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "shadowed-model",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req api.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
remoteModel = req.Model
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: req.Model,
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "target-model",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "alias-model",
Messages: []api.Message{{Role: "user", Content: "hi"}},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Model != "alias-model" {
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
}
if remoteModel != "test" {
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
}
}

1
test.txt Normal file
View File

@@ -0,0 +1 @@
hi