mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 16:55:44 +02:00
just for fun
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,11 +13,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
"gonum.org/v1/gonum/stat"
|
||||
"gonum.org/v1/plot"
|
||||
"gonum.org/v1/plot/plotter"
|
||||
"gonum.org/v1/plot/vg"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/embed"):
|
||||
line = strings.TrimPrefix(line, "/embed")
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
var strArray []string
|
||||
fmt.Printf("line is %s\n", line)
|
||||
err = json.Unmarshal([]byte(line), &strArray)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't parse input")
|
||||
return err
|
||||
}
|
||||
|
||||
for i, s := range strArray {
|
||||
fmt.Printf("strArray[%d] is %s\n", i, s)
|
||||
}
|
||||
|
||||
req := &api.EmbedRequest{
|
||||
Model: opts.Model,
|
||||
Input: strArray,
|
||||
}
|
||||
|
||||
resp, err := client.Embed(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get embeddings")
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings := resp.Embeddings
|
||||
|
||||
r, c := len(embeddings), len(embeddings[0])
|
||||
data := make([]float64, r*c)
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
data[i*c+j] = float64(embeddings[i][j])
|
||||
}
|
||||
}
|
||||
|
||||
X := mat.NewDense(r, c, data)
|
||||
|
||||
// Initialize PCA
|
||||
var pca stat.PC
|
||||
|
||||
// Perform PCA
|
||||
if !pca.PrincipalComponents(X, nil) {
|
||||
return fmt.Errorf("PCA failed")
|
||||
}
|
||||
|
||||
// Extract principal component vectors
|
||||
var vectors mat.Dense
|
||||
pca.VectorsTo(&vectors)
|
||||
|
||||
// // Extract variances of the principal components
|
||||
// var variances []float64
|
||||
// variances = pca.VarsTo(variances)
|
||||
|
||||
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
||||
|
||||
// Perform PCA reduction
|
||||
var reducedData mat.Dense
|
||||
reducedData.Mul(X, W)
|
||||
|
||||
// Print the projected 2D points
|
||||
fmt.Println("Reduced embeddings to 2D:")
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
row := reducedData.RowView(i)
|
||||
fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
|
||||
}
|
||||
|
||||
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
row := reducedData.RowView(i)
|
||||
points[i].X = row.AtVec(0)
|
||||
points[i].Y = row.AtVec(1)
|
||||
}
|
||||
|
||||
// Create a new plot
|
||||
p := plot.New()
|
||||
|
||||
// Set plot title and axis labels
|
||||
p.Title.Text = "2D Data Plot"
|
||||
p.X.Label.Text = "X"
|
||||
p.Y.Label.Text = "Y"
|
||||
|
||||
// Create a scatter plot of the points
|
||||
s, err := plotter.NewScatter(points)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(s)
|
||||
|
||||
/// Create labels plotter and add it to the plot
|
||||
|
||||
labels := make([]string, reducedData.RawMatrix().Rows)
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
labels[i] = fmt.Sprintf("%d", i+1)
|
||||
}
|
||||
|
||||
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(l)
|
||||
|
||||
// Make the grid square
|
||||
p.X.Min = -1
|
||||
p.X.Max = 1
|
||||
p.Y.Min = -1
|
||||
p.Y.Max = 1
|
||||
|
||||
// Set the aspect ratio to be 1:1
|
||||
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
|
||||
// Save the plot to a PNG file
|
||||
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
case strings.HasPrefix(line, "/"):
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
||||
Reference in New Issue
Block a user