Compare commits

...

8 Commits

Author SHA1 Message Date
Patrick Devine
c10a40db99 parser: tidy up parameter/message parsing
This change addresses how parameters and messages are handled while parsing a Modelfile.
Currently a `MESSAGE` command creates a string that looks like "<role>: <message>" which
then has to be re-parsed, and `PARAMETER` ends up being the default data type which makes
it difficult to add other multi-part commands.

This change introduces a Message and a Parameter type which properly handle properties
such as the role and the name of the parameter.
2025-09-15 18:09:05 -07:00
Daniel Hiltgen
93c64ea1b1 doc: show how to clear the cgo cache (#12298) 2025-09-15 15:45:35 -07:00
Michael Yang
3f6642f6fc model: implement bert in ollama engine (#9080)
* fix truncate

* s/SentencePieceModel/SentencePiece/

* bert

* wordpiece

* refactor pooling

* more tokenizers

* normalize embeddings
2025-09-15 15:35:59 -07:00
Michael Yang
6f7117145f batch: use tensors for outputs (#12185)
this cleans up the model interface slightly without too much impact in
other areas
2025-09-15 14:33:06 -07:00
jmorganca
92b96d54ef Revert "runner: move harmony to runner (#12052)"
This reverts commit 1a558f98e2.
2025-09-12 20:40:14 -03:00
jmorganca
9d56e63dbf Revert "runner: simplify parser entrypoints in runner (#12233)"
This reverts commit 8d6fffaead.
2025-09-12 20:40:14 -03:00
tc-mb
053092185e Fix image cannot be seen with slice image on llama engine
Ollama's recent engine update, llama.cpp, caused all models requiring a slice schema to not display images. As a result, the value of numTokens isn't always the length of the sliced ​​image embed, but rather the end length of the schema. This causes the image embed to not be correctly included during all slice processing.
2025-09-12 16:25:12 -07:00
Daniel Hiltgen
44a6792873 tests: tighten up a few flaky tests (#12271)
Sometimes the context test results are pure emoji's
Thanksgiving has too much variability, so swap for a more straight forward prompt.
2025-09-12 13:59:34 -07:00
39 changed files with 987 additions and 618 deletions

View File

@@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"`
normalizeEmbeddings bool
PoolingType uint32 PoolingType uint32
} }
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string var pooling string
for _, m := range modules { for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" { switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path pooling = m.Path
break case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
} }
} }
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert" kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType kv["bert.pooling_type"] = p.PoolingType
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

View File

@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve go run . serve
``` ```
> [!NOTE]
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
## macOS (Apple Silicon) ## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required. macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

View File

