package deepseekocr

import (
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/attention"
)

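// samModel is the SAM-style image encoder: a patch embedding convolution, an
// optional absolute position embedding, a stack of windowed/global attention
// blocks, and a neck followed by two downsampling convolutions.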
type samModel struct {
	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor  `gguf:"position_embd"`

	Blocks []samBlock `gguf:"blk"`

	Neck *samNeck   `gguf:"neck"`
	Net2 *nn.Conv2D `gguf:"net_2"`
	Net3 *nn.Conv2D `gguf:"net_3"`

	Options samOptions
}

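// absolutePositionEmbedding returns the learned position embedding, bilinearly
// interpolated to the spatial size of hiddenStates when the two differ.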
func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	source := m.PositionEmbedding.Dim(1)
	target := hiddenStates.Dim(2)
	if source != target {
		positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
		positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
		return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	}

	return m.PositionEmbedding
}

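// Forward encodes the input image: stride-16 patch embedding, absolute
// position embedding, the attention blocks (window size 14 except for the
// configured global attention layers), the neck, and two stride-2
// convolutions that further downsample the feature map.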
func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	if m.PositionEmbedding != nil {
		hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
	}

	for i, block := range m.Blocks {
		var windowSize int
		if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
			windowSize = 14
		}

		hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
	}

	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
	hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
	hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	return hiddenStates
}

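// samOptions holds the SAM encoder hyperparameters: hidden size, number of
// attention heads, layer norm epsilon, and the indices of layers that use
// global (non-windowed) attention.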
type samOptions struct {
	hiddenSize,
	numHeads int
	eps                   float32
	globalAttentionLayers []int32
}

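// headDim returns the dimension of a single attention head.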
func (o samOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

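// samBlock is a pre-norm transformer block: layer norm, (optionally windowed)
// self-attention, layer norm, and an MLP, with residual connections around the
// attention and the MLP.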
type samBlock struct {
	Norm1       *nn.LayerNorm `gguf:"norm1"`
	Attention   *samAttention `gguf:"attn"`
	Norm2       *nn.LayerNorm `gguf:"norm2"`
	FeedForward *samMLP       `gguf:"mlp"`
}

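// Forward applies the block. When windowSize > 0 the feature map is padded to
// a multiple of the window size and partitioned into windowSize x windowSize
// windows before attention; afterwards the partitioning is reversed and the
// padding cropped.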
func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
	c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)

	residual := hiddenStates
	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)

	var pw, ph int
	if windowSize > 0 {
		pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
		ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
		if pw > 0 || ph > 0 {
			hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
		}

		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
	}

	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)

	if windowSize > 0 {
		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
		hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
		hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

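// samAttention is multi-head self-attention with a fused QKV projection, an
// output projection, and decomposed relative position embeddings for the
// height and width axes.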
type samAttention struct {
	QKV    *nn.Linear `gguf:"qkv"`
	Output *nn.Linear `gguf:"proj"`

	RelativePosition *struct {
		Height ml.Tensor `gguf:"h"`
		Width  ml.Tensor `gguf:"w"`
	} `gguf:",pre:rel_pos_"`
}

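// relativeCoordinates computes, for every (query, key) position pair, the
// index of its relative position embedding, scaling coordinates when the query
// and key lengths differ.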
func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
	s := make([]int32, qn*kn)
	for i := range qn {
		for j := range kn {
			q := i * max(kn/qn, 1)
			k := j * max(qn/kn, 1)
			s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
		}
	}
	return ctx.Input().FromInts(s, qn*kn)
}

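// relativePositions resizes the relative position table to cover the
// 2*max(qn, kn)-1 possible relative distances when needed, then gathers the
// embedding row for each (query, key) pair.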
func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
	maxRelativeDistance := 2*max(qn, kn) - 1
	if positions.Dim(1) != maxRelativeDistance {
		// linear interpolation kernel not available so approx. with bilinear interpolation
		positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
	}

	rc := relativeCoordinates(ctx, qn, kn)
	return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}

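// decomposedRelativePositions multiplies the query with the height and width
// relative position embeddings, producing the two terms that are summed into
// the additive attention mask in Forward.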
func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
	qh, qw := qn[0], qn[1]
	kh, kw := kn[0], kn[1]

	rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
	rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)

	query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
	rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
	rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
	return rh, rw
}

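// Forward projects the input to fused QKV, splits it into per-head query, key,
// and value, builds the relative-position mask, and applies attention via
// nn.Attention with the mask supplied through attention.WithMask, followed by
// the output projection.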
func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	qkv := m.QKV.Forward(ctx, hiddenStates)
	qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	ctx.Forward(query, key, value)

	rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
	mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
	mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)

	hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
	hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
	return m.Output.Forward(ctx, hiddenStates)
}

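// samMLP is the feed-forward network of a block: two linear layers with a GELU
// activation in between.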
type samMLP struct {
	Lin1 *nn.Linear `gguf:"lin1"`
	Lin2 *nn.Linear `gguf:"lin2"`
}

func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}

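// LayerNorm2D applies layer normalization across the channel dimension of a 2D
// feature map, with a learned per-channel weight and bias.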
type LayerNorm2D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

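// Forward permutes channels to the innermost dimension, normalizes them to
// zero mean and unit variance (plus eps), applies the weight and bias, and
// restores the original layout.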
func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	u := x.Mean(ctx)
	d := x.Sub(ctx, u)
	s := d.Sqr(ctx).Mean(ctx)
	x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
	x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
	return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}

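// samNeck projects the encoder output through two convolutions (the second
// padded by 1), each followed by 2D layer normalization.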
type samNeck struct {
	C1  *nn.Conv2D   `gguf:"0"`
	LN1 *LayerNorm2D `gguf:"1"`
	C2  *nn.Conv2D   `gguf:"2"`
	LN2 *LayerNorm2D `gguf:"3"`
}

func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
	hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
	hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
	return hiddenStates
}