Compare commits

...

10 Commits

Author SHA1 Message Date
jmorganca
4ef2b2852d server: serve original error for remote models 2025-09-20 16:46:32 -07:00
Devon Rifkin
3677842ff1 Merge pull request #12358 from ollama/drifkin/qwen3-coder-ampersands
parsers: fix `&`s in qwen3coder parameter values
2025-09-20 12:40:33 -07:00
Devon Rifkin
242df70a75 parsers: fix &s in qwen3coder parameter values
In <https://github.com/ollama/ollama/issues/12357> we that the model
will output tool calls such as

```
<function=shell>
<parameter=command>
pwd && ls -la
</parameter>
</function>
```

We parse this using the approach of transforming into valid xml and then
using an xml parser. While we do transform the function and parameter
names, we weren't escaping the parameter values (which in this example
are invalid since `pwd && ls -la` contains unescaped ampersands).

This has been fixed by first transforming the tags in the same way, and
then walking the transformed string and escaping the text in between the
tags. This also fixes a case where `<` in the middle of a parameter
value would cause an xml parse failure.

Fixes: #12357
2025-09-20 12:11:38 -07:00
Patrick Devine
dba39b2eee gemma: fix rope scaling for qat models (#12348)
* gemma: fix rope scaling for qat models

* gofumpt yourself
2025-09-19 15:04:40 -07:00
Michael Yang
9f3a37fd36 fix: model load for unsupported embedding models (#12311)
with #12181, there's now support for embeddings in ollama engine.
this is done by mutating the architecture and adding _embed when it
detects an embedding model. however this introduced a bug where if
an embedding model was run based on an existing ollama engine model
without an embedding implementation, e.g. llama4, it will pass the
initial arch support check but fail when actually loaded.

there's currently two entrypoints to creating a model. previously this
second entrypoint was necessary because calling model.New would also
load the model. since #11818, this is no longer th case so merge them
to reduce complexity
2025-09-18 16:11:08 -07:00
Michael Yang
7460259eb3 feat: qwen3 embed (#12301)
* cleanup

* use pooling.TypeNone

* pooling test

* qwen3 embed
2025-09-18 15:50:32 -07:00
Jeffrey Morgan
22ccdd74c2 server: add unauthorized error to remote chat handler (#12338) 2025-09-18 15:40:31 -07:00
Daniel Hiltgen
0c3d0e7533 build: avoid unbounded parallel builds (#12319)
With the addition of cuda v13, on a clean setup, the level of parallelism
was causing docker desktop to become overwhelmed and compilers
were crashing.  This limits to 8 parallel per build stage, with the ability
to override if you have many more cores available.
2025-09-18 14:57:01 -07:00
Patrick Devine
eb0a5d4459 auth: check the permissions on the private key to see if it's readable (#12336) 2025-09-18 14:34:34 -07:00
Michael Yang
ceac416ec2 fix(integration): check truncated length (#12337) 2025-09-18 14:00:21 -07:00
13 changed files with 420 additions and 157 deletions

View File

@@ -1,6 +1,7 @@
# vim: filetype=dockerfile # vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH} ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3 ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1 ARG JETPACK5VERSION=r35.4.1
@@ -34,46 +35,51 @@ ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --parallel --preset 'CPU' \ && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8 && cmake --install build --component CPU --strip --parallel ${PARALLEL}
FROM base AS cuda-11 FROM base AS cuda-11
ARG CUDA11VERSION=11.8 ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel --preset 'CUDA 11' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
&& cmake --build --parallel --preset 'CUDA 12' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-13 FROM base AS cuda-13
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel --preset 'CUDA 13' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS rocm-6 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \ cmake --preset 'ROCm 6' \
&& cmake --build --parallel --preset 'ROCm 6' \ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel 8 && cmake --install build --component HIP --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION ARG CMAKEVERSION
@@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \ cmake --preset 'JetPack 5' \
&& cmake --build --parallel --preset 'JetPack 5' \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION ARG CMAKEVERSION
@@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \ cmake --preset 'JetPack 6' \
&& cmake --build --parallel --preset 'JetPack 6' \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama

View File

@@ -19,16 +19,28 @@ import (
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) { func keyPath() (string, error) {
fileExists := func(fp string) bool { fileIsReadable := func(fp string) bool {
info, err := os.Stat(fp) info, err := os.Stat(fp)
if err != nil { if err != nil {
return false return false
} }
return !info.IsDir()
// Check that it's a regular file, not a directory or other file type
if !info.Mode().IsRegular() {
return false
}
// Try to open it to check readability
file, err := os.Open(fp)
if err != nil {
return false
}
file.Close()
return true
} }
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
if fileExists(systemPath) { if fileIsReadable(systemPath) {
return systemPath, nil return systemPath, nil
} }

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
} }
res, err := embeddingTestHelper(ctx, client, t, req) res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embedding) != 384 { if len(res.Embedding) != 384 {
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embeddings) != 1 { if len(res.Embeddings) != 1 {
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embeddings) != 2 { if len(res.Embeddings) != 2 {
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
truncTrue, truncFalse := true, false truncTrue, truncFalse := true, false
type testReq struct { want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Name string Model: "all-minilm",
Request api.EmbedRequest Input: "why",
})
if err != nil {
t.Fatal(err)
} }
reqs := []testReq{ cases := []struct {
name string
request api.EmbedRequest
check func(*api.EmbedResponse, error)
}{
{ {
Name: "Target Truncation", name: "target truncation",
Request: api.EmbedRequest{ request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why", Input: "why",
}, },
}, check: func(got *api.EmbedResponse, err error) {
{ if err != nil {
Name: "Default Truncate", t.Fatal(err)
Request: api.EmbedRequest{ }
Model: "all-minilm",
Input: "why is the sky blue?", if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
Options: map[string]any{"num_ctx": 1}, t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
}, },
}, },
{ {
Name: "Explicit Truncate", name: "default truncate",
Request: api.EmbedRequest{ request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "explicit truncate",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncTrue, Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1}, Options: map[string]any{"num_ctx": 1},
}, },
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
}, },
} }
res := make(map[string]*api.EmbedResponse) for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
for _, req := range reqs { req.check(embedTestHelper(ctx, client, t, req.request))
response, err := embedTestHelper(ctx, client, t, req.Request) })
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
} }
} }
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatal(err)
} }
response, err := client.Embeddings(ctx, &req) return client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatal(err)
} }
response, err := client.Embed(ctx, &req) return client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }

View File

@@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return nil, err return nil, err
} }
arch := b.Config().Architecture() m, err := modelForArch(b.Config())
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(b.Config())
if err != nil { if err != nil {
return nil, err return nil, err
} }
base := Base{b: b, config: m.Config()} base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m) v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem())) v.Elem().Set(populateFields(base, v.Elem()))
return m, nil return m, nil
@@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
meta, err := fsggml.Decode(r, -1) meta, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { m, err := modelForArch(meta.KV())
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tp, ok := m.(TextProcessor) tp, ok := m.(TextProcessor)
if !ok { if !ok {
return nil, fmt.Errorf("%v is not a TextProcessor", m) return nil, ErrUnsupportedTokenizer
} }
return tp, nil return tp, nil
} }
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type() t := v.Type()

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"errors"
"reflect" "reflect"
"slices" "slices"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
) )
func TestParseTags(t *testing.T) { func TestParseTags(t *testing.T) {
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
} }
} }
func TestGetTextProcessor(t *testing.T) { func TestModelForArch(t *testing.T) {
tp, err := getTextProcessor(fsggml.KV{}) type fakeModel struct {
if err == nil { Model
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
} }
models["dummy"] = func(fs.Config) (Model, error) { type fakeEmbeddingModel struct {
return notTextProcessorModel{}, nil Model
} }
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
if err == nil { models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
t.Error("expected error") models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err) cases := []struct {
} else if tp != nil { name string
t.Error("expected nil tp") config fs.Config
want any
err error
}{
{
name: "model",
config: fsggml.KV{
"general.architecture": "model",
},
want: fakeModel{},
},
{
name: "embedding",
config: fsggml.KV{
"general.architecture": "model",
"model.pooling_type": uint32(1),
},
want: fakeEmbeddingModel{},
},
{
name: "unsupported",
config: fsggml.KV{
"general.architecture": "unsupported",
},
err: ErrUnsupportedModel,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := modelForArch(tt.config)
if !errors.Is(err, tt.err) {
t.Fatal(err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
}
})
} }
} }
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented")
}
func (notTextProcessorModel) Backend() ml.Backend {
panic("unimplemented")
}
func (notTextProcessorModel) Config() config {
panic("unimplemented")
}

View File

