mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 16:54:13 +02:00
Compare commits
3 Commits
pdevine/sa
...
parth-test
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f69453457d | ||
|
|
152b922265 | ||
|
|
ba75143e71 |
309
server/aliases.go
Normal file
309
server/aliases.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
routerConfigFilename = "config.json"
|
||||||
|
routerConfigVersion = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var errAliasCycle = errors.New("alias cycle detected")
|
||||||
|
|
||||||
|
type aliasEntry struct {
|
||||||
|
Alias string `json:"alias"`
|
||||||
|
Target string `json:"target"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type routerConfig struct {
|
||||||
|
Version int `json:"version"`
|
||||||
|
Aliases []aliasEntry `json:"aliases"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type aliasStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
path string
|
||||||
|
entries map[string]aliasEntry // normalized alias -> entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAliasStore(path string) (*aliasStore, error) {
|
||||||
|
store := &aliasStore{
|
||||||
|
path: path,
|
||||||
|
entries: make(map[string]aliasEntry),
|
||||||
|
}
|
||||||
|
if err := store.load(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) load() error {
|
||||||
|
data, err := os.ReadFile(s.path)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg routerConfig
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Version != 0 && cfg.Version != routerConfigVersion {
|
||||||
|
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range cfg.Aliases {
|
||||||
|
aliasName := model.ParseName(entry.Alias)
|
||||||
|
if !aliasName.IsValid() {
|
||||||
|
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
targetName := model.ParseName(entry.Target)
|
||||||
|
if !targetName.IsValid() {
|
||||||
|
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
canonicalAlias := displayAliasName(aliasName)
|
||||||
|
canonicalTarget := displayAliasName(targetName)
|
||||||
|
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||||
|
Alias: canonicalAlias,
|
||||||
|
Target: canonicalTarget,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) saveLocked() error {
|
||||||
|
dir := filepath.Dir(s.path)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]aliasEntry, 0, len(s.entries))
|
||||||
|
for _, entry := range s.entries {
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||||
|
})
|
||||||
|
|
||||||
|
cfg := routerConfig{
|
||||||
|
Version: routerConfigVersion,
|
||||||
|
Aliases: entries,
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.CreateTemp(dir, "router-*.json")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(cfg); err != nil {
|
||||||
|
_ = f.Close()
|
||||||
|
_ = os.Remove(f.Name())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
_ = os.Remove(f.Name())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||||
|
_ = os.Remove(f.Name())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Rename(f.Name(), s.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||||
|
key := normalizeAliasKey(name)
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
entry, ok := s.entries[key]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return name, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a local model exists, do not allow alias shadowing.
|
||||||
|
exists, err := localModelExists(name)
|
||||||
|
if err != nil {
|
||||||
|
return name, false, err
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return name, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
visited := map[string]struct{}{key: {}}
|
||||||
|
targetKey := normalizeAliasKeyString(entry.Target)
|
||||||
|
current := entry.Target
|
||||||
|
|
||||||
|
for {
|
||||||
|
targetName := model.ParseName(current)
|
||||||
|
if !targetName.IsValid() {
|
||||||
|
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, seen := visited[targetKey]; seen {
|
||||||
|
return name, false, errAliasCycle
|
||||||
|
}
|
||||||
|
visited[targetKey] = struct{}{}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
next, ok := s.entries[targetKey]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return targetName, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
current = next.Target
|
||||||
|
targetKey = normalizeAliasKeyString(current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) Set(alias, target model.Name) error {
|
||||||
|
aliasKey := normalizeAliasKey(alias)
|
||||||
|
targetKey := normalizeAliasKey(target)
|
||||||
|
|
||||||
|
if aliasKey == targetKey {
|
||||||
|
return fmt.Errorf("alias cannot point to itself")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
visited := map[string]struct{}{aliasKey: {}}
|
||||||
|
currentKey := targetKey
|
||||||
|
for {
|
||||||
|
if _, seen := visited[currentKey]; seen {
|
||||||
|
return errAliasCycle
|
||||||
|
}
|
||||||
|
visited[currentKey] = struct{}{}
|
||||||
|
|
||||||
|
next, ok := s.entries[currentKey]
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
currentKey = normalizeAliasKeyString(next.Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.entries[aliasKey] = aliasEntry{
|
||||||
|
Alias: displayAliasName(alias),
|
||||||
|
Target: displayAliasName(target),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.saveLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) Delete(alias model.Name) (bool, error) {
|
||||||
|
aliasKey := normalizeAliasKey(alias)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := s.entries[aliasKey]; !ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.entries, aliasKey)
|
||||||
|
return true, s.saveLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *aliasStore) List() []aliasEntry {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
entries := make([]aliasEntry, 0, len(s.entries))
|
||||||
|
for _, entry := range s.entries {
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||||
|
})
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAliasKey(name model.Name) string {
|
||||||
|
return strings.ToLower(displayAliasName(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAliasKeyString(value string) string {
|
||||||
|
n := model.ParseName(value)
|
||||||
|
if !n.IsValid() {
|
||||||
|
return strings.ToLower(strings.TrimSpace(value))
|
||||||
|
}
|
||||||
|
return normalizeAliasKey(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func displayAliasName(n model.Name) string {
|
||||||
|
display := n.DisplayShortest()
|
||||||
|
if strings.EqualFold(n.Tag, "latest") {
|
||||||
|
if idx := strings.LastIndex(display, ":"); idx != -1 {
|
||||||
|
return display[:idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return display
|
||||||
|
}
|
||||||
|
|
||||||
|
func localModelExists(name model.Name) (bool, error) {
|
||||||
|
manifests, err := manifest.Manifests(true)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
needle := name.String()
|
||||||
|
for existing := range manifests {
|
||||||
|
if strings.EqualFold(existing.String(), needle) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func routerConfigPath() string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return filepath.Join(".ollama", routerConfigFilename)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".ollama", routerConfigFilename)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) aliasStore() (*aliasStore, error) {
|
||||||
|
s.aliasesOnce.Do(func() {
|
||||||
|
s.aliases, s.aliasesErr = newAliasStore(routerConfigPath())
|
||||||
|
})
|
||||||
|
|
||||||
|
return s.aliases, s.aliasesErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) resolveModelAliasName(name model.Name) (model.Name, bool, error) {
|
||||||
|
store, err := s.aliasStore()
|
||||||
|
if err != nil {
|
||||||
|
return name, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if store == nil {
|
||||||
|
return name, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.ResolveName(name)
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -81,6 +82,9 @@ type Server struct {
|
|||||||
addr net.Addr
|
addr net.Addr
|
||||||
sched *Scheduler
|
sched *Scheduler
|
||||||
defaultNumCtx int
|
defaultNumCtx int
|
||||||
|
aliasesOnce sync.Once
|
||||||
|
aliases *aliasStore
|
||||||
|
aliasesErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolvedName, _, err := s.resolveModelAliasName(name)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name = resolvedName
|
||||||
|
|
||||||
// We cannot currently consolidate this into GetModel because all we'll
|
// We cannot currently consolidate this into GetModel because all we'll
|
||||||
// induce infinite recursion given the current code structure.
|
// induce infinite recursion given the current code structure.
|
||||||
name, err := getExistingName(name)
|
name, err = getExistingName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||||
return
|
return
|
||||||
@@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||||
r.POST("/api/copy", s.CopyHandler)
|
r.POST("/api/copy", s.CopyHandler)
|
||||||
|
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||||
|
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||||
|
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
r.GET("/api/ps", s.PsHandler)
|
r.GET("/api/ps", s.PsHandler)
|
||||||
@@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name, err := getExistingName(name)
|
resolvedName, _, err := s.resolveModelAliasName(name)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name = resolvedName
|
||||||
|
|
||||||
|
name, err = getExistingName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := GetModel(req.Model)
|
m, err := GetModel(name.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case os.IsNotExist(err):
|
case os.IsNotExist(err):
|
||||||
|
|||||||
138
server/routes_aliases.go
Normal file
138
server/routes_aliases.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type aliasListResponse struct {
|
||||||
|
Aliases []aliasEntry `json:"aliases"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type aliasDeleteRequest struct {
|
||||||
|
Alias string `json:"alias"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ListAliasesHandler(c *gin.Context) {
|
||||||
|
store, err := s.aliasStore()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var aliases []aliasEntry
|
||||||
|
if store != nil {
|
||||||
|
aliases = store.List()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) CreateAliasHandler(c *gin.Context) {
|
||||||
|
var req aliasEntry
|
||||||
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Alias = strings.TrimSpace(req.Alias)
|
||||||
|
req.Target = strings.TrimSpace(req.Target)
|
||||||
|
if req.Alias == "" || req.Target == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasName := model.ParseName(req.Alias)
|
||||||
|
if !aliasName.IsValid() {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
targetName := model.ParseName(req.Target)
|
||||||
|
if !targetName.IsValid() {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := localModelExists(aliasName)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := s.aliasStore()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := store.Set(aliasName, targetName); err != nil {
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
if errors.Is(err, errAliasCycle) {
|
||||||
|
status = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, aliasEntry{Alias: displayAliasName(aliasName), Target: displayAliasName(targetName)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) DeleteAliasHandler(c *gin.Context) {
|
||||||
|
var req aliasDeleteRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Alias = strings.TrimSpace(req.Alias)
|
||||||
|
if req.Alias == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasName := model.ParseName(req.Alias)
|
||||||
|
if !aliasName.IsValid() {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := s.aliasStore()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted, err := store.Delete(aliasName)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !deleted {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
110
server/routes_aliases_test.go
Normal file
110
server/routes_aliases_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAliasShadowingRejected(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Setenv("HOME", t.TempDir())
|
||||||
|
|
||||||
|
s := Server{}
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "shadowed-model",
|
||||||
|
RemoteHost: "example.com",
|
||||||
|
From: "test",
|
||||||
|
Info: map[string]any{
|
||||||
|
"capabilities": []string{"completion"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Setenv("HOME", t.TempDir())
|
||||||
|
|
||||||
|
var remoteModel string
|
||||||
|
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req api.ChatRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
remoteModel = req.Model
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
resp := api.ChatResponse{
|
||||||
|
Model: req.Model,
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "load",
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
|
p, err := url.Parse(rs.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||||
|
|
||||||
|
s := Server{}
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "target-model",
|
||||||
|
RemoteHost: rs.URL,
|
||||||
|
From: "test",
|
||||||
|
Info: map[string]any{
|
||||||
|
"capabilities": []string{"completion"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "alias-model",
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.ChatResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Model != "alias-model" {
|
||||||
|
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteModel != "test" {
|
||||||
|
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user