This commit is contained in:
ParthSareen
2025-03-25 16:45:27 -07:00
parent 5ec6bb52a0
commit 4450f871db
2 changed files with 25 additions and 12 deletions

View File

@@ -108,10 +108,13 @@ type PythonSampler struct {
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
@@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
@@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {