mirror of
https://github.com/ollama/ollama.git
synced 2026-04-26 02:36:09 +02:00
checkpoint
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@@ -76,9 +75,10 @@ func (s JSONState) String() string {
|
||||
}
|
||||
|
||||
type JSONSampler struct {
|
||||
curNode *Node
|
||||
proc model.TextProcessor
|
||||
stack []*Node
|
||||
curNode *Node
|
||||
proc model.TextProcessor
|
||||
stack []*Node
|
||||
bracketCounter int
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
||||
@@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
||||
return nil, err
|
||||
}
|
||||
js := &JSONSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
stack: []*Node{},
|
||||
bracketCounter: 0,
|
||||
}
|
||||
|
||||
return js, nil
|
||||
}
|
||||
|
||||
// isTokenSubset reports whether subset is contained in superset when both are
// treated as multisets: every token ID must occur in superset at least as many
// times as it occurs in subset. Order is ignored, and an empty (or nil) subset
// is contained in any superset.
func isTokenSubset(subset, superset []int32) bool {
	// A multiset can never be contained in a smaller one.
	if len(subset) > len(superset) {
		return false
	}

	// remaining[id] is the number of occurrences of id in superset that have
	// not yet been claimed by an element of subset.
	remaining := make(map[int32]int, len(superset))
	for _, id := range superset {
		remaining[id]++
	}

	for _, id := range subset {
		if remaining[id] == 0 {
			return false
		}
		remaining[id]--
	}
	return true
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
||||
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
||||
|
||||
// fmt.Println("tokenSlice", tokenSlice)
|
||||
// todo: account for strings here
|
||||
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only move to terminate state if stack is empty
|
||||
if s.curNode.State == StateEnd {
|
||||
fmt.Println("debug: node.State", s.curNode.State)
|
||||
if len(s.stack) > 0 {
|
||||
s.stack = s.stack[:len(s.stack)-1]
|
||||
fmt.Println("popped and cur state", s.curNode.State)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for node, edge := range s.curNode.TransitionEdges {
|
||||
for _, validToken := range edge {
|
||||
if slices.Equal(tokenSlice, validToken) {
|
||||
if isTokenSubset(tokenSlice, validToken) {
|
||||
s.curNode = node
|
||||
for _, token := range objectTokens {
|
||||
if isTokenSubset(tokenSlice, token) {
|
||||
fmt.Println("Appending to stack", s.curNode.State)
|
||||
s.stack = append(s.stack, s.curNode)
|
||||
}
|
||||
}
|
||||
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
||||
return nil
|
||||
}
|
||||
@@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
}
|
||||
fmt.Println("invalid token ", tokenSlice)
|
||||
dec, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("decoded token ", dec)
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
|
||||
@@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInString:
|
||||
penalizeNewlineVariants := []string{"\n", " \"\n"}
|
||||
penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
|
||||
logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
validStates := getValidStates(s.curNode)
|
||||
logits, err = s.maskLogits(logits, validStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
validStates := getValidStates(s.curNode)
|
||||
logits, err = s.maskLogits(logits, validStates)
|
||||
@@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
|
||||
// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
|
||||
for i := range logits {
|
||||
for _, token := range tokensToMask {
|
||||
for _, chunked := range token {
|
||||
if int(chunked) == i {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user