mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 07:54:25 +02:00
Compare commits
8 Commits
revert-122
...
pdevine/pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c10a40db99 | ||
|
|
93c64ea1b1 | ||
|
|
3f6642f6fc | ||
|
|
6f7117145f | ||
|
|
92b96d54ef | ||
|
|
9d56e63dbf | ||
|
|
053092185e | ||
|
|
44a6792873 |
@@ -28,6 +28,7 @@ type bertModel struct {
|
|||||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||||
NormEpsilon float32 `json:"norm_epsilon"`
|
NormEpsilon float32 `json:"norm_epsilon"`
|
||||||
|
normalizeEmbeddings bool
|
||||||
|
|
||||||
PoolingType uint32
|
PoolingType uint32
|
||||||
}
|
}
|
||||||
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
|||||||
|
|
||||||
var pooling string
|
var pooling string
|
||||||
for _, m := range modules {
|
for _, m := range modules {
|
||||||
if m.Type == "sentence_transformers.models.Pooling" {
|
switch m.Type {
|
||||||
|
case "sentence_transformers.models.Pooling":
|
||||||
pooling = m.Path
|
pooling = m.Path
|
||||||
break
|
case "sentence_transformers.models.Normalize":
|
||||||
|
p.normalizeEmbeddings = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv["general.architecture"] = "bert"
|
kv["general.architecture"] = "bert"
|
||||||
kv["bert.attention.causal"] = false
|
kv["bert.attention.causal"] = false
|
||||||
kv["bert.pooling_type"] = p.PoolingType
|
kv["bert.pooling_type"] = p.PoolingType
|
||||||
|
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
|
||||||
|
|
||||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
|
|||||||
go run . serve
|
go run . serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
|
||||||
|
|
||||||
|
|
||||||
## macOS (Apple Silicon)
|
## macOS (Apple Silicon)
|
||||||
|
|
||||||
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
||||||
|
|||||||
@@ -3,29 +3,15 @@ package harmony
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/template"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type harmonyParserState int
|
type harmonyParserState int
|
||||||
|
|
||||||
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
|
|
||||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
|
|
||||||
// heuristic to check whether the template expects to be parsed via harmony:
|
|
||||||
// search for harmony tags that are nearly always used
|
|
||||||
if template.Contains("<|start|>") && template.Contains("<|end|>") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||||
harmonyParserState_ParsingHeader
|
harmonyParserState_ParsingHeader
|
||||||
@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
|
|||||||
s.acc.WriteString("<|start|>assistant")
|
s.acc.WriteString("<|start|>assistant")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Prefill(lastMessage api.Message) string {
|
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||||
if lastMessage.Role != "assistant" {
|
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||||
return ""
|
// handle prefilling conditions
|
||||||
}
|
if lastMessage.Content != "" {
|
||||||
|
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||||
switch {
|
return
|
||||||
case strings.TrimSpace(lastMessage.Content) != "":
|
} else if lastMessage.Thinking != "" {
|
||||||
return "<|start|>assistant<|channel|>final<|message|>"
|
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||||
case strings.TrimSpace(lastMessage.Thinking) != "":
|
return
|
||||||
return "<|start|>assistant<|channel|>analysis<|message|>"
|
}
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
|
|
||||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
|
|
||||||
if strings.TrimSpace(prefillString) != "" {
|
|
||||||
s.acc.WriteString(prefillString)
|
|
||||||
} else {
|
|
||||||
s.AddImplicitStart()
|
|
||||||
}
|
}
|
||||||
|
s.AddImplicitStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||||
@@ -289,7 +265,6 @@ type HarmonyMessageHandler struct {
|
|||||||
state harmonyMessageState
|
state harmonyMessageState
|
||||||
HarmonyParser *HarmonyParser
|
HarmonyParser *HarmonyParser
|
||||||
FunctionNameMap *FunctionNameMap
|
FunctionNameMap *FunctionNameMap
|
||||||
ToolParser *HarmonyToolCallAccumulator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHarmonyMessageHandler creates a new message handler
|
// NewHarmonyMessageHandler creates a new message handler
|
||||||
@@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
|||||||
HeaderEndTag: "<|message|>",
|
HeaderEndTag: "<|message|>",
|
||||||
},
|
},
|
||||||
FunctionNameMap: NewFunctionNameMap(),
|
FunctionNameMap: NewFunctionNameMap(),
|
||||||
ToolParser: &HarmonyToolCallAccumulator{
|
|
||||||
state: harmonyToolCallState_Normal,
|
|
||||||
currentToolName: nil,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddContent processes the content and returns the content, thinking, and tool content.
|
// AddContent processes the content and returns the content, thinking, and tool content.
|
||||||
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
||||||
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
|
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
|
||||||
contentSb := strings.Builder{}
|
contentSb := strings.Builder{}
|
||||||
thinkingSb := strings.Builder{}
|
thinkingSb := strings.Builder{}
|
||||||
toolContentSb := strings.Builder{}
|
toolContentSb := strings.Builder{}
|
||||||
@@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
|
|||||||
// event.Header.Recipient is the tool name, something like
|
// event.Header.Recipient is the tool name, something like
|
||||||
// "browser.search" for a built-in, or "functions.calc" for a
|
// "browser.search" for a built-in, or "functions.calc" for a
|
||||||
// custom one
|
// custom one
|
||||||
h.ToolParser.SetToolName(event.Header.Recipient)
|
toolParser.SetToolName(event.Header.Recipient)
|
||||||
} else {
|
} else {
|
||||||
h.state = harmonyMessageState_Thinking
|
h.state = harmonyMessageState_Thinking
|
||||||
}
|
}
|
||||||
case "commentary":
|
case "commentary":
|
||||||
if event.Header.Recipient != "" {
|
if event.Header.Recipient != "" {
|
||||||
h.state = harmonyMessageState_ToolCalling
|
h.state = harmonyMessageState_ToolCalling
|
||||||
h.ToolParser.SetToolName(event.Header.Recipient)
|
toolParser.SetToolName(event.Header.Recipient)
|
||||||
} else {
|
} else {
|
||||||
h.state = harmonyMessageState_Normal
|
h.state = harmonyMessageState_Normal
|
||||||
}
|
}
|
||||||
@@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
|
|||||||
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
|
||||||
|
return &HarmonyToolCallAccumulator{
|
||||||
|
state: harmonyToolCallState_Normal,
|
||||||
|
currentToolName: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type harmonyToolCallState int
|
type harmonyToolCallState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package harmony
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
|
|
||||||
t.Run("thinking_then_content_streams", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
type step struct {
|
|
||||||
in string
|
|
||||||
wantContent string
|
|
||||||
wantThinking string
|
|
||||||
}
|
|
||||||
steps := []step{
|
|
||||||
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
|
|
||||||
{in: "<|end|>", wantThinking: ""},
|
|
||||||
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
|
|
||||||
{in: "<|end|>", wantContent: ""},
|
|
||||||
}
|
|
||||||
for i, s := range steps {
|
|
||||||
content, thinking, tool := handler.AddContent(s.in)
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
if content != s.wantContent || thinking != s.wantThinking {
|
|
||||||
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|start|>assistant<|message|>Hello",
|
|
||||||
", world",
|
|
||||||
"!<|end|>",
|
|
||||||
}
|
|
||||||
var got []string
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
if thinking != "" {
|
|
||||||
t.Fatalf("unexpected thinking %q", thinking)
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
got = append(got, content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
want := []string{"Hello", ", world", "!"}
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|channel|>analysis<|message|>Thinking...",
|
|
||||||
"<|end|>",
|
|
||||||
"<|start|>assistant<|message|>Answer",
|
|
||||||
"<|end|>",
|
|
||||||
}
|
|
||||||
var got []string
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
if thinking != "" {
|
|
||||||
got = append(got, thinking)
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
got = append(got, content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
want := []string{"Thinking...", "Answer"}
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|chan",
|
|
||||||
"nel|>analysis<|mess",
|
|
||||||
"age|>Deep ",
|
|
||||||
"thought",
|
|
||||||
"<|end|>",
|
|
||||||
"<|start|>assistant<|message|>Done",
|
|
||||||
"<|end|>",
|
|
||||||
}
|
|
||||||
var thinkingPieces []string
|
|
||||||
var contentPieces []string
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
if thinking != "" {
|
|
||||||
thinkingPieces = append(thinkingPieces, thinking)
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
contentPieces = append(contentPieces, content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
|
|
||||||
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
|
|
||||||
}
|
|
||||||
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
|
|
||||||
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|channel|>analysis<|message|>Think",
|
|
||||||
"<|end|>",
|
|
||||||
"<|start|>assistant<|message|>Answer",
|
|
||||||
"<|end|>",
|
|
||||||
}
|
|
||||||
var contentSb, thinkingSb strings.Builder
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
contentSb.WriteString(content)
|
|
||||||
thinkingSb.WriteString(thinking)
|
|
||||||
}
|
|
||||||
if contentSb.String() != "Answer" {
|
|
||||||
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
|
|
||||||
}
|
|
||||||
if thinkingSb.String() != "Think" {
|
|
||||||
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
|
|
||||||
}
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if content != "" || thinking != "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name, args := tp.Drain()
|
|
||||||
if name == nil || *name != "functions.calculate" {
|
|
||||||
t.Fatalf("unexpected tool name: %v", name)
|
|
||||||
}
|
|
||||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
|
||||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tool_call_across_chunks", func(t *testing.T) {
|
|
||||||
handler := NewHarmonyMessageHandler()
|
|
||||||
handler.HarmonyParser.AddImplicitStart()
|
|
||||||
tp := handler.ToolParser
|
|
||||||
inputs := []string{
|
|
||||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
|
|
||||||
"2\"}",
|
|
||||||
"<|end|>",
|
|
||||||
}
|
|
||||||
for _, in := range inputs {
|
|
||||||
content, thinking, tool := handler.AddContent(in)
|
|
||||||
if content != "" || thinking != "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tool != "" {
|
|
||||||
tp.Add(tool)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name, args := tp.Drain()
|
|
||||||
if name == nil || *name != "functions.calculate" {
|
|
||||||
t.Fatalf("unexpected tool name: %v", name)
|
|
||||||
}
|
|
||||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
|
||||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "Write me a story with a ton of emojis?",
|
Prompt: "Write me a story in english with a lot of emojis",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
|
|||||||
@@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}, {
|
}, {
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
Prompt: "how do rainbows form? Be brief but factual in your reply",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}, {
|
}, {
|
||||||
@@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
[][]string{
|
[][]string{
|
||||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
|
||||||
{"fourth", "july", "declaration", "independence"},
|
{"fourth", "july", "declaration", "independence"},
|
||||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
|
|||||||
}
|
}
|
||||||
nChunks := C.mtmd_input_chunks_size(ic)
|
nChunks := C.mtmd_input_chunks_size(ic)
|
||||||
numEmbed := llamaContext.Model().NEmbd()
|
numEmbed := llamaContext.Model().NEmbd()
|
||||||
lastChunkSize := 0
|
embed := make([][]float32, 0)
|
||||||
for i := range int(nChunks) {
|
for i := range int(nChunks) {
|
||||||
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
||||||
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
||||||
lastChunkSize = numTokens
|
slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
|
||||||
|
|
||||||
// Encode the chunk
|
// Encode the chunk
|
||||||
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
||||||
return nil, errors.New("unable to encode mtmd image chunk")
|
return nil, errors.New("unable to encode mtmd image chunk")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Get the embeddings
|
// Get the embeddings for this chunk
|
||||||
embed := make([][]float32, lastChunkSize)
|
chunkEmbed := make([][]float32, numTokens)
|
||||||
embd := C.mtmd_get_output_embd(c.c)
|
chunkEmbd := C.mtmd_get_output_embd(c.c)
|
||||||
if nil == embd {
|
if nil == chunkEmbd {
|
||||||
return nil, errors.New("failed to get image embedding")
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extend the embedding array for each token
|
// Extend the embedding array for each token
|
||||||
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
|
s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
|
||||||
rows := make([]float32, len(s))
|
rows := make([]float32, len(s))
|
||||||
copy(rows, s)
|
copy(rows, s)
|
||||||
for i := range lastChunkSize {
|
for i := range numTokens {
|
||||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||||
|
}
|
||||||
|
embed = append(embed, chunkEmbed...)
|
||||||
}
|
}
|
||||||
|
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
|
||||||
return embed, nil
|
return embed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ import (
|
|||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/parser"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type filteredEnv []string
|
type filteredEnv []string
|
||||||
@@ -1349,9 +1348,7 @@ type CompletionRequest struct {
|
|||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
|
||||||
Grammar string // set before sending the request to the subprocess
|
Grammar string // set before sending the request to the subprocess
|
||||||
ParserType parser.TokenParserType
|
|
||||||
PrefillString string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoneReason represents the reason why a completion response is done
|
// DoneReason represents the reason why a completion response is done
|
||||||
@@ -1378,15 +1375,13 @@ func (d DoneReason) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Thinking string `json:"thinking"`
|
DoneReason DoneReason `json:"done_reason"`
|
||||||
ToolCalls []api.ToolCall `json:"tool_calls"`
|
Done bool `json:"done"`
|
||||||
DoneReason DoneReason `json:"done_reason"`
|
PromptEvalCount int `json:"prompt_eval_count"`
|
||||||
Done bool `json:"done"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
@@ -1504,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
|
case strings.TrimSpace(c.Content) == lastToken:
|
||||||
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
|
|
||||||
tokenRepeat++
|
tokenRepeat++
|
||||||
default:
|
default:
|
||||||
lastToken = strings.TrimSpace(c.Content)
|
lastToken = strings.TrimSpace(c.Content)
|
||||||
@@ -1518,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.Content != "" {
|
||||||
|
fn(CompletionResponse{
|
||||||
|
Content: c.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if c.Done {
|
if c.Done {
|
||||||
fn(c)
|
fn(c)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
|
|
||||||
fn(c)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -416,6 +416,7 @@ type Tensor interface {
|
|||||||
AddID(ctx Context, t2, ids Tensor) Tensor
|
AddID(ctx Context, t2, ids Tensor) Tensor
|
||||||
|
|
||||||
Softmax(ctx Context) Tensor
|
Softmax(ctx Context) Tensor
|
||||||
|
L2Norm(ctx Context, eps float32) Tensor
|
||||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||||
Scale(ctx Context, s float64) Tensor
|
Scale(ctx Context, s float64) Tensor
|
||||||
|
|||||||
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||||
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
||||||
if w != nil {
|
if w != nil {
|
||||||
|
|||||||
36
ml/nn/pooling/pooling.go
Normal file
36
ml/nn/pooling/pooling.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package pooling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Type uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeNone Type = iota
|
||||||
|
TypeMean
|
||||||
|
TypeCLS
|
||||||
|
TypeLast
|
||||||
|
TypeRank
|
||||||
|
|
||||||
|
TypeUnknown = 0xFFFFFFFE
|
||||||
|
TypeUnspecified = 0xFFFFFFFF
|
||||||
|
)
|
||||||
|
|
||||||
|
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
|
||||||
|
switch poolingType {
|
||||||
|
case TypeNone:
|
||||||
|
return hiddenStates
|
||||||
|
case TypeMean:
|
||||||
|
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||||
|
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
|
case TypeCLS:
|
||||||
|
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||||
|
case TypeLast:
|
||||||
|
panic("not implemented")
|
||||||
|
case TypeRank:
|
||||||
|
panic("not implemented")
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -54,10 +54,9 @@ type Batch struct {
|
|||||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||||
Inputs ml.Tensor
|
Inputs ml.Tensor
|
||||||
|
|
||||||
// Multimodal is a set of multimodal embeddings previously created by
|
// Outputs are the set of indicies into Inputs for which output data should
|
||||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
// be returned.
|
||||||
// models or for batches without multimodal elements.
|
Outputs ml.Tensor
|
||||||
Multimodal []MultimodalIndex
|
|
||||||
|
|
||||||
// Positions is the position for each Input, relative to its sequence. Equal
|
// Positions is the position for each Input, relative to its sequence. Equal
|
||||||
// in length to Inputs.
|
// in length to Inputs.
|
||||||
@@ -66,7 +65,8 @@ type Batch struct {
|
|||||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||||
Sequences []int
|
Sequences []int
|
||||||
|
|
||||||
// Outputs are the set of indicies into Inputs for which output data should
|
// Multimodal is a set of multimodal embeddings previously created by
|
||||||
// be returned.
|
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||||
Outputs []int32
|
// models or for batches without multimodal elements.
|
||||||
|
Multimodal []MultimodalIndex
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ import (
|
|||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
var (
|
||||||
|
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||||
|
ErrUnsupportedModel = errors.New("model not supported")
|
||||||
|
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||||
|
)
|
||||||
|
|
||||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
type Model interface {
|
type Model interface {
|
||||||
@@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
|||||||
vv = vv.Elem()
|
vv = vv.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
vv = vv.Elem()
|
vv = reflect.Indirect(vv)
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
vv = reflect.New(v.Type().Elem()).Elem()
|
vv = reflect.New(v.Type().Elem()).Elem()
|
||||||
}
|
}
|
||||||
|
|||||||
181
model/models/bert/model.go
Normal file
181
model/models/bert/model.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package bert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.TextProcessor
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||||
|
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||||
|
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
|
||||||
|
|
||||||
|
Layers []EncoderLayer `gguf:"blk"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements model.Model.
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
|
||||||
|
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
|
||||||
|
for _, layer := range m.Layers {
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||||
|
if m.normalize {
|
||||||
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hiddenStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncoderLayer struct {
|
||||||
|
*Attention
|
||||||
|
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
|
||||||
|
|
||||||
|
*MLP
|
||||||
|
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
// Attention
|
||||||
|
residual := hiddenStates
|
||||||
|
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
|
||||||
|
// MLP
|
||||||
|
residual = hiddenStates
|
||||||
|
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
|
||||||
|
return hiddenStates
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
|
||||||
|
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
|
||||||
|
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
batchSize := hiddenStates.Dim(1)
|
||||||
|
|
||||||
|
query := a.Query.Forward(ctx, hiddenStates)
|
||||||
|
if a.QueryNorm != nil {
|
||||||
|
query = a.QueryNorm.Forward(ctx, query, opts.eps)
|
||||||
|
}
|
||||||
|
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||||
|
|
||||||
|
key := a.Key.Forward(ctx, hiddenStates)
|
||||||
|
if a.KeyNorm != nil {
|
||||||
|
key = a.KeyNorm.Forward(ctx, key, opts.eps)
|
||||||
|
}
|
||||||
|
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||||
|
|
||||||
|
value := a.Value.Forward(ctx, hiddenStates)
|
||||||
|
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||||
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
return a.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize,
|
||||||
|
numHeads,
|
||||||
|
numKVHeads,
|
||||||
|
keyLength,
|
||||||
|
valueLength int
|
||||||
|
poolingType pooling.Type
|
||||||
|
eps float32
|
||||||
|
normalize bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) headDim() int {
|
||||||
|
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
var processor model.TextProcessor
|
||||||
|
switch c.String("tokenizer.ggml.model", "bert") {
|
||||||
|
case "bert":
|
||||||
|
processor = model.NewWordPiece(
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||||
|
EOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||||
|
//nolint:misspell
|
||||||
|
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
||||||
|
// support it for compatibility.
|
||||||
|
c.Uint("tokenizer.ggml.seperator_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Model{
|
||||||
|
TextProcessor: processor,
|
||||||
|
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||||
|
Options: Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
eps: c.Float("attention.layer_norm_epsilon"),
|
||||||
|
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||||
|
normalize: c.Bool("normalize_embeddings", true),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("bert", New)
|
||||||
|
model.Register("bert_embed", New)
|
||||||
|
}
|
||||||
@@ -24,7 +24,7 @@ type Options struct {
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
Layers []Layer `gguf:"blk"`
|
Layers []Layer `gguf:"blk"`
|
||||||
@@ -40,7 +40,7 @@ const (
|
|||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||||
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
var lastLayerOutputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||||
|
|||||||
@@ -1,49 +1,38 @@
|
|||||||
package gemma3
|
package gemma3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
type embedModel struct {
|
type embedModel struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
*TextModel
|
*TextModel
|
||||||
PoolingType uint32
|
poolingType pooling.Type
|
||||||
|
|
||||||
Dense [2]*nn.Linear `gguf:"dense"`
|
Dense [2]*nn.Linear `gguf:"dense"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
batch.Outputs = batch.Positions // return all positions
|
|
||||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||||
|
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||||
switch m.PoolingType {
|
|
||||||
case 0: // None
|
|
||||||
case 1: // Mean
|
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unsupported pooling type")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dense := range m.Dense {
|
for _, dense := range m.Dense {
|
||||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||||
}
|
}
|
||||||
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
return hiddenStates, nil
|
return hiddenStates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||||
m := &embedModel{
|
m := &embedModel{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
PoolingType: c.Uint("pooling_type", 0),
|
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = kvcache.NewWrapperCache(
|
m.Cache = kvcache.NewWrapperCache(
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
*VisionModel `gguf:"v"`
|
*VisionModel `gguf:"v"`
|
||||||
*TextModel
|
*TextModel
|
||||||
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
|||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|||||||
@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||||
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
|||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
var lastLayerOutputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
*TextModel
|
*TextModel
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
|||||||
|
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
|
if i == len(m.TransformerBlocks)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||||
|
|||||||
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
// TODO: attention mask, cross attention mask
|
// TODO: attention mask, cross attention mask
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
_ "github.com/ollama/ollama/model/models/bert"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||||
|
|||||||
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||||
|
|||||||
@@ -12,18 +12,18 @@ import (
|
|||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
const spmWhitespaceSep = "▁"
|
||||||
|
|
||||||
type SentencePieceModel struct {
|
type SentencePiece struct {
|
||||||
maxTokenLen int
|
maxTokenLen int
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePiece)(nil)
|
||||||
|
|
||||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||||
return spm.vocab
|
return spm.vocab
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
counter := map[int]int{}
|
counter := map[int]int{}
|
||||||
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
|||||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||||
"max token len", maxTokenLen)
|
"max token len", maxTokenLen)
|
||||||
|
|
||||||
return SentencePieceModel{
|
return SentencePiece{
|
||||||
maxTokenLen: maxTokenLen,
|
maxTokenLen: maxTokenLen,
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||||
id := spm.vocab.Encode(special)
|
id := spm.vocab.Encode(special)
|
||||||
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
|
|||||||
return item
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
data := spm.vocab.Decode(id)
|
data := spm.vocab.Decode(id)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/ollama/ollama/convert/sentencepiece"
|
"github.com/ollama/ollama/convert/sentencepiece"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||||
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewSentencePieceModel(&v)
|
return NewSentencePiece(&v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceEncode(t *testing.T) {
|
func TestSentencePieceEncode(t *testing.T) {
|
||||||
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||||
vocab := &Vocabulary{
|
vocab := &Vocabulary{
|
||||||
Values: []string{
|
Values: []string{
|
||||||
"normal",
|
"normal",
|
||||||
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
|||||||
Scores: []float32{0, 0, 0, 0, 0},
|
Scores: []float32{0, 0, 0, 0, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
spm := NewSentencePieceModel(vocab)
|
spm := NewSentencePiece(vocab)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
167
model/wordpiece.go
Normal file
167
model/wordpiece.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WordPiece struct {
|
||||||
|
vocab *Vocabulary
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggmlPrefix marks a piece that starts a new word in GGML vocabularies.
// This differs from the original WordPiece convention, which instead marks
// continuation subwords with a "##" prefix.
const ggmlPrefix = "▁"
|
||||||
|
|
||||||
|
// wordPieceReplacer tidies decoded text by removing the space that precedes
// punctuation and English contraction suffixes. Pairs are listed as
// (old, new) and are tried in this order at each position.
var wordPieceReplacer = strings.NewReplacer([]string{
	" .", ".",
	" ?", "?",
	" !", "!",
	" ,", ",",
	" ' ", "'",
	" n't", "n't",
	" 'm", "'m",
	" do not", " don't",
	" 's", "'s",
	" 've", "'ve",
	" 're", "'re",
}...)
|
||||||
|
|
||||||
|
// Decode implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, id := range ids {
|
||||||
|
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||||
|
return "", fmt.Errorf("invalid token id: %d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
var separator string
|
||||||
|
piece := wpm.vocab.Values[id]
|
||||||
|
if i > 0 &&
|
||||||
|
(strings.HasPrefix(piece, ggmlPrefix) ||
|
||||||
|
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
|
||||||
|
separator = " "
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// words splits a string into words, treating CJK characters as separate words.
|
||||||
|
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||||
|
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||||
|
return func(yield func(string) bool) {
|
||||||
|
runes := make([]rune, 0, len(s)*3)
|
||||||
|
for _, r := range s {
|
||||||
|
switch {
|
||||||
|
case r >= 0x4E00 && r <= 0x9FFF,
|
||||||
|
r >= 0x3400 && r <= 0x4DBF,
|
||||||
|
r >= 0x20000 && r <= 0x2A6DF,
|
||||||
|
r >= 0x2A700 && r <= 0x2B73F,
|
||||||
|
r >= 0x2B740 && r <= 0x2B81F,
|
||||||
|
r >= 0x2B820 && r <= 0x2CEAF,
|
||||||
|
r >= 0xF900 && r <= 0xFAFF,
|
||||||
|
r >= 0x2F800 && r <= 0x2FA1F:
|
||||||
|
runes = append(runes, ' ', r, ' ')
|
||||||
|
default:
|
||||||
|
runes = append(runes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
|
||||||
|
// split on but keep punctuation
|
||||||
|
var start int
|
||||||
|
for start < len(w) {
|
||||||
|
end := strings.IndexFunc(w[start:], unicode.IsPunct)
|
||||||
|
if end < 0 {
|
||||||
|
end = len(w) - start
|
||||||
|
} else if end == 0 {
|
||||||
|
end = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(w[start : start+end]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
start += end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
|
var ids []int32
|
||||||
|
|
||||||
|
// TODO: use [UNK] from config
|
||||||
|
unk := wpm.vocab.Encode("[UNK]")
|
||||||
|
for word := range wpm.words(s) {
|
||||||
|
var start int
|
||||||
|
var pieces []int32
|
||||||
|
for start < len(word) {
|
||||||
|
end := len(word)
|
||||||
|
|
||||||
|
var piece int32
|
||||||
|
for start < end {
|
||||||
|
subword := word[start:end]
|
||||||
|
if start == 0 {
|
||||||
|
subword = ggmlPrefix + subword
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: some models might not want [ToLower]
|
||||||
|
piece = wpm.vocab.Encode(strings.ToLower(subword))
|
||||||
|
if piece >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
end--
|
||||||
|
}
|
||||||
|
|
||||||
|
if piece < 0 {
|
||||||
|
// Unknown token
|
||||||
|
pieces = pieces[:0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
pieces = append(pieces, piece)
|
||||||
|
start = end
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pieces) > 0 {
|
||||||
|
ids = append(ids, pieces...)
|
||||||
|
} else {
|
||||||
|
ids = append(ids, unk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if addSpecial && len(ids) > 0 {
|
||||||
|
ids = wpm.vocab.addSpecials(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||||
|
return wpm.vocab.Is(id, special)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vocabulary implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||||
|
return wpm.vocab
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TextProcessor = (*WordPiece)(nil)
|
||||||
|
|
||||||
|
func NewWordPiece(vocab *Vocabulary) WordPiece {
|
||||||
|
return WordPiece{
|
||||||
|
vocab: vocab,
|
||||||
|
}
|
||||||
|
}
|
||||||
51
model/wordpiece_test.go
Normal file
51
model/wordpiece_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWordPiece(t *testing.T) {
|
||||||
|
wpm := NewWordPiece(
|
||||||
|
&Vocabulary{
|
||||||
|
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: true,
|
||||||
|
BOS: []int32{1},
|
||||||
|
EOS: []int32{2},
|
||||||
|
})
|
||||||
|
|
||||||
|
ids, err := wpm.Encode("Hello world!", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||||
|
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
words, err := wpm.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWordPieceWords(t *testing.T) {
|
||||||
|
var wpm WordPiece
|
||||||
|
|
||||||
|
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||||
|
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||||
|
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
110
parser/parser.go
110
parser/parser.go
@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
for _, c := range f.Commands {
|
for _, c := range f.Commands {
|
||||||
switch c.Name {
|
switch c.Name {
|
||||||
case "model":
|
case "model":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
name := c.Args.(string)
|
||||||
|
path, err := expandPath(name, relativeDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
digestMap, err := fileDigestMap(path)
|
digestMap, err := fileDigestMap(path)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
req.From = c.Args
|
req.From = name
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "adapter":
|
case "adapter":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
adapter := c.Args.(string)
|
||||||
|
path, err := expandPath(adapter, relativeDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -95,21 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
|
|
||||||
req.Adapters = digestMap
|
req.Adapters = digestMap
|
||||||
case "template":
|
case "template":
|
||||||
req.Template = c.Args
|
template := c.Args.(string)
|
||||||
|
req.Template = template
|
||||||
case "system":
|
case "system":
|
||||||
req.System = c.Args
|
system := c.Args.(string)
|
||||||
|
req.System = system
|
||||||
case "license":
|
case "license":
|
||||||
licenses = append(licenses, c.Args)
|
license := c.Args.(string)
|
||||||
|
licenses = append(licenses, license)
|
||||||
case "message":
|
case "message":
|
||||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
msg := c.Args.(*Message)
|
||||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||||
default:
|
case "parameter":
|
||||||
if slices.Contains(deprecatedParameters, c.Name) {
|
if slices.Contains(deprecatedParameters, c.Name) {
|
||||||
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
|
fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
|
param := c.Args.(*Parameter)
|
||||||
|
ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
params[k] = v
|
params[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,7 +320,17 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
type Command struct {
|
type Command struct {
|
||||||
Name string
|
Name string
|
||||||
Args string
|
Args any
|
||||||
|
}
|
||||||
|
|
||||||
|
type Parameter struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Command) String() string {
|
func (c Command) String() string {
|
||||||
@@ -321,12 +339,16 @@ func (c Command) String() string {
|
|||||||
case "model":
|
case "model":
|
||||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||||
case "license", "template", "system", "adapter":
|
case "license", "template", "system", "adapter":
|
||||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
data := c.Args.(string)
|
||||||
|
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
|
||||||
case "message":
|
case "message":
|
||||||
role, message, _ := strings.Cut(c.Args, ": ")
|
data := c.Args.(*Message)
|
||||||
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
|
fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
|
||||||
|
case "parameter":
|
||||||
|
data := c.Args.(*Parameter)
|
||||||
|
fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
|
fmt.Printf("unknown command '%s'\n", c.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
@@ -366,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
var curr state
|
var curr state
|
||||||
var currLine int = 1
|
var currLine int = 1
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var role string
|
|
||||||
|
|
||||||
var f Modelfile
|
var f Modelfile
|
||||||
|
|
||||||
@@ -413,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
case "parameter":
|
case "parameter":
|
||||||
// transition to stateParameter which sets command name
|
// transition to stateParameter which sets command name
|
||||||
next = stateParameter
|
next = stateParameter
|
||||||
|
cmd.Name = s
|
||||||
case "message":
|
case "message":
|
||||||
// transition to stateMessage which validates the message role
|
// transition to stateMessage which validates the message role
|
||||||
next = stateMessage
|
next = stateMessage
|
||||||
@@ -421,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
cmd.Name = s
|
cmd.Name = s
|
||||||
}
|
}
|
||||||
case stateParameter:
|
case stateParameter:
|
||||||
cmd.Name = b.String()
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
|
if !ok || isSpace(r) {
|
||||||
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cmd.Args = &Parameter{
|
||||||
|
Name: s,
|
||||||
|
}
|
||||||
case stateMessage:
|
case stateMessage:
|
||||||
if !isValidMessageRole(b.String()) {
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
|
if !ok || isSpace(r) {
|
||||||
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidMessageRole(s) {
|
||||||
return nil, &ParserError{
|
return nil, &ParserError{
|
||||||
LineNumber: currLine,
|
LineNumber: currLine,
|
||||||
Msg: errInvalidMessageRole.Error(),
|
Msg: errInvalidMessageRole.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
role = b.String()
|
cmd.Args = &Message{
|
||||||
|
Role: s,
|
||||||
|
}
|
||||||
case stateComment, stateNil:
|
case stateComment, stateNil:
|
||||||
// pass
|
// pass
|
||||||
case stateValue:
|
case stateValue:
|
||||||
@@ -443,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if role != "" {
|
switch cmd.Name {
|
||||||
s = role + ": " + s
|
case "parameter":
|
||||||
role = ""
|
p := cmd.Args.(*Parameter)
|
||||||
|
p.Value = s
|
||||||
|
case "message":
|
||||||
|
m := cmd.Args.(*Message)
|
||||||
|
m.Content = s
|
||||||
|
default:
|
||||||
|
cmd.Args = s
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Args = s
|
|
||||||
f.Commands = append(f.Commands, cmd)
|
f.Commands = append(f.Commands, cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -473,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
||||||
if role != "" {
|
switch cmd.Name {
|
||||||
s = role + ": " + s
|
case "parameter":
|
||||||
|
c := cmd.Args.(*Parameter)
|
||||||
|
c.Value = s
|
||||||
|
case "message":
|
||||||
|
c := cmd.Args.(*Message)
|
||||||
|
c.Content = s
|
||||||
|
default:
|
||||||
|
cmd.Args = s
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Args = s
|
|
||||||
f.Commands = append(f.Commands, cmd)
|
f.Commands = append(f.Commands, cmd)
|
||||||
default:
|
default:
|
||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|||||||
{Name: "model", Args: "model1"},
|
{Name: "model", Args: "model1"},
|
||||||
{Name: "adapter", Args: "adapter1"},
|
{Name: "adapter", Args: "adapter1"},
|
||||||
{Name: "license", Args: "MIT"},
|
{Name: "license", Args: "MIT"},
|
||||||
{Name: "param1", Args: "value1"},
|
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
|
||||||
{Name: "param2", Args: "value2"},
|
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
|
||||||
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
|
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|||||||
{Name: "model", Args: " model 1"},
|
{Name: "model", Args: " model 1"},
|
||||||
{Name: "adapter", Args: "adapter3"},
|
{Name: "adapter", Args: "adapter3"},
|
||||||
{Name: "license", Args: "MIT "},
|
{Name: "license", Args: "MIT "},
|
||||||
{Name: "param1", Args: "value1"},
|
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
|
||||||
{Name: "param2", Args: "value2"},
|
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
|
||||||
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
|
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
|
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
|
||||||
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
|
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"PARAMETER param1 value1\nFROM foo",
|
"PARAMETER param1 value1\nFROM foo",
|
||||||
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
[]Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"PARAMETER what the \nFROM lemons make lemonade ",
|
"PARAMETER what the \nFROM lemons make lemonade ",
|
||||||
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
|
[]Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -211,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
|
|||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -221,7 +221,7 @@ FROM foo
|
|||||||
MESSAGE system You are a file parser. Always parse things.`,
|
MESSAGE system You are a file parser. Always parse things.`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -234,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
|
|||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
{Name: "message", Args: "user: Hey there!"},
|
{Name: "message", Args: &Message{"user", "Hey there!"}},
|
||||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
{Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -244,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
|
|||||||
`
|
`
|
||||||
FROM foo
|
FROM foo
|
||||||
MESSAGE system """
|
MESSAGE system """
|
||||||
You are a multiline file parser. Always parse things.
|
You are a multiline file "parser". Always parse things.
|
||||||
"""
|
"""
|
||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
|
{Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -514,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, []Command{
|
assert.Equal(t, []Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: v.name, Args: v.value},
|
{Name: "parameter", Args: &Parameter{v.name, v.value}},
|
||||||
}, modelfile.Commands)
|
}, modelfile.Commands)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -617,8 +617,8 @@ SYSTEM You are a utf16 file.
|
|||||||
|
|
||||||
expected := []Command{
|
expected := []Command{
|
||||||
{Name: "model", Args: "bob"},
|
{Name: "model", Args: "bob"},
|
||||||
{Name: "param1", Args: "1"},
|
{Name: "parameter", Args: &Parameter{"param1", "1"}},
|
||||||
{Name: "param2", Args: "4096"},
|
{Name: "parameter", Args: &Parameter{"param2", "4096"}},
|
||||||
{Name: "system", Args: "You are a utf16 file."},
|
{Name: "system", Args: "You are a utf16 file."},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,126 +0,0 @@
|
|||||||
package parser
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/harmony"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TokenParserType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
TokenParserTypeDefault TokenParserType = iota
|
|
||||||
TokenParserTypeHarmony
|
|
||||||
)
|
|
||||||
|
|
||||||
type TokenParser struct {
|
|
||||||
messageHandler MessageHandler
|
|
||||||
parserEngine ParserInternals
|
|
||||||
toolParser ToolParser
|
|
||||||
lastToken string
|
|
||||||
tokenRepeat int
|
|
||||||
repeatLimit int
|
|
||||||
}
|
|
||||||
|
|
||||||
const defaultTokenRepeatLimit = 30
|
|
||||||
|
|
||||||
type MessageHandler interface {
|
|
||||||
AddContent(token string) (content, thinking string, toolContent string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ParserInternals interface {
|
|
||||||
AddImplicitStartOrPrefill(prefillString string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParser interface {
|
|
||||||
Add(token string)
|
|
||||||
Drain() (toolName *string, toolContent string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default implementation for the TokenParser interface as a no-op passthrough
|
|
||||||
type defaultMessageHandler struct{}
|
|
||||||
|
|
||||||
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
|
|
||||||
return token, "", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultEngine struct{}
|
|
||||||
|
|
||||||
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
|
|
||||||
|
|
||||||
type defaultToolParser struct{}
|
|
||||||
|
|
||||||
func (defaultToolParser) Add(token string) {}
|
|
||||||
|
|
||||||
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
|
|
||||||
|
|
||||||
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
|
|
||||||
switch parserType {
|
|
||||||
case TokenParserTypeHarmony:
|
|
||||||
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
|
|
||||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
|
|
||||||
return TokenParser{
|
|
||||||
messageHandler: harmonyMessageHandler,
|
|
||||||
parserEngine: harmonyMessageHandler.HarmonyParser,
|
|
||||||
toolParser: harmonyMessageHandler.ToolParser,
|
|
||||||
repeatLimit: defaultTokenRepeatLimit,
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return TokenParser{
|
|
||||||
messageHandler: defaultMessageHandler{},
|
|
||||||
parserEngine: defaultEngine{},
|
|
||||||
toolParser: defaultToolParser{},
|
|
||||||
repeatLimit: 30,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TokenParser) AddContent(token string) (string, string, error) {
|
|
||||||
if p.repeatLimitReached(token) {
|
|
||||||
return "", "", errors.New("token repeat limit reached")
|
|
||||||
}
|
|
||||||
content, thinking, toolContent := p.messageHandler.AddContent(token)
|
|
||||||
p.toolParser.Add(toolContent)
|
|
||||||
return content, thinking, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
|
|
||||||
func (p *TokenParser) repeatLimitReached(token string) bool {
|
|
||||||
if p == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
trimmed := strings.TrimSpace(token)
|
|
||||||
if trimmed == p.lastToken {
|
|
||||||
p.tokenRepeat++
|
|
||||||
} else {
|
|
||||||
p.tokenRepeat = 0
|
|
||||||
}
|
|
||||||
p.lastToken = trimmed
|
|
||||||
|
|
||||||
return p.tokenRepeat >= p.repeatLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
|
|
||||||
func (p *TokenParser) Drain() []api.ToolCall {
|
|
||||||
toolName, toolContent := p.toolParser.Drain()
|
|
||||||
if toolName != nil {
|
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
|
||||||
var args api.ToolCallFunctionArguments
|
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: *toolName,
|
|
||||||
Arguments: args,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -34,7 +34,6 @@ import (
|
|||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
"github.com/ollama/ollama/parser"
|
|
||||||
"github.com/ollama/ollama/runner/common"
|
"github.com/ollama/ollama/runner/common"
|
||||||
"github.com/ollama/ollama/sample"
|
"github.com/ollama/ollama/sample"
|
||||||
|
|
||||||
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
|
|
||||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||||
var batchInputs []*input.Input
|
var batchInputs []*input.Input
|
||||||
|
var batchOutputs []int32
|
||||||
var batch input.Batch
|
var batch input.Batch
|
||||||
|
|
||||||
resumeSeq := -1
|
resumeSeq := -1
|
||||||
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(batch.Outputs)
|
seq.iBatch = len(batchOutputs)
|
||||||
if i+1 == len(seq.inputs) {
|
if i+1 == len(seq.inputs) || seq.embeddingOnly {
|
||||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
|
|
||||||
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||||
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||||
|
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
|
||||||
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("failed to build graph: %w", err)
|
err = fmt.Errorf("failed to build graph: %w", err)
|
||||||
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
|
||||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||||
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
|
|
||||||
|
|
||||||
if req.Options == nil {
|
if req.Options == nil {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
req.Options = &opts
|
req.Options = &opts
|
||||||
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
var thinking string
|
|
||||||
var err error
|
|
||||||
content, thinking, err = tokenParser.AddContent(content)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
close(seq.quit)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
Thinking: thinking,
|
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
toolCalls := tokenParser.Drain()
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
ToolCalls: toolCalls,
|
|
||||||
Done: true,
|
Done: true,
|
||||||
DoneReason: seq.doneReason,
|
DoneReason: seq.doneReason,
|
||||||
PromptEvalCount: seq.numPromptInputs,
|
PromptEvalCount: seq.numPromptInputs,
|
||||||
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||||||
batch.Positions[i] = int32(i)
|
batch.Positions[i] = int32(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Outputs = make([]int32, s.parallel)
|
|
||||||
for i := range batch.Outputs {
|
|
||||||
batch.Outputs[i] = int32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||||
|
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
|
||||||
|
|
||||||
cache := s.model.Config().Cache
|
cache := s.model.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
|
|||||||
154
server/routes.go
154
server/routes.go
@@ -36,7 +36,6 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/parser"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@@ -47,6 +46,18 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func shouldUseHarmony(model *Model) bool {
|
||||||
|
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||||
|
// heuristic to check whether the template expects to be parsed via harmony:
|
||||||
|
// search for harmony tags that are nearly always used
|
||||||
|
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func experimentEnabled(name string) bool {
|
func experimentEnabled(name string) bool {
|
||||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||||
}
|
}
|
||||||
@@ -196,17 +207,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
|
useHarmony := shouldUseHarmony(m) && !req.Raw
|
||||||
var parserType parser.TokenParserType
|
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||||
|
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
parserType = parser.TokenParserTypeHarmony
|
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||||
} else {
|
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
||||||
parserType = parser.TokenParserTypeDefault
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
}
|
|
||||||
var functionNameMap *harmony.FunctionNameMap
|
|
||||||
|
|
||||||
if useHarmony {
|
|
||||||
functionNameMap = harmony.NewFunctionNameMap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
// Validate Think value: string values currently only allowed for gptoss models
|
||||||
@@ -350,19 +357,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
ParserType: parserType,
|
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Response: cr.Content,
|
Response: cr.Content,
|
||||||
Done: cr.Done,
|
Done: cr.Done,
|
||||||
Thinking: cr.Thinking,
|
|
||||||
ToolCalls: cr.ToolCalls,
|
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: cr.PromptEvalCount,
|
PromptEvalCount: cr.PromptEvalCount,
|
||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
@@ -371,22 +375,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Done {
|
|
||||||
res.DoneReason = cr.DoneReason.String()
|
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
||||||
}
|
|
||||||
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
for i, tool := range res.ToolCalls {
|
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
|
||||||
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
res.Response = content
|
||||||
}
|
res.Thinking = thinking
|
||||||
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
|
harmonyToolParser.Add(toolContent)
|
||||||
ch <- res
|
} else if thinkingState != nil {
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if thinkingState != nil {
|
|
||||||
thinking, content := thinkingState.AddContent(cr.Content)
|
thinking, content := thinkingState.AddContent(cr.Content)
|
||||||
res.Thinking = thinking
|
res.Thinking = thinking
|
||||||
res.Response = content
|
res.Response = content
|
||||||
@@ -397,6 +391,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cr.Done {
|
if cr.Done {
|
||||||
|
if useHarmony {
|
||||||
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
|
if toolName != nil {
|
||||||
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
var args api.ToolCallFunctionArguments
|
||||||
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
|
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||||
|
ch <- gin.H{"error": errStr}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: *toolName,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res.DoneReason = cr.DoneReason.String()
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -470,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := true
|
truncate := true
|
||||||
|
|
||||||
if req.Truncate != nil && !*req.Truncate {
|
if req.Truncate != nil && !*req.Truncate {
|
||||||
truncate = false
|
truncate = false
|
||||||
}
|
}
|
||||||
@@ -537,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||||
|
ctxLen--
|
||||||
|
}
|
||||||
|
|
||||||
|
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||||
|
ctxLen--
|
||||||
|
}
|
||||||
|
|
||||||
tokens = tokens[:ctxLen]
|
tokens = tokens[:ctxLen]
|
||||||
|
|
||||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1599,27 +1625,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
|
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||||
var parserType parser.TokenParserType
|
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||||
if useHarmony {
|
|
||||||
parserType = parser.TokenParserTypeHarmony
|
useHarmony := shouldUseHarmony(m)
|
||||||
} else {
|
|
||||||
parserType = parser.TokenParserTypeDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
processedTools := req.Tools
|
processedTools := req.Tools
|
||||||
var functionNameMap *harmony.FunctionNameMap
|
|
||||||
var prefillString string
|
|
||||||
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
prefillString = harmony.Prefill(msgs[len(msgs)-1])
|
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||||
functionNameMap = harmony.NewFunctionNameMap()
|
var lastMessage *api.Message
|
||||||
|
if len(msgs) > 0 {
|
||||||
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
|
}
|
||||||
|
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||||
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
|
|
||||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||||
// renamed to be valid Harmony function names.
|
// renamed to be valid Harmony function names.
|
||||||
processedTools = make([]api.Tool, len(req.Tools))
|
processedTools = make([]api.Tool, len(req.Tools))
|
||||||
copy(processedTools, req.Tools)
|
copy(processedTools, req.Tools)
|
||||||
for i, tool := range processedTools {
|
for i, tool := range processedTools {
|
||||||
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
|
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1672,17 +1698,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
ParserType: parserType,
|
|
||||||
PrefillString: prefillString,
|
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
res := api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
|
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||||
Done: r.Done,
|
Done: r.Done,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: r.PromptEvalCount,
|
||||||
@@ -1698,13 +1722,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
for i, tool := range res.Message.ToolCalls {
|
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
|
||||||
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
res.Message.Content = content
|
||||||
|
res.Message.Thinking = thinking
|
||||||
|
harmonyToolParser.Add(toolContent)
|
||||||
|
|
||||||
|
if r.Done {
|
||||||
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
|
if toolName != nil {
|
||||||
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
||||||
|
var args api.ToolCallFunctionArguments
|
||||||
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
|
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||||
|
ch <- gin.H{"error": errStr}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only send messages with meaningful content (empty messages confuse clients)
|
// only send messages with meaningful content (empty messages confuse clients)
|
||||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
||||||
ch <- res
|
ch <- res
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "content streams as it arrives",
|
name: "content streams as it arrives",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "Hello", Done: false},
|
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
|
||||||
wantContent: "Hello",
|
wantContent: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
wantContent: ", world",
|
wantContent: ", world",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "!",
|
wantContent: "!",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "thinking streams separately from content",
|
name: "thinking streams separately from content",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
|
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
|
||||||
wantThinking: "Thinking...",
|
wantThinking: "Thinking...",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "Answer", Done: false},
|
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
|
||||||
wantContent: "Answer",
|
// No output expected - just closes the analysis message and resets state to normal
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
|
||||||
|
wantContent: "Answer", // After message end, state is reset to normal
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
|
// No output expected - just closes the assistant message
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "partial tags buffer until complete",
|
name: "partial tags buffer until complete",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
|
input: llm.CompletionResponse{Content: "<|chan", Done: false},
|
||||||
|
// No output - partial tag
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
|
||||||
|
// No output - still building tags
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
|
||||||
wantThinking: "Deep ",
|
wantThinking: "Deep ",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Thinking: "thought", Done: false},
|
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
|
||||||
wantThinking: "thought",
|
wantThinking: "thought",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "Done",
|
wantContent: "Done", // After message end, state is reset to normal
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "simple assistant after analysis",
|
name: "simple assistant after analysis",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "Answer",
|
wantContent: "Answer",
|
||||||
wantThinking: "Think",
|
wantThinking: "Think",
|
||||||
},
|
},
|
||||||
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "tool call parsed and returned correctly",
|
name: "tool call parsed and returned correctly",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "The weather is sunny",
|
wantContent: "The weather is sunny",
|
||||||
wantToolCalls: []api.ToolCall{
|
wantToolCalls: []api.ToolCall{
|
||||||
{
|
{
|
||||||
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "tool call with streaming JSON across chunks",
|
name: "tool call with streaming JSON across chunks",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Done: false},
|
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
|
||||||
|
// No output yet - incomplete JSON
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
|
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
|
||||||
|
// Still no output - incomplete JSON
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: llm.CompletionResponse{Content: "2\"}", Done: true},
|
||||||
wantToolCalls: []api.ToolCall{
|
wantToolCalls: []api.ToolCall{
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
mockResponses := []llm.CompletionResponse{
|
mockResponses := []llm.CompletionResponse{
|
||||||
{Content: "First ", Done: false},
|
{Content: "<|message|>First ", Done: false},
|
||||||
{Content: "chunk ", Done: false},
|
{Content: "chunk ", Done: false},
|
||||||
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
|
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock := mockRunner{
|
mock := mockRunner{
|
||||||
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||||||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
type expectedChunk struct {
|
||||||
|
afterResponse int // Which mock response this chunk should appear after
|
||||||
|
content string // Expected content in this chunk
|
||||||
|
thinking string // Expected thinking in this chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
mockResponses []llm.CompletionResponse
|
||||||
|
expectedChunks []expectedChunk
|
||||||
|
wantContent string
|
||||||
|
wantThinking string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple message without thinking",
|
||||||
|
mockResponses: []llm.CompletionResponse{
|
||||||
|
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
|
||||||
|
{Content: "how can I help?", Done: false},
|
||||||
|
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
|
},
|
||||||
|
expectedChunks: []expectedChunk{
|
||||||
|
{afterResponse: 1, content: "Hello, "},
|
||||||
|
{afterResponse: 2, content: "how can I help?"},
|
||||||
|
},
|
||||||
|
wantContent: "Hello, how can I help?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "message with analysis channel for thinking",
|
||||||
|
mockResponses: []llm.CompletionResponse{
|
||||||
|
{Content: "<|channel|>analysis<|message|>", Done: false},
|
||||||
|
{Content: "Let me think ", Done: false},
|
||||||
|
{Content: "about this problem...", Done: false},
|
||||||
|
{Content: "<|end|>", Done: false},
|
||||||
|
{Content: "<|start|>assistant<|message|>", Done: false},
|
||||||
|
{Content: "The answer ", Done: false},
|
||||||
|
{Content: "is 42", Done: false},
|
||||||
|
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
|
},
|
||||||
|
expectedChunks: []expectedChunk{
|
||||||
|
{afterResponse: 2, thinking: "Let me think "},
|
||||||
|
{afterResponse: 3, thinking: "about this problem..."},
|
||||||
|
{afterResponse: 6, content: "The answer "},
|
||||||
|
{afterResponse: 7, content: "is 42"},
|
||||||
|
},
|
||||||
|
wantContent: "The answer is 42",
|
||||||
|
wantThinking: "Let me think about this problem...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming with partial tags across boundaries",
|
||||||
|
mockResponses: []llm.CompletionResponse{
|
||||||
|
{Content: "<|chan", Done: false},
|
||||||
|
{Content: "nel|>analy", Done: false},
|
||||||
|
{Content: "sis<|mess", Done: false},
|
||||||
|
{Content: "age|>Think", Done: false},
|
||||||
|
{Content: "ing deeply...<|end|>", Done: false},
|
||||||
|
{Content: "<|start|>assi", Done: false},
|
||||||
|
{Content: "stant<|message|>Result ", Done: false},
|
||||||
|
{Content: "computed<|e", Done: false},
|
||||||
|
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
|
},
|
||||||
|
expectedChunks: []expectedChunk{
|
||||||
|
{afterResponse: 4, thinking: "Think"},
|
||||||
|
{afterResponse: 5, thinking: "ing deeply..."},
|
||||||
|
{afterResponse: 7, content: "Result "},
|
||||||
|
{afterResponse: 8, content: "computed"},
|
||||||
|
},
|
||||||
|
wantContent: "Result computed",
|
||||||
|
wantThinking: "Thinking deeply...",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Channel to synchronize mock responses with chunk verification
|
||||||
|
responsesSent := make(chan int, len(tc.mockResponses))
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
|
// Send mock responses one at a time, notifying when each is sent
|
||||||
|
for i, resp := range tc.mockResponses {
|
||||||
|
fn(resp)
|
||||||
|
responsesSent <- i + 1
|
||||||
|
}
|
||||||
|
close(responsesSent)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: discover.GetGPUInfo,
|
||||||
|
getCpuFn: discover.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
|
// Create a minimal model
|
||||||
|
_, digest := createHarmonyTestModel(t)
|
||||||
|
|
||||||
|
// Create model with passthrough template
|
||||||
|
stream := false
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "harmony-test",
|
||||||
|
Files: map[string]string{"file.gguf": digest},
|
||||||
|
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("failed to create model: %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test chat endpoint with streaming
|
||||||
|
streamTrue := true
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "harmony-test",
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||||
|
Stream: &streamTrue,
|
||||||
|
Tools: getTestTools(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse streaming response
|
||||||
|
var chunks []api.ChatResponse
|
||||||
|
var content, thinking strings.Builder
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(w.Body)
|
||||||
|
for decoder.More() {
|
||||||
|
var chunk api.ChatResponse
|
||||||
|
if err := decoder.Decode(&chunk); err != nil {
|
||||||
|
t.Fatalf("failed to decode chunk: %v", err)
|
||||||
|
}
|
||||||
|
chunks = append(chunks, chunk)
|
||||||
|
|
||||||
|
// Accumulate content and thinking from each chunk
|
||||||
|
content.WriteString(chunk.Message.Content)
|
||||||
|
thinking.WriteString(chunk.Message.Thinking)
|
||||||
|
|
||||||
|
// Debug output
|
||||||
|
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we got streaming chunks
|
||||||
|
if len(chunks) == 0 {
|
||||||
|
t.Fatal("expected streaming chunks, got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
gotContent := content.String()
|
||||||
|
gotThinking := thinking.String()
|
||||||
|
|
||||||
|
if gotContent != tc.wantContent {
|
||||||
|
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
|
||||||
|
}
|
||||||
|
if gotThinking != tc.wantThinking {
|
||||||
|
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify last chunk has done=true
|
||||||
|
lastChunk := chunks[len(chunks)-1]
|
||||||
|
if !lastChunk.Done {
|
||||||
|
t.Error("expected last chunk to have done=true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user