mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 00:36:11 +02:00
Compare commits
10 Commits
v0.12.0-rc
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ef2b2852d | ||
|
|
3677842ff1 | ||
|
|
242df70a75 | ||
|
|
dba39b2eee | ||
|
|
9f3a37fd36 | ||
|
|
7460259eb3 | ||
|
|
22ccdd74c2 | ||
|
|
0c3d0e7533 | ||
|
|
eb0a5d4459 | ||
|
|
ceac416ec2 |
36
Dockerfile
36
Dockerfile
@@ -1,6 +1,7 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
|
ARG PARALLEL=8
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=6.3.3
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
@@ -34,46 +35,51 @@ ENV LDFLAGS=-s
|
|||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel --preset 'CPU' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||||
&& cmake --install build --component CPU --strip --parallel 8
|
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
||||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
||||||
&& cmake --build --parallel --preset 'CUDA 12' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
||||||
&& cmake --build --parallel --preset 'CUDA 13' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-6
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 6' \
|
||||||
&& cmake --build --parallel --preset 'ROCm 6' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||||
&& cmake --install build --component HIP --strip --parallel 8
|
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
@@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
|
|||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel --preset 'JetPack 5' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
@@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
|
|||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
|
ARG PARALLEL
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel --preset 'JetPack 6' \
|
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
|
|||||||
18
auth/auth.go
18
auth/auth.go
@@ -19,16 +19,28 @@ import (
|
|||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func keyPath() (string, error) {
|
||||||
fileExists := func(fp string) bool {
|
fileIsReadable := func(fp string) bool {
|
||||||
info, err := os.Stat(fp)
|
info, err := os.Stat(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return !info.IsDir()
|
|
||||||
|
// Check that it's a regular file, not a directory or other file type
|
||||||
|
if !info.Mode().IsRegular() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to open it to check readability
|
||||||
|
file, err := os.Open(fp)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
|
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
|
||||||
if fileExists(systemPath) {
|
if fileIsReadable(systemPath) {
|
||||||
return systemPath, nil
|
return systemPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embedding) != 384 {
|
if len(res.Embedding) != 384 {
|
||||||
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 1 {
|
if len(res.Embeddings) != 1 {
|
||||||
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 2 {
|
if len(res.Embeddings) != 2 {
|
||||||
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
|
|
||||||
truncTrue, truncFalse := true, false
|
truncTrue, truncFalse := true, false
|
||||||
|
|
||||||
type testReq struct {
|
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||||
Name string
|
Model: "all-minilm",
|
||||||
Request api.EmbedRequest
|
Input: "why",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqs := []testReq{
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
request api.EmbedRequest
|
||||||
|
check func(*api.EmbedResponse, error)
|
||||||
|
}{
|
||||||
{
|
{
|
||||||
Name: "Target Truncation",
|
name: "target truncation",
|
||||||
Request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why",
|
Input: "why",
|
||||||
},
|
},
|
||||||
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "Default Truncate",
|
name: "default truncate",
|
||||||
Request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
Options: map[string]any{"num_ctx": 1},
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "Explicit Truncate",
|
name: "explicit truncate",
|
||||||
Request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "input after truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
Truncate: &truncTrue,
|
Truncate: &truncTrue,
|
||||||
Options: map[string]any{"num_ctx": 1},
|
Options: map[string]any{"num_ctx": 1},
|
||||||
},
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "input after truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 0},
|
||||||
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res := make(map[string]*api.EmbedResponse)
|
for _, req := range cases {
|
||||||
|
t.Run(req.name, func(t *testing.T) {
|
||||||
for _, req := range reqs {
|
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error: %v", err)
|
|
||||||
}
|
|
||||||
res[req.Name] = response
|
|
||||||
}
|
|
||||||
|
|
||||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
|
||||||
t.Fatal("expected default request to truncate correctly")
|
|
||||||
}
|
|
||||||
|
|
||||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
|
||||||
t.Fatal("expected default request and truncate true request to be the same")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that truncate set to false returns an error if context length is exceeded
|
|
||||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: "why is the sky blue?",
|
|
||||||
Truncate: &truncFalse,
|
|
||||||
Options: map[string]any{"num_ctx": 1},
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Embeddings(ctx, &req)
|
return client.Embeddings(ctx, &req)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Embed(ctx, &req)
|
return client.Embed(ctx, &req)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
arch := b.Config().Architecture()
|
m, err := modelForArch(b.Config())
|
||||||
if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
|
|
||||||
arch = arch + "_embed"
|
|
||||||
}
|
|
||||||
|
|
||||||
f, ok := models[arch]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := f(b.Config())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
base := Base{b: b, config: m.Config()}
|
base := Base{b: b, config: m.Config()}
|
||||||
|
|
||||||
v := reflect.ValueOf(m)
|
v := reflect.ValueOf(m)
|
||||||
v.Elem().Set(populateFields(base, v.Elem()))
|
v.Elem().Set(populateFields(base, v.Elem()))
|
||||||
return m, nil
|
return m, nil
|
||||||
@@ -135,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
meta, err := fsggml.Decode(r, -1)
|
meta, err := fsggml.Decode(r, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return getTextProcessor(meta.KV())
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
|
m, err := modelForArch(meta.KV())
|
||||||
arch := kv.Architecture()
|
|
||||||
f, ok := models[arch]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
||||||
}
|
|
||||||
m, err := f(kv)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tp, ok := m.(TextProcessor)
|
tp, ok := m.(TextProcessor)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
return nil, ErrUnsupportedTokenizer
|
||||||
}
|
}
|
||||||
return tp, nil
|
return tp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func modelForArch(c fs.Config) (Model, error) {
|
||||||
|
arch := c.Architecture()
|
||||||
|
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||||
|
arch = arch + "_embed"
|
||||||
|
}
|
||||||
|
|
||||||
|
f, ok := models[arch]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUnsupportedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
return f(c)
|
||||||
|
}
|
||||||
|
|
||||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/backend/ggml"
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseTags(t *testing.T) {
|
func TestParseTags(t *testing.T) {
|
||||||
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTextProcessor(t *testing.T) {
|
func TestModelForArch(t *testing.T) {
|
||||||
tp, err := getTextProcessor(fsggml.KV{})
|
type fakeModel struct {
|
||||||
if err == nil {
|
Model
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
models["dummy"] = func(fs.Config) (Model, error) {
|
type fakeEmbeddingModel struct {
|
||||||
return notTextProcessorModel{}, nil
|
Model
|
||||||
}
|
|
||||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type notTextProcessorModel struct{}
|
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
|
||||||
|
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
cases := []struct {
|
||||||
panic("unimplemented")
|
name string
|
||||||
|
config fs.Config
|
||||||
|
want any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
},
|
||||||
|
want: fakeModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "embedding",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
"model.pooling_type": uint32(1),
|
||||||
|
},
|
||||||
|
want: fakeEmbeddingModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "unsupported",
|
||||||
|
},
|
||||||
|
err: ErrUnsupportedModel,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (notTextProcessorModel) Backend() ml.Backend {
|
for _, tt := range cases {
|
||||||
panic("unimplemented")
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := modelForArch(tt.config)
|
||||||
|
if !errors.Is(err, tt.err) {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (notTextProcessorModel) Config() config {
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
panic("unimplemented")
|
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MLP struct {
|
type MLP struct {
|
||||||
|
|||||||
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
|
|||||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
ropeScale: 1,
|
||||||
|
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||||
|
// (8 instead of 1)
|
||||||
|
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
|||||||
ropeBase = m.TextConfig.ropeGlobalBase
|
ropeBase = m.TextConfig.ropeGlobalBase
|
||||||
}
|
}
|
||||||
|
|
||||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextMLP struct {
|
type TextMLP struct {
|
||||||
|
|||||||
73
model/models/qwen3/embed.go
Normal file
73
model/models/qwen3/embed.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package qwen3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type embedModel struct {
|
||||||
|
model.Base
|
||||||
|
model.BytePairEncoding
|
||||||
|
|
||||||
|
*Model
|
||||||
|
poolingType pooling.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
hiddenStates, err := m.forward(ctx, batch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||||
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
|
return hiddenStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEmbed(c fs.Config) (model.Model, error) {
|
||||||
|
layers := make([]Layer, c.Uint("block_count"))
|
||||||
|
for i := range layers {
|
||||||
|
layers[i].MLP = &dense{}
|
||||||
|
}
|
||||||
|
m := embedModel{
|
||||||
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Model: &Model{
|
||||||
|
Layers: layers,
|
||||||
|
Options: &Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
keyLength: int(c.Uint("attention.key_length")),
|
||||||
|
valueLength: int(c.Uint("attention.value_length")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
|
numExperts: int(c.Uint("expert_count")),
|
||||||
|
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||||
|
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
@@ -151,14 +151,25 @@ type Model struct {
|
|||||||
*Options
|
*Options
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward implements model.Model.
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
hiddenStates, err := m.forward(ctx, batch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements model.Model.
|
||||||
|
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
|
if m.Cache != nil {
|
||||||
m.Cache.SetLayer(i)
|
m.Cache.SetLayer(i)
|
||||||
|
}
|
||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
@@ -168,8 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
|
||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
@@ -227,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
func init() {
|
func init() {
|
||||||
model.Register("qwen3", New)
|
model.Register("qwen3", New)
|
||||||
model.Register("qwen3moe", New)
|
model.Register("qwen3moe", New)
|
||||||
|
model.Register("qwen3_embed", newEmbed)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -393,18 +393,55 @@ func parseValue(raw string, paramType api.PropertyType) any {
|
|||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
var (
|
||||||
|
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||||
|
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||||
|
)
|
||||||
|
|
||||||
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||||
// xml so that it can be parsed by any xml parser
|
// xml so that it can be parsed by any xml parser
|
||||||
func transformToXML(raw string) string {
|
func transformToXML(raw string) string {
|
||||||
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||||
// care to properly escape the string that becomes the attribute value
|
// care to properly escape the string that becomes the attribute value
|
||||||
return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||||
groups := qwenTagRegex.FindStringSubmatch(match)
|
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||||
tag := groups[1]
|
tag := groups[1]
|
||||||
var escapedValue strings.Builder
|
var escapedValue strings.Builder
|
||||||
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||||
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Walk the resulting string, escaping any character data that sits between the
|
||||||
|
// xml tags we just emitted
|
||||||
|
var out strings.Builder
|
||||||
|
lastIdx := 0
|
||||||
|
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||||
|
if loc[0] > lastIdx {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||||
|
}
|
||||||
|
out.WriteString(transformed[loc[0]:loc[1]])
|
||||||
|
lastIdx = loc[1]
|
||||||
|
}
|
||||||
|
if lastIdx < len(transformed) {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeTextNode escapes XML character data without altering other characters
|
||||||
|
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
|
||||||
|
func escapeTextNode(sb *strings.Builder, s string) {
|
||||||
|
for _, r := range s {
|
||||||
|
switch r {
|
||||||
|
case '&':
|
||||||
|
sb.WriteString("&amp;")
|
||||||
|
case '<':
|
||||||
|
sb.WriteString("&lt;")
|
||||||
|
case '>':
|
||||||
|
sb.WriteString("&gt;")
|
||||||
|
default:
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -312,6 +312,41 @@ true
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// regression test for <https://github.com/ollama/ollama/issues/12357>
|
||||||
|
{
|
||||||
|
name: "ampersands in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "done"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"done\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "angle brackets in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "a > b and a < b"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, step := range steps {
|
for i, step := range steps {
|
||||||
@@ -796,6 +831,19 @@ San Francisco
|
|||||||
<parameter name="&#34;unit with spaces&#34;">
|
<parameter name="&#34;unit with spaces&#34;">
|
||||||
celsius
|
celsius
|
||||||
</parameter>
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ampersands in parameter values",
|
||||||
|
raw: `<function=get_current_temperature>
|
||||||
|
<parameter=location>
|
||||||
|
San Francisco & San Jose
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
want: `<function name="get_current_temperature">
|
||||||
|
<parameter name="location">
|
||||||
|
San Francisco &amp; San Jose
|
||||||
|
</parameter>
|
||||||
</function>`,
|
</function>`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
|||||||
--build-arg=OLLAMA_FAST_BUILD \
|
--build-arg=OLLAMA_FAST_BUILD \
|
||||||
--build-arg=CUSTOM_CPU_FLAGS \
|
--build-arg=CUSTOM_CPU_FLAGS \
|
||||||
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
||||||
|
--build-arg=PARALLEL \
|
||||||
--build-arg=AMDGPU_TARGETS"
|
--build-arg=AMDGPU_TARGETS"
|
||||||
|
|
||||||
echo "Building Ollama"
|
echo "Building Ollama"
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/discover"
|
"github.com/ollama/ollama/discover"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
@@ -251,15 +250,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Generate(c, &req, fn)
|
err = client.Generate(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &authError) {
|
||||||
pk, pkErr := auth.GetPublicKey()
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
|
||||||
if pkErr != nil {
|
|
||||||
slog.Error("couldn't get public key", "error", pkErr)
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk})
|
var apiError api.StatusError
|
||||||
|
if errors.As(err, &apiError) {
|
||||||
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -634,7 +632,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||||
if len(tokens) > ctxLen {
|
if len(tokens) > ctxLen {
|
||||||
if !truncate {
|
if !truncate {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -646,6 +644,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
ctxLen--
|
ctxLen--
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
|
||||||
|
if ctxLen <= 0 {
|
||||||
|
// return error if the truncated input would be empty or just special tokens
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tokens = tokens[:ctxLen]
|
tokens = tokens[:ctxLen]
|
||||||
|
|
||||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
@@ -1803,6 +1808,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Chat(c, &req, fn)
|
err = client.Chat(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var authError api.AuthorizationError
|
||||||
|
if errors.As(err, &authError) {
|
||||||
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var apiError api.StatusError
|
||||||
|
if errors.As(err, &apiError) {
|
||||||
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user