@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
} }
type MLP struct { type MLP struct {

View File

@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.scaling.factor", 1.0), ropeScale: 1,
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
}, },
} }
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase ropeBase = m.TextConfig.ropeGlobalBase
} }
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
} }
type TextMLP struct { type TextMLP struct {

View File

@@ -0,0 +1,73 @@
package qwen3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.BytePairEncoding
*Model
poolingType pooling.Type
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbed(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
for i := range layers {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Model: &Model{
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
},
},
poolingType: pooling.Type(c.Uint("pooling_type")),
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}

View File

@@ -151,14 +151,25 @@ type Model struct {
*Options *Options
} }
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
return m.Output.Forward(ctx, hiddenStates), nil
}
// Forward implements model.Model.
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers { for i, layer := range m.Layers {
m.Cache.SetLayer(i) if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
@@ -168,8 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
} }
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
return m.Output.Forward(ctx, hiddenStates), nil
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
@@ -227,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
func init() { func init() {
model.Register("qwen3", New) model.Register("qwen3", New)
model.Register("qwen3moe", New) model.Register("qwen3moe", New)
model.Register("qwen3_embed", newEmbed)
} }

View File

@@ -393,18 +393,55 @@ func parseValue(raw string, paramType api.PropertyType) any {
return raw return raw
} }
var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) var (
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
)
// transformToXML transforms a raw qwen tool call with xml-like tags into valid // transformToXML transforms a raw qwen tool call with xml-like tags into valid
// xml so that it can be parsed by any xml parser // xml so that it can be parsed by any xml parser
func transformToXML(raw string) string { func transformToXML(raw string) string {
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking // take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
// care to properly escape the string that becomes the attribute value // care to properly escape the string that becomes the attribute value
return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
groups := qwenTagRegex.FindStringSubmatch(match) groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1] tag := groups[1]
var escapedValue strings.Builder var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2])) xml.EscapeText(&escapedValue, []byte(groups[2]))
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
}) })
// Walk the resulting string, escaping any character data that sits between the
// xml tags we just emitted
var out strings.Builder
lastIdx := 0
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
if loc[0] > lastIdx {
escapeTextNode(&out, transformed[lastIdx:loc[0]])
}
out.WriteString(transformed[loc[0]:loc[1]])
lastIdx = loc[1]
}
if lastIdx < len(transformed) {
escapeTextNode(&out, transformed[lastIdx:])
}
return out.String()
}
// escapeTextNode escapes XML character data without altering other characters
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
func escapeTextNode(sb *strings.Builder, s string) {
for _, r := range s {
switch r {
case '&':
sb.WriteString("&amp;")
case '<':
sb.WriteString("&lt;")
case '>':
sb.WriteString("&gt;")
default:
sb.WriteRune(r)
}
}
} }

View File

@@ -312,6 +312,41 @@ true
}, },
}, },
}, },
// regression test for <https://github.com/ollama/ollama/issues/12357>
{
name: "ampersands in parameter values",
tools: []api.Tool{},
rawToolCall: `<function=exec>
<parameter=command>
ls && echo "done"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"done\"",
},
},
},
},
{
name: "angle brackets in parameter values",
tools: []api.Tool{},
rawToolCall: `<function=exec>
<parameter=command>
ls && echo "a > b and a < b"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
},
},
},
},
} }
for i, step := range steps { for i, step := range steps {
@@ -798,6 +833,19 @@ celsius
</parameter> </parameter>
</function>`, </function>`,
}, },
{
desc: "ampersands in parameter values",
raw: `<function=get_current_temperature>
<parameter=location>
San Francisco & San Jose
</parameter>
</function>`,
want: `<function name="get_current_temperature">
<parameter name="location">
San Francisco &amp; San Jose
</parameter>
</function>`,
},
} }
for _, tc := range cases { for _, tc := range cases {

View File

@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=OLLAMA_FAST_BUILD \ --build-arg=OLLAMA_FAST_BUILD \
--build-arg=CUSTOM_CPU_FLAGS \ --build-arg=CUSTOM_CPU_FLAGS \
--build-arg=GPU_RUNNER_CPU_FLAGS \ --build-arg=GPU_RUNNER_CPU_FLAGS \
--build-arg=PARALLEL \
--build-arg=AMDGPU_TARGETS" --build-arg=AMDGPU_TARGETS"
echo "Building Ollama" echo "Building Ollama"

View File

@@ -29,7 +29,6 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover" "github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
@@ -251,15 +250,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient) client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn) err = client.Generate(c, &req, fn)
if err != nil { if err != nil {
var sErr api.AuthorizationError var authError api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { if errors.As(err, &authError) {
pk, pkErr := auth.GetPublicKey() c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
if pkErr != nil { return
slog.Error("couldn't get public key", "error", pkErr) }
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) var apiError api.StatusError
return if errors.As(err, &apiError) {
} c.JSON(apiError.StatusCode, apiError)
c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk})
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -634,7 +632,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen { if len(tokens) > ctxLen {
if !truncate { if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return return
} }
@@ -646,6 +644,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen-- ctxLen--
} }
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)
@@ -1803,6 +1808,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient) client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn) err = client.Chat(c, &req, fn)
if err != nil { if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }