checkpoint - cleanup still left, functionality setup

This commit is contained in:
ParthSareen
2025-05-08 18:48:44 -07:00
parent 6cb7494061
commit 779547fcde
3 changed files with 177 additions and 183 deletions

View File

@@ -53,7 +53,6 @@ func TestParseToolCalls(t *testing.T) {
output string
expectedToolCall []api.ToolCall
expectedTokens string
wantErr bool
}{
{
name: "mistral invalid json",
@@ -61,7 +60,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral multiple tool calls - no prefix",
@@ -69,7 +67,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral tool calls with text in between - no prefix",
@@ -78,7 +75,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: false,
},
{
name: "mistral valid json - with prefix",
@@ -86,7 +82,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
// In this case we'd be ignoring the text in between and just returning the tool calls
@@ -96,7 +91,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral incomplete json",
@@ -104,7 +98,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral without tool token",
@@ -114,7 +107,6 @@ func TestParseToolCalls(t *testing.T) {
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: true,
},
{
name: "mistral without tool token - tool first",
@@ -122,7 +114,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "command-r-plus with json block",
@@ -147,7 +138,6 @@ func TestParseToolCalls(t *testing.T) {
` + "```",
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "firefunction with functools",
@@ -155,7 +145,6 @@ func TestParseToolCalls(t *testing.T) {
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3 with tool call tags",
@@ -165,7 +154,6 @@ func TestParseToolCalls(t *testing.T) {
</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "xlam with tool_calls wrapper",
@@ -173,7 +161,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen2.5 with single tool call",
@@ -181,15 +168,34 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with invalid tool token",
name: "qwen with no tool prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
},
{
name: "qwen with no tool prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen with prefix",
model: "qwen2.5-coder",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
// tests the leftover logic as well
@@ -198,7 +204,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking spaces",
@@ -206,31 +211,20 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
// {
// name: "qwen3 testing",
// model: "qwen3",
// output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{},
// expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// wantErr: true,
// },
// {
// name: "qwen3 testing 2",
// model: "qwen3",
// output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{t1},
// expectedTokens: `<think></think>`,
// wantErr: true,
// },
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
name: "qwen3 testing",
model: "qwen3",
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
wantErr: true,
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 testing 2",
model: "qwen3",
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: `<think></think>`,
},
{
name: "llama3.2 with tool call - no prefix",
@@ -238,7 +232,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3.2 with incomplete tool call - no prefix",
@@ -246,7 +239,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "llama3.2 with tool call - in middle",
@@ -254,7 +246,6 @@ func TestParseToolCalls(t *testing.T) {
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
{
name: "llama3.2 - fake tool prefix",
@@ -262,7 +253,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
}
@@ -298,7 +288,6 @@ func TestParseToolCalls(t *testing.T) {
m := &Model{Template: tmpl}
tp := NewToolParser(m)
got := []api.ToolCall{}
success := false
var actualTokens strings.Builder
tokens := strings.Fields(tt.output)
@@ -306,40 +295,33 @@ func TestParseToolCalls(t *testing.T) {
add := true
s := " " + tok
// TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can
if tp.state != Done {
toolCalls, leftover, ok := tp.ParseToolCalls(s)
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
continue
}
if tp.state == ContainsPartialPrefix {
// actualTokens.Reset()
actualTokens.WriteString(leftover)
t.Log("leftover", leftover)
add = false
// continue
}
if ok && len(toolCalls) > 0 {
success = true
if !tp.Done {
toolCalls, leftover := tp.ParseToolCalls(s)
switch tp.ParserState {
case ToolCallFound:
got = append(got, toolCalls...)
add = false
// actualTokens.Reset()
case ToolCallSendTokens:
actualTokens.WriteString(s)
add = false
case ToolCallAccumulate:
add = false
case ToolCallSendPartial:
actualTokens.WriteString(" " + leftover)
add = false
}
}
// s = strings.TrimSpace(s)
if add {
actualTokens.WriteString(s)
}
}
if !tt.wantErr {
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
if !success && !tt.wantErr {
t.Errorf("expected success but got errors")
// Compare tool calls if we expect any
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
}
// Compare tokens if we expect any
stripped := strings.TrimSpace(actualTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)