@@ -3,29 +3,15 @@ package harmony
import ( import (
"fmt" "fmt"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"unicode" "unicode"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
) )
type harmonyParserState int type harmonyParserState int
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if template.Contains("<|start|>") && template.Contains("<|end|>") {
return true
}
}
return false
}
const ( const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader harmonyParserState_ParsingHeader
@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant") s.acc.WriteString("<|start|>assistant")
} }
func Prefill(lastMessage api.Message) string { func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
if lastMessage.Role != "assistant" { if lastMessage != nil && lastMessage.Role == "assistant" {
return "" // handle prefilling conditions
} if lastMessage.Content != "" {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
switch { return
case strings.TrimSpace(lastMessage.Content) != "": } else if lastMessage.Thinking != "" {
return "<|start|>assistant<|channel|>final<|message|>" s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
case strings.TrimSpace(lastMessage.Thinking) != "": return
return "<|start|>assistant<|channel|>analysis<|message|>" }
default:
return ""
}
}
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
if strings.TrimSpace(prefillString) != "" {
s.acc.WriteString(prefillString)
} else {
s.AddImplicitStart()
} }
s.AddImplicitStart()
} }
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@@ -289,7 +265,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState state harmonyMessageState
HarmonyParser *HarmonyParser HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap FunctionNameMap *FunctionNameMap
ToolParser *HarmonyToolCallAccumulator
} }
// NewHarmonyMessageHandler creates a new message handler // NewHarmonyMessageHandler creates a new message handler
@@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>", HeaderEndTag: "<|message|>",
}, },
FunctionNameMap: NewFunctionNameMap(), FunctionNameMap: NewFunctionNameMap(),
ToolParser: &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
},
} }
} }
// AddContent processes the content and returns the content, thinking, and tool content. // AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser // content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
contentSb := strings.Builder{} contentSb := strings.Builder{}
thinkingSb := strings.Builder{} thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{} toolContentSb := strings.Builder{}
@@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
// event.Header.Recipient is the tool name, something like // event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a // "browser.search" for a built-in, or "functions.calc" for a
// custom one // custom one
h.ToolParser.SetToolName(event.Header.Recipient) toolParser.SetToolName(event.Header.Recipient)
} else { } else {
h.state = harmonyMessageState_Thinking h.state = harmonyMessageState_Thinking
} }
case "commentary": case "commentary":
if event.Header.Recipient != "" { if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling h.state = harmonyMessageState_ToolCalling
h.ToolParser.SetToolName(event.Header.Recipient) toolParser.SetToolName(event.Header.Recipient)
} else { } else {
h.state = harmonyMessageState_Normal h.state = harmonyMessageState_Normal
} }
@@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
return contentSb.String(), thinkingSb.String(), toolContentSb.String() return contentSb.String(), thinkingSb.String(), toolContentSb.String()
} }
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
return &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
}
}
type harmonyToolCallState int type harmonyToolCallState int
const ( const (

View File

@@ -3,7 +3,6 @@ package harmony
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"testing" "testing"
) )
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
}) })
} }
} }
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
type step struct {
in string
wantContent string
wantThinking string
}
steps := []step{
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
{in: "<|end|>", wantThinking: ""},
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in)
if tool != "" {
tp.Add(tool)
}
if content != s.wantContent || thinking != s.wantThinking {
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
}
}
})
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
"!<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
t.Fatalf("unexpected thinking %q", thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Hello", ", world", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
got = append(got, thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Thinking...", "Answer"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
"age|>Deep ",
"thought",
"<|end|>",
"<|start|>assistant<|message|>Done",
"<|end|>",
}
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
thinkingPieces = append(thinkingPieces, thinking)
}
if content != "" {
contentPieces = append(contentPieces, content)
}
}
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
}
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
}
})
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
}
if contentSb.String() != "Answer" {
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
}
if thinkingSb.String() != "Think" {
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
}
})
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
}

View File

@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: "Write me a story with a ton of emojis?", Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,

View File

@@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, { }, {
Model: smol, Model: smol,
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply", Prompt: "how do rainbows form? Be brief but factual in your reply",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, { }, {
@@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{ [][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"}, {"water", "droplet", "refracted", "reflect", "color", "spectrum"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"}, {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
} }
} }

View File

@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
} }
nChunks := C.mtmd_input_chunks_size(ic) nChunks := C.mtmd_input_chunks_size(ic)
numEmbed := llamaContext.Model().NEmbd() numEmbed := llamaContext.Model().NEmbd()
lastChunkSize := 0 embed := make([][]float32, 0)
for i := range int(nChunks) { for i := range int(nChunks) {
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
lastChunkSize = numTokens slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
// Encode the chunk // Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk") return nil, errors.New("unable to encode mtmd image chunk")
} }
}
// Get the embeddings // Get the embeddings for this chunk
embed := make([][]float32, lastChunkSize) chunkEmbed := make([][]float32, numTokens)
embd := C.mtmd_get_output_embd(c.c) chunkEmbd := C.mtmd_get_output_embd(c.c)
if nil == embd { if nil == chunkEmbd {
return nil, errors.New("failed to get image embedding") continue
} }
// Extend the embedding array for each token // Extend the embedding array for each token
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
rows := make([]float32, len(s)) rows := make([]float32, len(s))
copy(rows, s) copy(rows, s)
for i := range lastChunkSize { for i := range numTokens {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed] chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
}
embed = append(embed, chunkEmbed...)
} }
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
return embed, nil return embed, nil
} }

View File

@@ -35,7 +35,6 @@ import (
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/parser"
) )
type filteredEnv []string type filteredEnv []string
@@ -1349,9 +1348,7 @@ type CompletionRequest struct {
Images []ImageData Images []ImageData
Options *api.Options Options *api.Options
Grammar string // set before sending the request to the subprocess Grammar string // set before sending the request to the subprocess
ParserType parser.TokenParserType
PrefillString string
} }
// DoneReason represents the reason why a completion response is done // DoneReason represents the reason why a completion response is done
@@ -1378,15 +1375,13 @@ func (d DoneReason) String() string {
} }
type CompletionResponse struct { type CompletionResponse struct {
Content string `json:"content"` Content string `json:"content"`
Thinking string `json:"thinking"` DoneReason DoneReason `json:"done_reason"`
ToolCalls []api.ToolCall `json:"tool_calls"` Done bool `json:"done"`
DoneReason DoneReason `json:"done_reason"` PromptEvalCount int `json:"prompt_eval_count"`
Done bool `json:"done"` PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
PromptEvalCount int `json:"prompt_eval_count"` EvalCount int `json:"eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"` EvalDuration time.Duration `json:"eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
} }
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1504,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err) return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
} }
switch { switch {
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future case strings.TrimSpace(c.Content) == lastToken:
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
tokenRepeat++ tokenRepeat++
default: default:
lastToken = strings.TrimSpace(c.Content) lastToken = strings.TrimSpace(c.Content)
@@ -1518,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err() return ctx.Err()
} }
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Done { if c.Done {
fn(c) fn(c)
return nil return nil
} }
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
fn(c)
}
} }
} }

View File

@@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor Scale(ctx Context, s float64) Tensor

View File

@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
} }
} }
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil { if w != nil {

36
ml/nn/pooling/pooling.go Normal file
View File

@@ -0,0 +1,36 @@
package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
)
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch poolingType {
case TypeNone:
return hiddenStates
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
panic("not implemented")
case TypeRank:
panic("not implemented")
default:
panic("not implemented")
}
}

View File

@@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs. // Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by // Outputs are the set of indicies into Inputs for which output data should
// EncodeMultimodal, along with an index into Inputs. Unused for text-only // be returned.
// models or for batches without multimodal elements. Outputs ml.Tensor
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal // Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs. // in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs. // Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int Sequences []int
// Outputs are the set of indicies into Inputs for which output data should // Multimodal is a set of multimodal embeddings previously created by
// be returned. // EncodeMultimodal, along with an index into Inputs. Unused for text-only
Outputs []int32 // models or for batches without multimodal elements.
Multimodal []MultimodalIndex
} }

View File

@@ -24,7 +24,11 @@ import (
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
var ErrNoVisionModel = errors.New("this model is missing data required for image input") var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
@@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem() vv = vv.Elem()
} }
vv = vv.Elem() vv = reflect.Indirect(vv)
if v.IsNil() { if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem() vv = reflect.New(v.Type().Elem()).Elem()
} }

181
model/models/bert/model.go Normal file
View File

@@ -0,0 +1,181 @@
package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
// Attention
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// MLP
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := a.Query.Forward(ctx, hiddenStates)
if a.QueryNorm != nil {
query = a.QueryNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key := a.Key.Forward(ctx, hiddenStates)
if a.KeyNorm != nil {
key = a.KeyNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
poolingType pooling.Type
eps float32
normalize bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func New(c fs.Config) (model.Model, error) {
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
//nolint:misspell
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
// support it for compatibility.
c.Uint("tokenizer.ggml.seperator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", true),
},
}, nil
}
func init() {
model.Register("bert", New)
model.Register("bert_embed", New)
}

View File

@@ -24,7 +24,7 @@ type Options struct {
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"` TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"` Layers []Layer `gguf:"blk"`
@@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)

View File

@@ -1,49 +1,38 @@
package gemma3 package gemma3
import ( import (
"errors"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
type embedModel struct { type embedModel struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
PoolingType uint32 poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"` Dense [2]*nn.Linear `gguf:"dense"`
} }
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense { for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates) hiddenStates = dense.Forward(ctx, hiddenStates)
} }
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil return hiddenStates, nil
} }
func newEmbedModel(c fs.Config) (model.Model, error) { func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{ m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
}, },
), ),
TextModel: newTextModel(c), TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0), poolingType: pooling.Type(c.Uint("pooling_type", 0)),
} }
m.Cache = kvcache.NewWrapperCache( m.Cache = kvcache.NewWrapperCache(

View File

@@ -16,7 +16,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*VisionModel `gguf:"v"` *VisionModel `gguf:"v"`
*TextModel *TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)

View File

@@ -10,7 +10,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
} }
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
TextModel: newTextModel(c), TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil return m.Output.Forward(ctx, hiddenStates), nil

