add thinking support to the api and cli (#10584)

- Both `/api/generate` and `/api/chat` now accept a `"think"`
  option that allows specifying whether thinking mode should be on or
  not
- Templates get passed this new option so, e.g., qwen3's template can
  put `/think` or `/no_think` in the system prompt depending on the
  value of the setting
- Models' thinking support is inferred by inspecting model templates.
  The prefix and suffix the parser uses to identify thinking support is
  also automatically inferred from templates
- Thinking control & parsing is opt-in via the API to prevent breaking
  existing API consumers. If the `"think"` option is not specified, the
  behavior is unchanged from previous versions of ollama
- Add parsing for thinking blocks in both streaming/non-streaming mode
  in both `/generate` and `/chat`
- Update the CLI to make use of these changes. Users can pass `--think`
  or `--think=false` to control thinking, or during an interactive
  session they can use the commands `/set think` or `/set nothink`
- A `--hidethinking` option has also been added to the CLI. This makes
  it easy to use thinking in scripting scenarios like
  `ollama run qwen3 --think --hidethinking "my question here"` where you
  just want to see the answer but still want the benefits of thinking
  models
This commit is contained in:
Devon Rifkin
2025-05-28 19:38:52 -07:00
committed by GitHub
parent aa25aff10d
commit 5f57b0ef42
17 changed files with 1195 additions and 49 deletions

View File

@@ -37,6 +37,7 @@ var (
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errInsecureProtocol = errors.New("insecure protocol http")
)
@@ -111,6 +112,12 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
@@ -127,6 +134,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
}
for _, cap := range want {
@@ -141,11 +149,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
}
}
var err error
if len(errs) > 0 {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
return nil
if slices.Contains(errs, errCapabilityThinking) {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
// append a message to the existing error
return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
}
}
return err
}
func (m *Model) String() string {

View File

@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
thinkVal := false
if think != nil {
thinkVal = *think
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}
@@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
thinkVal := false
if think != nil {
thinkVal = *think
}
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}

View File

@@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

View File

@@ -17,7 +17,6 @@ import (
"net/netip"
"os"
"os/signal"
"regexp"
"slices"
"strings"
"syscall"
@@ -186,6 +185,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
@@ -254,6 +260,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
values.Think = req.Think != nil && *req.Think
values.IsThinkSet = req.Think != nil
var b bytes.Buffer
if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
@@ -273,6 +282,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = b.String()
}
var thinkingState *thinkingParser
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinkingParser{
openingTag: openingTag,
closingTag: closingTag,
}
}
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
@@ -297,6 +315,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if thinkingState != nil {
thinking, content := thinkingState.addContent(cr.Content)
res.Thinking = thinking
res.Response = content
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -324,11 +348,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
var sb strings.Builder
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.GenerateResponse:
sb.WriteString(t.Response)
sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t
case gin.H:
msg, ok := t["error"].(string)
@@ -344,7 +370,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
r.Response = sb.String()
r.Thinking = sbThinking.String()
r.Response = sbContent.String()
c.JSON(http.StatusOK, r)
return
}
@@ -1436,6 +1464,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model)
if !name.IsValid() {
@@ -1476,13 +1507,22 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var thinkingState *thinkingParser
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinkingParser{
openingTag: openingTag,
closingTag: closingTag,
}
}
var toolParser *tools.Parser
if len(req.Tools) > 0 {
toolParser, err = tools.NewParser(m.Template.Template)
@@ -1516,6 +1556,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
},
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
@@ -1523,12 +1573,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(r.Content)
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
ch <- res
@@ -1536,6 +1588,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -1544,12 +1597,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var sb strings.Builder
var toolCalls []api.ToolCall
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
sbThinking.WriteString(t.Message.Thinking)
sbContent.WriteString(t.Message.Content)
resp = t
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
@@ -1568,7 +1623,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
resp.Message.Content = sb.String()
resp.Message.Content = sbContent.String()
resp.Message.Thinking = sbThinking.String()
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
}
@@ -1595,8 +1652,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
@@ -1608,7 +1663,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
// TODO(drifkin): this is from before we added proper thinking support.
// However, even if thinking is not enabled (and therefore we shouldn't
// change the user output), we should probably perform this filtering
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
// to save tokens and improve quality.
thinkingState := &thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
_, content := thinkingState.addContent(msg.Content)
msgs[i].Content = content
}
}
}

