mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 07:54:25 +02:00
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global arrays slice so MLX arrays can be safely created and managed from multiple goroutines. Convert Array value receivers to pointer receivers and struct fields from Array to *Array to avoid copying the atomic.
37 lines
722 B
Go
37 lines
722 B
Go
package mlx
|
|
|
|
type Linear struct {
|
|
Weight *Array `weight:"weight"`
|
|
Bias *Array `weight:"bias"`
|
|
}
|
|
|
|
// Forward computes the linear transformation: x @ Weight.T + Bias
|
|
func (m *Linear) Forward(x *Array) *Array {
|
|
w := m.Weight.Transpose(1, 0)
|
|
if m.Bias.Valid() {
|
|
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
|
}
|
|
|
|
return x.Matmul(w)
|
|
}
|
|
|
|
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
|
w := m.Weight.Transpose(0, 2, 1)
|
|
// TODO: bias
|
|
return x.GatherMM(w, lhs, rhs, sorted)
|
|
}
|
|
|
|
type Embedding struct {
|
|
Weight *Array `weight:"weight"`
|
|
}
|
|
|
|
func (e *Embedding) Forward(indices *Array) *Array {
|
|
return e.Weight.TakeAxis(indices, 0)
|
|
}
|
|
|
|
func (e *Embedding) AsLinear() Linear {
|
|
return Linear{
|
|
Weight: e.Weight,
|
|
}
|
|
}
|