View File

@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
} }
var outputs ml.Tensor var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { if i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)

View File

@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)

View File

@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }
func init() { func init() {

View File

@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
} }
func init() { func init() {

View File

@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask // TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
} }
func init() { func init() {

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/gemma3n"

View File

@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)

View File

@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
} }
func init() { func init() {

View File

@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)

View File

@@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁" const spmWhitespaceSep = "▁"
type SentencePieceModel struct { type SentencePiece struct {
maxTokenLen int maxTokenLen int
vocab *Vocabulary vocab *Vocabulary
} }
var _ TextProcessor = (*SentencePieceModel)(nil) var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary { func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab return spm.vocab
} }
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
return SentencePieceModel{ return SentencePiece{
maxTokenLen: maxTokenLen, maxTokenLen: maxTokenLen,
vocab: vocab, vocab: vocab,
} }
} }
func (spm SentencePieceModel) Is(id int32, special Special) bool { func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special) return spm.vocab.Is(id, special)
} }
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() { for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special) id := spm.vocab.Encode(special)
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item return item
} }
func (spm SentencePieceModel) Decode(ids []int32) (string, error) { func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder var sb strings.Builder
for _, id := range ids { for _, id := range ids {
data := spm.vocab.Decode(id) data := spm.vocab.Decode(id)

View File

@@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece" "github.com/ollama/ollama/convert/sentencepiece"
) )
func loadSentencePieceVocab(t *testing.T) SentencePieceModel { func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper() t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
} }
} }
return NewSentencePieceModel(&v) return NewSentencePiece(&v)
} }
func TestSentencePieceEncode(t *testing.T) { func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
}) })
} }
func TestSentencePieceModelDecodeByteTokens(t *testing.T) { func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{ vocab := &Vocabulary{
Values: []string{ Values: []string{
"normal", "normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0}, Scores: []float32{0, 0, 0, 0, 0},
} }
spm := NewSentencePieceModel(vocab) spm := NewSentencePiece(vocab)
tests := []struct { tests := []struct {
name string name string

167
model/wordpiece.go Normal file
View File

@@ -0,0 +1,167 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on but keep punctuation
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial && len(ids) > 0 {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
}
}

51
model/wordpiece_test.go Normal file
View File

@@ -0,0 +1,51 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands { for _, c := range f.Commands {
switch c.Name { switch c.Name {
case "model": case "model":
path, err := expandPath(c.Args, relativeDir) name := c.Args.(string)
path, err := expandPath(name, relativeDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
digestMap, err := fileDigestMap(path) digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
req.From = c.Args req.From = name
continue continue
} else if err != nil { } else if err != nil {
return nil, err return nil, err
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
} }
} }
case "adapter": case "adapter":
path, err := expandPath(c.Args, relativeDir) adapter := c.Args.(string)
path, err := expandPath(adapter, relativeDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -95,21 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.Adapters = digestMap req.Adapters = digestMap
case "template": case "template":
req.Template = c.Args template := c.Args.(string)
req.Template = template
case "system": case "system":
req.System = c.Args system := c.Args.(string)
req.System = system
case "license": case "license":
licenses = append(licenses, c.Args) license := c.Args.(string)
licenses = append(licenses, license)
case "message": case "message":
role, msg, _ := strings.Cut(c.Args, ": ") msg := c.Args.(*Message)
messages = append(messages, api.Message{Role: role, Content: msg}) messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
default: case "parameter":
if slices.Contains(deprecatedParameters, c.Name) { if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name) fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
break break
} }
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) param := c.Args.(*Parameter)
ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -123,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
params[k] = v params[k] = v
} }
} }
default:
return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
} }
} }
@@ -312,7 +320,17 @@ func filesForModel(path string) ([]string, error) {
type Command struct { type Command struct {
Name string Name string
Args string Args any
}
type Parameter struct {
Name string
Value string
}
type Message struct {
Role string
Content string
} }
func (c Command) String() string { func (c Command) String() string {
@@ -321,12 +339,16 @@ func (c Command) String() string {
case "model": case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args) fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter": case "license", "template", "system", "adapter":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) data := c.Args.(string)
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
case "message": case "message":
role, message, _ := strings.Cut(c.Args, ": ") data := c.Args.(*Message)
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
case "parameter":
data := c.Args.(*Parameter)
fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
default: default:
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) fmt.Printf("unknown command '%s'\n", c.Name)
} }
return sb.String() return sb.String()
@@ -366,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
var curr state var curr state
var currLine int = 1 var currLine int = 1
var b bytes.Buffer var b bytes.Buffer
var role string
var f Modelfile var f Modelfile
@@ -413,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
case "parameter": case "parameter":
// transition to stateParameter which sets command name // transition to stateParameter which sets command name
next = stateParameter next = stateParameter
cmd.Name = s
case "message": case "message":
// transition to stateMessage which validates the message role // transition to stateMessage which validates the message role
next = stateMessage next = stateMessage
@@ -421,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
cmd.Name = s cmd.Name = s
} }
case stateParameter: case stateParameter:
cmd.Name = b.String() s, ok := unquote(strings.TrimSpace(b.String()))
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
cmd.Args = &Parameter{
Name: s,
}
case stateMessage: case stateMessage:
if !isValidMessageRole(b.String()) { s, ok := unquote(strings.TrimSpace(b.String()))
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
if !isValidMessageRole(s) {
return nil, &ParserError{ return nil, &ParserError{
LineNumber: currLine, LineNumber: currLine,
Msg: errInvalidMessageRole.Error(), Msg: errInvalidMessageRole.Error(),
} }
} }
role = b.String() cmd.Args = &Message{
Role: s,
}
case stateComment, stateNil: case stateComment, stateNil:
// pass // pass
case stateValue: case stateValue:
@@ -443,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
continue continue
} }
if role != "" { switch cmd.Name {
s = role + ": " + s case "parameter":
role = "" p := cmd.Args.(*Parameter)
p.Value = s
case "message":
m := cmd.Args.(*Message)
m.Content = s
default:
cmd.Args = s
} }
cmd.Args = s
f.Commands = append(f.Commands, cmd) f.Commands = append(f.Commands, cmd)
} }
@@ -473,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
if role != "" { switch cmd.Name {
s = role + ": " + s case "parameter":
c := cmd.Args.(*Parameter)
c.Value = s
case "message":
c := cmd.Args.(*Message)
c.Content = s
default:
cmd.Args = s
} }
cmd.Args = s
f.Commands = append(f.Commands, cmd) f.Commands = append(f.Commands, cmd)
default: default:
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF

View File

@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: "model1"}, {Name: "model", Args: "model1"},
{Name: "adapter", Args: "adapter1"}, {Name: "adapter", Args: "adapter1"},
{Name: "license", Args: "MIT"}, {Name: "license", Args: "MIT"},
{Name: "param1", Args: "value1"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "param2", Args: "value2"}, {Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"}, {Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
} }
@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: " model 1"}, {Name: "model", Args: " model 1"},
{Name: "adapter", Args: "adapter3"}, {Name: "adapter", Args: "adapter3"},
{Name: "license", Args: "MIT "}, {Name: "license", Args: "MIT "},
{Name: "param1", Args: "value1"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "param2", Args: "value2"}, {Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "}, {Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
} }
@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
}, },
{ {
"FROM \"FOO BAR\"\nPARAMETER param1 value1", "FROM \"FOO BAR\"\nPARAMETER param1 value1",
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}}, []Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
nil, nil,
}, },
{ {
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
}, },
{ {
"PARAMETER param1 value1\nFROM foo", "PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, []Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
nil, nil,
}, },
{ {
"PARAMETER what the \nFROM lemons make lemonade ", "PARAMETER what the \nFROM lemons make lemonade ",
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}}, []Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
nil, nil,
}, },
} }
@@ -211,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
}, },
nil, nil,
}, },
@@ -221,7 +221,7 @@ FROM foo
MESSAGE system You are a file parser. Always parse things.`, MESSAGE system You are a file parser. Always parse things.`,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
}, },
nil, nil,
}, },
@@ -234,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
{Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: &Message{"user", "Hey there!"}},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, {Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
}, },
nil, nil,
}, },
@@ -244,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
` `
FROM foo FROM foo
MESSAGE system """ MESSAGE system """
You are a multiline file parser. Always parse things. You are a multiline file "parser". Always parse things.
""" """
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"}, {Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
}, },
nil, nil,
}, },
@@ -514,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {
assert.Equal(t, []Command{ assert.Equal(t, []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: v.name, Args: v.value}, {Name: "parameter", Args: &Parameter{v.name, v.value}},
}, modelfile.Commands) }, modelfile.Commands)
}) })
} }
@@ -617,8 +617,8 @@ SYSTEM You are a utf16 file.
expected := []Command{ expected := []Command{
{Name: "model", Args: "bob"}, {Name: "model", Args: "bob"},
{Name: "param1", Args: "1"}, {Name: "parameter", Args: &Parameter{"param1", "1"}},
{Name: "param2", Args: "4096"}, {Name: "parameter", Args: &Parameter{"param2", "4096"}},
{Name: "system", Args: "You are a utf16 file."}, {Name: "system", Args: "You are a utf16 file."},
} }

View File

@@ -1,126 +0,0 @@
package parser
import (
"encoding/json"
"errors"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type TokenParserType int
const (
TokenParserTypeDefault TokenParserType = iota
TokenParserTypeHarmony
)
type TokenParser struct {
messageHandler MessageHandler
parserEngine ParserInternals
toolParser ToolParser
lastToken string
tokenRepeat int
repeatLimit int
}
const defaultTokenRepeatLimit = 30
type MessageHandler interface {
AddContent(token string) (content, thinking string, toolContent string)
}
type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string)
}
type ToolParser interface {
Add(token string)
Drain() (toolName *string, toolContent string)
}
// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
return token, "", ""
}
type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {}
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
switch parserType {
case TokenParserTypeHarmony:
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
return TokenParser{
messageHandler: harmonyMessageHandler,
parserEngine: harmonyMessageHandler.HarmonyParser,
toolParser: harmonyMessageHandler.ToolParser,
repeatLimit: defaultTokenRepeatLimit,
}
default:
return TokenParser{
messageHandler: defaultMessageHandler{},
parserEngine: defaultEngine{},
toolParser: defaultToolParser{},
repeatLimit: 30,
}
}
}
func (p *TokenParser) AddContent(token string) (string, string, error) {
if p.repeatLimitReached(token) {
return "", "", errors.New("token repeat limit reached")
}
content, thinking, toolContent := p.messageHandler.AddContent(token)
p.toolParser.Add(toolContent)
return content, thinking, nil
}
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
if p == nil {
return false
}
trimmed := strings.TrimSpace(token)
if trimmed == p.lastToken {
p.tokenRepeat++
} else {
p.tokenRepeat = 0
}
p.lastToken = trimmed
return p.tokenRepeat >= p.repeatLimit
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
return nil
}
return []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
},
}
}
return nil
}

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample" "github.com/ollama/ollama/sample"
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input var batchInputs []*input.Input
var batchOutputs []int32
var batch input.Batch var batch input.Batch
resumeSeq := -1 resumeSeq := -1
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id) batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs) seq.iBatch = len(batchOutputs)
if i+1 == len(seq.inputs) { if i+1 == len(seq.inputs) || seq.embeddingOnly {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
} }
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil { if err != nil {
err = fmt.Errorf("failed to build graph: %w", err) err = fmt.Errorf("failed to build graph: %w", err)
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
} }
// sample a token // sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs) vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
if req.Options == nil { if req.Options == nil {
opts := api.DefaultOptions() opts := api.DefaultOptions()
req.Options = &opts req.Options = &opts
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
var thinking string
var err error
content, thinking, err = tokenParser.AddContent(content)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
close(seq.quit)
return
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content, Content: content,
Thinking: thinking,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit) close(seq.quit)
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush() flusher.Flush()
} else { } else {
toolCalls := tokenParser.Drain()
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true, Done: true,
DoneReason: seq.doneReason, DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs, PromptEvalCount: seq.numPromptInputs,
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i) batch.Positions[i] = int32(i)
} }
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache cache := s.model.Config().Cache
if cache != nil { if cache != nil {

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
@@ -47,6 +46,18 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func experimentEnabled(name string) bool { func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
} }
@@ -196,17 +207,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw useHarmony := shouldUseHarmony(m) && !req.Raw
var parserType parser.TokenParserType var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony { if useHarmony {
parserType = parser.TokenParserTypeHarmony harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
} else { harmonyMessageHandler.HarmonyParser.AddImplicitStart()
parserType = parser.TokenParserTypeDefault harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
functionNameMap = harmony.NewFunctionNameMap()
} }
// Validate Think value: string values currently only allowed for gptoss models // Validate Think value: string values currently only allowed for gptoss models
@@ -350,19 +357,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
ParserType: parserType,
}, func(cr llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Response: cr.Content, Response: cr.Content,
Done: cr.Done, Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount, PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
@@ -371,22 +375,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}, },
} }
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony { if useHarmony {
for i, tool := range res.ToolCalls { content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) res.Response = content
} res.Thinking = thinking
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done { harmonyToolParser.Add(toolContent)
ch <- res } else if thinkingState != nil {
}
return
}
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content) thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking res.Thinking = thinking
res.Response = content res.Response = content
@@ -397,6 +391,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
if cr.Done { if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil { if err != nil {
@@ -470,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
truncate := true truncate := true
if req.Truncate != nil && !*req.Truncate { if req.Truncate != nil && !*req.Truncate {
truncate = false truncate = false
} }
@@ -537,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1599,27 +1625,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) var harmonyMessageHandler *harmony.HarmonyMessageHandler
var parserType parser.TokenParserType var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony {
parserType = parser.TokenParserTypeHarmony useHarmony := shouldUseHarmony(m)
} else {
parserType = parser.TokenParserTypeDefault
}
processedTools := req.Tools processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
if useHarmony { if useHarmony {
prefillString = harmony.Prefill(msgs[len(msgs)-1]) harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
functionNameMap = harmony.NewFunctionNameMap() var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
// make a copy of tools to pass to the chat prompt. Function names may be // make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names. // renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools)) processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools) copy(processedTools, req.Tools)
for i, tool := range processedTools { for i, tool := range processedTools {
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name) processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
} }
} }
@@ -1672,17 +1698,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
ParserType: parserType,
PrefillString: prefillString,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls}, Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done, Done: r.Done,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
@@ -1698,13 +1722,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
if useHarmony { if useHarmony {
for i, tool := range res.Message.ToolCalls { content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
} }
// only send messages with meaningful content (empty messages confuse clients) // only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res ch <- res
} }
return return
} }