View File

@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("missing thinking capability", func(t *testing.T) {
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Think: &think,
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {

300
server/thinking.go Normal file
View File

@@ -0,0 +1,300 @@
package server
import (
"strings"
"text/template"
"text/template/parse"
"unicode"
)
type thinkingState int
const (
// We're looking for the opening tag, but we haven't seen any non-whitespace
// characters yet
thinkingState_LookingForOpening thinkingState = iota
// We've seen the opening tag, but we haven't seen any non-whitespace
// characters yet (we want to eat any whitespace between the opening tag and
// the thinking content)
thinkingState_ThinkingStartedEatingWhitespace
// We've seen non-whitespace characters after the opening tag, but we haven't
// seen the closing tag yet
thinkingState_Thinking
// We've seen the closing tag, but we haven't seen any non-whitespace
// characters after the closing tag yet (we want to eat any whitespace between
// the closing tag and the content)
thinkingState_ThinkingDoneEatingWhitespace
// We've seen the closing tag and seen at least one non-whitespace character
// after it
thinkingState_ThinkingDone
)
func (s thinkingState) String() string {
switch s {
case thinkingState_LookingForOpening:
return "LookingForOpening"
case thinkingState_ThinkingStartedEatingWhitespace:
return "ThinkingStartedEatingWhitespace"
case thinkingState_Thinking:
return "Thinking"
case thinkingState_ThinkingDoneEatingWhitespace:
return "ThinkingDoneEatingWhitespace"
case thinkingState_ThinkingDone:
return "ThinkingDone"
default:
return "Unknown"
}
}
type thinkingParser struct {
state thinkingState
openingTag string
closingTag string
acc strings.Builder
}
// addContent returns the thinking content and the non-thinking content that
// should be immediately sent to the user. It will internally buffer if it needs
// to see more raw content to disambiguate
func (s *thinkingParser) addContent(content string) (string, string) {
s.acc.WriteString(content)
var thinkingSb, remainingSb strings.Builder
var thinking, remaining string
keepLooping := true
// we loop because we might pass through multiple parsing states in a single
// call to addContent, and we want to make sure callers don't have to wait for
// data that's already unambiguous
for keepLooping {
thinking, remaining, keepLooping = eat(s)
thinkingSb.WriteString(thinking)
remainingSb.WriteString(remaining)
}
return thinkingSb.String(), remainingSb.String()
}
// the additional bool return is true iff we should continue eating
func eat(s *thinkingParser) (string, string, bool) {
switch s.state {
case thinkingState_LookingForOpening:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, s.openingTag) {
after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
// after might contain more than just thinking tokens, so we continue
// parsing instead of returning it as thinking tokens here
s.acc.Reset()
s.acc.WriteString(after)
if after == "" {
s.state = thinkingState_ThinkingStartedEatingWhitespace
} else {
s.state = thinkingState_Thinking
}
return "", "", true
} else if strings.HasPrefix(s.openingTag, trimmed) {
// partial opening seen, so let's keep accumulating
return "", "", false
} else if trimmed == "" {
// saw whitespace only, so let's keep accumulating
return "", "", false
} else {
// didn't see an opening tag, but we have content, so thinking was skipped
s.state = thinkingState_ThinkingDone
// note that we use the original content, not the trimmed one because we
// don't want to eat any whitespace in the real content if there were no
// thinking tags
return "", s.acc.String(), false
}
case thinkingState_ThinkingStartedEatingWhitespace:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
s.acc.Reset()
if trimmed == "" {
return "", "", false
} else {
s.state = thinkingState_Thinking
s.acc.WriteString(trimmed)
return "", "", true
}
case thinkingState_Thinking:
acc := s.acc.String()
if strings.Contains(acc, s.closingTag) {
split := strings.Split(acc, s.closingTag)
thinking := split[0]
remaining := strings.Join(split[1:], s.closingTag)
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
s.acc.Reset()
if remaining == "" {
s.state = thinkingState_ThinkingDoneEatingWhitespace
} else {
s.state = thinkingState_ThinkingDone
}
return thinking, remaining, false
} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
thinking := acc[:len(acc)-overlapLen]
remaining := acc[len(acc)-overlapLen:]
s.acc.Reset()
// keep track of the candidate closing tag. We have to buffer it until it
// becomes disambiguated
s.acc.WriteString(remaining)
return thinking, "", false
} else {
// purely just thinking tokens, so we can return them
s.acc.Reset()
return acc, "", false
}
case thinkingState_ThinkingDoneEatingWhitespace:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
s.acc.Reset()
// if we see non-whitespace, we're done eating the leading whitespace of the content
if trimmed != "" {
s.state = thinkingState_ThinkingDone
}
return "", trimmed, false
case thinkingState_ThinkingDone:
acc := s.acc.String()
s.acc.Reset()
return "", acc, false
default:
panic("unknown state")
}
}
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
if n == nil {
return
}
shouldContinue := enterFn(n)
if !shouldContinue {
return
}
switch x := n.(type) {
case *parse.ListNode:
for _, c := range x.Nodes {
templateVisit(c, enterFn, exitFn)
}
case *parse.BranchNode:
if x.Pipe != nil {
templateVisit(x.Pipe, enterFn, exitFn)
}
if x.List != nil {
templateVisit(x.List, enterFn, exitFn)
}
if x.ElseList != nil {
templateVisit(x.ElseList, enterFn, exitFn)
}
case *parse.ActionNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.WithNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.RangeNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.IfNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.TemplateNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.PipeNode:
for _, c := range x.Cmds {
templateVisit(c, enterFn, exitFn)
}
case *parse.CommandNode:
for _, a := range x.Args {
templateVisit(a, enterFn, exitFn)
}
// text, field, number, etc. are leaves nothing to recurse into
}
if exitFn != nil {
exitFn(n)
}
}
// We use a heuristic to infer the tags that surround thinking traces:
// We look for a range node that iterates over "Messages" and then look for a
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
// ListNode and take the first and last TextNodes as the opening and closing
// tags.
func inferThinkingTags(t *template.Template) (string, string) {
ancestors := []parse.Node{}
openingTag := ""
closingTag := ""
enterFn := func(n parse.Node) bool {
ancestors = append(ancestors, n)
switch x := n.(type) {
case *parse.FieldNode:
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
var mostRecentRange *parse.RangeNode
for i := len(ancestors) - 1; i >= 0; i-- {
if r, ok := ancestors[i].(*parse.RangeNode); ok {
mostRecentRange = r
break
}
}
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
return true
}
// TODO(drifkin): to be more robust, check that it's in the action
// part, not the `if`'s pipeline part. We do match on the nearest list
// that starts and ends with text nodes, which makes this not strictly
// necessary for our heuristic
// go up to the nearest ancestor that is a *parse.ListNode
for i := len(ancestors) - 1; i >= 0; i-- {
if l, ok := ancestors[i].(*parse.ListNode); ok {
firstNode := l.Nodes[0]
if t, ok := firstNode.(*parse.TextNode); ok {
openingTag = strings.TrimSpace(t.String())
}
lastNode := l.Nodes[len(l.Nodes)-1]
if t, ok := lastNode.(*parse.TextNode); ok {
closingTag = strings.TrimSpace(t.String())
}
break
}
}
}
}
return true
}
exitFn := func(n parse.Node) {
ancestors = ancestors[:len(ancestors)-1]
}
templateVisit(t.Root, enterFn, exitFn)
return openingTag, closingTag
}
// checks to see if the given field name is present in the pipeline of the given range node
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
found := false
enterFn := func(n parse.Node) bool {
switch x := n.(type) {
case *parse.FieldNode:
if x.Ident[0] == field {
found = true
}
}
return true
}
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
return found
}

