Compare commits

...

15 Commits

Author SHA1 Message Date
Jeffrey Morgan
1b308e1d2a model: fix global layer rope scale values for gemma 3 (#13452) 2025-12-12 16:29:01 -08:00
Daniel Hiltgen
bd6c1d6b49 flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine

If the user does not specify fa in the environment, use auto-mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

additional review comments
2025-12-12 13:27:19 -08:00
Jeffrey Morgan
3af5d3b738 model: force rope factor 1.0 for Gemma 3 (#13445) 2025-12-12 13:27:08 -08:00
Daniel Hiltgen
7730895158 Enable Ollama engine by default (#13443)
This changes the default behavior to use the Ollama engine for supported
models, while retaining the ability to disable the Ollama engine and
fall back to the Llama engine.  Models in the OllamaEngineRequired list
will always run on the Ollama engine.
2025-12-12 11:48:43 -08:00
Eva H
de9ecfd01c tidy up lint warnings on windows (#13430) 2025-12-12 11:43:35 -05:00
Eva H
95fdd8d619 fix: select and update models folder in settings (#13412) 2025-12-12 11:09:37 -05:00
Devon Rifkin
9f7822851c docs: add docs for v1/responses and rework openai compat section (#13416)
* docs: add docs for v1/responses and rework openai compat section

I reworked the examples to be separated by topic and to be fully
runnable (i.e., they now log output instead of just suggesting how a
call might be made).

We now use `<CodeGroup>`s so that each example has a dropdown on the
docs site for users to choose, which makes the examples a lot more
digestible (since you only see approx 1/3 of the code you used to).

I also added a new tool to extract code examples into files so that it's
easier to actually run them and check that they work.

## Example

```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```

Output:

```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

  - 01_basic.py
  - 01_basic.js
  - 01_basic.sh
  - 02_responses.py
  - 02_responses.js
  - 02_responses.sh
  - 03_vision.py
  - 03_vision.js
  - 03_vision.sh

Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

To run examples:

  cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
  npm install   # for JS examples

then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```

In the future we should consider actually running the examples in CI and
having some sort of acceptance test so we can automatically detect when
our examples break. So this is just a start in that direction.

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2025-12-11 17:39:40 -08:00
Parth Sareen
9b2035d194 openai: add tool call appending to previous assistant message (#13434)
* openai: add tool call appending to previous asst message

* add tests for thinking appending
2025-12-11 17:30:12 -08:00
Alexander Gusak
93d45d7a04 docs: fix link to modelfile.mdx (#13220) 2025-12-11 16:14:45 -08:00
JJ
709f842457 Update README.md (#13373)
Correct Markdown syntax for Swollama GitHub and DocC documentation links
2025-12-11 16:08:57 -08:00
Jeffrey Morgan
2dfb74410d model: fix rotary embeddings for ministral 3 (#13432) 2025-12-11 16:02:05 -08:00
Devon Rifkin
1eb5e75972 openai: add v1/responses support (#13351)
Only supporting the stateless part of the API.

Doc updates to come once this is shipped.

Closes: #9659
2025-12-11 15:37:10 -08:00
nicole pardal
3475d915cb embeddings: modified batch size (#13429)
This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch.
Previously, if batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash.
This change ensures all tokens stay in one batch and prevents crashes.
Fixes: #12938 #13054

Co-authored-by: Jesse Gross <jesse@ollama.com>
2025-12-11 15:36:31 -08:00
Jeffrey Morgan
48e78e9be1 template: add yesterdayDate helper function (#13431) 2025-12-11 14:47:55 -08:00
Jeffrey Morgan
a838421ea3 model: conversion and hyperparameter fixes for ministral and devstral (#13424) 2025-12-11 13:04:00 -08:00
34 changed files with 3875 additions and 282 deletions

View File

@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama. - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples) - [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift) - [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama) - [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [GoLamify](https://github.com/prasad89/golamify) - [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)

View File

@@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that // Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]). // behaves similarly to other methods (see [Client.Pull]).
// //
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md // [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse var resp ProgressResponse

View File

@@ -191,13 +191,6 @@ func LaunchNewApp() {
C.launchApp(appName) C.launchApp(appName)
} }
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) { func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations // Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem() C.unregisterSelfFromLoginItem()

View File

@@ -263,11 +263,6 @@ func createLoginShortcut() error {
return nil return nil
} }
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() { func LaunchNewApp() {
} }

View File

@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
} }
NSArray* urls = [panel URLs]; NSArray* urls = [panel URLs];
if(self->params->allowMultiple && [urls count] >= 1) { if([urls count] == 0) {
return DLG_CANCEL;
}
if(self->params->allowMultiple) {
// For multiple files, we need to return all paths separated by null bytes // For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf; char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf; int remainingBuf = self->params->nbuf;
// Calculate total required buffer size first // Calculate total required buffer size first
int totalSize = 0; int totalSize = 0;
for(NSURL* url in urls) { for(NSURL* url in urls) {
char tempBuf[PATH_MAX]; char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) { if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL; return DLG_URLFAIL;
} }
totalSize += strlen(tempBuf) + 1; // +1 for null terminator totalSize += strlen(tempBuf) + 1; // +1 for null terminator
} }
totalSize += 1; // Final null terminator totalSize += 1; // Final null terminator
if(totalSize > self->params->nbuf) { if(totalSize > self->params->nbuf) {
// Not enough buffer space // Not enough buffer space
return DLG_URLFAIL; return DLG_URLFAIL;
} }
// Now actually copy the paths (we know we have space) // Now actually copy the paths (we know we have space)
bufPtr = self->params->buf; bufPtr = self->params->buf;
for(NSURL* url in urls) { for(NSURL* url in urls) {
char tempBuf[PATH_MAX]; char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]; [url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf); int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf); strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1; bufPtr += pathLen + 1;
} }
*bufPtr = '\0'; // Final null terminator *bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
} }
return DLG_OK; return DLG_OK;

View File

@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int type WinDlgError int
func (e WinDlgError) Error() string { func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", e) return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
} }
func err() error { func err() error {

View File

@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil { if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models env["OLLAMA_MODELS"] = settings.Models
} else { } else {
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err) slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
} }
} }
if settings.ContextLength > 0 { if settings.ContextLength > 0 {

View File

@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
case uint32(UI_REQUEST_MSG_ID): case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread // Requests for the UI must always come from the main event thread
l := int(wParam) l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
t.app.UIRun(path) t.app.UIRun(path)
case WM_COPYDATA: case WM_COPYDATA:
// Handle URL scheme requests from other instances // Handle URL scheme requests from other instances
if lParam != 0 { if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
if cds.DwData == 1 { // Our identifier for URL scheme messages if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string // Convert the data back to string
data := make([]byte, cds.CbData) data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
urlScheme := string(data) urlScheme := string(data)
handleURLSchemeRequest(urlScheme) handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success lResult = 1 // Return non-zero to indicate success

View File

@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &llama4Model{} conv = &llama4Model{}
case "Mistral3ForConditionalGeneration": case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{} conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtralModel{} conv = &mixtralModel{}
case "GemmaForCausalLM": case "GemmaForCausalLM":

View File

@@ -30,13 +30,15 @@ type mistral3Model struct {
HiddenAct string `json:"hidden_act"` HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
RopeParameters struct { RopeParameters struct {
BetaFast float32 `json:"beta_fast"` BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"` BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"` Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"` Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"` RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"` } `json:"rope_parameters"`
} `json:"text_config"` } `json:"text_config"`
VisionModel struct { VisionModel struct {
@@ -50,6 +52,9 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"` HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"` } `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"` MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"` ProjectorHiddenAct string `json:"projector_hidden_act"`
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads) kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta) kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 { if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta }
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
} }
// Vision configuration // Vision configuration
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
// Multimodal configuration // Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@@ -0,0 +1,181 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3CausalModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.NumHiddenLayers
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.HiddenSize
kv["mistral3.feed_forward_length"] = p.IntermediateSize
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["mistral3.attention.key_length"] = p.HeadDim
kv["mistral3.attention.value_length"] = p.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
if p.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
}
if p.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
}
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
if p.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
return kv
}
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3CausalModel) Replacements() []string {
return []string{
"model.norm", "output_norm",
"model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional): Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema - `format`: the format to return a response in. Format can be `json` or a JSON schema
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`) - `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
Advanced parameters (optional): Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema. - `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
- `template`: (optional) the prompt template for the model - `template`: (optional) the prompt template for the model
- `license`: (optional) a string or list of strings containing the license or licenses for the model - `license`: (optional) a string or list of strings containing the license or licenses for the model
- `system`: (optional) a string containing the system prompt for the model - `system`: (optional) a string containing the system prompt for the model
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters) - `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
- `messages`: (optional) a list of message objects used to create a conversation - `messages`: (optional) a list of message objects used to create a conversation
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `quantize` (optional): quantize a non-quantized (e.g. float16) model - `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
Advanced parameters: Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true` - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding - `dimensions`: number of dimensions for the embedding
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
Advanced parameters: Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples ### Examples

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,46 @@
# extract-examples
Extracts code examples from MDX files to a temp directory so you can run them.
## Usage
```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```
## Example
```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```
Output:
```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
- 01_basic.py
- 01_basic.js
- 01_basic.sh
- 02_responses.py
- 02_responses.js
- 02_responses.sh
- 03_vision.py
- 03_vision.js
- 03_vision.sh
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
To run examples:
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
npm install # for JS examples
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```
## How it works
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory

View File

@@ -0,0 +1,137 @@
package main
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
)
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
os.Exit(1)
}
mdxFile := os.Args[1]
f, err := os.Open(mdxFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
defer f.Close()
// Create temp directory
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
os.Exit(1)
}
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
// Patterns
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
codeGroupStart := regexp.MustCompile("^<CodeGroup")
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
scanner := bufio.NewScanner(f)
inCodeBlock := false
inCodeGroup := false
var currentFile string
var content strings.Builder
count := 0
codeGroupNum := 0
for scanner.Scan() {
line := scanner.Text()
// Track CodeGroup boundaries
if codeGroupStart.MatchString(line) {
inCodeGroup = true
codeGroupNum++
continue
}
if codeGroupEnd.MatchString(line) {
inCodeGroup = false
continue
}
if inCodeBlock {
if line == "```" {
// End of code block - write file
if currentFile != "" {
outPath := filepath.Join(tempDir, currentFile)
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
} else {
fmt.Printf(" - %s\n", currentFile)
count++
}
}
inCodeBlock = false
currentFile = ""
content.Reset()
} else {
content.WriteString(line)
content.WriteString("\n")
}
} else {
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
inCodeBlock = true
filename := matches[2]
// Prefix with CodeGroup number if inside a CodeGroup
if inCodeGroup {
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
} else {
currentFile = filename
}
content.Reset()
}
}
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
os.Exit(1)
}
// Write package.json for JavaScript dependencies
packageJSON := `{
"name": "mdx-examples",
"type": "module",
"dependencies": {
"openai": "^4",
"ollama": "^0.5"
}
}
`
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
}
// Write pyproject.toml for Python dependencies
pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
"openai",
"ollama",
]
`
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
}
fmt.Printf("\n")
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
fmt.Printf("\n")
fmt.Printf("To run examples:\n")
fmt.Printf("\n")
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
fmt.Printf("\n")
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}

View File

@@ -199,7 +199,7 @@ var (
// MultiUserCache optimizes prompt caching for multi-user scenarios // MultiUserCache optimizes prompt caching for multi-user scenarios
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine // Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE") NewEngine = BoolWithDefault("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length // ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server // Auth enables authentication between the Ollama client and server
@@ -291,7 +291,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(true), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational // Informational

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil" "github.com/ollama/ollama/fs/util/bufioutil"
"github.com/ollama/ollama/ml"
) )
type GGML struct { type GGML struct {
@@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil }, nil
} }
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) { func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel) context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
@@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
} }
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention { if useFlashAttention == ml.FlashAttentionEnabled {
// rough estimate of graph size with flash attention on // rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
} }
@@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
} }
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention // SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool { func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())] _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]

View File

@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
} }
} }
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()
// Test with progressively larger inputs
testCases := []struct {
name string
inputWords int
}{
{"medium_input_256_words", 256},
{"large_input_512_words", 512},
{"very_large_input_800_words", 800},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
words := make([]string, tc.inputWords)
for i := range words {
words[i] = "word"
}
input := strings.Join(words, " ")
req := api.EmbedRequest{
Model: model,
Input: input,
KeepAlive: &api.Duration{Duration: 30 * time.Second},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) == 0 {
t.Fatal("expected non-empty embedding")
}
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
})
}
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint // TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client. // properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler // This test specifically checks the error handling path in EmbedHandler

View File

@@ -118,18 +118,22 @@ type ContextParams struct {
c C.struct_llama_context_params c C.struct_llama_context_params
} }
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams { func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
params := C.llama_context_default_params() params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx) params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize) params.n_batch = C.uint(batchSize * numSeqMax)
params.n_ubatch = C.uint(batchSize)
params.n_seq_max = C.uint(numSeqMax) params.n_seq_max = C.uint(numSeqMax)
params.n_threads = C.int(threads) params.n_threads = C.int(threads)
params.n_threads_batch = params.n_threads params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true) params.embeddings = C.bool(true)
if flashAttention { switch flashAttention {
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED case ml.FlashAttentionEnabled:
} else { params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED case ml.FlashAttentionDisabled:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
case ml.FlashAttentionAuto:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
} }
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

View File

@@ -143,7 +143,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
var llamaModel *llama.Model var llamaModel *llama.Model
var textProcessor model.TextProcessor var textProcessor model.TextProcessor
var err error var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { if envconfig.NewEngine(true) || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 { if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath) textProcessor, err = model.NewTextProcessor(modelPath)
} else { } else {
@@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if len(projectors) > 0 && llamaModel != nil { if len(projectors) > 0 && llamaModel != nil {
loadRequest.ProjectorPath = projectors[0] loadRequest.ProjectorPath = projectors[0]
} }
// Determine if the user has forced FA on or off
faUserSet := false
if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
faUserSet = true
}
fa := envconfig.FlashAttention(f.FlashAttention()) fa := envconfig.FlashAttention(f.FlashAttention())
@@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType()) kvct := strings.ToLower(envconfig.KvCacheType())
if fa { if textProcessor == nil {
slog.Info("enabling flash attention") flashAttention := ml.FlashAttentionAuto
loadRequest.FlashAttention = true if faUserSet {
if fa {
// Flash Attention also supports kv cache quantization flashAttention = ml.FlashAttentionEnabled
// Enable if the requested and kv cache type is supported by the model } else {
if f.SupportsKVCacheType(kvct) { flashAttention = ml.FlashAttentionDisabled
loadRequest.KvCacheType = kvct }
} else { }
slog.Warn("kv cache type not supported by model", "type", kvct)
if kvct != "" {
if f.KVCacheTypeIsQuantized(kvct) {
if flashAttention != ml.FlashAttentionEnabled {
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
loadRequest.KvCacheType = ""
} else if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
} else {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
}
}
loadRequest.FlashAttention = flashAttention
} else {
// For Ollama engine, use our SupportsFlashAttention logic
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = ml.FlashAttentionEnabled
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
} }
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
} }
gpuLibs := ml.LibraryPaths(gpus) gpuLibs := ml.LibraryPaths(gpus)
@@ -435,7 +472,7 @@ type LoadRequest struct {
LoraPath []string LoraPath []string
Parallel int Parallel int
BatchSize int BatchSize int
FlashAttention bool FlashAttention ml.FlashAttentionType
KvSize int KvSize int
KvCacheType string KvCacheType string
NumThreads int NumThreads int
@@ -474,6 +511,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers) s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
} }
// Check if embedding model and adjust batch size accordingly
_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
s.loadRequest.BatchSize = s.options.NumCtx
slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
}
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize), kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention) s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)

View File

@@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc {
c.Next() c.Next()
} }
} }
type ResponsesWriter struct {
BaseWriter
converter *openai.ResponsesStreamConverter
model string
stream bool
responseID string
itemID string
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
if err := json.Unmarshal(data, &chatResponse); err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *ResponsesWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
chatReq, err := openai.FromResponsesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
// Check if client requested streaming (defaults to false)
streamRequested := req.Stream != nil && *req.Stream
// Pass streaming preference to the underlying chat request
chatReq.Stream = &streamRequested
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
}
// Set headers based on streaming mode
if streamRequested {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -74,7 +74,7 @@ type BackendParams struct {
GPULayers GPULayersList GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel // FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool FlashAttention FlashAttentionType
} }
var backends = make(map[string]func(string, BackendParams) (Backend, error)) var backends = make(map[string]func(string, BackendParams) (Backend, error))

View File

@@ -109,7 +109,7 @@ type Backend struct {
// btDeviceMemory maps from a buffer type to the memory allocations associated with that device // btDeviceMemory maps from a buffer type to the memory allocations associated with that device
btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
flashAttention bool flashAttention ml.FlashAttentionType
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
} }
func (b *Backend) CacheConfig() ml.CacheConfig { func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention { if b.flashAttention == ml.FlashAttentionEnabled {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else { } else {
return ml.CacheConfig{CachePadding: 256, PermutedV: true} return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
query := t.Permute(ctx, 0, 2, 1, 3) query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention { if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3) value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)

View File

@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
return true return true
} }
type FlashAttentionType int32
const (
// Aligned with llama_flash_attn_type
FlashAttentionAuto FlashAttentionType = -1
FlashAttentionDisabled FlashAttentionType = 0
FlashAttentionEnabled FlashAttentionType = 1
)
func (f FlashAttentionType) LogValue() slog.Value {
return slog.AnyValue(f.String())
}
func (f FlashAttentionType) String() string {
switch f {
case FlashAttentionAuto:
return "Auto"
case FlashAttentionDisabled:
return "Disabled"
case FlashAttentionEnabled:
return "Enabled"
default:
return "unknown"
}
}
// Given the list of GPUs this instantiation is targeted for, // Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables // figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices // Set mustFilter true to enable filtering of CUDA devices

View File

@@ -28,10 +28,10 @@ type TextConfig struct {
finalLogitSoftcap float32 finalLogitSoftcap float32
} }
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor { func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()} ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" { if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
ropeOpts = append(ropeOpts, ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext), rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation), rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
) )
} }
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...) return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
} }
type TextModel struct { type TextModel struct {
@@ -83,19 +83,22 @@ func newTextModel(c fs.Config) *TextModel {
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0), ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0), ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0), ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0), ropeScale: c.Float("rope.scaling.factor", 8.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0), finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
}, },
} }
// Google's Gemma 3 release with sliding window attention does // Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0 // not use final logit softcapping, and so force it to 0.0
// The QAT weights for Gemma 3 also included an incorrect
// value for the rope scale, so we need to set it to 1.0 here.
// TODO (jmorganca): this should ideally be set to 0.0 in the // TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of // model configuration instead of here, as future versions of
// models may include both sliding window attention and final // models may include both sliding window attention and final
// logit softcapping. // logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) { if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0 m.TextConfig.finalLogitSoftcap = 0.0
m.TextConfig.ropeScale = 1.0
} }
if numBlocks == gemma27BLayerCount { if numBlocks == gemma27BLayerCount {
@@ -114,31 +117,31 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
} }
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 { func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] { if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase return opts.ropeLocalBase, 1.0
} }
// Standard Gemma3: only every n-th layer is global, // Standard Gemma3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount, otherwise use // where n = gemmaGlobalCacheCount, otherwise use
// the local rope base // the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 { if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase return opts.ropeLocalBase, 1.0
} }
// default to global rope base // default to global rope base
return opts.ropeBase return opts.ropeBase, opts.ropeScale
} }
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
ropeBase := opts.ropeBaseForLayer(layer) ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase) q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
if opts.largeModelScaling { if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -149,7 +152,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps) k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase) k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -162,7 +165,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
} }
type TextMLP struct { type TextMLP struct {

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
@@ -17,10 +18,30 @@ type TextOptions struct {
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int ropeOrigPosEmbeddings int
ropeScalingBeta float32 ropeScalingBeta float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
ropeMscale float32
ropeMscaleAllDim float32
} }
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale) var ropeOpts []func(*rope.Options)
if o.ropeType == "yarn" {
if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0)))
}
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
} }
type TextModel struct { type TextModel struct {
@@ -150,9 +171,15 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.scaling.factor", 1.0),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")), ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"), ropeScalingBeta: c.Float("rope.scaling_beta", 0.1),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeType: c.String("rope.scaling.type"),
ropeMscale: c.Float("rope.scaling.mscale"),
ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1),
}, },
} }
} }

View File

@@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
} }
types := []string{"jpeg", "jpg", "png", "webp"} img, err := decodeImageURL(url)
valid := false
// support blank mime type to match api/chat taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
valid = true
}
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, errors.New("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil { if err != nil {
return nil, errors.New("invalid message format") return nil, err
} }
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
@@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
return "" return ""
} }
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
} else {
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, errors.New("invalid image input")
}
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, errors.New("invalid image input")
}
return img, nil
}
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall // FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
apiToolCalls := make([]api.ToolCall, len(toolCalls)) apiToolCalls := make([]api.ToolCall, len(toolCalls))

1015
openai/responses.go Normal file

File diff suppressed because it is too large Load Diff

1842
openai/responses_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
) )
@@ -832,7 +833,7 @@ func (s *Server) loadModel(
ppath string, ppath string,
kvSize int, kvSize int,
kvCacheType string, kvCacheType string,
flashAttention bool, flashAttention ml.FlashAttentionType,
threads int, threads int,
multiUserCache bool, multiUserCache bool,
) { ) {
@@ -842,7 +843,7 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType) ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams) s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
return errors.New("loras are not yet implemented") return errors.New("loras are not yet implemented")
} }
if s.model.Config().Cache == nil {
if parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
if s.batchSize < kvSize {
s.batchSize = kvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
}
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil { if err != nil {
return err return err
} }
if !s.cache.enabled && parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
s.parallel = parallel s.parallel = parallel
s.seqs = make([]*Sequence, s.parallel) s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

View File

@@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler) r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
if rc != nil { if rc != nil {
// wrap old with new // wrap old with new
@@ -2393,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
} }
return msgs return msgs
} }

View File

@@ -127,6 +127,9 @@ var funcs = template.FuncMap{
// Default format is YYYY-MM-DD // Default format is YYYY-MM-DD
return time.Now().Format("2006-01-02") return time.Now().Format("2006-01-02")
}, },
"yesterdayDate": func(args ...string) string {
return time.Now().AddDate(0, 0, -1).Format("2006-01-02")
},
"toTypeScriptType": func(v any) string { "toTypeScriptType": func(v any) string {
if param, ok := v.(api.ToolProperty); ok { if param, ok := v.(api.ToolProperty); ok {
return param.ToTypeScriptType() return param.ToTypeScriptType()

View File

@@ -10,6 +10,7 @@ import (
"slices" "slices"
"strings" "strings"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@@ -451,6 +452,72 @@ func TestExecuteWithSuffix(t *testing.T) {
} }
} }
func TestDateFunctions(t *testing.T) {
t.Run("currentDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Today is {{ currentDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
expected := "Hello Today is " + time.Now().Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})
t.Run("yesterdayDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Yesterday was {{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
expected := "Hello Yesterday was " + time.Now().AddDate(0, 0, -1).Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})
t.Run("yesterdayDate format", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ end }}{{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
// Verify the format matches YYYY-MM-DD
result := b.String()
if len(result) != 10 {
t.Errorf("expected date length 10, got %d: %q", len(result), result)
}
// Parse and verify it's a valid date
parsed, err := time.Parse("2006-01-02", result)
if err != nil {
t.Errorf("failed to parse date %q: %v", result, err)
}
// Verify it's yesterday
yesterday := time.Now().AddDate(0, 0, -1)
if parsed.Year() != yesterday.Year() || parsed.Month() != yesterday.Month() || parsed.Day() != yesterday.Day() {
t.Errorf("expected yesterday's date, got %v", parsed)
}
})
}
func TestCollate(t *testing.T) { func TestCollate(t *testing.T) {
cases := []struct { cases := []struct {
name string name string