View File

@@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives", name: "content streams as it arrives",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Content: "Hello", Done: false}, input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
wantContent: "Hello", wantContent: "Hello",
}, },
{ {
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world", wantContent: ", world",
}, },
{ {
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!", wantContent: "!",
}, },
}, },
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content", name: "thinking streams separately from content",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false}, input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
wantThinking: "Thinking...", wantThinking: "Thinking...",
}, },
{ {
input: llm.CompletionResponse{Content: "Answer", Done: false}, input: llm.CompletionResponse{Content: "<|end|>", Done: false},
wantContent: "Answer", // No output expected - just closes the analysis message and resets state to normal
}, },
{ {
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
wantContent: "Answer", // After message end, state is reset to normal
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
// No output expected - just closes the assistant message
}, },
}, },
}, },
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete", name: "partial tags buffer until complete",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Deep ", Done: false}, input: llm.CompletionResponse{Content: "<|chan", Done: false},
// No output - partial tag
},
{
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
// No output - still building tags
},
{
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
wantThinking: "Deep ", wantThinking: "Deep ",
}, },
{ {
input: llm.CompletionResponse{Thinking: "thought", Done: false}, input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
wantThinking: "thought", wantThinking: "thought",
}, },
{ {
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done", wantContent: "Done", // After message end, state is reset to normal
}, },
}, },
}, },
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis", name: "simple assistant after analysis",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer", wantContent: "Answer",
wantThinking: "Think", wantThinking: "Think",
}, },
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly", name: "tool call parsed and returned correctly",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny", wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{ wantToolCalls: []api.ToolCall{
{ {
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks", name: "tool call with streaming JSON across chunks",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Done: false}, input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
// No output yet - incomplete JSON
}, },
{ {
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true}, input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
// Still no output - incomplete JSON
},
{
input: llm.CompletionResponse{Content: "2\"}", Done: true},
wantToolCalls: []api.ToolCall{ wantToolCalls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{ mockResponses := []llm.CompletionResponse{
{Content: "First ", Done: false}, {Content: "<|message|>First ", Done: false},
{Content: "chunk ", Done: false}, {Content: "chunk ", Done: false},
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop}, {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
} }
mock := mockRunner{ mock := mockRunner{
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
} }
} }
func TestChatHarmonyParserStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
type expectedChunk struct {
afterResponse int // Which mock response this chunk should appear after
content string // Expected content in this chunk
thinking string // Expected thinking in this chunk
}
testCases := []struct {
name string
mockResponses []llm.CompletionResponse
expectedChunks []expectedChunk
wantContent string
wantThinking string
}{
{
name: "simple message without thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
{Content: "how can I help?", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 1, content: "Hello, "},
{afterResponse: 2, content: "how can I help?"},
},
wantContent: "Hello, how can I help?",
},
{
name: "message with analysis channel for thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|channel|>analysis<|message|>", Done: false},
{Content: "Let me think ", Done: false},
{Content: "about this problem...", Done: false},
{Content: "<|end|>", Done: false},
{Content: "<|start|>assistant<|message|>", Done: false},
{Content: "The answer ", Done: false},
{Content: "is 42", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 2, thinking: "Let me think "},
{afterResponse: 3, thinking: "about this problem..."},
{afterResponse: 6, content: "The answer "},
{afterResponse: 7, content: "is 42"},
},
wantContent: "The answer is 42",
wantThinking: "Let me think about this problem...",
},
{
name: "streaming with partial tags across boundaries",
mockResponses: []llm.CompletionResponse{
{Content: "<|chan", Done: false},
{Content: "nel|>analy", Done: false},
{Content: "sis<|mess", Done: false},
{Content: "age|>Think", Done: false},
{Content: "ing deeply...<|end|>", Done: false},
{Content: "<|start|>assi", Done: false},
{Content: "stant<|message|>Result ", Done: false},
{Content: "computed<|e", Done: false},
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 4, thinking: "Think"},
{afterResponse: 5, thinking: "ing deeply..."},
{afterResponse: 7, content: "Result "},
{afterResponse: 8, content: "computed"},
},
wantContent: "Result computed",
wantThinking: "Thinking deeply...",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Channel to synchronize mock responses with chunk verification
responsesSent := make(chan int, len(tc.mockResponses))
mock := mockRunner{
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Send mock responses one at a time, notifying when each is sent
for i, resp := range tc.mockResponses {
fn(resp)
responsesSent <- i + 1
}
close(responsesSent)
return nil
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a minimal model
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Tools: getTestTools(),
})
if w.Code != http.StatusOK {
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
}
// Parse streaming response
var chunks []api.ChatResponse
var content, thinking strings.Builder
decoder := json.NewDecoder(w.Body)
for decoder.More() {
var chunk api.ChatResponse
if err := decoder.Decode(&chunk); err != nil {
t.Fatalf("failed to decode chunk: %v", err)
}
chunks = append(chunks, chunk)
// Accumulate content and thinking from each chunk
content.WriteString(chunk.Message.Content)
thinking.WriteString(chunk.Message.Thinking)
// Debug output
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
}
// Verify we got streaming chunks
if len(chunks) == 0 {
t.Fatal("expected streaming chunks, got none")
}
gotContent := content.String()
gotThinking := thinking.String()
if gotContent != tc.wantContent {
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
}
if gotThinking != tc.wantThinking {
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
}
// Verify last chunk has done=true
lastChunk := chunks[len(chunks)-1]
if !lastChunk.Done {
t.Error("expected last chunk to have done=true")
}
})
}
}