mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 12:54:12 +02:00
Compare commits
8 Commits
pdevine/bf
...
parth/samp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0de5bbd0fe | ||
|
|
42a14f7f63 | ||
|
|
f8c3dbe5b5 | ||
|
|
b078dd157c | ||
|
|
2ddacd7516 | ||
|
|
da0e345200 | ||
|
|
df94175a0f | ||
|
|
61a8825216 |
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
conv = &commandrModel{}
|
conv = &commandrModel{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.Wait() != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -211,8 +211,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
// final logit softcap
|
// final logit softcap
|
||||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||||
hiddenState = hiddenState.Tanh(ctx)
|
hiddenState = hiddenState.Tanh(ctx)
|
||||||
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
|
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||||
return hiddenState.Rows(ctx, outputs), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -561,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(req.Options, grammar)
|
||||||
req.Options.Temperature,
|
|
||||||
req.Options.TopK,
|
|
||||||
req.Options.TopP,
|
|
||||||
req.Options.MinP,
|
|
||||||
req.Options.Seed,
|
|
||||||
grammar,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.Options.NumPredict,
|
numPredict: req.Options.NumPredict,
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,6 +28,10 @@ type Sampler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
|
if len(logits) == 0 {
|
||||||
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
|
}
|
||||||
|
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
@@ -94,13 +100,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
tokens = topP(tokens, s.topP)
|
tokens = topP(tokens, s.topP)
|
||||||
tokens = minP(tokens, s.minP)
|
tokens = minP(tokens, s.minP)
|
||||||
|
|
||||||
// TODO: this should fall back to greedy sampling
|
|
||||||
// or topP, topK values etc should be such that
|
|
||||||
// there are always tokens to sample from
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
return token{}, errors.New("no tokens to sample from")
|
|
||||||
}
|
|
||||||
|
|
||||||
var r float32
|
var r float32
|
||||||
if s.rng != nil {
|
if s.rng != nil {
|
||||||
r = s.rng.Float32()
|
r = s.rng.Float32()
|
||||||
@@ -123,43 +122,71 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
return 1
|
return 1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if math.IsNaN(float64(sum)) {
|
||||||
|
return token{}, errors.New("sample: logits sum to NaN, check model output")
|
||||||
|
}
|
||||||
return tokens[idx], nil
|
return tokens[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// SamplerParams contains the validated and normalized parameters for a sampler
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
type SamplerParams struct {
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
TopK int `json:"top_k"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
MinP float32 `json:"min_p"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
|
||||||
|
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
|
||||||
|
type rawParams SamplerParams
|
||||||
|
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate and normalize after unmarshaling
|
||||||
|
if p.Temperature < 0.0 {
|
||||||
|
p.Temperature = 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.TopP < 0.0 {
|
||||||
|
p.TopP = 0.0
|
||||||
|
}
|
||||||
|
if p.TopP >= 1.0 {
|
||||||
|
p.TopP = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.MinP < 0.0 {
|
||||||
|
p.MinP = 0.0
|
||||||
|
}
|
||||||
|
if p.MinP >= 1.0 {
|
||||||
|
p.MinP = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSampler creates a new sampler with the given options
|
||||||
|
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
|
||||||
|
var params SamplerParams
|
||||||
|
data, _ := json.Marshal(opts)
|
||||||
|
_ = json.Unmarshal(data, ¶ms)
|
||||||
|
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if params.Seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
// Use original seed for sequence
|
// Use original seed for sequence
|
||||||
sequence := uint64(seed)
|
sequence := uint64(params.Seed)
|
||||||
// Use golden ratio hash to generate statistically independent seeds
|
// Use golden ratio hash to generate statistically independent seeds
|
||||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||||
}
|
}
|
||||||
if temperature < 0.0 {
|
|
||||||
temperature = 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
if topP < 0.0 {
|
|
||||||
topP = 0.0
|
|
||||||
}
|
|
||||||
if topP >= 1.0 {
|
|
||||||
topP = 1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
if minP < 0.0 {
|
|
||||||
minP = 0.0
|
|
||||||
}
|
|
||||||
if minP >= 1.0 {
|
|
||||||
minP = 1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: params.TopK,
|
||||||
topP: topP,
|
topP: params.TopP,
|
||||||
minP: minP,
|
minP: params.MinP,
|
||||||
temperature: temperature,
|
temperature: params.Temperature,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range configs {
|
for _, tc := range configs {
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|||||||
@@ -1,13 +1,26 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
|
||||||
|
return &api.Options{
|
||||||
|
Temperature: temperature,
|
||||||
|
TopK: topK,
|
||||||
|
TopP: topP,
|
||||||
|
MinP: minP,
|
||||||
|
Seed: seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
logits := []float32{-10, 3, -10, -10}
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||||
got, err := sampler.Sample(logits)
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -19,7 +32,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
logits = []float32{-100, -10, 0, 10}
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -29,12 +42,35 @@ func TestWeighted(t *testing.T) {
|
|||||||
if want != got {
|
if want != got {
|
||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test very high p
|
||||||
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
|
// Use extremely small topP to filter out all tokens
|
||||||
|
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Should get the token with the highest logit
|
||||||
|
want = int32(0)
|
||||||
|
if want != got {
|
||||||
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
|
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
|
||||||
|
got, err = sampler.Sample(logits)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error, got %d", got)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
// Generate random logits for benchmarking
|
||||||
|
|||||||
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
|
|||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topK(tokens, 20)
|
tokens = topK(tokens, 20)
|
||||||
|
|
||||||
// Then apply topP
|
// Test with very high p value
|
||||||
tokens = topP(tokens, 0.95)
|
got := topP(tokens, 1.0)
|
||||||
|
|
||||||
// Should keep tokens until cumsum > 0.95
|
// Should keep all tokens since p is 1
|
||||||
if len(tokens) > 3 {
|
if len(got) != len(input) {
|
||||||
|
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with normal p value
|
||||||
|
got = topP(tokens, 0.95)
|
||||||
|
|
||||||
|
if len(got) > 3 {
|
||||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test edge case - ensure at least one token remains
|
// Test edge case - ensure at least one token remains
|
||||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
input = []float32{-1e6, -1e6, -1e7}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = topP(tokens, 0.0) // Very small p
|
got = topP(tokens, 0.0)
|
||||||
if len(tokens) < 1 {
|
if len(got) < 1 {
|
||||||
t.Error("topP should keep at least one token")
|
t.Error("topP should keep at least one token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with zero p value
|
||||||
|
got = topP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = toTokens(input)
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
got = topP(tokens, 1e-10)
|
||||||
|
if len(got) == 0 {
|
||||||
|
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinP(t *testing.T) {
|
func TestMinP(t *testing.T) {
|
||||||
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
||||||
tokens := toTokens(input)
|
tokens := toTokens(input)
|
||||||
|
|
||||||
// First apply temperature and softmax
|
// First apply temperature and softmax
|
||||||
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
|
|||||||
t.Logf("got: %v", tokens)
|
t.Logf("got: %v", tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test with single token
|
||||||
|
tokens = toTokens(input[:1])
|
||||||
|
tokens = topK(tokens, 20)
|
||||||
|
softmax(tokens)
|
||||||
|
tokens = minP(tokens, 0.1)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(tokens) != 1 {
|
||||||
|
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
||||||
|
t.Logf("got: %v", tokens)
|
||||||
|
}
|
||||||
|
|
||||||
input = []float32{1e-10, 1e-10, 1e-10}
|
input = []float32{1e-10, 1e-10, 1e-10}
|
||||||
tokens = toTokens(input)
|
tokens = toTokens(input)
|
||||||
softmax(tokens)
|
softmax(tokens)
|
||||||
tokens = minP(tokens, 1.0)
|
tokens = minP(tokens, 1.0)
|
||||||
if len(tokens) < 1 {
|
if len(tokens) < 1 {
|
||||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||||
}
|
got := minP(tokens, 1.0)
|
||||||
|
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortLogits(t *testing.T) {
|
// Test with normal p value
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
got = minP(tokens, 0.2)
|
||||||
tokens := toTokens(input)
|
|
||||||
|
|
||||||
tokens = topK(tokens, 20)
|
// Should keep tokens with prob >= 0.2 * max_prob
|
||||||
|
if len(got) > 3 {
|
||||||
for i := 1; i < len(tokens); i++ {
|
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||||
if tokens[i].value > tokens[i-1].value {
|
t.Logf("got: %v", got)
|
||||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
|
||||||
i, tokens[i].value, tokens[i-1].value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
// Test with zero p value
|
||||||
compareLogits(t, "sortLogits", want, tokens)
|
got = minP(tokens, 0.0)
|
||||||
|
|
||||||
|
// Should keep only the highest probability token
|
||||||
|
if len(got) != len(tokens) {
|
||||||
|
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
|
||||||
"github.com/ollama/ollama/server/internal/internal/names"
|
"github.com/ollama/ollama/server/internal/internal/names"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
@@ -213,12 +212,6 @@ type Registry struct {
|
|||||||
// request. If zero, [DefaultChunkingThreshold] is used.
|
// request. If zero, [DefaultChunkingThreshold] is used.
|
||||||
ChunkingThreshold int64
|
ChunkingThreshold int64
|
||||||
|
|
||||||
// MaxChunkSize is the maximum size of a chunk to download. If zero,
|
|
||||||
// the default is [DefaultMaxChunkSize].
|
|
||||||
//
|
|
||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
|
||||||
MaxChunkSize int64
|
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified names
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||||
Mask string
|
Mask string
|
||||||
@@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(bmizerany): decide if this should be considered valid. Maybe
|
||||||
|
// server-side we special case '{}' to have some special meaning? Maybe
|
||||||
|
// "archiving" a tag (which is how we reason about it in the registry
|
||||||
|
// already, just with a different twist).
|
||||||
if len(m.Layers) == 0 {
|
if len(m.Layers) == 0 {
|
||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
@@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
// TODO(bmizerany): work to remove the need to do this
|
||||||
info, err := c.Get(l.Digest)
|
|
||||||
return err == nil && info.Size == l.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
layers := m.Layers
|
layers := m.Layers
|
||||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||||
layers = append(layers, m.Config)
|
layers = append(layers, m.Config)
|
||||||
@@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
// Send initial layer trace events to allow clients to have an
|
// Send initial layer trace events to allow clients to have an
|
||||||
// understanding of work to be done before work starts.
|
// understanding of work to be done before work starts.
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
skip := make([]bool, len(layers))
|
for _, l := range layers {
|
||||||
for i, l := range layers {
|
|
||||||
t.update(l, 0, nil)
|
t.update(l, 0, nil)
|
||||||
if exists(l) {
|
|
||||||
skip[i] = true
|
|
||||||
t.update(l, l.Size, ErrCached)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
var g errgroup.Group
|
||||||
g.SetLimit(r.maxStreams())
|
g.SetLimit(r.maxStreams())
|
||||||
for i, l := range layers {
|
for _, l := range layers {
|
||||||
if skip[i] {
|
info, err := c.Get(l.Digest)
|
||||||
|
if err == nil && info.Size == l.Size {
|
||||||
|
t.update(l, l.Size, ErrCached)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -490,23 +481,26 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
t.update(l, 0, err)
|
t.update(l, 0, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// TODO(bmizerany): fix this unbounded use of defer
|
||||||
defer chunked.Close()
|
defer chunked.Close()
|
||||||
|
|
||||||
var progress atomic.Int64
|
var progress atomic.Int64
|
||||||
for cs, err := range r.chunksums(ctx, name, l) {
|
for cs, err := range r.chunksums(ctx, name, l) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Bad chunksums response, update tracing
|
||||||
|
// clients and then bail.
|
||||||
t.update(l, progress.Load(), err)
|
t.update(l, progress.Load(), err)
|
||||||
break
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() { t.update(l, progress.Load(), err) }()
|
defer func() {
|
||||||
|
|
||||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||||
}
|
}
|
||||||
err := func() error {
|
t.update(l, progress.Load(), err)
|
||||||
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -518,35 +512,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
// Count bytes towards
|
// Count bytes towards progress, as they
|
||||||
// progress, as they arrive, so
|
// arrive, so that our bytes piggyback other
|
||||||
// that our bytes piggyback
|
// chunk updates on completion.
|
||||||
// other chunk updates on
|
|
||||||
// completion.
|
|
||||||
//
|
//
|
||||||
// This tactic is enough to
|
// This tactic is enough to show "smooth"
|
||||||
// show "smooth" progress given
|
// progress given the current CLI client. In
|
||||||
// the current CLI client. In
|
// the near future, the server should report
|
||||||
// the near future, the server
|
// download rate since it knows better than a
|
||||||
// should report download rate
|
// client that is measuring rate based on
|
||||||
// since it knows better than
|
// wall-clock time-since-last-update.
|
||||||
// a client that is measuring
|
|
||||||
// rate based on wall-clock
|
|
||||||
// time-since-last-update.
|
|
||||||
body := &trackingReader{r: res.Body, n: &progress}
|
body := &trackingReader{r: res.Body, n: &progress}
|
||||||
|
|
||||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
if !canRetry(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// store the manifest blob
|
|
||||||
md := blob.DigestFromBytes(m.Data)
|
md := blob.DigestFromBytes(m.Data)
|
||||||
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// commit the manifest with a link
|
|
||||||
return c.Link(m.Name, md)
|
return c.Link(m.Name, md)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
|||||||
}
|
}
|
||||||
blobURL := res.Header.Get("Content-Location")
|
blobURL := res.Header.Get("Content-Location")
|
||||||
|
|
||||||
|
var size int64
|
||||||
s := bufio.NewScanner(res.Body)
|
s := bufio.NewScanner(res.Body)
|
||||||
s.Split(bufio.ScanWords)
|
s.Split(bufio.ScanWords)
|
||||||
for {
|
for {
|
||||||
if !s.Scan() {
|
if !s.Scan() {
|
||||||
if s.Err() != nil {
|
if s.Err() != nil {
|
||||||
yield(chunksum{}, s.Err())
|
yield(chunksum{}, s.Err())
|
||||||
|
} else if size != l.Size {
|
||||||
|
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size += chunk.Size()
|
||||||
|
if size > l.Size {
|
||||||
|
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
cs := chunksum{
|
cs := chunksum{
|
||||||
URL: blobURL,
|
URL: blobURL,
|
||||||
Chunk: chunk,
|
Chunk: chunk,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
// communication is attempted.
|
// communication is attempted.
|
||||||
//
|
//
|
||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
@@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
r := &Registry{
|
r := &Registry{
|
||||||
Cache: c,
|
Cache: c,
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(upstreamRegistry),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPullChunksums(t *testing.T) {
|
||||||
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
content := "hello"
|
||||||
|
var chunksums string
|
||||||
|
contentDigest := func() blob.Digest {
|
||||||
|
return blob.DigestFromBytes(content)
|
||||||
|
}
|
||||||
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||||
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||||
|
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||||
|
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||||
|
w.Header().Set("Content-Location", loc)
|
||||||
|
io.WriteString(w, chunksums)
|
||||||
|
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||||
|
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected request: %v", r)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||||
|
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var reads []int64
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Logf("Update: %v %d %v", l, n, err)
|
||||||
|
mu.Lock()
|
||||||
|
reads = append(reads, n)
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||||
|
blob.DigestFromBytes("hel"),
|
||||||
|
blob.DigestFromBytes("lo"),
|
||||||
|
)
|
||||||
|
err := rc.Pull(ctx, "test")
|
||||||
|
check(err)
|
||||||
|
if !slices.Equal(reads, []int64{0, 3, 5}) {
|
||||||
|
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
|
||||||
|
}
|
||||||
|
|
||||||
|
mw, err := rc.Resolve(t.Context(), "test")
|
||||||
|
check(err)
|
||||||
|
mg, err := rc.ResolveLocal("test")
|
||||||
|
check(err)
|
||||||
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
|
}
|
||||||
|
for i := range mg.Layers {
|
||||||
|
_, err = c.Get(mg.Layers[i].Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing chunks
|
||||||
|
content = "llama"
|
||||||
|
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||||
|
err = rc.Pull(ctx, "missingchunks")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error because of missing chunks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
|||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
if t, err := template.Named(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err)
|
slog.Debug("template detection", "error", err, "template", s)
|
||||||
} else {
|
} else {
|
||||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
13
template/gemma3-instruct.gotmpl
Normal file
13
template/gemma3-instruct.gotmpl
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||||
|
{{- if eq .Role "user" }}<start_of_turn>user
|
||||||
|
{{- if and (eq $i 1) $.System }}
|
||||||
|
{{ $.System }}
|
||||||
|
{{ end }}
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ else if eq .Role "assistant" }}<start_of_turn>model
|
||||||
|
{{ .Content }}<end_of_turn>
|
||||||
|
{{ end }}
|
||||||
|
{{- if $last }}<start_of_turn>model
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
6
template/gemma3-instruct.json
Normal file
6
template/gemma3-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"stop": [
|
||||||
|
"<end_of_turn>"
|
||||||
|
],
|
||||||
|
"temperature": 0.1
|
||||||
|
}
|
||||||
@@ -87,6 +87,10 @@
|
|||||||
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"name": "gemma-instruct"
|
"name": "gemma-instruct"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
|
||||||
|
"name": "gemma3-instruct"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
||||||
"name": "llama3-instruct"
|
"name": "llama3-instruct"
|
||||||
|
|||||||
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
10
template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
You are a helpful assistant.
|
||||||
|
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/gemma3-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
8
template/testdata/gemma3-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<start_of_turn>user
|
||||||
|
Hello, how are you?<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
I'm doing great. How can I help you today?<end_of_turn>
|
||||||
|
<start_of_turn>user
|
||||||
|
I'd like to show off how chat templating works!<end_of_turn>
|
||||||
|
<start_of_turn>model
|
||||||
|
|
||||||
Reference in New Issue
Block a user