just for fun

This commit is contained in:
Roy Han
2024-07-25 17:34:11 -07:00
parent 918fd32884
commit 42009d2974
3 changed files with 256 additions and 140 deletions

View File

@@ -1,7 +1,6 @@
package cmd
import (
"encoding/json"
"errors"
"fmt"
"io"
@@ -14,11 +13,6 @@ import (
"strings"
"github.com/spf13/cobra"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/embed"):
line = strings.TrimPrefix(line, "/embed")
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
var strArray []string
fmt.Printf("line is %s\n", line)
err = json.Unmarshal([]byte(line), &strArray)
if err != nil {
fmt.Println("error: couldn't parse input")
return err
}
for i, s := range strArray {
fmt.Printf("strArray[%d] is %s\n", i, s)
}
req := &api.EmbedRequest{
Model: opts.Model,
Input: strArray,
}
resp, err := client.Embed(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get embeddings")
return err
}
embeddings := resp.Embeddings
r, c := len(embeddings), len(embeddings[0])
data := make([]float64, r*c)
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
data[i*c+j] = float64(embeddings[i][j])
}
}
X := mat.NewDense(r, c, data)
// Initialize PCA
var pca stat.PC
// Perform PCA
if !pca.PrincipalComponents(X, nil) {
return fmt.Errorf("PCA failed")
}
// Extract principal component vectors
var vectors mat.Dense
pca.VectorsTo(&vectors)
// // Extract variances of the principal components
// var variances []float64
// variances = pca.VarsTo(variances)
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
// Perform PCA reduction
var reducedData mat.Dense
reducedData.Mul(X, W)
// Print the projected 2D points
fmt.Println("Reduced embeddings to 2D:")
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
row := reducedData.RowView(i)
fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
}
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
row := reducedData.RowView(i)
points[i].X = row.AtVec(0)
points[i].Y = row.AtVec(1)
}
// Create a new plot
p := plot.New()
// Set plot title and axis labels
p.Title.Text = "2D Data Plot"
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"
// Create a scatter plot of the points
s, err := plotter.NewScatter(points)
if err != nil {
panic(err)
}
p.Add(s)
/// Create labels plotter and add it to the plot
labels := make([]string, reducedData.RawMatrix().Rows)
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
labels[i] = fmt.Sprintf("%d", i+1)
}
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
if err != nil {
panic(err)
}
p.Add(l)
// Make the grid square
p.X.Min = -1
p.X.Max = 1
p.Y.Min = -1
p.Y.Max = 1
// Set the aspect ratio to be 1:1
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
{Value: -1, Label: "-1"},
{Value: -0.5, Label: "-0.5"},
{Value: 0, Label: "0"},
{Value: 0.5, Label: "0.5"},
{Value: 1, Label: "1"},
})
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
{Value: -1, Label: "-1"},
{Value: -0.5, Label: "-0.5"},
{Value: 0, Label: "0"},
{Value: 0.5, Label: "0.5"},
{Value: 1, Label: "1"},
})
// Save the plot to a PNG file
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
panic(err)
}
case strings.HasPrefix(line, "/"):
args := strings.Fields(line)
isFile := false