403
server/thinking_test.go Normal file
View File

@@ -0,0 +1,403 @@
package server
import (
"testing"
"text/template"
)
func TestExtractThinking(t *testing.T) {
tests := []struct {
in, wantContent, wantThink string
}{
{
in: "<think> internal </think> world",
wantThink: "internal ",
wantContent: "world",
},
{
in: "<think>a</think><think>b</think>c",
wantThink: "a",
wantContent: "<think>b</think>c",
},
{
in: "no think",
wantThink: "",
wantContent: "no think",
},
}
for i, tt := range tests {
parser := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
gotThinking, gotContent := parser.addContent(tt.in)
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
}
}
}
func TestThinkingStreaming(t *testing.T) {
type step struct {
input string
wantThinking string
wantContent string
wantStateAfter thinkingState
}
cases := []struct {
desc string
skip bool
steps []step
}{
{
desc: "content without a thinking tag",
steps: []step{
{
input: " abc",
wantThinking: "",
wantContent: " abc",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "content before a thinking tag nerfs the thinking tag",
steps: []step{
{
input: " abc <think>def</think> ghi",
wantThinking: "",
wantContent: " abc <think>def</think> ghi",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "building up a thinking tag partially",
steps: []step{
{
input: " <th",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_LookingForOpening,
},
{
input: "in",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_LookingForOpening,
},
{
input: "k>a",
wantThinking: "a",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
},
},
{
desc: "partial closing tag",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ink>def",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "partial closing tag fakeout",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ing>def",
wantThinking: "</thing>def",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ghi</thi",
wantThinking: "ghi",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "nk>jkl",
wantThinking: "",
wantContent: "jkl",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag",
steps: []step{
{
input: " <think>abc</think>\n\ndef",
wantThinking: "abc",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag (incremental)",
steps: []step{
{
input: " <think>abc</think>",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "\n\ndef",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag with content and more whitespace",
steps: []step{
{
input: " <think>abc</think>\n\ndef ",
wantThinking: "abc",
wantContent: "def ",
wantStateAfter: thinkingState_ThinkingDone,
},
{
input: " ghi",
wantThinking: "",
wantContent: " ghi",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "token by token",
steps: []step{
{
input: "<think>",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: "\n",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: "</think>",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "\n\n",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "Hi",
wantThinking: "",
wantContent: "Hi",
wantStateAfter: thinkingState_ThinkingDone,
},
{
input: " there",
wantThinking: "",
wantContent: " there",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "leading thinking whitespace",
steps: []step{
{
input: " <think> \t ",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: " these are some ",
wantThinking: "these are some ",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "thoughts </think> ",
wantThinking: "thoughts ",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: " more content",
wantThinking: "",
wantContent: "more content",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
}
for _, c := range cases {
parser := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
if c.skip {
continue
}
for i, step := range c.steps {
thinking, content := parser.addContent(step.input)
if content != step.wantContent || thinking != step.wantThinking {
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
}
if parser.state != step.wantStateAfter {
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state, step.wantStateAfter)
}
}
}
}
func TestInferThinkingTags(t *testing.T) {
cases := []struct {
desc string
tmplString string
wantOpeningTag string
wantClosingTag string
}{
{
desc: "basic",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
{
desc: "doubly nested range",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- range $j, $_ := .NotMessages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
{{ end }}
`,
wantOpeningTag: "",
wantClosingTag: "",
},
{
desc: "whitespace is trimmed",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
Some text before {{ .Thinking }} Some text after
{{ end }}
{{ end }}
`,
wantOpeningTag: "Some text before",
wantClosingTag: "Some text after",
},
{
desc: "qwen3",
tmplString: `
{{- if or .System .Tools .Thinking }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}
{{- if .Thinking }}
/think
{{- else }}
/no_think
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
}
for _, c := range cases {
tmpl := template.Must(template.New("test").Parse(c.tmplString))
openingTag, closingTag := inferThinkingTags(tmpl)
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
}
}
}