runner: enable returning more info from runner processing

Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.
This commit is contained in:
Bruce MacDonald
2025-02-21 16:31:31 -08:00
parent 9f8a18ec05
commit d5eae8248d
6 changed files with 378 additions and 230 deletions

View File

@@ -2,6 +2,8 @@ package common
import (
"strings"
"github.com/ollama/ollama/llm"
)
func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,68 +31,41 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
var sequence string
for _, resp := range resps {
sequence += resp.Content
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
idx := strings.Index(sequence, stop)
if idx < 0 {
return resps, false
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
truncated := sequence[:idx]
if len(truncated) == 0 {
return nil, true
}
result := make([]llm.CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if len(chunk) < len(resp.Content) {
truncationHappened = true
}
result = append(result, joined[start:end])
start = end
if len(chunk) > 0 {
result = append(result, llm.CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
}
return result, tokenTruncated
}
func IncompleteUnicode(token string) bool {
incomplete := false
// check if there is incomplete UTF-8 character at the end
for i := 1; i < 5 && i <= len(token); i++ {
c := token[len(token)-i]
if (c & 0xc0) == 0x80 {
// continuation byte: 10xxxxxx
continue
}
if (c & 0xe0) == 0xc0 {
// 2-byte character: 110xxxxx ...
incomplete = i < 2
} else if (c & 0xf0) == 0xe0 {
// 3-byte character: 1110xxxx ...
incomplete = i < 3
} else if (c & 0xf8) == 0xf0 {
// 4-byte character: 11110xxx ...
incomplete = i < 4
}
// else 1-byte character or invalid byte
break
}
return incomplete
return result, truncationHappened
}

View File

@@ -1,51 +1,84 @@
package common
import (
"fmt"
"reflect"
"testing"
"github.com/ollama/ollama/llm"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
pieces []llm.CompletionResponse
stop string
expected []string
expected []llm.CompletionResponse
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
name: "Single word",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: "world"},
},
stop: "world",
expected: []llm.CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
name: "Partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wor"},
},
stop: "or",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
name: "Suffix",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
name: "Suffix partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
name: "Middle",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wo"},
},
stop: "llo w",
expected: []llm.CompletionResponse{
{Content: "He"},
},
expectedTrunc: true,
},
}
@@ -54,76 +87,23 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
}
})
}
}
func TestIncompleteUnicode(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Basic",
input: "hi",
expected: false,
},
{
name: "Two byte",
input: "hi" + string([]byte{0xc2, 0xa3}),
expected: false,
},
{
name: "Two byte - missing last",
input: "hi" + string([]byte{0xc2}),
expected: true,
},
{
name: "Three byte",
input: "hi" + string([]byte{0xe0, 0xA0, 0x80}),
expected: false,
},
{
name: "Three byte - missing last",
input: "hi" + string([]byte{0xe0, 0xA0}),
expected: true,
},
{
name: "Three byte - missing last 2",
input: "hi" + string([]byte{0xe0}),
expected: true,
},
{
name: "Four byte",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}),
expected: false,
},
{
name: "Four byte - missing last",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a}),
expected: true,
},
{
name: "Four byte - missing last 2",
input: "hi" + string([]byte{0xf0, 0x92}),
expected: true,
},
{
name: "Four byte - missing last 3",
input: "hi" + string([]byte{0xf0}),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IncompleteUnicode(tt.input)
if result != tt.expected {
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
}
})
func formatContentDiff(result, expected []llm.CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}