checkpoint

This commit is contained in:
ParthSareen
2025-01-23 09:46:14 -08:00
parent a7c8cc06da
commit 6ba557f25b
2 changed files with 198 additions and 56 deletions

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
@@ -76,9 +75,10 @@ func (s JSONState) String() string {
}
type JSONSampler struct {
curNode *Node
proc model.TextProcessor
stack []*Node
curNode *Node
proc model.TextProcessor
stack []*Node
bracketCounter int
}
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
@@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
return nil, err
}
js := &JSONSampler{
curNode: startNode,
proc: proc,
curNode: startNode,
proc: proc,
stack: []*Node{},
bracketCounter: 0,
}
return js, nil
}
func isTokenSubset(subset, superset []int32) bool {
freq1 := make(map[int32]int)
freq2 := make(map[int32]int)
for _, v := range subset {
freq1[v]++
}
for _, v := range superset {
freq2[v]++
}
isSubset := true
for k, count1 := range freq1 {
count2 := freq2[k]
if count1 > count2 {
isSubset = false
break
}
}
return isSubset
}
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
// fmt.Printf("Current state: %s\n", s.curNode.State)
// fmt.Println("tokenSlice", tokenSlice)
// todo: account for strings here
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
if err != nil {
return err
}
// only move to terminate state if stack is empty
if s.curNode.State == StateEnd {
fmt.Println("debug: node.State", s.curNode.State)
if len(s.stack) > 0 {
s.stack = s.stack[:len(s.stack)-1]
fmt.Println("popped and cur state", s.curNode.State)
return nil
}
return nil
}
for node, edge := range s.curNode.TransitionEdges {
for _, validToken := range edge {
if slices.Equal(tokenSlice, validToken) {
if isTokenSubset(tokenSlice, validToken) {
s.curNode = node
for _, token := range objectTokens {
if isTokenSubset(tokenSlice, token) {
fmt.Println("Appending to stack", s.curNode.State)
s.stack = append(s.stack, s.curNode)
}
}
// fmt.Printf("Transitioned to state: %s\n", node.State)
return nil
}
@@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
}
}
fmt.Println("invalid token ", tokenSlice)
dec, err := s.proc.Decode(tokenSlice)
if err != nil {
return err
}
fmt.Println("decoded token ", dec)
return errors.New("invalid token")
}
@@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
}
return logits, nil
case StateInString:
penalizeNewlineVariants := []string{"\n", " \"\n"}
penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
if err != nil {
return nil, err
}
penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
if err != nil {
return nil, err
}
validStates := getValidStates(s.curNode)
logits, err = s.maskLogits(logits, validStates)
if err != nil {
return nil, err
}
return logits, nil
default:
validStates := getValidStates(s.curNode)
logits, err = s.maskLogits(logits, validStates)
@@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float
}
return logits, nil
}
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
for i := range logits {
for _, token := range tokensToMask {
for _, chunked := range token {
if int(chunked) == i {
logits[i] = math.NaN()
}
}
}
}
return logits, nil
}