mirror of
https://github.com/ollama/ollama.git
synced 2026-04-19 20:54:25 +02:00
Compare commits
3 Commits
parth/fix-
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc3ac5fee3 | ||
|
|
b5d0f72f16 | ||
|
|
148a1be0a3 |
@@ -26,6 +26,16 @@ irm https://claude.ai/install.ps1 | iex
|
|||||||
|
|
||||||
## Usage with Ollama
|
## Usage with Ollama
|
||||||
|
|
||||||
|
Configure Claude Code to use Ollama:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama config claude
|
||||||
|
```
|
||||||
|
|
||||||
|
This will prompt you to select a model and automatically configure Claude Code to use Ollama.
|
||||||
|
|
||||||
|
<Accordion title="Manual Configuration">
|
||||||
|
|
||||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||||
|
|
||||||
1. Set the environment variables:
|
1. Set the environment variables:
|
||||||
@@ -47,7 +57,9 @@ Or run with environment variables inline:
|
|||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
</Accordion>
|
||||||
|
|
||||||
|
<Note>Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.</Note>
|
||||||
|
|
||||||
## Connecting to ollama.com
|
## Connecting to ollama.com
|
||||||
|
|
||||||
|
|||||||
@@ -2,22 +2,31 @@
|
|||||||
title: Codex
|
title: Codex
|
||||||
---
|
---
|
||||||
|
|
||||||
|
Codex is OpenAI's agentic coding tool for the command line.
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
Install the [Codex CLI](https://developers.openai.com/codex/cli/):
|
Install the [Codex CLI](https://developers.openai.com/codex/cli/):
|
||||||
|
|
||||||
```
|
```shell
|
||||||
npm install -g @openai/codex
|
npm install -g @openai/codex
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage with Ollama
|
## Usage with Ollama
|
||||||
|
|
||||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
Configure Codex to use Ollama:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama config codex
|
||||||
|
```
|
||||||
|
|
||||||
|
This will prompt you to select a model and automatically configure Codex to use Ollama.
|
||||||
|
|
||||||
|
<Accordion title="Manual Configuration">
|
||||||
|
|
||||||
To use `codex` with Ollama, use the `--oss` flag:
|
To use `codex` with Ollama, use the `--oss` flag:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
codex --oss
|
codex --oss
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -25,20 +34,22 @@ codex --oss
|
|||||||
|
|
||||||
By default, codex will use the local `gpt-oss:20b` model. However, you can specify a different model with the `-m` flag:
|
By default, codex will use the local `gpt-oss:20b` model. However, you can specify a different model with the `-m` flag:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
codex --oss -m gpt-oss:120b
|
codex --oss -m gpt-oss:120b
|
||||||
```
|
```
|
||||||
|
|
||||||
### Cloud Models
|
### Cloud Models
|
||||||
|
|
||||||
```
|
```shell
|
||||||
codex --oss -m gpt-oss:120b-cloud
|
codex --oss -m gpt-oss:120b-cloud
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
||||||
|
|
||||||
## Connecting to ollama.com
|
## Connecting to ollama.com
|
||||||
|
|
||||||
|
|
||||||
Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||||
|
|
||||||
To use ollama.com directly, edit your `~/.codex/config.toml` file to point to ollama.com.
|
To use ollama.com directly, edit your `~/.codex/config.toml` file to point to ollama.com.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
title: Droid
|
title: Droid
|
||||||
---
|
---
|
||||||
|
|
||||||
|
Droid is Factory's agentic coding tool for the command line.
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
@@ -11,63 +12,77 @@ Install the [Droid CLI](https://factory.ai/):
|
|||||||
curl -fsSL https://app.factory.ai/cli | sh
|
curl -fsSL https://app.factory.ai/cli | sh
|
||||||
```
|
```
|
||||||
|
|
||||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
|
||||||
|
|
||||||
## Usage with Ollama
|
## Usage with Ollama
|
||||||
|
|
||||||
Add a local configuration block to `~/.factory/config.json`:
|
Configure Droid to use Ollama:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama config droid
|
||||||
|
```
|
||||||
|
|
||||||
|
This will prompt you to select models and automatically configure Droid to use Ollama.
|
||||||
|
|
||||||
|
<Accordion title="Manual Configuration">
|
||||||
|
|
||||||
|
Add a local configuration block to `~/.factory/settings.json`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"custom_models": [
|
"customModels": [
|
||||||
{
|
{
|
||||||
"model_display_name": "qwen3-coder [Ollama]",
|
|
||||||
"model": "qwen3-coder",
|
"model": "qwen3-coder",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"displayName": "qwen3-coder [Ollama]",
|
||||||
"api_key": "not-needed",
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"apiKey": "ollama",
|
||||||
"provider": "generic-chat-completion-api",
|
"provider": "generic-chat-completion-api",
|
||||||
"max_tokens": 32000
|
"maxOutputTokens": 32000
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Adjust `maxOutputTokens` based on your model's context length (the automated setup detects this automatically).
|
||||||
|
|
||||||
|
### Cloud Models
|
||||||
|
|
||||||
## Cloud Models
|
|
||||||
`qwen3-coder:480b-cloud` is the recommended model for use with Droid.
|
`qwen3-coder:480b-cloud` is the recommended model for use with Droid.
|
||||||
|
|
||||||
Add the cloud configuration block to `~/.factory/config.json`:
|
Add the cloud configuration block to `~/.factory/settings.json`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"custom_models": [
|
"customModels": [
|
||||||
{
|
{
|
||||||
"model_display_name": "qwen3-coder [Ollama Cloud]",
|
|
||||||
"model": "qwen3-coder:480b-cloud",
|
"model": "qwen3-coder:480b-cloud",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"displayName": "qwen3-coder:480b-cloud [Ollama]",
|
||||||
"api_key": "not-needed",
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"apiKey": "ollama",
|
||||||
"provider": "generic-chat-completion-api",
|
"provider": "generic-chat-completion-api",
|
||||||
"max_tokens": 128000
|
"maxOutputTokens": 128000
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||||
|
|
||||||
## Connecting to ollama.com
|
## Connecting to ollama.com
|
||||||
|
|
||||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||||
2. Add the cloud configuration block to `~/.factory/config.json`:
|
2. Add the cloud configuration block to `~/.factory/settings.json`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"custom_models": [
|
"customModels": [
|
||||||
{
|
{
|
||||||
"model_display_name": "qwen3-coder [Ollama Cloud]",
|
|
||||||
"model": "qwen3-coder:480b",
|
"model": "qwen3-coder:480b",
|
||||||
"base_url": "https://ollama.com/v1/",
|
"displayName": "qwen3-coder:480b [Ollama Cloud]",
|
||||||
"api_key": "OLLAMA_API_KEY",
|
"baseUrl": "https://ollama.com/v1",
|
||||||
|
"apiKey": "OLLAMA_API_KEY",
|
||||||
"provider": "generic-chat-completion-api",
|
"provider": "generic-chat-completion-api",
|
||||||
"max_tokens": 128000
|
"maxOutputTokens": 128000
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
63
docs/integrations/opencode.mdx
Normal file
63
docs/integrations/opencode.mdx
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
---
|
||||||
|
title: OpenCode
|
||||||
|
---
|
||||||
|
|
||||||
|
OpenCode is an agentic coding tool for the terminal.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Install [OpenCode](https://opencode.ai):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -fsSL https://opencode.ai/install | bash
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage with Ollama
|
||||||
|
|
||||||
|
Configure OpenCode to use Ollama:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama config opencode
|
||||||
|
```
|
||||||
|
|
||||||
|
This will prompt you to select models and automatically configure OpenCode to use Ollama.
|
||||||
|
|
||||||
|
<Accordion title="Manual Configuration">
|
||||||
|
|
||||||
|
Add the Ollama provider to `~/.config/opencode/opencode.json`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"$schema": "https://opencode.ai/config.json",
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
|
"name": "Ollama (local)",
|
||||||
|
"options": {
|
||||||
|
"baseURL": "http://localhost:11434/v1"
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen3-coder": {
|
||||||
|
"name": "qwen3-coder [Ollama]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
### Cloud models
|
||||||
|
- `qwen3-coder:480b` - Large coding model
|
||||||
|
- `glm-4.7:cloud` - High-performance cloud model
|
||||||
|
- `minimax-m2.1:cloud` - Fast cloud model
|
||||||
|
|
||||||
|
### Local models
|
||||||
|
- `qwen3-coder` - Excellent for coding tasks
|
||||||
|
- `gpt-oss:20b` - Strong general-purpose model
|
||||||
|
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
@@ -14,7 +14,7 @@ type Layer struct {
|
|||||||
Size int64 `json:"size"`
|
Size int64 `json:"size"`
|
||||||
From string `json:"from,omitempty"`
|
From string `json:"from,omitempty"`
|
||||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||||
status string
|
Status string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -22,7 +22,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||||
blobs, err := GetBlobsPath("")
|
blobs, err := BlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Layer{}, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||||
blob, err := GetBlobsPath(digest)
|
blob, err := BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Layer{}, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
|||||||
MediaType: mediatype,
|
MediaType: mediatype,
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Size: n,
|
Size: n,
|
||||||
status: fmt.Sprintf("%s %s", status, digest),
|
Status: fmt.Sprintf("%s %s", status, digest),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
|||||||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||||
}
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(digest)
|
blob, err := BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Layer{}, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
|||||||
Digest: digest,
|
Digest: digest,
|
||||||
Size: fi.Size(),
|
Size: fi.Size(),
|
||||||
From: from,
|
From: from,
|
||||||
status: fmt.Sprintf("using existing layer %s", digest),
|
Status: fmt.Sprintf("using existing layer %s", digest),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
|||||||
return nil, errors.New("opening layer with empty digest")
|
return nil, errors.New("opening layer with empty digest")
|
||||||
}
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(l.Digest)
|
blob, err := BlobsPath(l.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(l.Digest)
|
blob, err := BlobsPath(l.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
package server
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manifest) Digest() string {
|
||||||
|
return m.digest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manifest) FileInfo() os.FileInfo {
|
||||||
|
return m.fi
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadConfigJSON reads and unmarshals a config layer as JSON.
|
||||||
|
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
||||||
|
for _, layer := range m.Layers {
|
||||||
|
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
|
||||||
|
blobPath, err := BlobsPath(layer.Digest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(blobPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("config %q not found in manifest", configPath)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manifest) Remove() error {
|
func (m *Manifest) Remove() error {
|
||||||
if err := os.Remove(m.filepath); err != nil {
|
if err := os.Remove(m.filepath); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
|
|||||||
if _, used := inUse[layer.Digest]; used {
|
if _, used := inUse[layer.Digest]; used {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
blob, err := GetBlobsPath(layer.Digest)
|
blob, err := BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
return nil, model.Unqualified(n)
|
return nil, model.Unqualified(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
95
manifest/paths.go
Normal file
95
manifest/paths.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package manifest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||||
|
|
||||||
|
func Path() (string, error) {
|
||||||
|
path := filepath.Join(envconfig.Models(), "manifests")
|
||||||
|
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathForName returns the path to the manifest file for a specific model name.
|
||||||
|
func PathForName(n model.Name) (string, error) {
|
||||||
|
if !n.IsValid() {
|
||||||
|
return "", os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
manifests, err := Path()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(manifests, n.Filepath()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BlobsPath(digest string) (string, error) {
|
||||||
|
// only accept actual sha256 digests
|
||||||
|
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
|
||||||
|
if digest != "" && !re.MatchString(digest) {
|
||||||
|
return "", ErrInvalidDigestFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
digest = strings.ReplaceAll(digest, ":", "-")
|
||||||
|
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||||
|
dirPath := filepath.Dir(path)
|
||||||
|
if digest == "" {
|
||||||
|
dirPath = path
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PruneDirectory removes empty directories recursively.
|
||||||
|
func PruneDirectory(path string) error {
|
||||||
|
info, err := os.Lstat(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
||||||
|
entries, err := os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err = os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Remove(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
ofs "github.com/ollama/ollama/fs"
|
ofs "github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
ch <- resp
|
ch <- resp
|
||||||
}
|
}
|
||||||
|
|
||||||
oldManifest, _ := ParseNamedManifest(name)
|
oldManifest, _ := manifest.ParseNamedManifest(name)
|
||||||
|
|
||||||
var baseLayers []*layerGGML
|
var baseLayers []*layerGGML
|
||||||
var err error
|
var err error
|
||||||
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||||
manifest, mErr := ParseNamedManifest(fromName)
|
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||||
if mErr == nil && manifest.Config.Digest != "" {
|
if mErr == nil && mf.Config.Digest != "" {
|
||||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||||
if pErr == nil {
|
if pErr == nil {
|
||||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||||
var baseConfig model.ConfigV2
|
var baseConfig model.ConfigV2
|
||||||
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
|
|||||||
return "gguf"
|
return "gguf"
|
||||||
} else {
|
} else {
|
||||||
// try to see if we can find a gguf file even without the file extension
|
// try to see if we can find a gguf file even without the file extension
|
||||||
blobPath, err := GetBlobsPath(files[fn])
|
blobPath, err := manifest.BlobsPath(files[fn])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error getting blobs path", "file", fn)
|
slog.Error("error getting blobs path", "file", fn)
|
||||||
return ""
|
return ""
|
||||||
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
|||||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
blobPath, err := GetBlobsPath(digest)
|
blobPath, err := manifest.BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(t, mediaType)
|
layer, err := manifest.NewLayer(t, mediaType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||||
var layers []Layer
|
var layers []manifest.Layer
|
||||||
for _, layer := range baseLayers {
|
for _, layer := range baseLayers {
|
||||||
if layer.GGML != nil {
|
if layer.GGML != nil {
|
||||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||||
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if layer.status != "" {
|
if layer.Status != "" {
|
||||||
fn(api.ProgressResponse{Status: layer.status})
|
fn(api.ProgressResponse{Status: layer.Status})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
if err := WriteManifest(name, *configLayer, layers); err != nil {
|
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(layer.Digest)
|
blob, err := manifest.BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
|||||||
}
|
}
|
||||||
temp.Seek(0, io.SeekStart)
|
temp.Seek(0, io.SeekStart)
|
||||||
fn(api.ProgressResponse{Status: "verifying conversion"})
|
fn(api.ProgressResponse{Status: "verifying conversion"})
|
||||||
newLayer, err := NewLayer(temp, layer.MediaType)
|
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
|||||||
var layers []*layerGGML
|
var layers []*layerGGML
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
||||||
blobPath, err := GetBlobsPath(digest)
|
blobPath, err := manifest.BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
|||||||
mediatype = "application/vnd.ollama.image.projector"
|
mediatype = "application/vnd.ollama.image.projector"
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
|
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("could not create new layer from layer", "error", err)
|
slog.Debug("could not create new layer from layer", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
|||||||
return detectChatTemplate(layers)
|
return detectChatTemplate(layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeLayer(layers []Layer, mediatype string) []Layer {
|
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
|
||||||
return slices.DeleteFunc(layers, func(layer Layer) bool {
|
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
|
||||||
if layer.MediaType != mediatype {
|
if layer.MediaType != mediatype {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
|
||||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||||
if _, err := template.Parse(t); err != nil {
|
if _, err := template.Parse(t); err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||||
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
blob := strings.NewReader(t)
|
blob := strings.NewReader(t)
|
||||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
|
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setSystem(layers []Layer, s string) ([]Layer, error) {
|
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
|
||||||
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
||||||
if s != "" {
|
if s != "" {
|
||||||
blob := strings.NewReader(s)
|
blob := strings.NewReader(s)
|
||||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
|
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setLicense(layers []Layer, l string) ([]Layer, error) {
|
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
|
||||||
blob := strings.NewReader(l)
|
blob := strings.NewReader(l)
|
||||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
|
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = make(map[string]any)
|
p = make(map[string]any)
|
||||||
}
|
}
|
||||||
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
digestPath, err := GetBlobsPath(layer.Digest)
|
digestPath, err := manifest.BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
|||||||
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
|
||||||
// this leaves the old messages intact if no new messages were specified
|
// this leaves the old messages intact if no new messages were specified
|
||||||
// which may not be the correct behaviour
|
// which may not be the correct behaviour
|
||||||
if len(m) == 0 {
|
if len(m) == 0 {
|
||||||
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
|||||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
|
||||||
digests := make([]string, len(layers))
|
digests := make([]string, len(layers))
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
digests[i] = layer.Digest
|
digests[i] = layer.Digest
|
||||||
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
|||||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConvertFromSafetensors(t *testing.T) {
|
func TestConvertFromSafetensors(t *testing.T) {
|
||||||
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
|
|||||||
|
|
||||||
// Helper function to create a new layer and return its digest
|
// Helper function to create a new layer and return its digest
|
||||||
makeTemp := func(content string) string {
|
makeTemp := func(content string) string {
|
||||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create layer: %v", err)
|
t.Fatalf("Failed to create layer: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxRetries = 6
|
const maxRetries = 6
|
||||||
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
|||||||
}
|
}
|
||||||
|
|
||||||
type downloadOpts struct {
|
type downloadOpts struct {
|
||||||
mp ModelPath
|
n model.Name
|
||||||
digest string
|
digest string
|
||||||
regOpts *registryOptions
|
regOpts *registryOptions
|
||||||
fn func(api.ProgressResponse)
|
fn func(api.ProgressResponse)
|
||||||
@@ -465,10 +467,10 @@ type downloadOpts struct {
|
|||||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||||
if opts.digest == "" {
|
if opts.digest == "" {
|
||||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
|
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
fp, err := GetBlobsPath(opts.digest)
|
fp, err := manifest.BlobsPath(opts.digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
|||||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||||
download := data.(*blobDownload)
|
download := data.(*blobDownload)
|
||||||
if !ok {
|
if !ok {
|
||||||
requestURL := opts.mp.BaseURL()
|
requestURL := opts.n.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
|
||||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||||
blobDownloadManager.Delete(opts.digest)
|
blobDownloadManager.Delete(opts.digest)
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
205
server/images.go
205
server/images.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -24,6 +23,7 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/fs/gguf"
|
"github.com/ollama/ollama/fs/gguf"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/model/parsers"
|
"github.com/ollama/ollama/model/parsers"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@@ -274,44 +274,22 @@ func (m *Model) String() string {
|
|||||||
return modelfile.String()
|
return modelfile.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
|
||||||
fp, err := mp.GetManifestPath()
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Open(fp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
|
||||||
|
|
||||||
var manifest Manifest
|
|
||||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetModel(name string) (*Model, error) {
|
func GetModel(name string) (*Model, error) {
|
||||||
mp := ParseModelPath(name)
|
n := model.ParseName(name)
|
||||||
manifest, digest, err := GetManifest(mp)
|
mf, err := manifest.ParseNamedManifest(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &Model{
|
m := &Model{
|
||||||
Name: mp.GetFullTagname(),
|
Name: n.String(),
|
||||||
ShortName: mp.GetShortTagname(),
|
ShortName: n.DisplayShortest(),
|
||||||
Digest: digest,
|
Digest: mf.Digest(),
|
||||||
Template: template.DefaultTemplate,
|
Template: template.DefaultTemplate,
|
||||||
}
|
}
|
||||||
|
|
||||||
if manifest.Config.Digest != "" {
|
if mf.Config.Digest != "" {
|
||||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
filename, err := manifest.BlobsPath(mf.Config.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -322,29 +300,29 @@ func GetModel(name string) (*Model, error) {
|
|||||||
}
|
}
|
||||||
defer configFile.Close()
|
defer configFile.Close()
|
||||||
|
|
||||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range manifest.Layers {
|
for _, layer := range mf.Layers {
|
||||||
filename, err := GetBlobsPath(layer.Digest)
|
filename, err := manifest.BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch layer.MediaType {
|
switch layer.MediaType {
|
||||||
case "application/vnd.ollama.image.model":
|
case "application/vnd.ollama.image.model":
|
||||||
model.ModelPath = filename
|
m.ModelPath = filename
|
||||||
model.ParentModel = layer.From
|
m.ParentModel = layer.From
|
||||||
case "application/vnd.ollama.image.embed":
|
case "application/vnd.ollama.image.embed":
|
||||||
// Deprecated in versions > 0.1.2
|
// Deprecated in versions > 0.1.2
|
||||||
// TODO: remove this warning in a future version
|
// TODO: remove this warning in a future version
|
||||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||||
case "application/vnd.ollama.image.adapter":
|
case "application/vnd.ollama.image.adapter":
|
||||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
m.AdapterPaths = append(m.AdapterPaths, filename)
|
||||||
case "application/vnd.ollama.image.projector":
|
case "application/vnd.ollama.image.projector":
|
||||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
m.ProjectorPaths = append(m.ProjectorPaths, filename)
|
||||||
case "application/vnd.ollama.image.prompt",
|
case "application/vnd.ollama.image.prompt",
|
||||||
"application/vnd.ollama.image.template":
|
"application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
@@ -352,7 +330,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.Template, err = template.Parse(string(bts))
|
m.Template, err = template.Parse(string(bts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -362,7 +340,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.System = string(bts)
|
m.System = string(bts)
|
||||||
case "application/vnd.ollama.image.params":
|
case "application/vnd.ollama.image.params":
|
||||||
params, err := os.Open(filename)
|
params, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -371,7 +349,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
defer params.Close()
|
defer params.Close()
|
||||||
|
|
||||||
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "application/vnd.ollama.image.messages":
|
case "application/vnd.ollama.image.messages":
|
||||||
@@ -381,7 +359,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
}
|
}
|
||||||
defer msgs.Close()
|
defer msgs.Close()
|
||||||
|
|
||||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "application/vnd.ollama.image.license":
|
case "application/vnd.ollama.image.license":
|
||||||
@@ -389,11 +367,11 @@ func GetModel(name string) (*Model, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
model.License = append(model.License, string(bts))
|
m.License = append(m.License, string(bts))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyModel(src, dst model.Name) error {
|
func CopyModel(src, dst model.Name) error {
|
||||||
@@ -408,7 +386,7 @@ func CopyModel(src, dst model.Name) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := manifest.Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -437,7 +415,7 @@ func CopyModel(src, dst model.Name) error {
|
|||||||
|
|
||||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||||
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
|
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
|
||||||
manifests, err := Manifests(true)
|
manifests, err := manifest.Manifests(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -452,7 +430,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
|||||||
|
|
||||||
// only delete the files which are still in the deleteMap
|
// only delete the files which are still in the deleteMap
|
||||||
for k := range deleteMap {
|
for k := range deleteMap {
|
||||||
fp, err := GetBlobsPath(k)
|
fp, err := manifest.BlobsPath(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
||||||
continue
|
continue
|
||||||
@@ -468,7 +446,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
|||||||
|
|
||||||
func PruneLayers() error {
|
func PruneLayers() error {
|
||||||
deleteMap := make(map[string]struct{})
|
deleteMap := make(map[string]struct{})
|
||||||
p, err := GetBlobsPath("")
|
p, err := manifest.BlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -483,9 +461,9 @@ func PruneLayers() error {
|
|||||||
name := blob.Name()
|
name := blob.Name()
|
||||||
name = strings.ReplaceAll(name, "-", ":")
|
name = strings.ReplaceAll(name, "-", ":")
|
||||||
|
|
||||||
_, err := GetBlobsPath(name)
|
_, err := manifest.BlobsPath(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrInvalidDigestFormat) {
|
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||||
// remove invalid blobs (e.g. partial downloads)
|
// remove invalid blobs (e.g. partial downloads)
|
||||||
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
|
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
|
||||||
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
|
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
|
||||||
@@ -510,63 +488,30 @@ func PruneLayers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PruneDirectory(path string) error {
|
|
||||||
info, err := os.Lstat(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
|
||||||
entries, err := os.ReadDir(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entry := range entries {
|
|
||||||
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
entries, err = os.ReadDir(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(entries) > 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.Remove(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
n := model.ParseName(name)
|
||||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return errInsecureProtocol
|
return errInsecureProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, _, err := GetManifest(mp)
|
mf, err := manifest.ParseNamedManifest(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var layers []Layer
|
var layers []manifest.Layer
|
||||||
layers = append(layers, manifest.Layers...)
|
layers = append(layers, mf.Layers...)
|
||||||
if manifest.Config.Digest != "" {
|
if mf.Config.Digest != "" {
|
||||||
layers = append(layers, manifest.Config)
|
layers = append(layers, mf.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use fast transfer for models with tensor layers (many small blobs)
|
// Use fast transfer for models with tensor layers (many small blobs)
|
||||||
if hasTensorLayers(layers) {
|
if hasTensorLayers(layers) {
|
||||||
// Read raw manifest JSON to preserve tensor metadata fields
|
// Read raw manifest JSON to preserve tensor metadata fields
|
||||||
manifestPath, err := mp.GetManifestPath()
|
manifestPath, err := manifest.PathForName(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -574,7 +519,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fn(api.ProgressResponse{Status: "success"})
|
fn(api.ProgressResponse{Status: "success"})
|
||||||
@@ -582,17 +527,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
|
||||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||||
requestURL := mp.BaseURL()
|
requestURL := n.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||||
|
|
||||||
manifestJSON, err := json.Marshal(manifest)
|
manifestJSON, err := json.Marshal(mf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -611,44 +556,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
n := model.ParseName(name)
|
||||||
|
|
||||||
// build deleteMap to prune unused layers
|
// build deleteMap to prune unused layers
|
||||||
deleteMap := make(map[string]struct{})
|
deleteMap := make(map[string]struct{})
|
||||||
manifest, _, err := GetManifest(mp)
|
existingMf, err := manifest.ParseNamedManifest(n)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
// noop
|
// noop
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
|
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
|
||||||
} else {
|
} else {
|
||||||
for _, l := range manifest.Layers {
|
for _, l := range existingMf.Layers {
|
||||||
deleteMap[l.Digest] = struct{}{}
|
deleteMap[l.Digest] = struct{}{}
|
||||||
}
|
}
|
||||||
if manifest.Config.Digest != "" {
|
if existingMf.Config.Digest != "" {
|
||||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return errInsecureProtocol
|
return errInsecureProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||||
|
|
||||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
mf, err := pullModelManifest(ctx, n, regOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("pull model manifest: %s", err)
|
return fmt.Errorf("pull model manifest: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var layers []Layer
|
var layers []manifest.Layer
|
||||||
layers = append(layers, manifest.Layers...)
|
layers = append(layers, mf.Layers...)
|
||||||
if manifest.Config.Digest != "" {
|
if mf.Config.Digest != "" {
|
||||||
layers = append(layers, manifest.Config)
|
layers = append(layers, mf.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use fast transfer for models with tensor layers (many small blobs)
|
// Use fast transfer for models with tensor layers (many small blobs)
|
||||||
if hasTensorLayers(layers) {
|
if hasTensorLayers(layers) {
|
||||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fn(api.ProgressResponse{Status: "success"})
|
fn(api.ProgressResponse{Status: "success"})
|
||||||
@@ -658,7 +603,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
skipVerify := make(map[string]bool)
|
skipVerify := make(map[string]bool)
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||||
mp: mp,
|
n: n,
|
||||||
digest: layer.Digest,
|
digest: layer.Digest,
|
||||||
regOpts: regOpts,
|
regOpts: regOpts,
|
||||||
fn: fn,
|
fn: fn,
|
||||||
@@ -677,7 +622,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
}
|
}
|
||||||
if err := verifyBlob(layer.Digest); err != nil {
|
if err := verifyBlob(layer.Digest); err != nil {
|
||||||
if errors.Is(err, errDigestMismatch) {
|
if errors.Is(err, errDigestMismatch) {
|
||||||
fp, err := GetBlobsPath(layer.Digest)
|
fp, err := manifest.BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -692,16 +637,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
delete(deleteMap, layer.Digest)
|
delete(deleteMap, layer.Digest)
|
||||||
}
|
}
|
||||||
delete(deleteMap, manifest.Config.Digest)
|
delete(deleteMap, mf.Config.Digest)
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
|
|
||||||
manifestJSON, err := json.Marshal(manifest)
|
manifestJSON, err := json.Marshal(mf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fp, err := mp.GetManifestPath()
|
fp, err := manifest.PathForName(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -728,9 +673,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hasTensorLayers checks if any layer has tensor media type.
|
// hasTensorLayers checks if any layer has tensor media type.
|
||||||
func hasTensorLayers(layers []Layer) bool {
|
func hasTensorLayers(layers []manifest.Layer) bool {
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if layer.MediaType == MediaTypeImageTensor {
|
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -738,7 +683,7 @@ func hasTensorLayers(layers []Layer) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
blobs := make([]transfer.Blob, len(layers))
|
blobs := make([]transfer.Blob, len(layers))
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
blobs[i] = transfer.Blob{
|
blobs[i] = transfer.Blob{
|
||||||
@@ -747,12 +692,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
destDir, err := GetBlobsPath("")
|
destDir, err := manifest.BlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
base := mp.BaseURL()
|
base := n.BaseURL()
|
||||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||||
base.Scheme = "http"
|
base.Scheme = "http"
|
||||||
}
|
}
|
||||||
@@ -784,7 +729,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
Blobs: blobs,
|
Blobs: blobs,
|
||||||
BaseURL: baseURL,
|
BaseURL: baseURL,
|
||||||
DestDir: destDir,
|
DestDir: destDir,
|
||||||
Repository: mp.GetNamespaceRepository(),
|
Repository: n.DisplayNamespaceModel(),
|
||||||
Progress: progress,
|
Progress: progress,
|
||||||
Token: regOpts.Token,
|
Token: regOpts.Token,
|
||||||
GetToken: getToken,
|
GetToken: getToken,
|
||||||
@@ -795,12 +740,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
|
|
||||||
// Write manifest
|
// Write manifest
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
manifestJSON, err := json.Marshal(manifest)
|
manifestJSON, err := json.Marshal(mf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fp, err := mp.GetManifestPath()
|
fp, err := manifest.PathForName(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -812,7 +757,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
blobs := make([]transfer.Blob, len(layers))
|
blobs := make([]transfer.Blob, len(layers))
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
blobs[i] = transfer.Blob{
|
blobs[i] = transfer.Blob{
|
||||||
@@ -822,12 +767,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
srcDir, err := GetBlobsPath("")
|
srcDir, err := manifest.BlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
base := mp.BaseURL()
|
base := n.BaseURL()
|
||||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||||
base.Scheme = "http"
|
base.Scheme = "http"
|
||||||
}
|
}
|
||||||
@@ -864,13 +809,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
|||||||
GetToken: getToken,
|
GetToken: getToken,
|
||||||
Logger: slog.Default(),
|
Logger: slog.Default(),
|
||||||
Manifest: manifestJSON,
|
Manifest: manifestJSON,
|
||||||
ManifestRef: mp.Tag,
|
ManifestRef: n.Tag,
|
||||||
Repository: mp.GetNamespaceRepository(),
|
Repository: n.DisplayNamespaceModel(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
|
||||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||||
@@ -880,7 +825,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var m Manifest
|
var m manifest.Manifest
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1042,7 +987,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
|
|||||||
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
||||||
|
|
||||||
func verifyBlob(digest string) error {
|
func verifyBlob(digest string) error {
|
||||||
fp, err := GetBlobsPath(digest)
|
fp, err := manifest.BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -20,19 +21,19 @@ import (
|
|||||||
var intermediateBlobs map[string]string = make(map[string]string)
|
var intermediateBlobs map[string]string = make(map[string]string)
|
||||||
|
|
||||||
type layerGGML struct {
|
type layerGGML struct {
|
||||||
Layer
|
manifest.Layer
|
||||||
*ggml.GGML
|
*ggml.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||||
m, err := ParseNamedManifest(name)
|
m, err := manifest.ParseNamedManifest(name)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err = ParseNamedManifest(name)
|
m, err = manifest.ParseNamedManifest(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range m.Layers {
|
for _, layer := range m.Layers {
|
||||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
|||||||
case "application/vnd.ollama.image.model",
|
case "application/vnd.ollama.image.model",
|
||||||
"application/vnd.ollama.image.projector",
|
"application/vnd.ollama.image.projector",
|
||||||
"application/vnd.ollama.image.adapter":
|
"application/vnd.ollama.image.adapter":
|
||||||
blobpath, err := GetBlobsPath(layer.Digest)
|
blobpath, err := manifest.BlobsPath(layer.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
|||||||
if t, err := template.Named(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err, "template", s)
|
slog.Debug("template detection", "error", err, "template", s)
|
||||||
} else {
|
} else {
|
||||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||||
layers = append(layers, &layerGGML{layer, nil})
|
layers = append(layers, &layerGGML{layer, nil})
|
||||||
|
|
||||||
if t.Parameters != nil {
|
if t.Parameters != nil {
|
||||||
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,146 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"github.com/ollama/ollama/types/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ModelPath struct {
|
|
||||||
ProtocolScheme string
|
|
||||||
Registry string
|
|
||||||
Namespace string
|
|
||||||
Repository string
|
|
||||||
Tag string
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultRegistry = "registry.ollama.ai"
|
|
||||||
DefaultNamespace = "library"
|
|
||||||
DefaultTag = "latest"
|
|
||||||
DefaultProtocolScheme = "https"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
|
||||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
|
||||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
|
||||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
|
||||||
ErrModelPathInvalid = errors.New("invalid model path")
|
|
||||||
)
|
|
||||||
|
|
||||||
func ParseModelPath(name string) ModelPath {
|
|
||||||
mp := ModelPath{
|
|
||||||
ProtocolScheme: DefaultProtocolScheme,
|
|
||||||
Registry: DefaultRegistry,
|
|
||||||
Namespace: DefaultNamespace,
|
|
||||||
Repository: "",
|
|
||||||
Tag: DefaultTag,
|
|
||||||
}
|
|
||||||
|
|
||||||
before, after, found := strings.Cut(name, "://")
|
|
||||||
if found {
|
|
||||||
mp.ProtocolScheme = before
|
|
||||||
name = after
|
|
||||||
}
|
|
||||||
|
|
||||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
|
||||||
parts := strings.Split(name, "/")
|
|
||||||
switch len(parts) {
|
|
||||||
case 3:
|
|
||||||
mp.Registry = parts[0]
|
|
||||||
mp.Namespace = parts[1]
|
|
||||||
mp.Repository = parts[2]
|
|
||||||
case 2:
|
|
||||||
mp.Namespace = parts[0]
|
|
||||||
mp.Repository = parts[1]
|
|
||||||
case 1:
|
|
||||||
mp.Repository = parts[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
|
||||||
mp.Repository = repo
|
|
||||||
mp.Tag = tag
|
|
||||||
}
|
|
||||||
|
|
||||||
return mp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mp ModelPath) GetNamespaceRepository() string {
|
|
||||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mp ModelPath) GetFullTagname() string {
|
|
||||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mp ModelPath) GetShortTagname() string {
|
|
||||||
if mp.Registry == DefaultRegistry {
|
|
||||||
if mp.Namespace == DefaultNamespace {
|
|
||||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
|
||||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
|
||||||
name := model.Name{
|
|
||||||
Host: mp.Registry,
|
|
||||||
Namespace: mp.Namespace,
|
|
||||||
Model: mp.Repository,
|
|
||||||
Tag: mp.Tag,
|
|
||||||
}
|
|
||||||
if !name.IsValid() {
|
|
||||||
return "", fs.ErrNotExist
|
|
||||||
}
|
|
||||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mp ModelPath) BaseURL() *url.URL {
|
|
||||||
return &url.URL{
|
|
||||||
Scheme: mp.ProtocolScheme,
|
|
||||||
Host: mp.Registry,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetManifestPath() (string, error) {
|
|
||||||
path := filepath.Join(envconfig.Models(), "manifests")
|
|
||||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
|
||||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return path, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetBlobsPath(digest string) (string, error) {
|
|
||||||
// only accept actual sha256 digests
|
|
||||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
|
|
||||||
if digest != "" && !re.MatchString(digest) {
|
|
||||||
return "", ErrInvalidDigestFormat
|
|
||||||
}
|
|
||||||
|
|
||||||
digest = strings.ReplaceAll(digest, ":", "-")
|
|
||||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
|
||||||
dirPath := filepath.Dir(path)
|
|
||||||
if digest == "" {
|
|
||||||
dirPath = path
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
|
||||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return path, nil
|
|
||||||
}
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetBlobsPath(t *testing.T) {
|
|
||||||
// GetBlobsPath expects an actual directory to exist
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
digest string
|
|
||||||
expected string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"empty digest",
|
|
||||||
"",
|
|
||||||
filepath.Join(tempDir, "blobs"),
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"valid with colon",
|
|
||||||
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
|
||||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"valid with dash",
|
|
||||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
|
||||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"digest too short",
|
|
||||||
"sha256-45640291",
|
|
||||||
"",
|
|
||||||
ErrInvalidDigestFormat,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"digest too long",
|
|
||||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
|
|
||||||
"",
|
|
||||||
ErrInvalidDigestFormat,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"digest invalid chars",
|
|
||||||
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
|
|
||||||
"",
|
|
||||||
ErrInvalidDigestFormat,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
|
||||||
|
|
||||||
got, err := GetBlobsPath(tc.digest)
|
|
||||||
|
|
||||||
require.ErrorIs(t, tc.err, err, tc.name)
|
|
||||||
assert.Equal(t, tc.expected, got, tc.name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseModelPath(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
arg string
|
|
||||||
want ModelPath
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"full path https",
|
|
||||||
"https://example.com/ns/repo:tag",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "https",
|
|
||||||
Registry: "example.com",
|
|
||||||
Namespace: "ns",
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: "tag",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"full path http",
|
|
||||||
"http://example.com/ns/repo:tag",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "http",
|
|
||||||
Registry: "example.com",
|
|
||||||
Namespace: "ns",
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: "tag",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"no protocol",
|
|
||||||
"example.com/ns/repo:tag",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "https",
|
|
||||||
Registry: "example.com",
|
|
||||||
Namespace: "ns",
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: "tag",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"no registry",
|
|
||||||
"ns/repo:tag",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "https",
|
|
||||||
Registry: DefaultRegistry,
|
|
||||||
Namespace: "ns",
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: "tag",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"no namespace",
|
|
||||||
"repo:tag",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "https",
|
|
||||||
Registry: DefaultRegistry,
|
|
||||||
Namespace: DefaultNamespace,
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: "tag",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"no tag",
|
|
||||||
"repo",
|
|
||||||
ModelPath{
|
|
||||||
ProtocolScheme: "https",
|
|
||||||
Registry: DefaultRegistry,
|
|
||||||
Namespace: DefaultNamespace,
|
|
||||||
Repository: "repo",
|
|
||||||
Tag: DefaultTag,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := ParseModelPath(tc.arg)
|
|
||||||
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("got: %q want: %q", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -39,6 +39,7 @@ import (
|
|||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/middleware"
|
"github.com/ollama/ollama/middleware"
|
||||||
"github.com/ollama/ollama/model/parsers"
|
"github.com/ollama/ollama/model/parsers"
|
||||||
"github.com/ollama/ollama/model/renderers"
|
"github.com/ollama/ollama/model/renderers"
|
||||||
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||||||
// is.
|
// is.
|
||||||
func getExistingName(n model.Name) (model.Name, error) {
|
func getExistingName(n model.Name) (model.Name, error) {
|
||||||
var zero model.Name
|
var zero model.Name
|
||||||
existing, err := Manifests(true)
|
existing, err := manifest.Manifests(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := ParseNamedManifest(n)
|
m, err := manifest.ParseNamedManifest(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case os.IsNotExist(err):
|
case os.IsNotExist(err):
|
||||||
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
|||||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
name := model.ParseName(req.Model)
|
name := model.ParseName(req.Model)
|
||||||
if !name.IsValid() {
|
if !name.IsValid() {
|
||||||
return nil, ErrModelPathInvalid
|
return nil, model.Unqualified(name)
|
||||||
}
|
}
|
||||||
name, err := getExistingName(name)
|
name, err := getExistingName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
|
|
||||||
// For safetensors LLM models (experimental), populate details from config.json
|
// For safetensors LLM models (experimental), populate details from config.json
|
||||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||||
modelDetails.Family = arch
|
modelDetails.Family = arch
|
||||||
}
|
}
|
||||||
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Get torch_dtype directly from config.json for quantization level
|
// Get torch_dtype directly from config.json for quantization level
|
||||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
|
||||||
modelDetails.QuantizationLevel = dtype
|
modelDetails.QuantizationLevel = dtype
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, err := ParseNamedManifest(name)
|
mf, err := manifest.ParseNamedManifest(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1147,7 +1148,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
Details: modelDetails,
|
Details: modelDetails,
|
||||||
Messages: msgs,
|
Messages: msgs,
|
||||||
Capabilities: m.Capabilities(),
|
Capabilities: m.Capabilities(),
|
||||||
ModifiedAt: manifest.fi.ModTime(),
|
ModifiedAt: mf.FileInfo().ModTime(),
|
||||||
Requires: m.Config.Requires,
|
Requires: m.Config.Requires,
|
||||||
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
|
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
|
||||||
// default we return an empty map.
|
// default we return an empty map.
|
||||||
@@ -1214,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||||
// Populate tensor info if verbose
|
// Populate tensor info if verbose
|
||||||
if req.Verbose {
|
if req.Verbose {
|
||||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||||
resp.Tensors = tensors
|
resp.Tensors = tensors
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1223,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
|
|
||||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||||
resp.ModelInfo = info
|
resp.ModelInfo = info
|
||||||
}
|
}
|
||||||
// Populate tensor info if verbose
|
// Populate tensor info if verbose
|
||||||
if req.Verbose {
|
if req.Verbose {
|
||||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||||
resp.Tensors = tensors
|
resp.Tensors = tensors
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1285,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListHandler(c *gin.Context) {
|
func (s *Server) ListHandler(c *gin.Context) {
|
||||||
ms, err := Manifests(true)
|
ms, err := manifest.Manifests(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1316,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
|||||||
RemoteModel: cf.RemoteModel,
|
RemoteModel: cf.RemoteModel,
|
||||||
RemoteHost: cf.RemoteHost,
|
RemoteHost: cf.RemoteHost,
|
||||||
Size: m.Size(),
|
Size: m.Size(),
|
||||||
Digest: m.digest,
|
Digest: m.Digest(),
|
||||||
ModifiedAt: m.fi.ModTime(),
|
ModifiedAt: m.FileInfo().ModTime(),
|
||||||
Details: api.ModelDetails{
|
Details: api.ModelDetails{
|
||||||
Format: cf.ModelFormat,
|
Format: cf.ModelFormat,
|
||||||
Family: cf.ModelFamily,
|
Family: cf.ModelFamily,
|
||||||
@@ -1376,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||||
path, err := GetBlobsPath(c.Param("digest"))
|
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1392,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
|
|||||||
|
|
||||||
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||||
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
|
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
|
||||||
p, err := GetBlobsPath(ib)
|
p, err := manifest.BlobsPath(ib)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1410,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
path, err := GetBlobsPath(c.Param("digest"))
|
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1428,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(c.Request.Body, "")
|
layer, err := manifest.NewLayer(c.Request.Body, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1628,7 +1629,7 @@ func Serve(ln net.Listener) error {
|
|||||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||||
slog.Info("server config", "env", envconfig.Values())
|
slog.Info("server config", "env", envconfig.Values())
|
||||||
|
|
||||||
blobsDir, err := GetBlobsPath("")
|
blobsDir, err := manifest.BlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1637,7 +1638,7 @@ func Serve(ln net.Listener) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !envconfig.NoPrune() {
|
if !envconfig.NoPrune() {
|
||||||
if _, err := Manifests(false); err != nil {
|
if _, err := manifest.Manifests(false); err != nil {
|
||||||
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
|
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
|
||||||
} else {
|
} else {
|
||||||
// clean up unused layers and manifests
|
// clean up unused layers and manifests
|
||||||
@@ -1645,12 +1646,12 @@ func Serve(ln net.Listener) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
manifestsPath, err := GetManifestPath()
|
manifestsPath, err := manifest.Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := PruneDirectory(manifestsPath); err != nil {
|
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, err := ParseNamedManifest(model.ParseName("child"))
|
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parse manifest: %v", err)
|
t.Fatalf("parse manifest: %v", err)
|
||||||
}
|
}
|
||||||
if manifest.Config.Digest == "" {
|
if mf.Config.Digest == "" {
|
||||||
t.Fatalf("unexpected empty config digest for child manifest")
|
t.Fatalf("unexpected empty config digest for child manifest")
|
||||||
}
|
}
|
||||||
|
|
||||||
configPath, err := GetBlobsPath(manifest.Config.Digest)
|
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("config blob path: %v", err)
|
t.Fatalf("config blob path: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a manifest with duplicate layers
|
// create a manifest with duplicate layers
|
||||||
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,12 +21,14 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
var blobUploadManager sync.Map
|
var blobUploadManager sync.Map
|
||||||
|
|
||||||
type blobUpload struct {
|
type blobUpload struct {
|
||||||
Layer
|
manifest.Layer
|
||||||
|
|
||||||
Total int64
|
Total int64
|
||||||
Completed atomic.Int64
|
Completed atomic.Int64
|
||||||
@@ -51,7 +53,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||||
p, err := GetBlobsPath(b.Digest)
|
p, err := manifest.BlobsPath(b.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
|
|||||||
if b.From != "" {
|
if b.From != "" {
|
||||||
values := requestURL.Query()
|
values := requestURL.Query()
|
||||||
values.Add("mount", b.Digest)
|
values.Add("mount", b.Digest)
|
||||||
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
|
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
|
||||||
requestURL.RawQuery = values.Encode()
|
requestURL.RawQuery = values.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
|||||||
defer blobUploadManager.Delete(b.Digest)
|
defer blobUploadManager.Delete(b.Digest)
|
||||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||||
|
|
||||||
p, err := GetBlobsPath(b.Digest)
|
p, err := manifest.BlobsPath(b.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.err = err
|
b.err = err
|
||||||
return
|
return
|
||||||
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
|
|||||||
p.written = 0
|
p.written = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
requestURL := mp.BaseURL()
|
requestURL := n.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
|
||||||
|
|
||||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||||
switch {
|
switch {
|
||||||
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
|
|||||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||||
upload := data.(*blobUpload)
|
upload := data.(*blobUpload)
|
||||||
if !ok {
|
if !ok {
|
||||||
requestURL := mp.BaseURL()
|
requestURL := n.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
|
||||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||||
blobUploadManager.Delete(layer.Digest)
|
blobUploadManager.Delete(layer.Digest)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
|
|||||||
const MissingPart = "!MISSING!"
|
const MissingPart = "!MISSING!"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultHost = "registry.ollama.ai"
|
defaultHost = "registry.ollama.ai"
|
||||||
defaultNamespace = "library"
|
defaultNamespace = "library"
|
||||||
defaultTag = "latest"
|
defaultTag = "latest"
|
||||||
|
defaultProtocolScheme = "https"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultName returns a name with the default values for the host, namespace,
|
// DefaultName returns a name with the default values for the host, namespace,
|
||||||
// and tag parts. The model and digest parts are empty.
|
// tag, and protocol scheme parts. The model and digest parts are empty.
|
||||||
//
|
//
|
||||||
// - The default host is ("registry.ollama.ai")
|
// - The default host is ("registry.ollama.ai")
|
||||||
// - The default namespace is ("library")
|
// - The default namespace is ("library")
|
||||||
// - The default tag is ("latest")
|
// - The default tag is ("latest")
|
||||||
|
// - The default protocol scheme is ("https")
|
||||||
func DefaultName() Name {
|
func DefaultName() Name {
|
||||||
return Name{
|
return Name{
|
||||||
Host: defaultHost,
|
Host: defaultHost,
|
||||||
Namespace: defaultNamespace,
|
Namespace: defaultNamespace,
|
||||||
Tag: defaultTag,
|
Tag: defaultTag,
|
||||||
|
ProtocolScheme: defaultProtocolScheme,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,10 +91,11 @@ func (k partKind) String() string {
|
|||||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||||
// is valid.
|
// is valid.
|
||||||
type Name struct {
|
type Name struct {
|
||||||
Host string
|
Host string
|
||||||
Namespace string
|
Namespace string
|
||||||
Model string
|
Model string
|
||||||
Tag string
|
Tag string
|
||||||
|
ProtocolScheme string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseName parses and assembles a Name from a name string. The
|
// ParseName parses and assembles a Name from a name string. The
|
||||||
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
|
|||||||
}
|
}
|
||||||
|
|
||||||
scheme, host, ok := strings.Cut(s, "://")
|
scheme, host, ok := strings.Cut(s, "://")
|
||||||
if !ok {
|
if ok {
|
||||||
|
n.ProtocolScheme = scheme
|
||||||
|
} else {
|
||||||
host = scheme
|
host = scheme
|
||||||
}
|
}
|
||||||
n.Host = host
|
n.Host = host
|
||||||
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge merges the host, namespace, and tag parts of the two names,
|
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
|
||||||
// preferring the non-empty parts of a.
|
// preferring the non-empty parts of a.
|
||||||
func Merge(a, b Name) Name {
|
func Merge(a, b Name) Name {
|
||||||
a.Host = cmp.Or(a.Host, b.Host)
|
a.Host = cmp.Or(a.Host, b.Host)
|
||||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||||
|
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
|
|||||||
strings.EqualFold(n.Tag, o.Tag)
|
strings.EqualFold(n.Tag, o.Tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BaseURL returns the base URL for the registry.
|
||||||
|
func (n Name) BaseURL() *url.URL {
|
||||||
|
return &url.URL{
|
||||||
|
Scheme: n.ProtocolScheme,
|
||||||
|
Host: n.Host,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayNamespaceModel returns the namespace and model joined by "/".
|
||||||
|
func (n Name) DisplayNamespaceModel() string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(n.Namespace)
|
||||||
|
b.WriteByte('/')
|
||||||
|
b.WriteString(n.Model)
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
func isValidLen(kind partKind, s string) bool {
|
func isValidLen(kind partKind, s string) bool {
|
||||||
switch kind {
|
switch kind {
|
||||||
case kindHost:
|
case kindHost:
|
||||||
|
|||||||
@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
|
|||||||
{
|
{
|
||||||
in: "scheme://host:port/namespace/model:tag",
|
in: "scheme://host:port/namespace/model:tag",
|
||||||
want: Name{
|
want: Name{
|
||||||
Host: "host:port",
|
Host: "host:port",
|
||||||
Namespace: "namespace",
|
Namespace: "namespace",
|
||||||
Model: "model",
|
Model: "model",
|
||||||
Tag: "tag",
|
Tag: "tag",
|
||||||
|
ProtocolScheme: "scheme",
|
||||||
},
|
},
|
||||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/server"
|
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/x/create"
|
"github.com/ollama/ollama/x/create"
|
||||||
)
|
)
|
||||||
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||||
func newLayerCreator() create.LayerCreator {
|
func newLayerCreator() create.LayerCreator {
|
||||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||||
layer, err := server.NewLayer(r, mediaType)
|
layer, err := manifest.NewLayer(r, mediaType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return create.LayerInfo{}, err
|
return create.LayerInfo{}, err
|
||||||
}
|
}
|
||||||
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create layer for quantized weight
|
// Create layer for quantized weight
|
||||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create layer for scales
|
// Create layer for scales
|
||||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
|||||||
|
|
||||||
// Add qbiases layer if present (affine mode)
|
// Add qbiases layer if present (affine mode)
|
||||||
if qbiasData != nil {
|
if qbiasData != nil {
|
||||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
|||||||
|
|
||||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create config layer blob
|
// Create config layer blob
|
||||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create config layer: %w", err)
|
return fmt.Errorf("failed to create config layer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert LayerInfo to server.Layer
|
// Convert LayerInfo to manifest.Layer
|
||||||
serverLayers := make([]server.Layer, 0, len(layers))
|
manifestLayers := make([]manifest.Layer, 0, len(layers))
|
||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
serverLayers = append(serverLayers, server.Layer{
|
manifestLayers = append(manifestLayers, manifest.Layer{
|
||||||
MediaType: l.MediaType,
|
MediaType: l.MediaType,
|
||||||
Digest: l.Digest,
|
Digest: l.Digest,
|
||||||
Size: l.Size,
|
Size: l.Size,
|
||||||
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
serverLayers = append(serverLayers, modelfileLayers...)
|
manifestLayers = append(manifestLayers, modelfileLayers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return server.WriteManifest(name, configLayer, serverLayers)
|
return manifest.WriteManifest(name, configLayer, manifestLayers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||||
var layers []server.Layer
|
var layers []manifest.Layer
|
||||||
|
|
||||||
if mf.Template != "" {
|
if mf.Template != "" {
|
||||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||||
}
|
}
|
||||||
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mf.System != "" {
|
if mf.System != "" {
|
||||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||||
}
|
}
|
||||||
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mf.License != "" {
|
if mf.License != "" {
|
||||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
6
x/imagegen/cache/step.go
vendored
6
x/imagegen/cache/step.go
vendored
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
|
|||||||
// shallow layers change little between consecutive steps, so we can
|
// shallow layers change little between consecutive steps, so we can
|
||||||
// cache their outputs and skip recomputation on non-refresh steps.
|
// cache their outputs and skip recomputation on non-refresh steps.
|
||||||
//
|
//
|
||||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
// Supports both single-stream and dual-stream architectures:
|
||||||
// - Single-stream: use Get/Set for the single output per layer
|
// - Single-stream: use Get/Set for the single output per layer
|
||||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||||
//
|
//
|
||||||
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||||
// Used for dual-stream architectures like Qwen-Image.
|
// Used for dual-stream architectures.
|
||||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||||
if layer < len(c.layers2) {
|
if layer < len(c.layers2) {
|
||||||
return c.layers2[layer]
|
return c.layers2[layer]
|
||||||
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||||
// Used for dual-stream architectures like Qwen-Image.
|
// Used for dual-stream architectures.
|
||||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||||
if layer < len(c.layers2) {
|
if layer < len(c.layers2) {
|
||||||
if c.layers2[layer] != nil {
|
if c.layers2[layer] != nil {
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ import (
|
|||||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||||
)
|
)
|
||||||
@@ -61,14 +59,11 @@ func main() {
|
|||||||
listTensors := flag.Bool("list", false, "List tensors only")
|
listTensors := flag.Bool("list", false, "List tensors only")
|
||||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
|
||||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||||
|
|
||||||
// Legacy mode flags
|
// Legacy mode flags
|
||||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
|
||||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
|
||||||
var inputImages stringSlice
|
var inputImages stringSlice
|
||||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||||
@@ -166,60 +161,6 @@ func main() {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
err = saveImageArray(img, *out)
|
err = saveImageArray(img, *out)
|
||||||
}
|
}
|
||||||
case *qwenImage:
|
|
||||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
|
||||||
if loadErr != nil {
|
|
||||||
log.Fatal(loadErr)
|
|
||||||
}
|
|
||||||
var img *mlx.Array
|
|
||||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
|
||||||
Prompt: *prompt,
|
|
||||||
NegativePrompt: *negativePrompt,
|
|
||||||
CFGScale: float32(*cfgScale),
|
|
||||||
Width: int32(*width),
|
|
||||||
Height: int32(*height),
|
|
||||||
Steps: *steps,
|
|
||||||
Seed: *seed,
|
|
||||||
LayerCache: *layerCache,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
err = saveImageArray(img, *out)
|
|
||||||
}
|
|
||||||
case *qwenImageEdit:
|
|
||||||
if len(inputImages) == 0 {
|
|
||||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
|
||||||
}
|
|
||||||
|
|
||||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
|
||||||
if loadErr != nil {
|
|
||||||
log.Fatal(loadErr)
|
|
||||||
}
|
|
||||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
|
||||||
// unless explicitly overridden from defaults
|
|
||||||
editWidth := int32(0)
|
|
||||||
editHeight := int32(0)
|
|
||||||
if *width != 1024 {
|
|
||||||
editWidth = int32(*width)
|
|
||||||
}
|
|
||||||
if *height != 1024 {
|
|
||||||
editHeight = int32(*height)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := &qwen_image_edit.GenerateConfig{
|
|
||||||
Prompt: *prompt,
|
|
||||||
NegativePrompt: *negativePrompt,
|
|
||||||
CFGScale: float32(*cfgScale),
|
|
||||||
Width: editWidth,
|
|
||||||
Height: editHeight,
|
|
||||||
Steps: *steps,
|
|
||||||
Seed: *seed,
|
|
||||||
}
|
|
||||||
|
|
||||||
var img *mlx.Array
|
|
||||||
img, err = m.EditFromConfig(inputImages, cfg)
|
|
||||||
if err == nil {
|
|
||||||
err = saveImageArray(img, *out)
|
|
||||||
}
|
|
||||||
case *listTensors:
|
case *listTensors:
|
||||||
err = listModelTensors(*modelPath)
|
err = listModelTensors(*modelPath)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestMain initializes MLX before running tests.
|
|
||||||
// If MLX libraries are not available, tests are skipped.
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
// Change to repo root so ./build/lib/ollama/ path works
|
|
||||||
_, thisFile, _, _ := runtime.Caller(0)
|
|
||||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
|
||||||
if err := os.Chdir(repoRoot); err != nil {
|
|
||||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := mlx.InitMLX(); err != nil {
|
|
||||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
|
||||||
os.Exit(m.Run())
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPipelineOutput runs the full pipeline (integration test).
|
|
||||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
|
||||||
func TestPipelineOutput(t *testing.T) {
|
|
||||||
modelPath := "../../../weights/Qwen-Image-2512"
|
|
||||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
|
||||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load model
|
|
||||||
pm, err := LoadPersistent(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Skipf("Skipping: failed to load model: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run 2-step pipeline (minimum for stable scheduler)
|
|
||||||
cfg := &GenerateConfig{
|
|
||||||
Prompt: "a cat",
|
|
||||||
Width: 256,
|
|
||||||
Height: 256,
|
|
||||||
Steps: 2,
|
|
||||||
Seed: 42,
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := pm.GenerateFromConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Pipeline failed: %v", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(output)
|
|
||||||
|
|
||||||
// Verify output shape [1, C, H, W]
|
|
||||||
shape := output.Shape()
|
|
||||||
if len(shape) != 4 {
|
|
||||||
t.Errorf("Expected 4D output, got %v", shape)
|
|
||||||
}
|
|
||||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
|
||||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify values in expected range [0, 1]
|
|
||||||
data := output.Data()
|
|
||||||
minVal, maxVal := float32(1.0), float32(0.0)
|
|
||||||
for _, v := range data {
|
|
||||||
if v < minVal {
|
|
||||||
minVal = v
|
|
||||||
}
|
|
||||||
if v > maxVal {
|
|
||||||
maxVal = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
|
||||||
|
|
||||||
if minVal < -0.1 || maxVal > 1.1 {
|
|
||||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,367 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GenerateConfig holds all options for image generation.
|
|
||||||
type GenerateConfig struct {
|
|
||||||
Prompt string
|
|
||||||
NegativePrompt string // Empty = no CFG
|
|
||||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
|
||||||
Width int32 // Image width (default: 1024)
|
|
||||||
Height int32 // Image height (default: 1024)
|
|
||||||
Steps int // Denoising steps (default: 30)
|
|
||||||
Seed int64 // Random seed
|
|
||||||
Progress func(step, totalSteps int) // Optional progress callback
|
|
||||||
|
|
||||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
|
||||||
LayerCache bool // Enable layer caching (default: false)
|
|
||||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
|
||||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model represents a Qwen-Image diffusion model.
|
|
||||||
type Model struct {
|
|
||||||
ModelPath string
|
|
||||||
Tokenizer *tokenizer.Tokenizer
|
|
||||||
TextEncoder *Qwen25VL
|
|
||||||
Transformer *Transformer
|
|
||||||
VAEDecoder *VAEDecoder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the Qwen-Image model from a directory.
|
|
||||||
func (m *Model) Load(modelPath string) error {
|
|
||||||
fmt.Println("Loading Qwen-Image model...")
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
if mlx.GPUIsAvailable() {
|
|
||||||
mlx.SetDefaultDeviceGPU()
|
|
||||||
mlx.EnableCompile()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ModelPath = modelPath
|
|
||||||
|
|
||||||
// Load tokenizer
|
|
||||||
fmt.Print(" Loading tokenizer... ")
|
|
||||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
|
||||||
tok, err := tokenizer.Load(tokenizerPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
m.Tokenizer = tok
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
|
||||||
m.TextEncoder = &Qwen25VL{}
|
|
||||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
|
||||||
return fmt.Errorf("text encoder: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
// Load transformer
|
|
||||||
m.Transformer = &Transformer{}
|
|
||||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
|
||||||
return fmt.Errorf("transformer: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
// Load VAE decoder
|
|
||||||
m.VAEDecoder = &VAEDecoder{}
|
|
||||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
|
||||||
return fmt.Errorf("VAE decoder: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
mem := mlx.MetalGetActiveMemory()
|
|
||||||
peak := mlx.MetalGetPeakMemory()
|
|
||||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
|
||||||
time.Since(start).Seconds(),
|
|
||||||
float64(mem)/(1024*1024*1024),
|
|
||||||
float64(peak)/(1024*1024*1024))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate creates an image from a prompt.
|
|
||||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
|
||||||
return m.GenerateFromConfig(&GenerateConfig{
|
|
||||||
Prompt: prompt,
|
|
||||||
Width: width,
|
|
||||||
Height: height,
|
|
||||||
Steps: steps,
|
|
||||||
Seed: seed,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateWithProgress creates an image with progress callback.
|
|
||||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
|
||||||
return m.GenerateFromConfig(&GenerateConfig{
|
|
||||||
Prompt: prompt,
|
|
||||||
Width: width,
|
|
||||||
Height: height,
|
|
||||||
Steps: steps,
|
|
||||||
Seed: seed,
|
|
||||||
Progress: progress,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
|
||||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
|
||||||
return m.GenerateFromConfig(&GenerateConfig{
|
|
||||||
Prompt: prompt,
|
|
||||||
NegativePrompt: negativePrompt,
|
|
||||||
CFGScale: cfgScale,
|
|
||||||
Width: width,
|
|
||||||
Height: height,
|
|
||||||
Steps: steps,
|
|
||||||
Seed: seed,
|
|
||||||
Progress: progress,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateFromConfig generates an image using the unified config struct.
|
|
||||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
|
||||||
start := time.Now()
|
|
||||||
result, err := m.generate(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if cfg.NegativePrompt != "" {
|
|
||||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateImage implements model.ImageModel interface.
|
|
||||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
|
||||||
return m.Generate(prompt, width, height, steps, seed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate is the internal denoising pipeline.
|
|
||||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
|
||||||
// Apply defaults
|
|
||||||
if cfg.Width <= 0 {
|
|
||||||
cfg.Width = 1024
|
|
||||||
}
|
|
||||||
if cfg.Height <= 0 {
|
|
||||||
cfg.Height = 1024
|
|
||||||
}
|
|
||||||
if cfg.Steps <= 0 {
|
|
||||||
cfg.Steps = 50
|
|
||||||
}
|
|
||||||
if cfg.CFGScale <= 0 {
|
|
||||||
cfg.CFGScale = 4.0
|
|
||||||
}
|
|
||||||
if cfg.CacheInterval <= 0 {
|
|
||||||
cfg.CacheInterval = 3
|
|
||||||
}
|
|
||||||
if cfg.CacheLayers <= 0 {
|
|
||||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
|
||||||
}
|
|
||||||
|
|
||||||
useCFG := cfg.NegativePrompt != ""
|
|
||||||
tcfg := m.Transformer.Config
|
|
||||||
latentH := cfg.Height / 8
|
|
||||||
latentW := cfg.Width / 8
|
|
||||||
pH := latentH / tcfg.PatchSize
|
|
||||||
pW := latentW / tcfg.PatchSize
|
|
||||||
imgSeqLen := pH * pW
|
|
||||||
|
|
||||||
// Text encoding
|
|
||||||
var posEmb, negEmb *mlx.Array
|
|
||||||
{
|
|
||||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
|
||||||
if useCFG {
|
|
||||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
|
||||||
mlx.Keep(posEmb, negEmb)
|
|
||||||
mlx.Eval(posEmb, negEmb)
|
|
||||||
} else {
|
|
||||||
mlx.Keep(posEmb)
|
|
||||||
mlx.Eval(posEmb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pad sequences to same length for CFG
|
|
||||||
txtLen := posEmb.Shape()[1]
|
|
||||||
if useCFG {
|
|
||||||
negLen := negEmb.Shape()[1]
|
|
||||||
if negLen > txtLen {
|
|
||||||
txtLen = negLen
|
|
||||||
}
|
|
||||||
if posEmb.Shape()[1] < txtLen {
|
|
||||||
posEmb = padSequence(posEmb, txtLen)
|
|
||||||
}
|
|
||||||
if negEmb.Shape()[1] < txtLen {
|
|
||||||
negEmb = padSequence(negEmb, txtLen)
|
|
||||||
}
|
|
||||||
mlx.Keep(posEmb, negEmb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
|
||||||
var batchedEmb *mlx.Array
|
|
||||||
if useCFG {
|
|
||||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
|
||||||
mlx.Keep(batchedEmb)
|
|
||||||
mlx.Eval(batchedEmb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scheduler
|
|
||||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
|
||||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
|
||||||
|
|
||||||
// Init latents [B, C, T, H, W]
|
|
||||||
var latents *mlx.Array
|
|
||||||
{
|
|
||||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
|
||||||
mlx.Eval(latents)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoPE cache
|
|
||||||
var ropeCache *RoPECache
|
|
||||||
{
|
|
||||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
|
||||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
mlx.Eval(ropeCache.ImgFreqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
|
||||||
var stepCache *cache.StepCache
|
|
||||||
if cfg.LayerCache {
|
|
||||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
|
||||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Denoising loop
|
|
||||||
for i := 0; i < cfg.Steps; i++ {
|
|
||||||
stepStart := time.Now()
|
|
||||||
if cfg.Progress != nil {
|
|
||||||
cfg.Progress(i+1, cfg.Steps)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := scheduler.Timesteps[i]
|
|
||||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
|
||||||
|
|
||||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
|
||||||
latents2D := mlx.Squeeze(latents, 2)
|
|
||||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
|
||||||
|
|
||||||
var output *mlx.Array
|
|
||||||
if useCFG {
|
|
||||||
// CFG Batching: single forward pass with batch=2
|
|
||||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
|
||||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
|
||||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
|
||||||
|
|
||||||
// Single batched forward pass
|
|
||||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
|
|
||||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
|
||||||
L := batchedOutput.Shape()[1]
|
|
||||||
D := batchedOutput.Shape()[2]
|
|
||||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
|
||||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
|
||||||
|
|
||||||
diff := mlx.Sub(posOutput, negOutput)
|
|
||||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
|
||||||
combPred := mlx.Add(negOutput, scaledDiff)
|
|
||||||
|
|
||||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
|
||||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
|
||||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
|
||||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
|
||||||
} else if stepCache != nil {
|
|
||||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
|
||||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
|
||||||
} else {
|
|
||||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
|
||||||
oldLatents := latents
|
|
||||||
latents = scheduler.Step(noisePred, latents, i)
|
|
||||||
|
|
||||||
// Keep cached arrays alive across cleanup
|
|
||||||
if stepCache != nil {
|
|
||||||
mlx.Keep(stepCache.Arrays()...)
|
|
||||||
}
|
|
||||||
mlx.Eval(latents)
|
|
||||||
oldLatents.Free()
|
|
||||||
|
|
||||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
|
||||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
|
||||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free denoising temporaries before VAE decode
|
|
||||||
posEmb.Free()
|
|
||||||
if negEmb != nil {
|
|
||||||
negEmb.Free()
|
|
||||||
}
|
|
||||||
if batchedEmb != nil {
|
|
||||||
batchedEmb.Free()
|
|
||||||
}
|
|
||||||
ropeCache.ImgFreqs.Free()
|
|
||||||
ropeCache.TxtFreqs.Free()
|
|
||||||
if stepCache != nil {
|
|
||||||
stepCache.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAE decode (Decode manages its own pools for staged memory)
|
|
||||||
decoded := m.VAEDecoder.Decode(latents)
|
|
||||||
latents.Free()
|
|
||||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
|
||||||
{
|
|
||||||
decoded = mlx.Squeeze(decoded, 2)
|
|
||||||
decoded = mlx.AddScalar(decoded, 1.0)
|
|
||||||
decoded = mlx.DivScalar(decoded, 2.0)
|
|
||||||
mlx.Eval(decoded)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
return decoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// padSequence pads a sequence tensor to the target length with zeros
|
|
||||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
currentLen := shape[1]
|
|
||||||
if currentLen >= targetLen {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
padLen := targetLen - currentLen
|
|
||||||
// Pad on sequence dimension (axis 1)
|
|
||||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadPersistent is an alias for backward compatibility.
|
|
||||||
// Use m := &Model{}; m.Load(path) instead.
|
|
||||||
func LoadPersistent(modelPath string) (*Model, error) {
|
|
||||||
m := &Model{}
|
|
||||||
if err := m.Load(modelPath); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
@@ -1,218 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
|
||||||
type SchedulerConfig struct {
|
|
||||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
|
||||||
BaseShift float32 `json:"base_shift"` // 0.5
|
|
||||||
MaxShift float32 `json:"max_shift"` // 0.9
|
|
||||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
|
||||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
|
||||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
|
||||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
|
||||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
|
||||||
return &SchedulerConfig{
|
|
||||||
NumTrainTimesteps: 1000,
|
|
||||||
BaseShift: 0.5,
|
|
||||||
MaxShift: 0.9, // Matches scheduler_config.json
|
|
||||||
BaseImageSeqLen: 256,
|
|
||||||
MaxImageSeqLen: 8192,
|
|
||||||
ShiftTerminal: 0.02,
|
|
||||||
UseDynamicShift: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
|
||||||
type FlowMatchScheduler struct {
|
|
||||||
Config *SchedulerConfig
|
|
||||||
Timesteps []float32
|
|
||||||
Sigmas []float32
|
|
||||||
NumSteps int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFlowMatchScheduler creates a new scheduler
|
|
||||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
|
||||||
return &FlowMatchScheduler{
|
|
||||||
Config: cfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CalculateShift computes the dynamic shift based on image sequence length
|
|
||||||
// This matches Python's calculate_shift function
|
|
||||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
|
||||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
|
||||||
b := baseShift - m*float32(baseSeqLen)
|
|
||||||
mu := float32(imageSeqLen)*m + b
|
|
||||||
return mu
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
|
||||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
|
||||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
|
||||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
|
||||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
|
||||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
|
||||||
s.NumSteps = numSteps
|
|
||||||
|
|
||||||
// Calculate mu for dynamic shifting
|
|
||||||
var mu float32
|
|
||||||
if s.Config.UseDynamicShift {
|
|
||||||
mu = CalculateShift(
|
|
||||||
imageSeqLen,
|
|
||||||
s.Config.BaseImageSeqLen,
|
|
||||||
s.Config.MaxImageSeqLen,
|
|
||||||
s.Config.BaseShift,
|
|
||||||
s.Config.MaxShift,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
|
||||||
// Python (pipeline_qwenimage.py:639):
|
|
||||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
|
||||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
|
||||||
sigmas := make([]float32, numSteps)
|
|
||||||
sigmaMax := float32(1.0)
|
|
||||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
|
||||||
if numSteps == 1 {
|
|
||||||
sigmas[0] = sigmaMax
|
|
||||||
} else {
|
|
||||||
for i := 0; i < numSteps; i++ {
|
|
||||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 2: Apply time shift if using dynamic shifting
|
|
||||||
if s.Config.UseDynamicShift && mu != 0 {
|
|
||||||
for i := range sigmas {
|
|
||||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 3: Apply stretch_shift_to_terminal
|
|
||||||
if s.Config.ShiftTerminal > 0 {
|
|
||||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 4: Append terminal sigma (0) and store
|
|
||||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
|
||||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
|
||||||
s.Sigmas = make([]float32, numSteps+1)
|
|
||||||
s.Timesteps = make([]float32, numSteps+1)
|
|
||||||
for i := 0; i < numSteps; i++ {
|
|
||||||
s.Sigmas[i] = sigmas[i]
|
|
||||||
s.Timesteps[i] = sigmas[i]
|
|
||||||
}
|
|
||||||
s.Sigmas[numSteps] = 0.0
|
|
||||||
s.Timesteps[numSteps] = 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
|
||||||
// so the final value equals shift_terminal (matches Python behavior)
|
|
||||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
|
||||||
if len(sigmas) == 0 {
|
|
||||||
return sigmas
|
|
||||||
}
|
|
||||||
|
|
||||||
// one_minus_z = 1 - t
|
|
||||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
|
||||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
|
||||||
lastSigma := sigmas[len(sigmas)-1]
|
|
||||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
|
||||||
|
|
||||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
|
||||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
|
||||||
if scaleFactor < 1e-6 {
|
|
||||||
return sigmas
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]float32, len(sigmas))
|
|
||||||
for i, t := range sigmas {
|
|
||||||
oneMinusZ := 1.0 - t
|
|
||||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// timeShift applies the dynamic time shift (exponential)
|
|
||||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
|
||||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
|
||||||
if t <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
expMu := float32(math.Exp(float64(mu)))
|
|
||||||
return expMu / (expMu + (1.0/t - 1.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step performs one denoising step
|
|
||||||
// modelOutput: predicted velocity from the transformer
|
|
||||||
// sample: current noisy sample
|
|
||||||
// timestepIdx: current timestep index
|
|
||||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
|
||||||
// Get current and next sigma
|
|
||||||
sigma := s.Sigmas[timestepIdx]
|
|
||||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
|
||||||
|
|
||||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
|
||||||
dt := sigmaNext - sigma
|
|
||||||
|
|
||||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
|
||||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
|
||||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
|
||||||
|
|
||||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
|
||||||
result := mlx.Add(sampleF32, scaledOutput)
|
|
||||||
|
|
||||||
// Cast back to original dtype
|
|
||||||
return mlx.ToBFloat16(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTimestep returns the timestep value at the given index
|
|
||||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
|
||||||
if idx < len(s.Timesteps) {
|
|
||||||
return s.Timesteps[idx]
|
|
||||||
}
|
|
||||||
return 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
|
||||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
|
||||||
return mlx.RandomNormal(shape, uint64(seed))
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
|
||||||
// This matches how Python diffusers generates noise - directly in packed space.
|
|
||||||
// Generating in unpacked format and then packing produces different spatial
|
|
||||||
// correlation structure, which affects model output quality.
|
|
||||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
|
||||||
shape := []int32{batchSize, seqLen, channels}
|
|
||||||
return mlx.RandomNormal(shape, uint64(seed))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLatentShape returns the latent shape for a given image size
|
|
||||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
|
||||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
|
||||||
latentH := height / 8
|
|
||||||
latentW := width / 8
|
|
||||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPatchedLatentShape returns the patchified latent shape
|
|
||||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
|
||||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
|
||||||
latentH := height / 8
|
|
||||||
latentW := width / 8
|
|
||||||
pH := latentH / patchSize
|
|
||||||
pW := latentW / patchSize
|
|
||||||
inChannels := int32(64) // 16 * patch_size^2
|
|
||||||
return []int32{batchSize, pH * pW, inChannels}
|
|
||||||
}
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
|
||||||
// Golden values generated via:
|
|
||||||
//
|
|
||||||
// python3 -c "
|
|
||||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
|
||||||
// import numpy as np
|
|
||||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
|
||||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
|
||||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
|
||||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
|
||||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
|
||||||
// print(s.sigmas.numpy())"
|
|
||||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
|
||||||
cfg := DefaultSchedulerConfig()
|
|
||||||
scheduler := NewFlowMatchScheduler(cfg)
|
|
||||||
scheduler.SetTimesteps(30, 4096)
|
|
||||||
|
|
||||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
|
||||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
|
||||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
|
||||||
|
|
||||||
// Check first 3
|
|
||||||
for i, want := range wantFirst {
|
|
||||||
got := scheduler.Sigmas[i]
|
|
||||||
if abs32(got-want) > 1e-4 {
|
|
||||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check last 3 (indices 27, 28, 29)
|
|
||||||
for i, want := range wantLast {
|
|
||||||
idx := 27 + i
|
|
||||||
got := scheduler.Sigmas[idx]
|
|
||||||
if abs32(got-want) > 1e-4 {
|
|
||||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check terminal is 0
|
|
||||||
if scheduler.Sigmas[30] != 0.0 {
|
|
||||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check length
|
|
||||||
if len(scheduler.Sigmas) != 31 {
|
|
||||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
|
||||||
func TestSchedulerProperties(t *testing.T) {
|
|
||||||
cfg := DefaultSchedulerConfig()
|
|
||||||
scheduler := NewFlowMatchScheduler(cfg)
|
|
||||||
scheduler.SetTimesteps(30, 4096)
|
|
||||||
|
|
||||||
// Property: sigmas monotonically decreasing
|
|
||||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
|
||||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
|
||||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
|
||||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: first sigma should be ~1.0 (with time shift)
|
|
||||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
|
||||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: terminal sigma should be exactly 0
|
|
||||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
|
||||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
|
||||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
|
||||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
|
||||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: length = steps + 1
|
|
||||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
|
||||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
|
||||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
|
||||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
|
||||||
func TestCalculateShift(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
imgSeqLen int32
|
|
||||||
want float32
|
|
||||||
}{
|
|
||||||
{256, 0.5}, // base case
|
|
||||||
{8192, 0.9}, // max case
|
|
||||||
{4096, 0.6935}, // middle case (rounded)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cases {
|
|
||||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
|
||||||
if abs32(got-c.want) > 0.001 {
|
|
||||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSchedulerStep verifies the Euler step formula.
|
|
||||||
func TestSchedulerStep(t *testing.T) {
|
|
||||||
cfg := DefaultSchedulerConfig()
|
|
||||||
scheduler := NewFlowMatchScheduler(cfg)
|
|
||||||
scheduler.SetTimesteps(30, 4096)
|
|
||||||
|
|
||||||
// Verify dt calculation for first step
|
|
||||||
sigma0 := scheduler.Sigmas[0]
|
|
||||||
sigma1 := scheduler.Sigmas[1]
|
|
||||||
expectedDt := sigma1 - sigma0
|
|
||||||
|
|
||||||
// dt should be negative (sigmas decrease)
|
|
||||||
if expectedDt >= 0 {
|
|
||||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func abs32(x float32) float32 {
|
|
||||||
return float32(math.Abs(float64(x)))
|
|
||||||
}
|
|
||||||
@@ -1,174 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
|
||||||
type TinyTextEncoderConfig struct {
|
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
|
||||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
|
||||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
||||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
||||||
VocabSize int32 `json:"vocab_size"`
|
|
||||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
|
||||||
HeadDim int32 `json:"head_dim"`
|
|
||||||
MRoPESection []int32 `json:"mrope_section"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
|
||||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
|
||||||
|
|
||||||
// Load config
|
|
||||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
|
||||||
if err != nil {
|
|
||||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
|
||||||
}
|
|
||||||
|
|
||||||
var tinyCfg TinyTextEncoderConfig
|
|
||||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
|
||||||
t.Fatalf("Failed to parse config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create encoder config (using Qwen25VLConfig)
|
|
||||||
cfg := &Qwen25VLConfig{
|
|
||||||
HiddenSize: tinyCfg.HiddenSize,
|
|
||||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
|
||||||
IntermediateSize: tinyCfg.IntermediateSize,
|
|
||||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
|
||||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
|
||||||
VocabSize: tinyCfg.VocabSize,
|
|
||||||
RMSNormEps: tinyCfg.RMSNormEps,
|
|
||||||
RopeTheta: tinyCfg.RopeTheta,
|
|
||||||
HeadDim: tinyCfg.HeadDim,
|
|
||||||
MRoPESection: tinyCfg.MRoPESection,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load weights
|
|
||||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load weights: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
|
||||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build encoder
|
|
||||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get embedding: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
|
||||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
|
||||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
|
||||||
}
|
|
||||||
blocks[i] = block
|
|
||||||
}
|
|
||||||
|
|
||||||
finalNorm, err := weights.Get("model.norm.weight")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get final norm: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
encoder := &Qwen25VL{
|
|
||||||
Config: cfg,
|
|
||||||
Embedding: embedding,
|
|
||||||
Blocks: blocks,
|
|
||||||
FinalNorm: finalNorm,
|
|
||||||
HasVision: false, // Text-only mode
|
|
||||||
}
|
|
||||||
|
|
||||||
return encoder, &tinyCfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
|
||||||
func TestTextEncoderForward(t *testing.T) {
|
|
||||||
encoder, cfg := loadTinyTextEncoder(t)
|
|
||||||
|
|
||||||
// Create test tokens (within vocab range)
|
|
||||||
tokens := []int32{1, 2, 3, 4, 5}
|
|
||||||
|
|
||||||
// Forward pass using EncodeTextOnly
|
|
||||||
out := encoder.EncodeTextOnly(tokens)
|
|
||||||
mlx.Eval(out)
|
|
||||||
|
|
||||||
// Verify output shape: [batch, seq_len, hidden_size]
|
|
||||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
|
||||||
if !slices.Equal(out.Shape(), wantShape) {
|
|
||||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify output is finite (not NaN or Inf)
|
|
||||||
data := out.Data()
|
|
||||||
for i, v := range data {
|
|
||||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
|
||||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTextEncoderBatch tests batch processing.
|
|
||||||
func TestTextEncoderBatch(t *testing.T) {
|
|
||||||
encoder, cfg := loadTinyTextEncoder(t)
|
|
||||||
|
|
||||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
|
||||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
|
||||||
tokens := []int32{1, 2, 3}
|
|
||||||
|
|
||||||
out := encoder.EncodeTextOnly(tokens)
|
|
||||||
mlx.Eval(out)
|
|
||||||
|
|
||||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
|
||||||
if !slices.Equal(out.Shape(), wantShape) {
|
|
||||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
|
||||||
func TestMRoPEComputation(t *testing.T) {
|
|
||||||
encoder, cfg := loadTinyTextEncoder(t)
|
|
||||||
|
|
||||||
cossin := encoder.computeTextRoPE(10, 1)
|
|
||||||
mlx.Eval(cossin[0], cossin[1])
|
|
||||||
|
|
||||||
// Verify shapes: [3, B, L, head_dim]
|
|
||||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
|
||||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
|
||||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
|
||||||
}
|
|
||||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
|
||||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify cos/sin values are in valid range [-1, 1]
|
|
||||||
cosData := cossin[0].Data()
|
|
||||||
sinData := cossin[1].Data()
|
|
||||||
for i := 0; i < min(100, len(cosData)); i++ {
|
|
||||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
|
||||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
|
||||||
}
|
|
||||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
|
||||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,868 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TransformerConfig holds Qwen-Image transformer configuration
|
|
||||||
type TransformerConfig struct {
|
|
||||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
|
||||||
NHeads int32 `json:"num_attention_heads"` // 24
|
|
||||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
|
||||||
NLayers int32 `json:"num_layers"` // 60
|
|
||||||
InChannels int32 `json:"in_channels"` // 64
|
|
||||||
OutChannels int32 `json:"out_channels"` // 16
|
|
||||||
PatchSize int32 `json:"patch_size"` // 2
|
|
||||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
|
||||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
|
||||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
|
||||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
|
||||||
func defaultTransformerConfig() *TransformerConfig {
|
|
||||||
return &TransformerConfig{
|
|
||||||
HiddenDim: 3072, // 24 * 128
|
|
||||||
NHeads: 24,
|
|
||||||
HeadDim: 128,
|
|
||||||
NLayers: 60,
|
|
||||||
InChannels: 64,
|
|
||||||
OutChannels: 16,
|
|
||||||
PatchSize: 2,
|
|
||||||
JointAttentionDim: 3584,
|
|
||||||
NormEps: 1e-6,
|
|
||||||
AxesDimsRope: []int32{16, 56, 56},
|
|
||||||
GuidanceEmbeds: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimestepEmbedder creates timestep embeddings
|
|
||||||
type TimestepEmbedder struct {
|
|
||||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
|
||||||
Linear1Bias *mlx.Array
|
|
||||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
|
||||||
Linear2Bias *mlx.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTimestepEmbedder creates a timestep embedder from weights
|
|
||||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
|
||||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TimestepEmbedder{
|
|
||||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
|
||||||
Linear1Bias: linear1Bias,
|
|
||||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
|
||||||
Linear2Bias: linear2Bias,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward computes timestep embeddings
|
|
||||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
|
||||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
|
||||||
half := int32(128) // embedding_dim / 2
|
|
||||||
|
|
||||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
|
||||||
freqs := make([]float32, half)
|
|
||||||
for i := int32(0); i < half; i++ {
|
|
||||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
|
||||||
}
|
|
||||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
|
||||||
|
|
||||||
tExpanded := mlx.ExpandDims(t, 1)
|
|
||||||
args := mlx.Mul(tExpanded, freqsArr)
|
|
||||||
args = mlx.MulScalar(args, 1000.0) // scale
|
|
||||||
|
|
||||||
// [cos, sin] (flip_sin_to_cos=True)
|
|
||||||
sinArgs := mlx.Sin(args)
|
|
||||||
cosArgs := mlx.Cos(args)
|
|
||||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
|
||||||
|
|
||||||
// MLP: linear1 -> silu -> linear2
|
|
||||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
|
||||||
h = mlx.Add(h, te.Linear1Bias)
|
|
||||||
h = mlx.SiLU(h)
|
|
||||||
h = mlx.Linear(h, te.Linear2Weight)
|
|
||||||
h = mlx.Add(h, te.Linear2Bias)
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// JointAttention implements dual-stream joint attention
|
|
||||||
type JointAttention struct {
|
|
||||||
// Image projections
|
|
||||||
ToQ *mlx.Array
|
|
||||||
ToQB *mlx.Array
|
|
||||||
ToK *mlx.Array
|
|
||||||
ToKB *mlx.Array
|
|
||||||
ToV *mlx.Array
|
|
||||||
ToVB *mlx.Array
|
|
||||||
ToOut *mlx.Array
|
|
||||||
ToOutB *mlx.Array
|
|
||||||
NormQ *mlx.Array
|
|
||||||
NormK *mlx.Array
|
|
||||||
|
|
||||||
// Text (added) projections
|
|
||||||
AddQProj *mlx.Array
|
|
||||||
AddQProjB *mlx.Array
|
|
||||||
AddKProj *mlx.Array
|
|
||||||
AddKProjB *mlx.Array
|
|
||||||
AddVProj *mlx.Array
|
|
||||||
AddVProjB *mlx.Array
|
|
||||||
ToAddOut *mlx.Array
|
|
||||||
ToAddOutB *mlx.Array
|
|
||||||
NormAddQ *mlx.Array
|
|
||||||
NormAddK *mlx.Array
|
|
||||||
|
|
||||||
NHeads int32
|
|
||||||
HeadDim int32
|
|
||||||
Scale float32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newJointAttention creates a joint attention layer
|
|
||||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
|
||||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
|
||||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
|
||||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
|
||||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
|
||||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
|
||||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
|
||||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
|
||||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
|
||||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
|
||||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
|
||||||
|
|
||||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
|
||||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
|
||||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
|
||||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
|
||||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
|
||||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
|
||||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
|
||||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
|
||||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
|
||||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
|
||||||
|
|
||||||
return &JointAttention{
|
|
||||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
|
||||||
ToQB: toQB,
|
|
||||||
ToK: mlx.Transpose(toK, 1, 0),
|
|
||||||
ToKB: toKB,
|
|
||||||
ToV: mlx.Transpose(toV, 1, 0),
|
|
||||||
ToVB: toVB,
|
|
||||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
|
||||||
ToOutB: toOutB,
|
|
||||||
NormQ: normQ,
|
|
||||||
NormK: normK,
|
|
||||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
|
||||||
AddQProjB: addQProjB,
|
|
||||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
|
||||||
AddKProjB: addKProjB,
|
|
||||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
|
||||||
AddVProjB: addVProjB,
|
|
||||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
|
||||||
ToAddOutB: toAddOutB,
|
|
||||||
NormAddQ: normAddQ,
|
|
||||||
NormAddK: normAddK,
|
|
||||||
NHeads: cfg.NHeads,
|
|
||||||
HeadDim: cfg.HeadDim,
|
|
||||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward computes joint attention
|
|
||||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
|
||||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
|
||||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
||||||
imgShape := img.Shape()
|
|
||||||
B := imgShape[0]
|
|
||||||
Limg := imgShape[1]
|
|
||||||
D := imgShape[2]
|
|
||||||
|
|
||||||
txtShape := txt.Shape()
|
|
||||||
Ltxt := txtShape[1]
|
|
||||||
|
|
||||||
// === Image Q/K/V ===
|
|
||||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
|
||||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
|
||||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
|
||||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
|
||||||
|
|
||||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
|
||||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
|
||||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
|
||||||
|
|
||||||
// QK norm (RMSNorm per head)
|
|
||||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
|
||||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
|
||||||
|
|
||||||
// Apply RoPE
|
|
||||||
if imgFreqs != nil {
|
|
||||||
qImg = applyRoPE(qImg, imgFreqs)
|
|
||||||
kImg = applyRoPE(kImg, imgFreqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// === Text Q/K/V ===
|
|
||||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
|
||||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
|
||||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
|
||||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
|
||||||
|
|
||||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
|
||||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
|
||||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
|
||||||
|
|
||||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
|
||||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
|
||||||
|
|
||||||
if txtFreqs != nil {
|
|
||||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
|
||||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Concatenate for joint attention: [txt, img] order
|
|
||||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
|
||||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
|
||||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
|
||||||
|
|
||||||
// Transpose to [B, nheads, L, head_dim]
|
|
||||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
|
||||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
|
||||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
|
||||||
|
|
||||||
// SDPA
|
|
||||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
|
||||||
|
|
||||||
// Transpose back and split
|
|
||||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
|
||||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
|
||||||
|
|
||||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
|
||||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
|
||||||
|
|
||||||
// Output projections
|
|
||||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
|
||||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
|
||||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
|
||||||
|
|
||||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
|
||||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
|
||||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
|
||||||
|
|
||||||
return outImg, outTxt
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyRoPE applies rotary embeddings using complex multiplication
|
|
||||||
// x: [B, L, nheads, head_dim]
|
|
||||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
|
||||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
L := shape[1]
|
|
||||||
nheads := shape[2]
|
|
||||||
headDim := shape[3]
|
|
||||||
halfDim := headDim / 2
|
|
||||||
|
|
||||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
|
||||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
|
||||||
|
|
||||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
|
||||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
|
||||||
|
|
||||||
// Extract real/imag parts
|
|
||||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
|
||||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
|
||||||
xReal = mlx.Squeeze(xReal, 4)
|
|
||||||
xImag = mlx.Squeeze(xImag, 4)
|
|
||||||
|
|
||||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
|
||||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
|
||||||
freqReal = mlx.Squeeze(freqReal, 4)
|
|
||||||
freqImag = mlx.Squeeze(freqImag, 4)
|
|
||||||
|
|
||||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
|
||||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
|
||||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
|
||||||
|
|
||||||
// Interleave back
|
|
||||||
outReal = mlx.ExpandDims(outReal, 4)
|
|
||||||
outImag = mlx.ExpandDims(outImag, 4)
|
|
||||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
|
||||||
|
|
||||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MLP implements GELU MLP (not GEGLU)
|
|
||||||
type MLP struct {
|
|
||||||
ProjWeight *mlx.Array
|
|
||||||
ProjBias *mlx.Array
|
|
||||||
OutWeight *mlx.Array
|
|
||||||
OutBias *mlx.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMLP creates a GELU MLP
|
|
||||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
|
||||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
|
||||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
|
||||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
|
||||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
|
||||||
|
|
||||||
return &MLP{
|
|
||||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
|
||||||
ProjBias: projBias,
|
|
||||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
|
||||||
OutBias: outBias,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies GELU MLP
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
L := shape[1]
|
|
||||||
D := shape[2]
|
|
||||||
|
|
||||||
xFlat := mlx.Reshape(x, B*L, D)
|
|
||||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
|
||||||
h = geluApprox(h)
|
|
||||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
|
||||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// geluApprox implements approximate GELU
|
|
||||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
|
||||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
|
||||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
|
||||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
|
||||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
|
||||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TransformerBlock is a single dual-stream transformer block
|
|
||||||
type TransformerBlock struct {
|
|
||||||
Attention *JointAttention
|
|
||||||
ImgMLP *MLP
|
|
||||||
TxtMLP *MLP
|
|
||||||
|
|
||||||
ImgModWeight *mlx.Array
|
|
||||||
ImgModBias *mlx.Array
|
|
||||||
TxtModWeight *mlx.Array
|
|
||||||
TxtModBias *mlx.Array
|
|
||||||
|
|
||||||
HiddenDim int32
|
|
||||||
NormEps float32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTransformerBlock creates a transformer block
|
|
||||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
|
||||||
attn, err := newJointAttention(weights, prefix, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
|
||||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
|
||||||
|
|
||||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
|
||||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
|
||||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
|
||||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
|
||||||
|
|
||||||
return &TransformerBlock{
|
|
||||||
Attention: attn,
|
|
||||||
ImgMLP: imgMLP,
|
|
||||||
TxtMLP: txtMLP,
|
|
||||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
|
||||||
ImgModBias: imgModBias,
|
|
||||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
|
||||||
TxtModBias: txtModBias,
|
|
||||||
HiddenDim: cfg.HiddenDim,
|
|
||||||
NormEps: cfg.NormEps,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies the transformer block
|
|
||||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
||||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
|
||||||
siluT := mlx.SiLU(temb)
|
|
||||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
|
||||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
|
||||||
|
|
||||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
|
||||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
|
||||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
|
||||||
|
|
||||||
// Pre-attention: norm + modulate
|
|
||||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
|
||||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
|
||||||
|
|
||||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
|
||||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
|
||||||
|
|
||||||
// Joint attention
|
|
||||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
|
||||||
|
|
||||||
// Residual with gate
|
|
||||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
|
||||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
|
||||||
|
|
||||||
// Pre-MLP: norm + modulate
|
|
||||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
|
||||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
|
||||||
|
|
||||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
|
||||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
|
||||||
|
|
||||||
// MLP
|
|
||||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
|
||||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
|
||||||
|
|
||||||
// Residual with gate
|
|
||||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
|
||||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
|
||||||
|
|
||||||
return img, txt
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
|
||||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
|
||||||
shape := mod.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
parts := make([]*mlx.Array, 6)
|
|
||||||
for i := int32(0); i < 6; i++ {
|
|
||||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
|
||||||
parts[i] = mlx.ExpandDims(part, 1)
|
|
||||||
}
|
|
||||||
return parts
|
|
||||||
}
|
|
||||||
|
|
||||||
// layerNormNoAffine applies layer norm without learnable parameters
|
|
||||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
|
||||||
ndim := x.Ndim()
|
|
||||||
lastAxis := ndim - 1
|
|
||||||
mean := mlx.Mean(x, lastAxis, true)
|
|
||||||
xCentered := mlx.Sub(x, mean)
|
|
||||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
|
||||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transformer is the full Qwen-Image transformer model
|
|
||||||
type Transformer struct {
|
|
||||||
Config *TransformerConfig
|
|
||||||
|
|
||||||
ImgIn *mlx.Array
|
|
||||||
ImgInBias *mlx.Array
|
|
||||||
TxtIn *mlx.Array
|
|
||||||
TxtInBias *mlx.Array
|
|
||||||
TxtNorm *mlx.Array
|
|
||||||
|
|
||||||
TEmbed *TimestepEmbedder
|
|
||||||
Layers []*TransformerBlock
|
|
||||||
|
|
||||||
NormOutWeight *mlx.Array
|
|
||||||
NormOutBias *mlx.Array
|
|
||||||
ProjOut *mlx.Array
|
|
||||||
ProjOutBias *mlx.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the transformer from a directory
|
|
||||||
func (m *Transformer) Load(path string) error {
|
|
||||||
fmt.Println("Loading Qwen-Image transformer...")
|
|
||||||
|
|
||||||
cfg := defaultTransformerConfig()
|
|
||||||
m.Config = cfg
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bulk load all weights as bf16
|
|
||||||
fmt.Print(" Loading weights as bf16... ")
|
|
||||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
|
||||||
return fmt.Errorf("load weights: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
fmt.Print(" Loading input projections... ")
|
|
||||||
imgIn, _ := weights.Get("img_in.weight")
|
|
||||||
imgInBias, _ := weights.Get("img_in.bias")
|
|
||||||
txtIn, _ := weights.Get("txt_in.weight")
|
|
||||||
txtInBias, _ := weights.Get("txt_in.bias")
|
|
||||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
|
||||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
|
||||||
m.ImgInBias = imgInBias
|
|
||||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
|
||||||
m.TxtInBias = txtInBias
|
|
||||||
m.TxtNorm = txtNorm
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
fmt.Print(" Loading timestep embedder... ")
|
|
||||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("timestep embedder: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
|
||||||
for i := int32(0); i < cfg.NLayers; i++ {
|
|
||||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
|
||||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
|
||||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("layer %d: %w", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
|
||||||
|
|
||||||
fmt.Print(" Loading output layers... ")
|
|
||||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
|
||||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
|
||||||
projOut, _ := weights.Get("proj_out.weight")
|
|
||||||
projOutBias, _ := weights.Get("proj_out.bias")
|
|
||||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
|
||||||
m.NormOutBias = normOutBias
|
|
||||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
|
||||||
m.ProjOutBias = projOutBias
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
weights.ReleaseAll()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadFromPath is a convenience function to load transformer from path
|
|
||||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
|
||||||
m := &Transformer{}
|
|
||||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the transformer
|
|
||||||
// img: [B, L_img, in_channels] patchified latents
|
|
||||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
|
||||||
// t: [B] timesteps (0-1)
|
|
||||||
// imgFreqs, txtFreqs: RoPE frequencies
|
|
||||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
|
||||||
imgShape := img.Shape()
|
|
||||||
B := imgShape[0]
|
|
||||||
Limg := imgShape[1]
|
|
||||||
|
|
||||||
txtShape := txt.Shape()
|
|
||||||
Ltxt := txtShape[1]
|
|
||||||
|
|
||||||
// Timestep embedding
|
|
||||||
temb := tr.TEmbed.Forward(t)
|
|
||||||
|
|
||||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
|
||||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
|
||||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
|
||||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
|
||||||
|
|
||||||
// Project text: RMSNorm then linear
|
|
||||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
|
||||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
|
||||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
|
||||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
|
||||||
|
|
||||||
for _, layer := range tr.Layers {
|
|
||||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final norm with modulation (AdaLayerNormContinuous)
|
|
||||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
|
||||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
|
||||||
modShape := finalMod.Shape()
|
|
||||||
halfDim := modShape[1] / 2
|
|
||||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
|
||||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
|
||||||
|
|
||||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
|
||||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
|
||||||
|
|
||||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
|
||||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
|
||||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
|
||||||
|
|
||||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
|
||||||
return mlx.Reshape(out, B, Limg, outChannels)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
|
||||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
|
||||||
// shallow layers change little between denoising steps, so we cache their
|
|
||||||
// outputs and reuse them on non-refresh steps.
|
|
||||||
//
|
|
||||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
|
||||||
// step: current denoising step (0-indexed)
|
|
||||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
|
||||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
|
||||||
func (tr *Transformer) ForwardWithCache(
|
|
||||||
img, txt, t *mlx.Array,
|
|
||||||
imgFreqs, txtFreqs *mlx.Array,
|
|
||||||
stepCache *cache.StepCache,
|
|
||||||
step, cacheInterval, cacheLayers int,
|
|
||||||
) *mlx.Array {
|
|
||||||
imgShape := img.Shape()
|
|
||||||
B := imgShape[0]
|
|
||||||
Limg := imgShape[1]
|
|
||||||
|
|
||||||
txtShape := txt.Shape()
|
|
||||||
Ltxt := txtShape[1]
|
|
||||||
|
|
||||||
// Timestep embedding
|
|
||||||
temb := tr.TEmbed.Forward(t)
|
|
||||||
|
|
||||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
|
||||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
|
||||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
|
||||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
|
||||||
|
|
||||||
// Project text: RMSNorm then linear
|
|
||||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
|
||||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
|
||||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
|
||||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
|
||||||
|
|
||||||
// Check if we should refresh the cache
|
|
||||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
|
||||||
|
|
||||||
for i, layer := range tr.Layers {
|
|
||||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
|
||||||
// Use cached outputs for shallow layers
|
|
||||||
imgH = stepCache.Get(i)
|
|
||||||
txtH = stepCache.Get2(i)
|
|
||||||
} else {
|
|
||||||
// Compute layer
|
|
||||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
|
||||||
// Cache shallow layers on refresh steps
|
|
||||||
if i < cacheLayers && refreshCache {
|
|
||||||
stepCache.Set(i, imgH)
|
|
||||||
stepCache.Set2(i, txtH)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final norm with modulation (AdaLayerNormContinuous)
|
|
||||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
|
||||||
modShape := finalMod.Shape()
|
|
||||||
halfDim := modShape[1] / 2
|
|
||||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
|
||||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
|
||||||
|
|
||||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
|
||||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
|
||||||
|
|
||||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
|
||||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
|
||||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
|
||||||
|
|
||||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
|
||||||
return mlx.Reshape(out, B, Limg, outChannels)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoPECache holds precomputed RoPE frequencies
|
|
||||||
type RoPECache struct {
|
|
||||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
|
||||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrepareRoPE computes RoPE for image and text sequences
|
|
||||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
|
||||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
|
||||||
theta := float64(10000)
|
|
||||||
maxIdx := int32(4096)
|
|
||||||
|
|
||||||
// Compute base frequencies for each axis dimension
|
|
||||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
|
||||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
|
||||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
|
||||||
|
|
||||||
// Build frequency lookup tables
|
|
||||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
|
||||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
|
||||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
|
||||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
|
||||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
|
||||||
|
|
||||||
// Image frequencies with scale_rope=True
|
|
||||||
imgLen := imgH * imgW
|
|
||||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
|
||||||
imgFreqsData := make([]float32, imgLen*headDim)
|
|
||||||
|
|
||||||
hHalf := imgH / 2
|
|
||||||
wHalf := imgW / 2
|
|
||||||
|
|
||||||
idx := int32(0)
|
|
||||||
for y := int32(0); y < imgH; y++ {
|
|
||||||
for x := int32(0); x < imgW; x++ {
|
|
||||||
// Frame = 0
|
|
||||||
for i := 0; i < len(freqsT)*2; i++ {
|
|
||||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsT) * 2)
|
|
||||||
|
|
||||||
// Height: scale_rope pattern
|
|
||||||
hNegCount := imgH - hHalf
|
|
||||||
if y < hNegCount {
|
|
||||||
negTableIdx := maxIdx - hNegCount + y
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := y - hNegCount
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsH) * 2)
|
|
||||||
|
|
||||||
// Width: scale_rope pattern
|
|
||||||
wNegCount := imgW - wHalf
|
|
||||||
if x < wNegCount {
|
|
||||||
negTableIdx := maxIdx - wNegCount + x
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := x - wNegCount
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsW) * 2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
|
||||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
|
||||||
|
|
||||||
// Text frequencies
|
|
||||||
maxVidIdx := max(hHalf, wHalf)
|
|
||||||
txtFreqsData := make([]float32, txtLen*headDim)
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
for t := int32(0); t < txtLen; t++ {
|
|
||||||
pos := maxVidIdx + t
|
|
||||||
for i := 0; i < len(freqsT)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsT) * 2)
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsH) * 2)
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsW) * 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
|
||||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
|
||||||
|
|
||||||
return &RoPECache{
|
|
||||||
ImgFreqs: imgFreqs,
|
|
||||||
TxtFreqs: txtFreqs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
|
||||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
|
||||||
halfDim := dim / 2
|
|
||||||
freqs := make([]float64, halfDim)
|
|
||||||
for i := int32(0); i < halfDim; i++ {
|
|
||||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
|
||||||
}
|
|
||||||
return freqs
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
|
||||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
|
||||||
table := make([][]float32, maxIdx)
|
|
||||||
for idx := int32(0); idx < maxIdx; idx++ {
|
|
||||||
var pos float64
|
|
||||||
if negative {
|
|
||||||
pos = float64(-maxIdx + int32(idx))
|
|
||||||
} else {
|
|
||||||
pos = float64(idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
row := make([]float32, len(baseFreqs)*2)
|
|
||||||
for i, f := range baseFreqs {
|
|
||||||
angle := pos * f
|
|
||||||
row[i*2] = float32(math.Cos(angle))
|
|
||||||
row[i*2+1] = float32(math.Sin(angle))
|
|
||||||
}
|
|
||||||
table[idx] = row
|
|
||||||
}
|
|
||||||
return table
|
|
||||||
}
|
|
||||||
|
|
||||||
func max(a, b int32) int32 {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
|
||||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
|
||||||
shape := latents.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
C := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
|
|
||||||
pH := H / patchSize
|
|
||||||
pW := W / patchSize
|
|
||||||
|
|
||||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
|
||||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
|
||||||
// -> [B, pH, pW, C, 2, 2]
|
|
||||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
|
||||||
// -> [B, pH*pW, C*4]
|
|
||||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
|
||||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
|
||||||
shape := patches.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
channels := shape[2] / (patchSize * patchSize)
|
|
||||||
|
|
||||||
pH := H / patchSize
|
|
||||||
pW := W / patchSize
|
|
||||||
|
|
||||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
|
||||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
|
||||||
// -> [B, C, pH, 2, pW, 2]
|
|
||||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
|
||||||
// -> [B, C, H, W]
|
|
||||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
|
||||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
|
||||||
return mlx.ExpandDims(x, 2)
|
|
||||||
}
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestTransformerConfig tests configuration invariants.
|
|
||||||
func TestTransformerConfig(t *testing.T) {
|
|
||||||
cfg := defaultTransformerConfig()
|
|
||||||
|
|
||||||
// Property: hidden_dim = n_heads * head_dim
|
|
||||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
|
||||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
|
||||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: axes_dims_rope sums to head_dim
|
|
||||||
var ropeSum int32
|
|
||||||
for _, d := range cfg.AxesDimsRope {
|
|
||||||
ropeSum += d
|
|
||||||
}
|
|
||||||
if ropeSum != cfg.HeadDim {
|
|
||||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: in_channels = out_channels * patch_size^2
|
|
||||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
|
||||||
if cfg.InChannels != expectedIn {
|
|
||||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
|
||||||
func TestTransformerRoPE(t *testing.T) {
|
|
||||||
cfg := defaultTransformerConfig()
|
|
||||||
|
|
||||||
// Test with small image dimensions
|
|
||||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
|
||||||
txtLen := int32(5)
|
|
||||||
|
|
||||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
|
||||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
|
|
||||||
// Verify shapes: [seq_len, head_dim]
|
|
||||||
imgSeqLen := imgH * imgW
|
|
||||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
|
||||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
|
||||||
}
|
|
||||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
|
||||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
|
||||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify values are finite
|
|
||||||
imgData := ropeCache.ImgFreqs.Data()
|
|
||||||
for i := 0; i < min(100, len(imgData)); i++ {
|
|
||||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
|
||||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTransformerForward tests full forward pass (integration test).
|
|
||||||
// Skips if model weights are not available.
|
|
||||||
func TestTransformerForward(t *testing.T) {
|
|
||||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
|
||||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
|
||||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
transformer := &Transformer{}
|
|
||||||
if err := transformer.Load(weightsPath); err != nil {
|
|
||||||
t.Fatalf("Failed to load transformer: %v", err)
|
|
||||||
}
|
|
||||||
mlx.Keep(mlx.Collect(transformer)...)
|
|
||||||
cfg := transformer.Config
|
|
||||||
|
|
||||||
// Small test inputs
|
|
||||||
batchSize := int32(1)
|
|
||||||
imgH, imgW := int32(4), int32(4)
|
|
||||||
imgSeqLen := imgH * imgW
|
|
||||||
txtSeqLen := int32(5)
|
|
||||||
|
|
||||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
|
||||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
|
||||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
|
||||||
|
|
||||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
|
||||||
|
|
||||||
// Forward pass
|
|
||||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
mlx.Eval(out)
|
|
||||||
|
|
||||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
|
||||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
|
||||||
gotShape := out.Shape()
|
|
||||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
|
||||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify output is finite
|
|
||||||
outData := out.Data()
|
|
||||||
for i := 0; i < min(100, len(outData)); i++ {
|
|
||||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
|
||||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,854 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// VAEConfig holds Qwen-Image VAE configuration
|
|
||||||
type VAEConfig struct {
|
|
||||||
ZDim int32 `json:"z_dim"` // 16
|
|
||||||
BaseDim int32 `json:"base_dim"` // 96
|
|
||||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
|
||||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
|
||||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
|
||||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
|
||||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
|
||||||
func defaultVAEConfig() *VAEConfig {
|
|
||||||
return &VAEConfig{
|
|
||||||
ZDim: 16,
|
|
||||||
BaseDim: 96,
|
|
||||||
DimMult: []int32{1, 2, 4, 4},
|
|
||||||
NumResBlocks: 2,
|
|
||||||
LatentsMean: []float32{
|
|
||||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
|
||||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
|
||||||
0.4134, -0.0715, 0.5517, -0.3632,
|
|
||||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
|
||||||
},
|
|
||||||
LatentsStd: []float32{
|
|
||||||
2.8184, 1.4541, 2.3275, 2.6558,
|
|
||||||
1.2196, 1.7708, 2.6052, 2.0743,
|
|
||||||
3.2687, 2.1526, 2.8652, 1.5579,
|
|
||||||
1.6382, 1.1253, 2.8251, 1.916,
|
|
||||||
},
|
|
||||||
TemperalDownsample: []bool{false, true, true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
|
||||||
type CausalConv3d struct {
|
|
||||||
Weight *mlx.Array
|
|
||||||
Bias *mlx.Array
|
|
||||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
|
||||||
KernelT int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCausalConv3d creates a 3D causal conv
|
|
||||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
|
||||||
weight, err := weights.Get(prefix + ".weight")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
|
||||||
}
|
|
||||||
bias, _ := weights.Get(prefix + ".bias")
|
|
||||||
|
|
||||||
kernelT := weight.Shape()[2]
|
|
||||||
outC := weight.Shape()[0]
|
|
||||||
|
|
||||||
var biasReshaped *mlx.Array
|
|
||||||
if bias != nil {
|
|
||||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CausalConv3d{
|
|
||||||
Weight: weight,
|
|
||||||
Bias: bias,
|
|
||||||
BiasReshaped: biasReshaped,
|
|
||||||
KernelT: kernelT,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies causal 3D convolution
|
|
||||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
|
||||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
|
||||||
kernelT := shape[2]
|
|
||||||
kernelH := shape[3]
|
|
||||||
kernelW := shape[4]
|
|
||||||
|
|
||||||
// Causal temporal padding, same spatial padding
|
|
||||||
// Input is channels-last: [B, T, H, W, C]
|
|
||||||
padT := kernelT - 1
|
|
||||||
padH := kernelH / 2
|
|
||||||
padW := kernelW / 2
|
|
||||||
|
|
||||||
// Stage 1: Pad
|
|
||||||
{
|
|
||||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: Conv + bias
|
|
||||||
var out *mlx.Array
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
|
||||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
|
||||||
if c.Bias != nil {
|
|
||||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
|
||||||
out = mlx.Add(out, bias)
|
|
||||||
}
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// RMSNorm3D applies RMS normalization over channels
|
|
||||||
// Works with channels-last [B, T, H, W, C] format
|
|
||||||
type RMSNorm3D struct {
|
|
||||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRMSNorm3D creates an RMS norm
|
|
||||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
|
||||||
gamma, err := weights.Get(prefix + ".gamma")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
|
||||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
|
||||||
return &RMSNorm3D{Gamma: gamma}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
|
||||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
|
||||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
|
||||||
return mlx.Mul(normalized, n.Gamma)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResBlock is a residual block with RMS norm and causal convs
|
|
||||||
type ResBlock struct {
|
|
||||||
Norm1 *RMSNorm3D
|
|
||||||
Conv1 *CausalConv3d
|
|
||||||
Norm2 *RMSNorm3D
|
|
||||||
Conv2 *CausalConv3d
|
|
||||||
Shortcut *CausalConv3d
|
|
||||||
}
|
|
||||||
|
|
||||||
// newResBlock creates a residual block
|
|
||||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
|
||||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var shortcut *CausalConv3d
|
|
||||||
if inDim != outDim {
|
|
||||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ResBlock{
|
|
||||||
Norm1: norm1,
|
|
||||||
Conv1: conv1,
|
|
||||||
Norm2: norm2,
|
|
||||||
Conv2: conv2,
|
|
||||||
Shortcut: shortcut,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies the residual block
|
|
||||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
|
||||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
|
||||||
var h *mlx.Array
|
|
||||||
|
|
||||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
|
||||||
mlx.Keep(x)
|
|
||||||
|
|
||||||
// Stage 1: norm1 + silu
|
|
||||||
{
|
|
||||||
h = r.Norm1.Forward(x)
|
|
||||||
h = silu3D(h)
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: conv1 (handles its own pools)
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Conv1.Forward(h)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 3: norm2 + silu
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Norm2.Forward(h)
|
|
||||||
h = silu3D(h)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 4: conv2 (handles its own pools)
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Conv2.Forward(h)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Residual connection (shortcut handles its own pools if present)
|
|
||||||
if r.Shortcut != nil {
|
|
||||||
shortcut := r.Shortcut.Forward(x)
|
|
||||||
h = mlx.Add(h, shortcut)
|
|
||||||
mlx.Eval(h)
|
|
||||||
} else {
|
|
||||||
h = mlx.Add(h, x)
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// AttentionBlock is a 2D attention block
|
|
||||||
type AttentionBlock struct {
|
|
||||||
Norm *RMSNorm3D
|
|
||||||
ToQKV *mlx.Array
|
|
||||||
ToQKVBias *mlx.Array
|
|
||||||
Proj *mlx.Array
|
|
||||||
ProjBias *mlx.Array
|
|
||||||
Dim int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAttentionBlock creates an attention block
|
|
||||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
|
||||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
|
||||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
|
||||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
|
||||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
|
||||||
|
|
||||||
return &AttentionBlock{
|
|
||||||
Norm: norm,
|
|
||||||
ToQKV: toQKV,
|
|
||||||
ToQKVBias: toQKVBias,
|
|
||||||
Proj: proj,
|
|
||||||
ProjBias: projBias,
|
|
||||||
Dim: dim,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies 2D attention
|
|
||||||
// Input: [B, T, H, W, C] (channels-last)
|
|
||||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
|
|
||||||
identity := x
|
|
||||||
|
|
||||||
// Flatten to [B*T, 1, H, W, C] for norm
|
|
||||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
|
||||||
x = a.Norm.Forward(x)
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
|
|
||||||
// Flatten spatial to [B*T, H*W, C]
|
|
||||||
x = mlx.Reshape(x, B*T, H*W, C)
|
|
||||||
|
|
||||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
|
||||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
|
||||||
wShape := a.ToQKV.Shape()
|
|
||||||
var w *mlx.Array
|
|
||||||
if len(wShape) == 4 {
|
|
||||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
|
||||||
} else {
|
|
||||||
w = a.ToQKV
|
|
||||||
}
|
|
||||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
|
||||||
|
|
||||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
|
||||||
if a.ToQKVBias != nil {
|
|
||||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
|
||||||
}
|
|
||||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
|
||||||
|
|
||||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
|
||||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
|
||||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
|
||||||
|
|
||||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
|
||||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
|
||||||
|
|
||||||
// out: [B*T, 1, H*W, C]
|
|
||||||
out = mlx.Reshape(out, B*T, H*W, C)
|
|
||||||
|
|
||||||
// Project back
|
|
||||||
pShape := a.Proj.Shape()
|
|
||||||
var p *mlx.Array
|
|
||||||
if len(pShape) == 4 {
|
|
||||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
|
||||||
} else {
|
|
||||||
p = a.Proj
|
|
||||||
}
|
|
||||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
|
||||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
|
||||||
if a.ProjBias != nil {
|
|
||||||
out = mlx.Add(out, a.ProjBias)
|
|
||||||
}
|
|
||||||
|
|
||||||
out = mlx.Reshape(out, B, T, H, W, C)
|
|
||||||
return mlx.Add(out, identity)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpBlock handles upsampling in decoder
|
|
||||||
type UpBlock struct {
|
|
||||||
ResBlocks []*ResBlock
|
|
||||||
Upsampler *Upsample
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUpBlock creates an up block
|
|
||||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
|
||||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
|
||||||
|
|
||||||
currentDim := inDim
|
|
||||||
for i := int32(0); i <= numBlocks; i++ {
|
|
||||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
|
||||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resBlocks[i] = block
|
|
||||||
currentDim = outDim
|
|
||||||
}
|
|
||||||
|
|
||||||
var upsampler *Upsample
|
|
||||||
if upsampleMode != "" {
|
|
||||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UpBlock{
|
|
||||||
ResBlocks: resBlocks,
|
|
||||||
Upsampler: upsampler,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies up block with staged memory management
|
|
||||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// ResBlocks handle their own pools
|
|
||||||
for _, block := range u.ResBlocks {
|
|
||||||
prev := x
|
|
||||||
x = block.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsampler handles its own pools
|
|
||||||
if u.Upsampler != nil {
|
|
||||||
prev := x
|
|
||||||
x = u.Upsampler.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsample handles spatial upsampling
|
|
||||||
type Upsample struct {
|
|
||||||
Conv *mlx.Array
|
|
||||||
Bias *mlx.Array
|
|
||||||
Mode string
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUpsample creates an upsampler
|
|
||||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
|
||||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
|
||||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
|
||||||
return &Upsample{
|
|
||||||
Conv: conv,
|
|
||||||
Bias: bias,
|
|
||||||
Mode: mode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
|
||||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
|
||||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
outC := u.Conv.Shape()[0]
|
|
||||||
|
|
||||||
// Stage 1: 2x nearest neighbor upsample
|
|
||||||
{
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
x = upsample2xChannelsLast(x)
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: Conv + bias
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
|
||||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
|
||||||
if u.Bias != nil {
|
|
||||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
|
||||||
x = mlx.Add(x, bias)
|
|
||||||
}
|
|
||||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// MidBlock is the middle block of decoder
|
|
||||||
type MidBlock struct {
|
|
||||||
ResBlock1 *ResBlock
|
|
||||||
Attention *AttentionBlock
|
|
||||||
ResBlock2 *ResBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMidBlock creates a mid block
|
|
||||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
|
||||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &MidBlock{
|
|
||||||
ResBlock1: res1,
|
|
||||||
Attention: attn,
|
|
||||||
ResBlock2: res2,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies mid block
|
|
||||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// Each component handles its own pools; we just free inputs
|
|
||||||
prev := x
|
|
||||||
x = m.ResBlock1.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
prev = x
|
|
||||||
x = m.Attention.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
prev = x
|
|
||||||
x = m.ResBlock2.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAEDecoder is the full VAE decoder
|
|
||||||
type VAEDecoder struct {
|
|
||||||
Config *VAEConfig
|
|
||||||
|
|
||||||
PostQuantConv *CausalConv3d
|
|
||||||
ConvIn *CausalConv3d
|
|
||||||
MidBlock *MidBlock
|
|
||||||
UpBlocks []*UpBlock
|
|
||||||
NormOut *RMSNorm3D
|
|
||||||
ConvOut *CausalConv3d
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the VAE decoder from a directory
|
|
||||||
func (m *VAEDecoder) Load(path string) error {
|
|
||||||
fmt.Println("Loading Qwen-Image VAE decoder...")
|
|
||||||
|
|
||||||
cfg := defaultVAEConfig()
|
|
||||||
m.Config = cfg
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bulk load all weights as bf16
|
|
||||||
fmt.Print(" Loading weights as bf16... ")
|
|
||||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
|
||||||
return fmt.Errorf("failed to load weights: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
fmt.Print(" Loading post_quant_conv... ")
|
|
||||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.PostQuantConv = postQuantConv
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
fmt.Print(" Loading conv_in... ")
|
|
||||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.ConvIn = convIn
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
|
|
||||||
fmt.Print(" Loading mid_block... ")
|
|
||||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
|
||||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.MidBlock = midBlock
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Up blocks (reversed dim_mult)
|
|
||||||
fmt.Print(" Loading up_blocks... ")
|
|
||||||
numUpBlocks := len(cfg.DimMult)
|
|
||||||
m.UpBlocks = make([]*UpBlock, numUpBlocks)
|
|
||||||
|
|
||||||
dimsMult := make([]int32, numUpBlocks+1)
|
|
||||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
|
||||||
for i := 0; i < numUpBlocks; i++ {
|
|
||||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
|
||||||
}
|
|
||||||
|
|
||||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
|
||||||
for i := range cfg.TemperalDownsample {
|
|
||||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < numUpBlocks; i++ {
|
|
||||||
inDim := cfg.BaseDim * dimsMult[i]
|
|
||||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
|
||||||
|
|
||||||
if i > 0 {
|
|
||||||
inDim = inDim / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
upsampleMode := ""
|
|
||||||
if i < numUpBlocks-1 {
|
|
||||||
if temporalUpsample[i] {
|
|
||||||
upsampleMode = "upsample3d"
|
|
||||||
} else {
|
|
||||||
upsampleMode = "upsample2d"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
|
||||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.UpBlocks[i] = upBlock
|
|
||||||
}
|
|
||||||
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
|
|
||||||
|
|
||||||
fmt.Print(" Loading output layers... ")
|
|
||||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.NormOut = normOut
|
|
||||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.ConvOut = convOut
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
weights.ReleaseAll()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
|
||||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
|
||||||
m := &VAEDecoder{}
|
|
||||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode converts latents to image
|
|
||||||
// z: [B, C, T, H, W] normalized latents
|
|
||||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
|
||||||
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
|
||||||
var x *mlx.Array
|
|
||||||
|
|
||||||
// Stage 1a: Denormalize and transpose
|
|
||||||
{
|
|
||||||
z = vae.Denormalize(z)
|
|
||||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
|
||||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
|
||||||
mlx.Eval(z)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 1b: PostQuantConv (handles its own pools)
|
|
||||||
x = vae.PostQuantConv.Forward(z)
|
|
||||||
z.Free()
|
|
||||||
|
|
||||||
// Stage 1c: ConvIn (handles its own pools)
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = vae.ConvIn.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: Mid block (handles its own pools)
|
|
||||||
x = vae.MidBlock.Forward(x)
|
|
||||||
|
|
||||||
// Stage 3: Up blocks (each handles its own pools)
|
|
||||||
for _, upBlock := range vae.UpBlocks {
|
|
||||||
x = upBlock.Forward(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 4a: NormOut + silu
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = vae.NormOut.Forward(x)
|
|
||||||
x = silu3D(x)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 4b: ConvOut (handles its own pools)
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = vae.ConvOut.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 4c: Post-processing
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
// Clamp to [-1, 1]
|
|
||||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
|
||||||
// Convert back from channels-last to channels-first
|
|
||||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Denormalize reverses the normalization applied during encoding
|
|
||||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
|
||||||
shape := z.Shape()
|
|
||||||
C := shape[1]
|
|
||||||
|
|
||||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
|
|
||||||
mean = mlx.ToBFloat16(mean)
|
|
||||||
std = mlx.ToBFloat16(std)
|
|
||||||
|
|
||||||
return mlx.Add(mlx.Mul(z, std), mean)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
|
|
||||||
func silu3D(x *mlx.Array) *mlx.Array {
|
|
||||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
|
||||||
}
|
|
||||||
|
|
||||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
|
||||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
|
||||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
|
||||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
|
||||||
}
|
|
||||||
|
|
||||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
|
||||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
|
||||||
}
|
|
||||||
|
|
||||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
|
|
||||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
|
||||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
|
||||||
|
|
||||||
wShape := weight.Shape()
|
|
||||||
var w *mlx.Array
|
|
||||||
if len(wShape) == 4 {
|
|
||||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
|
||||||
} else {
|
|
||||||
w = weight
|
|
||||||
}
|
|
||||||
w = mlx.Transpose(w, 1, 0)
|
|
||||||
|
|
||||||
out := mlx.Linear(x, w)
|
|
||||||
outC := w.Dim(1)
|
|
||||||
out = mlx.Reshape(out, B, H, W, outC)
|
|
||||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
|
||||||
x = pad2D(x, 1, 1, 1, 1)
|
|
||||||
return conv2D(x, weight, 1, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
|
||||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
|
||||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
|
||||||
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
|
|
||||||
wShape := w.Shape()
|
|
||||||
Cout := wShape[0]
|
|
||||||
kH := wShape[1]
|
|
||||||
kW := wShape[2]
|
|
||||||
|
|
||||||
outH := (H-kH)/strideH + 1
|
|
||||||
outW := (W-kW)/strideW + 1
|
|
||||||
|
|
||||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
|
||||||
wFlat := mlx.Reshape(w, Cout, -1)
|
|
||||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
|
||||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
|
||||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
|
||||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
C := shape[3]
|
|
||||||
|
|
||||||
outH := (H-kH)/strideH + 1
|
|
||||||
outW := (W-kW)/strideW + 1
|
|
||||||
|
|
||||||
patches := make([]*mlx.Array, outH*outW)
|
|
||||||
idx := 0
|
|
||||||
for i := int32(0); i < outH; i++ {
|
|
||||||
for j := int32(0); j < outW; j++ {
|
|
||||||
startH := i * strideH
|
|
||||||
startW := j * strideW
|
|
||||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
|
||||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
|
||||||
patches[idx] = patch
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range patches {
|
|
||||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
|
||||||
}
|
|
||||||
stacked := mlx.Concatenate(patches, 1)
|
|
||||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
|
||||||
}
|
|
||||||
|
|
||||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
|
|
||||||
rowIdxData := make([]int32, H*2)
|
|
||||||
for i := int32(0); i < H; i++ {
|
|
||||||
rowIdxData[i*2] = i
|
|
||||||
rowIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
|
||||||
|
|
||||||
colIdxData := make([]int32, W*2)
|
|
||||||
for i := int32(0); i < W; i++ {
|
|
||||||
colIdxData[i*2] = i
|
|
||||||
colIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
|
||||||
|
|
||||||
x = mlx.Take(x, rowIdx, 2)
|
|
||||||
x = mlx.Take(x, colIdx, 3)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
|
||||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
|
|
||||||
// Create repeat indices for rows
|
|
||||||
rowIdxData := make([]int32, H*2)
|
|
||||||
for i := int32(0); i < H; i++ {
|
|
||||||
rowIdxData[i*2] = i
|
|
||||||
rowIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
|
||||||
|
|
||||||
// Create repeat indices for columns
|
|
||||||
colIdxData := make([]int32, W*2)
|
|
||||||
for i := int32(0); i < W; i++ {
|
|
||||||
colIdxData[i*2] = i
|
|
||||||
colIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
|
||||||
|
|
||||||
// Take along H (axis 1) then W (axis 2)
|
|
||||||
x = mlx.Take(x, rowIdx, 1)
|
|
||||||
x = mlx.Take(x, colIdx, 2)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
|
||||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
|
||||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
|
||||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
|
||||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
|
||||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
|
||||||
// stride=1, padding=0 (we already padded manually)
|
|
||||||
return mlx.Conv2d(x, weight, 1, 0)
|
|
||||||
}
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestVAEConfig tests configuration invariants.
|
|
||||||
func TestVAEConfig(t *testing.T) {
|
|
||||||
cfg := defaultVAEConfig()
|
|
||||||
|
|
||||||
// Property: latents_mean and latents_std have z_dim elements
|
|
||||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
|
||||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
|
||||||
}
|
|
||||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
|
||||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: dim_mult defines 4 stages
|
|
||||||
if len(cfg.DimMult) != 4 {
|
|
||||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
|
||||||
if len(cfg.TemperalDownsample) != 3 {
|
|
||||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
|
||||||
func TestVAELatentsNormalization(t *testing.T) {
|
|
||||||
cfg := defaultVAEConfig()
|
|
||||||
|
|
||||||
// Verify latents_std values are all positive
|
|
||||||
for i, std := range cfg.LatentsStd {
|
|
||||||
if std <= 0 {
|
|
||||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify values are in reasonable range (from actual model)
|
|
||||||
for i, mean := range cfg.LatentsMean {
|
|
||||||
if math.Abs(float64(mean)) > 5 {
|
|
||||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i, std := range cfg.LatentsStd {
|
|
||||||
if std > 10 {
|
|
||||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
|
||||||
// Skips if model weights are not available.
|
|
||||||
func TestVAEDecoderForward(t *testing.T) {
|
|
||||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
|
||||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
|
||||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
vae := &VAEDecoder{}
|
|
||||||
if err := vae.Load(weightsPath); err != nil {
|
|
||||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
|
||||||
}
|
|
||||||
mlx.Keep(mlx.Collect(vae)...)
|
|
||||||
|
|
||||||
// Small test input: [B, C, T, H, W]
|
|
||||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
|
||||||
batchSize := int32(1)
|
|
||||||
channels := int32(16)
|
|
||||||
frames := int32(1)
|
|
||||||
latentH := int32(4)
|
|
||||||
latentW := int32(4)
|
|
||||||
|
|
||||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
|
||||||
|
|
||||||
// Decode
|
|
||||||
out := vae.Decode(latents)
|
|
||||||
mlx.Eval(out)
|
|
||||||
|
|
||||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
|
||||||
outShape := out.Shape()
|
|
||||||
if outShape[0] != batchSize {
|
|
||||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
|
||||||
}
|
|
||||||
if outShape[1] != 3 {
|
|
||||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
|
||||||
}
|
|
||||||
if outShape[2] != frames {
|
|
||||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
|
||||||
}
|
|
||||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
|
||||||
expectedW := latentW * 16
|
|
||||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
|
||||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
|
||||||
outShape[3], outShape[4], expectedH, expectedW)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
|
||||||
outData := out.Data()
|
|
||||||
for i := 0; i < min(100, len(outData)); i++ {
|
|
||||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
|
||||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,682 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image_edit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
|
||||||
type CausalConv3d struct {
|
|
||||||
Weight *mlx.Array
|
|
||||||
Bias *mlx.Array
|
|
||||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
|
||||||
KernelT int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCausalConv3d creates a 3D causal conv
|
|
||||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
|
||||||
weight, err := weights.Get(prefix + ".weight")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
|
||||||
}
|
|
||||||
bias, _ := weights.Get(prefix + ".bias")
|
|
||||||
|
|
||||||
kernelT := weight.Shape()[2]
|
|
||||||
outC := weight.Shape()[0]
|
|
||||||
|
|
||||||
var biasReshaped *mlx.Array
|
|
||||||
if bias != nil {
|
|
||||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CausalConv3d{
|
|
||||||
Weight: weight,
|
|
||||||
Bias: bias,
|
|
||||||
BiasReshaped: biasReshaped,
|
|
||||||
KernelT: kernelT,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
|
||||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
|
||||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := c.Weight.Shape()
|
|
||||||
|
|
||||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
|
||||||
if len(shape) == 4 {
|
|
||||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
|
||||||
return c.forward2D(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3D conv: [O, I, kT, kH, kW]
|
|
||||||
kernelT := shape[2]
|
|
||||||
kernelH := shape[3]
|
|
||||||
kernelW := shape[4]
|
|
||||||
|
|
||||||
// Causal temporal padding, same spatial padding
|
|
||||||
padT := kernelT - 1
|
|
||||||
padH := kernelH / 2
|
|
||||||
padW := kernelW / 2
|
|
||||||
|
|
||||||
// Stage 1: Pad
|
|
||||||
{
|
|
||||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: Conv + bias
|
|
||||||
var out *mlx.Array
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
|
||||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
|
||||||
if c.Bias != nil {
|
|
||||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
|
||||||
out = mlx.Add(out, bias)
|
|
||||||
}
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
|
||||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
|
||||||
xShape := x.Shape()
|
|
||||||
B := xShape[0]
|
|
||||||
T := xShape[1]
|
|
||||||
H := xShape[2]
|
|
||||||
W := xShape[3]
|
|
||||||
C := xShape[4]
|
|
||||||
|
|
||||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
|
||||||
kernelH := wShape[2]
|
|
||||||
kernelW := wShape[3]
|
|
||||||
outC := wShape[0]
|
|
||||||
|
|
||||||
padH := kernelH / 2
|
|
||||||
padW := kernelW / 2
|
|
||||||
|
|
||||||
// Reshape to [B*T, H, W, C] for 2D conv
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
|
|
||||||
// Pad spatially
|
|
||||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
|
||||||
|
|
||||||
// Apply 2D conv
|
|
||||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
|
||||||
x = mlx.Conv2d(x, weight, 1, 0)
|
|
||||||
|
|
||||||
if c.Bias != nil {
|
|
||||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
|
||||||
x = mlx.Add(x, bias)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get output spatial dims
|
|
||||||
outH := H
|
|
||||||
outW := W
|
|
||||||
|
|
||||||
// Reshape back to [B, T, H, W, C]
|
|
||||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
|
||||||
mlx.Eval(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// RMSNorm3D applies RMS normalization over channels
|
|
||||||
type RMSNorm3D struct {
|
|
||||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRMSNorm3D creates an RMS norm
|
|
||||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
|
||||||
gamma, err := weights.Get(prefix + ".gamma")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
|
||||||
return &RMSNorm3D{Gamma: gamma}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
|
||||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
|
||||||
return mlx.Mul(normalized, n.Gamma)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResBlock is a residual block with RMS norm and causal convs
|
|
||||||
type ResBlock struct {
|
|
||||||
Norm1 *RMSNorm3D
|
|
||||||
Conv1 *CausalConv3d
|
|
||||||
Norm2 *RMSNorm3D
|
|
||||||
Conv2 *CausalConv3d
|
|
||||||
Shortcut *CausalConv3d
|
|
||||||
}
|
|
||||||
|
|
||||||
// newResBlock creates a residual block
|
|
||||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
|
||||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var shortcut *CausalConv3d
|
|
||||||
if inDim != outDim {
|
|
||||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ResBlock{
|
|
||||||
Norm1: norm1,
|
|
||||||
Conv1: conv1,
|
|
||||||
Norm2: norm2,
|
|
||||||
Conv2: conv2,
|
|
||||||
Shortcut: shortcut,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies the residual block
|
|
||||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
var h *mlx.Array
|
|
||||||
|
|
||||||
mlx.Keep(x)
|
|
||||||
|
|
||||||
// Stage 1: norm1 + silu
|
|
||||||
{
|
|
||||||
h = r.Norm1.Forward(x)
|
|
||||||
h = silu3D(h)
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: conv1
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Conv1.Forward(h)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 3: norm2 + silu
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Norm2.Forward(h)
|
|
||||||
h = silu3D(h)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 4: conv2
|
|
||||||
{
|
|
||||||
prev := h
|
|
||||||
h = r.Conv2.Forward(h)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Residual connection
|
|
||||||
if r.Shortcut != nil {
|
|
||||||
shortcut := r.Shortcut.Forward(x)
|
|
||||||
h = mlx.Add(h, shortcut)
|
|
||||||
mlx.Eval(h)
|
|
||||||
} else {
|
|
||||||
h = mlx.Add(h, x)
|
|
||||||
mlx.Eval(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// AttentionBlock is a 2D attention block
|
|
||||||
type AttentionBlock struct {
|
|
||||||
Norm *RMSNorm3D
|
|
||||||
ToQKV *mlx.Array
|
|
||||||
ToQKVBias *mlx.Array
|
|
||||||
Proj *mlx.Array
|
|
||||||
ProjBias *mlx.Array
|
|
||||||
Dim int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAttentionBlock creates an attention block
|
|
||||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
|
||||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
|
||||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
|
||||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
|
||||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
|
||||||
|
|
||||||
return &AttentionBlock{
|
|
||||||
Norm: norm,
|
|
||||||
ToQKV: toQKV,
|
|
||||||
ToQKVBias: toQKVBias,
|
|
||||||
Proj: proj,
|
|
||||||
ProjBias: projBias,
|
|
||||||
Dim: dim,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies 2D attention
|
|
||||||
// Input: [B, T, H, W, C] (channels-last)
|
|
||||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
|
|
||||||
identity := x
|
|
||||||
|
|
||||||
// Flatten to [B*T, 1, H, W, C] for norm
|
|
||||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
|
||||||
x = a.Norm.Forward(x)
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
|
|
||||||
// Flatten spatial to [B*T, H*W, C]
|
|
||||||
x = mlx.Reshape(x, B*T, H*W, C)
|
|
||||||
|
|
||||||
// Linear to get Q, K, V
|
|
||||||
wShape := a.ToQKV.Shape()
|
|
||||||
var w *mlx.Array
|
|
||||||
if len(wShape) == 4 {
|
|
||||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
|
||||||
} else {
|
|
||||||
w = a.ToQKV
|
|
||||||
}
|
|
||||||
w = mlx.Transpose(w, 1, 0)
|
|
||||||
|
|
||||||
qkv := mlx.Linear(x, w)
|
|
||||||
if a.ToQKVBias != nil {
|
|
||||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
|
||||||
}
|
|
||||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
|
||||||
|
|
||||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
|
||||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
|
||||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
|
||||||
|
|
||||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
|
||||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
|
||||||
|
|
||||||
out = mlx.Reshape(out, B*T, H*W, C)
|
|
||||||
|
|
||||||
// Project back
|
|
||||||
pShape := a.Proj.Shape()
|
|
||||||
var p *mlx.Array
|
|
||||||
if len(pShape) == 4 {
|
|
||||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
|
||||||
} else {
|
|
||||||
p = a.Proj
|
|
||||||
}
|
|
||||||
p = mlx.Transpose(p, 1, 0)
|
|
||||||
out = mlx.Linear(out, p)
|
|
||||||
if a.ProjBias != nil {
|
|
||||||
out = mlx.Add(out, a.ProjBias)
|
|
||||||
}
|
|
||||||
|
|
||||||
out = mlx.Reshape(out, B, T, H, W, C)
|
|
||||||
return mlx.Add(out, identity)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpBlock handles upsampling in decoder
|
|
||||||
type UpBlock struct {
|
|
||||||
ResBlocks []*ResBlock
|
|
||||||
Upsampler *Upsample
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUpBlock creates an up block
|
|
||||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
|
||||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
|
||||||
|
|
||||||
currentDim := inDim
|
|
||||||
for i := int32(0); i <= numBlocks; i++ {
|
|
||||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
|
||||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resBlocks[i] = block
|
|
||||||
currentDim = outDim
|
|
||||||
}
|
|
||||||
|
|
||||||
var upsampler *Upsample
|
|
||||||
if upsampleMode != "" {
|
|
||||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UpBlock{
|
|
||||||
ResBlocks: resBlocks,
|
|
||||||
Upsampler: upsampler,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies up block
|
|
||||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
for _, block := range u.ResBlocks {
|
|
||||||
prev := x
|
|
||||||
x = block.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.Upsampler != nil {
|
|
||||||
prev := x
|
|
||||||
x = u.Upsampler.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsample handles spatial upsampling
|
|
||||||
type Upsample struct {
|
|
||||||
Conv *mlx.Array
|
|
||||||
Bias *mlx.Array
|
|
||||||
Mode string
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUpsample creates an upsampler
|
|
||||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
|
||||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
|
||||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
|
||||||
return &Upsample{
|
|
||||||
Conv: conv,
|
|
||||||
Bias: bias,
|
|
||||||
Mode: mode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
|
||||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
outC := u.Conv.Shape()[0]
|
|
||||||
|
|
||||||
// Stage 1: 2x nearest neighbor upsample
|
|
||||||
{
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
x = upsample2xChannelsLast(x)
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stage 2: Conv + bias
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
|
||||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
|
||||||
if u.Bias != nil {
|
|
||||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
|
||||||
x = mlx.Add(x, bias)
|
|
||||||
}
|
|
||||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// MidBlock is the middle block
|
|
||||||
type MidBlock struct {
|
|
||||||
ResBlock1 *ResBlock
|
|
||||||
Attention *AttentionBlock
|
|
||||||
ResBlock2 *ResBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMidBlock creates a mid block
|
|
||||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
|
||||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &MidBlock{
|
|
||||||
ResBlock1: res1,
|
|
||||||
Attention: attn,
|
|
||||||
ResBlock2: res2,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies mid block
|
|
||||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
prev := x
|
|
||||||
x = m.ResBlock1.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
prev = x
|
|
||||||
x = m.Attention.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
prev = x
|
|
||||||
x = m.ResBlock2.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
|
|
||||||
func silu3D(x *mlx.Array) *mlx.Array {
|
|
||||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
|
||||||
}
|
|
||||||
|
|
||||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
|
||||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
|
||||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
|
||||||
}
|
|
||||||
|
|
||||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
|
||||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
|
|
||||||
rowIdxData := make([]int32, H*2)
|
|
||||||
for i := int32(0); i < H; i++ {
|
|
||||||
rowIdxData[i*2] = i
|
|
||||||
rowIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
|
||||||
|
|
||||||
colIdxData := make([]int32, W*2)
|
|
||||||
for i := int32(0); i < W; i++ {
|
|
||||||
colIdxData[i*2] = i
|
|
||||||
colIdxData[i*2+1] = i
|
|
||||||
}
|
|
||||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
|
||||||
|
|
||||||
x = mlx.Take(x, rowIdx, 1)
|
|
||||||
x = mlx.Take(x, colIdx, 2)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
|
||||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
|
||||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
|
||||||
return mlx.Conv2d(x, weight, 1, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
|
||||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
|
||||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
|
|
||||||
wShape := weight.Shape()
|
|
||||||
Cout := wShape[0]
|
|
||||||
kH := wShape[1]
|
|
||||||
kW := wShape[2]
|
|
||||||
|
|
||||||
outH := (H - kH) / stride + 1
|
|
||||||
outW := (W - kW) / stride + 1
|
|
||||||
|
|
||||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
|
||||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
|
||||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
|
||||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
|
||||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
|
||||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
|
||||||
// strideT, strideH, strideW are the strides for each dimension
|
|
||||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
|
||||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
|
|
||||||
wShape := weight.Shape()
|
|
||||||
Cout := wShape[0]
|
|
||||||
// I := wShape[1]
|
|
||||||
kT := wShape[2]
|
|
||||||
kH := wShape[3]
|
|
||||||
kW := wShape[4]
|
|
||||||
|
|
||||||
// For temporal: if T < kT, we need to repeat frames temporally
|
|
||||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
|
||||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
|
||||||
if T < kT {
|
|
||||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
|
||||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
|
||||||
T = kT
|
|
||||||
}
|
|
||||||
|
|
||||||
outT := (T - kT) / strideT + 1
|
|
||||||
outH := (H - kH) / strideH + 1
|
|
||||||
outW := (W - kW) / strideW + 1
|
|
||||||
|
|
||||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
|
||||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
|
||||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
|
||||||
|
|
||||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
|
||||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
|
||||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
|
||||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
|
||||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractPatches3DStrided extracts 3D patches with given strides
|
|
||||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
|
||||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
|
|
||||||
outT := (T - kT) / strideT + 1
|
|
||||||
outH := (H - kH) / strideH + 1
|
|
||||||
outW := (W - kW) / strideW + 1
|
|
||||||
|
|
||||||
numPatches := outT * outH * outW
|
|
||||||
patches := make([]*mlx.Array, numPatches)
|
|
||||||
idx := 0
|
|
||||||
for t := int32(0); t < outT; t++ {
|
|
||||||
for i := int32(0); i < outH; i++ {
|
|
||||||
for j := int32(0); j < outW; j++ {
|
|
||||||
startT := t * strideT
|
|
||||||
startH := i * strideH
|
|
||||||
startW := j * strideW
|
|
||||||
// Extract patch: [B, kT, kH, kW, C]
|
|
||||||
patch := mlx.Slice(x,
|
|
||||||
[]int32{0, startT, startH, startW, 0},
|
|
||||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
|
||||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
|
||||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
|
||||||
// Flatten to [B, C*T*H*W]
|
|
||||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
|
||||||
patches[idx] = patch
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range patches {
|
|
||||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
|
||||||
}
|
|
||||||
stacked := mlx.Concatenate(patches, 1)
|
|
||||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractPatches2DStrided extracts patches with given stride
|
|
||||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
H := shape[1]
|
|
||||||
W := shape[2]
|
|
||||||
C := shape[3]
|
|
||||||
|
|
||||||
outH := (H - kH) / stride + 1
|
|
||||||
outW := (W - kW) / stride + 1
|
|
||||||
|
|
||||||
patches := make([]*mlx.Array, outH*outW)
|
|
||||||
idx := 0
|
|
||||||
for i := int32(0); i < outH; i++ {
|
|
||||||
for j := int32(0); j < outW; j++ {
|
|
||||||
startH := i * stride
|
|
||||||
startW := j * stride
|
|
||||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
|
||||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
|
||||||
patches[idx] = patch
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range patches {
|
|
||||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
|
||||||
}
|
|
||||||
stacked := mlx.Concatenate(patches, 1)
|
|
||||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
|
||||||
}
|
|
||||||
|
|
||||||
// layerNormNoAffine applies layer norm without learnable parameters
|
|
||||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
|
||||||
ndim := x.Ndim()
|
|
||||||
lastAxis := ndim - 1
|
|
||||||
mean := mlx.Mean(x, lastAxis, true)
|
|
||||||
xCentered := mlx.Sub(x, mean)
|
|
||||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
|
||||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
|
||||||
}
|
|
||||||
@@ -1,475 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image_edit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"image"
|
|
||||||
"image/color"
|
|
||||||
_ "image/jpeg"
|
|
||||||
_ "image/png"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"golang.org/x/image/draw"
|
|
||||||
_ "golang.org/x/image/webp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// loadImageFile loads an image from disk
|
|
||||||
func loadImageFile(path string) (image.Image, error) {
|
|
||||||
f, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("open image: %w", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
img, _, err := image.Decode(f)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("decode image: %w", err)
|
|
||||||
}
|
|
||||||
return img, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
|
|
||||||
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
|
|
||||||
pixels := make([]float32, width*height*3)
|
|
||||||
idx := 0
|
|
||||||
for y := 0; y < height; y++ {
|
|
||||||
for x := 0; x < width; x++ {
|
|
||||||
r, g, b, _ := img.At(x, y).RGBA()
|
|
||||||
pixels[idx] = float32(r) / 65535.0
|
|
||||||
pixels[idx+1] = float32(g) / 65535.0
|
|
||||||
pixels[idx+2] = float32(b) / 65535.0
|
|
||||||
idx += 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pixels
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
|
||||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
|
||||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
|
||||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
|
||||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
|
||||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
|
||||||
// Transpose to [C, H, W] and make contiguous
|
|
||||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
|
||||||
// Add batch dimension [1, C, H, W]
|
|
||||||
arr = mlx.ExpandDims(arr, 0)
|
|
||||||
// Convert to bf16
|
|
||||||
arr = mlx.ToBFloat16(arr)
|
|
||||||
mlx.Eval(arr)
|
|
||||||
return arr
|
|
||||||
}
|
|
||||||
|
|
||||||
// clampFloat clamps a value to [0, 255] and returns uint8
|
|
||||||
func clampFloat(v, weightSum float64) uint8 {
|
|
||||||
v /= weightSum
|
|
||||||
if v < 0 {
|
|
||||||
v = 0
|
|
||||||
}
|
|
||||||
if v > 255 {
|
|
||||||
v = 255
|
|
||||||
}
|
|
||||||
return uint8(math.Round(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImageDims holds dimensions for a preprocessed image
|
|
||||||
type ImageDims struct {
|
|
||||||
// Original image dimensions
|
|
||||||
OrigW, OrigH int32
|
|
||||||
// Condition image dimensions (for vision encoder)
|
|
||||||
CondW, CondH int32
|
|
||||||
// VAE image dimensions
|
|
||||||
VaeW, VaeH int32
|
|
||||||
// Latent dimensions (VAE dims / vae_scale_factor)
|
|
||||||
LatentW, LatentH int32
|
|
||||||
// Patch dimensions (latent dims / patch_size)
|
|
||||||
PatchW, PatchH int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessorConfig holds image processor configuration
|
|
||||||
type ProcessorConfig struct {
|
|
||||||
// Condition image size (target pixel area for vision encoder input)
|
|
||||||
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
|
|
||||||
// Pipeline resizes image to this area before passing to encode_prompt
|
|
||||||
ConditionImageSize int32
|
|
||||||
|
|
||||||
// VAE image size (target pixel area)
|
|
||||||
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
|
|
||||||
VAEImageSize int32
|
|
||||||
|
|
||||||
// Image normalization (ImageNet stats for vision encoder)
|
|
||||||
ImageMean []float32
|
|
||||||
ImageStd []float32
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultProcessorConfig returns default processor config
|
|
||||||
func defaultProcessorConfig() *ProcessorConfig {
|
|
||||||
return &ProcessorConfig{
|
|
||||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
|
||||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
|
||||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
|
||||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
|
||||||
type Processor struct {
|
|
||||||
Config *ProcessorConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the processor config
|
|
||||||
func (p *Processor) Load(path string) error {
|
|
||||||
p.Config = defaultProcessorConfig()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
|
||||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
|
||||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
|
||||||
img, err := loadImageFile(imagePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bounds := img.Bounds()
|
|
||||||
origW := bounds.Dx()
|
|
||||||
origH := bounds.Dy()
|
|
||||||
ratio := float64(origW) / float64(origH)
|
|
||||||
|
|
||||||
// Calculate dimensions for condition image (vision encoder)
|
|
||||||
// Python pipeline does TWO resizes:
|
|
||||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
|
||||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
|
||||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
|
||||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
|
||||||
|
|
||||||
// Calculate dimensions for VAE image (1024x1024 area)
|
|
||||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
|
||||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
|
||||||
|
|
||||||
// Preprocess for condition (vision encoder) - two-step resize
|
|
||||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
|
||||||
|
|
||||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
|
||||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
|
||||||
|
|
||||||
return condImage, vaeImage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
|
||||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
|
||||||
// Used by edit_plus pipeline for multi-image input
|
|
||||||
// Returns: [B, C, H, W] normalized tensor
|
|
||||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
|
||||||
resized := resizeImageLanczos(img, int(width), int(height))
|
|
||||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
|
||||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
|
||||||
arr = p.normalizeImageNet(arr)
|
|
||||||
return prepareImageTensor(arr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
|
||||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
|
||||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
|
||||||
// Returns: [B, C, H, W] normalized tensor
|
|
||||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
|
||||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
|
||||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
|
||||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
|
||||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
|
||||||
arr = p.normalizeImageNet(arr)
|
|
||||||
return prepareImageTensor(arr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// preprocessImage converts image to tensor for vision encoder
|
|
||||||
// Returns: [B, C, H, W] normalized tensor
|
|
||||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
|
||||||
resized := resizeImageBicubic(img, int(width), int(height))
|
|
||||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
|
||||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
|
||||||
if normalize {
|
|
||||||
arr = p.normalizeImageNet(arr)
|
|
||||||
}
|
|
||||||
return prepareImageTensor(arr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
|
||||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
|
||||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
|
||||||
resized := resizeImageLanczos(img, int(width), int(height))
|
|
||||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
|
||||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
|
||||||
|
|
||||||
// Scale to [-1, 1]: arr * 2 - 1
|
|
||||||
arr = mlx.MulScalar(arr, 2.0)
|
|
||||||
arr = mlx.AddScalar(arr, -1.0)
|
|
||||||
|
|
||||||
// Transpose to [C, H, W] and make contiguous
|
|
||||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
|
||||||
|
|
||||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
|
||||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
|
||||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
|
||||||
|
|
||||||
arr = mlx.ToBFloat16(arr)
|
|
||||||
mlx.Eval(arr)
|
|
||||||
return arr
|
|
||||||
}
|
|
||||||
|
|
||||||
// smartResize implements Python Qwen2VL processor's smart_resize logic
|
|
||||||
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
|
|
||||||
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
|
|
||||||
// Round to factor
|
|
||||||
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
|
|
||||||
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
|
|
||||||
|
|
||||||
// Ensure minimum factor size
|
|
||||||
if hBar < factor {
|
|
||||||
hBar = factor
|
|
||||||
}
|
|
||||||
if wBar < factor {
|
|
||||||
wBar = factor
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check pixel constraints
|
|
||||||
total := hBar * wBar
|
|
||||||
if total > maxPixels {
|
|
||||||
// Scale down
|
|
||||||
beta := math.Sqrt(float64(maxPixels) / float64(total))
|
|
||||||
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
|
|
||||||
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
|
|
||||||
} else if total < minPixels {
|
|
||||||
// Scale up
|
|
||||||
beta := math.Sqrt(float64(minPixels) / float64(total))
|
|
||||||
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
|
||||||
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
|
||||||
}
|
|
||||||
|
|
||||||
return hBar, wBar
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculateDimensions calculates width and height for a target area while maintaining ratio
|
|
||||||
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
|
|
||||||
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
|
|
||||||
width := math.Sqrt(float64(targetArea) * ratio)
|
|
||||||
height := width / ratio
|
|
||||||
|
|
||||||
m := float64(multiple)
|
|
||||||
width = math.Round(width/m) * m
|
|
||||||
height = math.Round(height/m) * m
|
|
||||||
|
|
||||||
// Ensure minimum dimensions
|
|
||||||
if width < m {
|
|
||||||
width = m
|
|
||||||
}
|
|
||||||
if height < m {
|
|
||||||
height = m
|
|
||||||
}
|
|
||||||
|
|
||||||
return int32(width), int32(height)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
|
||||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
|
||||||
bounds := img.Bounds()
|
|
||||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
|
||||||
|
|
||||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
|
||||||
lanczos3 := &draw.Kernel{
|
|
||||||
Support: 3.0,
|
|
||||||
At: func(t float64) float64 {
|
|
||||||
if t == 0 {
|
|
||||||
return 1.0
|
|
||||||
}
|
|
||||||
if t < 0 {
|
|
||||||
t = -t
|
|
||||||
}
|
|
||||||
if t >= 3.0 {
|
|
||||||
return 0.0
|
|
||||||
}
|
|
||||||
// sinc(t) * sinc(t/3)
|
|
||||||
piT := math.Pi * t
|
|
||||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
|
||||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
|
||||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
|
||||||
bounds := img.Bounds()
|
|
||||||
srcW := bounds.Dx()
|
|
||||||
srcH := bounds.Dy()
|
|
||||||
|
|
||||||
// Convert to RGBA if needed
|
|
||||||
var src *image.RGBA
|
|
||||||
if rgba, ok := img.(*image.RGBA); ok {
|
|
||||||
src = rgba
|
|
||||||
} else {
|
|
||||||
src = image.NewRGBA(bounds)
|
|
||||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
|
||||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
|
||||||
src.Set(x, y, img.At(x, y))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
|
||||||
cubic := func(x float64) float64 {
|
|
||||||
if x < 0 {
|
|
||||||
x = -x
|
|
||||||
}
|
|
||||||
if x < 1 {
|
|
||||||
return 1.5*x*x*x - 2.5*x*x + 1
|
|
||||||
}
|
|
||||||
if x < 2 {
|
|
||||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Horizontal pass: srcW -> width, keep srcH rows
|
|
||||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
|
||||||
for y := 0; y < srcH; y++ {
|
|
||||||
for dstX := 0; dstX < width; dstX++ {
|
|
||||||
// PIL coordinate mapping: center-to-center
|
|
||||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
|
||||||
baseX := int(math.Floor(srcXf))
|
|
||||||
|
|
||||||
var sumR, sumG, sumB, sumA, weightSum float64
|
|
||||||
for i := -1; i <= 2; i++ {
|
|
||||||
sx := baseX + i
|
|
||||||
if sx < 0 {
|
|
||||||
sx = 0
|
|
||||||
}
|
|
||||||
if sx >= srcW {
|
|
||||||
sx = srcW - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
|
||||||
c := src.RGBAAt(sx, y)
|
|
||||||
sumR += float64(c.R) * w
|
|
||||||
sumG += float64(c.G) * w
|
|
||||||
sumB += float64(c.B) * w
|
|
||||||
sumA += float64(c.A) * w
|
|
||||||
weightSum += w
|
|
||||||
}
|
|
||||||
|
|
||||||
temp.SetRGBA(dstX, y, color.RGBA{
|
|
||||||
clampFloat(sumR, weightSum),
|
|
||||||
clampFloat(sumG, weightSum),
|
|
||||||
clampFloat(sumB, weightSum),
|
|
||||||
clampFloat(sumA, weightSum),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Vertical pass: srcH -> height
|
|
||||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
|
||||||
for x := 0; x < width; x++ {
|
|
||||||
for dstY := 0; dstY < height; dstY++ {
|
|
||||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
|
||||||
baseY := int(math.Floor(srcYf))
|
|
||||||
|
|
||||||
var sumR, sumG, sumB, sumA, weightSum float64
|
|
||||||
for j := -1; j <= 2; j++ {
|
|
||||||
sy := baseY + j
|
|
||||||
if sy < 0 {
|
|
||||||
sy = 0
|
|
||||||
}
|
|
||||||
if sy >= srcH {
|
|
||||||
sy = srcH - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
|
||||||
c := temp.RGBAAt(x, sy)
|
|
||||||
sumR += float64(c.R) * w
|
|
||||||
sumG += float64(c.G) * w
|
|
||||||
sumB += float64(c.B) * w
|
|
||||||
sumA += float64(c.A) * w
|
|
||||||
weightSum += w
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.SetRGBA(x, dstY, color.RGBA{
|
|
||||||
clampFloat(sumR, weightSum),
|
|
||||||
clampFloat(sumG, weightSum),
|
|
||||||
clampFloat(sumB, weightSum),
|
|
||||||
clampFloat(sumA, weightSum),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
|
||||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
|
||||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
|
||||||
const vaeScaleFactor int32 = 8
|
|
||||||
const patchSize int32 = 2
|
|
||||||
|
|
||||||
condImages := make([]*mlx.Array, len(imagePaths))
|
|
||||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
|
||||||
dims := make([]ImageDims, len(imagePaths))
|
|
||||||
|
|
||||||
for i, imagePath := range imagePaths {
|
|
||||||
img, err := loadImageFile(imagePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bounds := img.Bounds()
|
|
||||||
origW := int32(bounds.Dx())
|
|
||||||
origH := int32(bounds.Dy())
|
|
||||||
ratio := float64(origW) / float64(origH)
|
|
||||||
|
|
||||||
// Calculate dimensions for condition image (vision encoder)
|
|
||||||
// Python pipeline does TWO resizes:
|
|
||||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
|
||||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
|
||||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
|
||||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
|
||||||
|
|
||||||
// Calculate dimensions for VAE image (1024x1024 area)
|
|
||||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
|
||||||
|
|
||||||
// Calculate derived dimensions
|
|
||||||
latentW := vaeW / vaeScaleFactor
|
|
||||||
latentH := vaeH / vaeScaleFactor
|
|
||||||
patchW := latentW / patchSize
|
|
||||||
patchH := latentH / patchSize
|
|
||||||
|
|
||||||
dims[i] = ImageDims{
|
|
||||||
OrigW: origW,
|
|
||||||
OrigH: origH,
|
|
||||||
CondW: condW,
|
|
||||||
CondH: condH,
|
|
||||||
VaeW: vaeW,
|
|
||||||
VaeH: vaeH,
|
|
||||||
LatentW: latentW,
|
|
||||||
LatentH: latentH,
|
|
||||||
PatchW: patchW,
|
|
||||||
PatchH: patchH,
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
|
||||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
|
||||||
|
|
||||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
|
||||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
|
||||||
|
|
||||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
|
||||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
|
||||||
}
|
|
||||||
|
|
||||||
return condImages, vaeImages, dims, nil
|
|
||||||
}
|
|
||||||
@@ -1,625 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
|
||||||
// It reuses components from qwen_image where possible.
|
|
||||||
package qwen_image_edit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GenerateConfig holds all options for image editing.
|
|
||||||
type GenerateConfig struct {
|
|
||||||
Prompt string
|
|
||||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
|
||||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
|
||||||
Width int32 // Output width (default: from input image)
|
|
||||||
Height int32 // Output height (default: from input image)
|
|
||||||
Steps int // Denoising steps (default: 50)
|
|
||||||
Seed int64 // Random seed
|
|
||||||
Progress func(step, totalSteps int) // Optional progress callback
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model represents a Qwen-Image-Edit diffusion model.
|
|
||||||
type Model struct {
|
|
||||||
ModelPath string
|
|
||||||
Tokenizer *tokenizer.Tokenizer
|
|
||||||
Processor *Processor // Image processor for vision encoder
|
|
||||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
|
||||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
|
||||||
VAE *VAE // Combined encoder + decoder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the Qwen-Image-Edit model from a directory.
|
|
||||||
func (m *Model) Load(modelPath string) error {
|
|
||||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
if mlx.GPUIsAvailable() {
|
|
||||||
mlx.SetDefaultDeviceGPU()
|
|
||||||
mlx.EnableCompile()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ModelPath = modelPath
|
|
||||||
|
|
||||||
// Load tokenizer from processor directory
|
|
||||||
fmt.Print(" Loading tokenizer... ")
|
|
||||||
processorPath := filepath.Join(modelPath, "processor")
|
|
||||||
tok, err := tokenizer.Load(processorPath)
|
|
||||||
if err != nil {
|
|
||||||
// Fallback to tokenizer directory
|
|
||||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
|
||||||
tok, err = tokenizer.Load(tokenizerPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.Tokenizer = tok
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Load processor (image preprocessing config)
|
|
||||||
fmt.Print(" Loading processor... ")
|
|
||||||
m.Processor = &Processor{}
|
|
||||||
if err := m.Processor.Load(processorPath); err != nil {
|
|
||||||
return fmt.Errorf("processor: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
|
||||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
|
||||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
|
||||||
return fmt.Errorf("text encoder: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
// Load transformer (reuse qwen_image)
|
|
||||||
m.Transformer = &qwen_image.Transformer{}
|
|
||||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
|
||||||
return fmt.Errorf("transformer: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
// Load VAE (encoder + decoder)
|
|
||||||
m.VAE = &VAE{}
|
|
||||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
|
||||||
return fmt.Errorf("VAE: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
|
||||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
|
||||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
|
||||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
mem := mlx.MetalGetActiveMemory()
|
|
||||||
peak := mlx.MetalGetPeakMemory()
|
|
||||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
|
||||||
time.Since(start).Seconds(),
|
|
||||||
float64(mem)/(1024*1024*1024),
|
|
||||||
float64(peak)/(1024*1024*1024))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Edit edits an image based on a text prompt.
|
|
||||||
// inputImagePath: path to input image
|
|
||||||
// prompt: text description of desired edit
|
|
||||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
|
||||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
|
||||||
Prompt: prompt,
|
|
||||||
Width: width,
|
|
||||||
Height: height,
|
|
||||||
Steps: steps,
|
|
||||||
Seed: seed,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// EditFromConfig edits images using the unified config struct.
|
|
||||||
// Accepts one or more input images.
|
|
||||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
|
||||||
if len(inputImagePaths) == 0 {
|
|
||||||
return nil, fmt.Errorf("no input images provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
result, err := m.edit(inputImagePaths, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.NegativePrompt != "" {
|
|
||||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
|
||||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
|
||||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EditImage implements model.ImageEditModel interface.
|
|
||||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
|
||||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// EditMultiImage edits using multiple source images.
|
|
||||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
|
||||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
|
||||||
return m.EditFromConfig(inputImagePaths, cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// edit is the internal editing pipeline that handles one or more images.
|
|
||||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
|
||||||
// Apply defaults
|
|
||||||
if cfg.Steps <= 0 {
|
|
||||||
cfg.Steps = 50
|
|
||||||
}
|
|
||||||
if cfg.CFGScale <= 0 {
|
|
||||||
cfg.CFGScale = 4.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load and preprocess all input images
|
|
||||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
|
||||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
|
||||||
}
|
|
||||||
for _, img := range condImages {
|
|
||||||
mlx.Keep(img)
|
|
||||||
}
|
|
||||||
for _, img := range vaeImages {
|
|
||||||
mlx.Keep(img)
|
|
||||||
}
|
|
||||||
mlx.Eval(append(condImages, vaeImages...)...)
|
|
||||||
|
|
||||||
useCFG := cfg.NegativePrompt != ""
|
|
||||||
tcfg := m.Transformer.Config
|
|
||||||
vaeScaleFactor := int32(8)
|
|
||||||
|
|
||||||
// Output dimensions - if not specified, use first input image dimensions
|
|
||||||
if cfg.Width <= 0 {
|
|
||||||
cfg.Width = inputDims[0].VaeW
|
|
||||||
}
|
|
||||||
if cfg.Height <= 0 {
|
|
||||||
cfg.Height = inputDims[0].VaeH
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output (noise) latent dimensions
|
|
||||||
outLatentH := cfg.Height / vaeScaleFactor
|
|
||||||
outLatentW := cfg.Width / vaeScaleFactor
|
|
||||||
outPH := outLatentH / tcfg.PatchSize
|
|
||||||
outPW := outLatentW / tcfg.PatchSize
|
|
||||||
noiseSeqLen := outPH * outPW
|
|
||||||
imgSeqLen := noiseSeqLen
|
|
||||||
|
|
||||||
// Encode prompt with all images for conditioning
|
|
||||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Keep(posEmb)
|
|
||||||
mlx.Eval(posEmb)
|
|
||||||
|
|
||||||
var negEmb *mlx.Array
|
|
||||||
if useCFG {
|
|
||||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
|
||||||
}
|
|
||||||
mlx.Keep(negEmb)
|
|
||||||
mlx.Eval(negEmb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pad sequences to same length for CFG
|
|
||||||
txtLen := posEmb.Shape()[1]
|
|
||||||
if useCFG {
|
|
||||||
negLen := negEmb.Shape()[1]
|
|
||||||
if negLen > txtLen {
|
|
||||||
txtLen = negLen
|
|
||||||
}
|
|
||||||
if posEmb.Shape()[1] < txtLen {
|
|
||||||
posEmb = padSequence(posEmb, txtLen)
|
|
||||||
}
|
|
||||||
if negEmb.Shape()[1] < txtLen {
|
|
||||||
negEmb = padSequence(negEmb, txtLen)
|
|
||||||
}
|
|
||||||
mlx.Keep(posEmb, negEmb)
|
|
||||||
mlx.Eval(posEmb, negEmb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
|
||||||
var batchedEmb *mlx.Array
|
|
||||||
if useCFG {
|
|
||||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
|
||||||
mlx.Keep(batchedEmb)
|
|
||||||
mlx.Eval(batchedEmb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode all input images to latents and concatenate
|
|
||||||
fmt.Println("Encoding images to latents...")
|
|
||||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
|
||||||
for i, vaeImage := range vaeImages {
|
|
||||||
imageLatents := m.VAE.Encode(vaeImage)
|
|
||||||
imageLatents = m.VAE.Normalize(imageLatents)
|
|
||||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
|
||||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
|
||||||
mlx.Keep(packed)
|
|
||||||
mlx.Eval(packed)
|
|
||||||
allImageLatentsPacked[i] = packed
|
|
||||||
}
|
|
||||||
|
|
||||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
|
||||||
mlx.Keep(imageLatentsPacked)
|
|
||||||
mlx.Eval(imageLatentsPacked)
|
|
||||||
|
|
||||||
// Scheduler
|
|
||||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
|
||||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
|
||||||
|
|
||||||
// Init noise latents in packed format
|
|
||||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
|
||||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
|
||||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
|
||||||
mlx.Eval(latents)
|
|
||||||
|
|
||||||
// RoPE cache
|
|
||||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
|
||||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
|
|
||||||
// Denoising loop
|
|
||||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
|
||||||
for i := 0; i < cfg.Steps; i++ {
|
|
||||||
stepStart := time.Now()
|
|
||||||
if cfg.Progress != nil {
|
|
||||||
cfg.Progress(i+1, cfg.Steps)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := scheduler.Timesteps[i]
|
|
||||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
|
||||||
mlx.Eval(timestep)
|
|
||||||
|
|
||||||
latents2D := mlx.Squeeze(latents, 2)
|
|
||||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
|
||||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
|
||||||
|
|
||||||
var output *mlx.Array
|
|
||||||
if useCFG {
|
|
||||||
// CFG Batching: single forward pass with batch=2
|
|
||||||
// Tile inputs: [1, L, D] -> [2, L, D]
|
|
||||||
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
|
||||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
|
||||||
|
|
||||||
// Single batched forward pass
|
|
||||||
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
|
|
||||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
|
||||||
D := batchedOutput.Shape()[2]
|
|
||||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
|
||||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
|
||||||
|
|
||||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
|
||||||
} else {
|
|
||||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
|
||||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
|
||||||
}
|
|
||||||
|
|
||||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
|
||||||
oldLatents := latents
|
|
||||||
latents = scheduler.Step(noisePred, latents, i)
|
|
||||||
mlx.Eval(latents)
|
|
||||||
oldLatents.Free()
|
|
||||||
|
|
||||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free denoising temporaries
|
|
||||||
posEmb.Free()
|
|
||||||
if negEmb != nil {
|
|
||||||
negEmb.Free()
|
|
||||||
}
|
|
||||||
if batchedEmb != nil {
|
|
||||||
batchedEmb.Free()
|
|
||||||
}
|
|
||||||
ropeCache.ImgFreqs.Free()
|
|
||||||
ropeCache.TxtFreqs.Free()
|
|
||||||
imageLatentsPacked.Free()
|
|
||||||
|
|
||||||
// Decode latents
|
|
||||||
decoded := m.decodeAndPostprocess(latents)
|
|
||||||
latents.Free()
|
|
||||||
|
|
||||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
|
||||||
return decoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
|
||||||
// This prevents CFG from inflating magnitude too much.
|
|
||||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
|
||||||
// Upcast to float32 for precision
|
|
||||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
|
||||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
|
||||||
|
|
||||||
// CFG: pred = neg + scale * (pos - neg)
|
|
||||||
diff := mlx.Sub(posF32, negF32)
|
|
||||||
scaledDiff := mlx.MulScalar(diff, scale)
|
|
||||||
combPred := mlx.Add(negF32, scaledDiff)
|
|
||||||
|
|
||||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
|
||||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
|
||||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
|
||||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
|
||||||
|
|
||||||
mlx.Eval(output)
|
|
||||||
return mlx.ToBFloat16(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
|
||||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
|
||||||
latents = m.VAE.Denormalize(latents)
|
|
||||||
decoded := m.VAE.Decode(latents)
|
|
||||||
|
|
||||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
|
||||||
decoded = mlx.Squeeze(decoded, 2)
|
|
||||||
decoded = mlx.AddScalar(decoded, 1.0)
|
|
||||||
decoded = mlx.DivScalar(decoded, 2.0)
|
|
||||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
|
||||||
mlx.Eval(decoded)
|
|
||||||
return decoded
|
|
||||||
}
|
|
||||||
|
|
||||||
// padSequence pads a sequence tensor to the target length with zeros
|
|
||||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
currentLen := shape[1]
|
|
||||||
if currentLen >= targetLen {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
padLen := targetLen - currentLen
|
|
||||||
// Pad on sequence dimension (axis 1)
|
|
||||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadPersistent is an alias for backward compatibility.
|
|
||||||
func LoadPersistent(modelPath string) (*Model, error) {
|
|
||||||
m := &Model{}
|
|
||||||
if err := m.Load(modelPath); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
|
||||||
// Handles single or multiple input images with different resolutions.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
|
||||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
|
||||||
// - txtLen: text sequence length
|
|
||||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
|
||||||
//
|
|
||||||
// Returns RoPE cache where:
|
|
||||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
|
||||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
|
||||||
// - Following positions are for each input image (interpolated from output res)
|
|
||||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
|
||||||
theta := float64(10000)
|
|
||||||
maxIdx := int32(4096)
|
|
||||||
|
|
||||||
// Compute base frequencies for each axis dimension
|
|
||||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
|
||||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
|
||||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
|
||||||
|
|
||||||
// Build frequency lookup tables
|
|
||||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
|
||||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
|
||||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
|
||||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
|
||||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
|
||||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
|
||||||
|
|
||||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
|
||||||
|
|
||||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
|
||||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
|
||||||
row := make([]float32, headDim)
|
|
||||||
idx := 0
|
|
||||||
|
|
||||||
// Frame position
|
|
||||||
for i := 0; i < len(freqsT)*2; i++ {
|
|
||||||
row[idx+i] = posFreqsT[framePos][i]
|
|
||||||
}
|
|
||||||
idx += len(freqsT) * 2
|
|
||||||
|
|
||||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
|
||||||
outHHalf := outPH / 2
|
|
||||||
hNegCount := outPH - outHHalf
|
|
||||||
if y < hNegCount {
|
|
||||||
negTableIdx := maxIdx - hNegCount + y
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := y - hNegCount
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
row[idx+i] = posFreqsH[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
idx += len(freqsH) * 2
|
|
||||||
|
|
||||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
|
||||||
outWHalf := outPW / 2
|
|
||||||
wNegCount := outPW - outWHalf
|
|
||||||
if x < wNegCount {
|
|
||||||
negTableIdx := maxIdx - wNegCount + x
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := x - wNegCount
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
row[idx+i] = posFreqsW[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return row
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
|
||||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
|
||||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
|
||||||
row := make([]float32, headDim)
|
|
||||||
idx := 0
|
|
||||||
|
|
||||||
// Frame -1: use last row of negative frame frequencies
|
|
||||||
negFrameIdx := maxIdx - 1
|
|
||||||
for i := 0; i < len(freqsT)*2; i++ {
|
|
||||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
|
||||||
}
|
|
||||||
idx += len(freqsT) * 2
|
|
||||||
|
|
||||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
|
||||||
outHHalf := outPH / 2
|
|
||||||
hNegCount := outPH - outHHalf
|
|
||||||
if y < hNegCount {
|
|
||||||
negTableIdx := maxIdx - hNegCount + y
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := y - hNegCount
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
row[idx+i] = posFreqsH[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
idx += len(freqsH) * 2
|
|
||||||
|
|
||||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
|
||||||
outWHalf := outPW / 2
|
|
||||||
wNegCount := outPW - outWHalf
|
|
||||||
if x < wNegCount {
|
|
||||||
negTableIdx := maxIdx - wNegCount + x
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
posIdx := x - wNegCount
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
row[idx+i] = posFreqsW[posIdx][i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return row
|
|
||||||
}
|
|
||||||
|
|
||||||
// Total image sequence length: noise + all input images
|
|
||||||
noiseSeqLen := outPH * outPW
|
|
||||||
totalImgLen := noiseSeqLen
|
|
||||||
for _, dims := range inputDims {
|
|
||||||
totalImgLen += dims.PatchH * dims.PatchW
|
|
||||||
}
|
|
||||||
|
|
||||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
|
||||||
idx := int32(0)
|
|
||||||
|
|
||||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
|
||||||
for y := int32(0); y < outPH; y++ {
|
|
||||||
for x := int32(0); x < outPW; x++ {
|
|
||||||
row := computePosFreqs(0, y, x)
|
|
||||||
copy(imgFreqsData[idx:], row)
|
|
||||||
idx += headDim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
|
||||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
|
||||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
|
||||||
// (_compute_condition_freqs), positive indices for others.
|
|
||||||
numImages := len(inputDims)
|
|
||||||
lastImgIdx := numImages - 1
|
|
||||||
for imgIdx, dims := range inputDims {
|
|
||||||
inPH := dims.PatchH
|
|
||||||
inPW := dims.PatchW
|
|
||||||
|
|
||||||
// Determine frame index for this image
|
|
||||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
|
||||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
|
||||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
|
||||||
|
|
||||||
// Map each input position to an output position using linear interpolation
|
|
||||||
for y := int32(0); y < inPH; y++ {
|
|
||||||
for x := int32(0); x < inPW; x++ {
|
|
||||||
// Interpolate: map input (y, x) to output grid position
|
|
||||||
// This is the key fix from DiffSynth's forward_sampling
|
|
||||||
var yOut, xOut int32
|
|
||||||
if inPH == 1 {
|
|
||||||
yOut = 0
|
|
||||||
} else {
|
|
||||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
|
||||||
yOut = y * (outPH - 1) / (inPH - 1)
|
|
||||||
}
|
|
||||||
if inPW == 1 {
|
|
||||||
xOut = 0
|
|
||||||
} else {
|
|
||||||
xOut = x * (outPW - 1) / (inPW - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
var row []float32
|
|
||||||
if useNegFrame {
|
|
||||||
// Last image in multi-image uses frame -1
|
|
||||||
row = computeNegFrameFreqs(yOut, xOut)
|
|
||||||
} else {
|
|
||||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
|
||||||
frameIdx := int32(imgIdx + 1)
|
|
||||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
|
||||||
}
|
|
||||||
copy(imgFreqsData[idx:], row)
|
|
||||||
idx += headDim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
|
||||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
|
||||||
|
|
||||||
// Text frequencies - start after max video index
|
|
||||||
maxVidIdx := max(outPH/2, outPW/2)
|
|
||||||
|
|
||||||
txtFreqsData := make([]float32, txtLen*headDim)
|
|
||||||
idx = 0
|
|
||||||
for t := int32(0); t < txtLen; t++ {
|
|
||||||
pos := maxVidIdx + t
|
|
||||||
for i := 0; i < len(freqsT)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsT) * 2)
|
|
||||||
for i := 0; i < len(freqsH)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsH) * 2)
|
|
||||||
for i := 0; i < len(freqsW)*2; i++ {
|
|
||||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
|
||||||
}
|
|
||||||
idx += int32(len(freqsW) * 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
|
||||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
|
||||||
|
|
||||||
return &qwen_image.RoPECache{
|
|
||||||
ImgFreqs: imgFreqs,
|
|
||||||
TxtFreqs: txtFreqs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image_edit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestMain initializes MLX before running tests.
|
|
||||||
// If MLX libraries are not available, tests are skipped.
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
// Change to repo root so ./build/lib/ollama/ path works
|
|
||||||
_, thisFile, _, _ := runtime.Caller(0)
|
|
||||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
|
||||||
if err := os.Chdir(repoRoot); err != nil {
|
|
||||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := mlx.InitMLX(); err != nil {
|
|
||||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
|
||||||
os.Exit(m.Run())
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
|
||||||
func TestComputeAxisFreqs(t *testing.T) {
|
|
||||||
theta := float64(10000)
|
|
||||||
|
|
||||||
// Expected values from Python:
|
|
||||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
|
||||||
expectedFreqsT := []float64{
|
|
||||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
|
||||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedFreqsH_first4 := []float64{
|
|
||||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedFreqsH_last4 := []float64{
|
|
||||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test temporal frequencies (dim=16)
|
|
||||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
|
||||||
if len(freqsT) != 8 {
|
|
||||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
|
||||||
}
|
|
||||||
for i, expected := range expectedFreqsT {
|
|
||||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
|
||||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test height/width frequencies (dim=56)
|
|
||||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
|
||||||
if len(freqsH) != 28 {
|
|
||||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
|
||||||
}
|
|
||||||
for i, expected := range expectedFreqsH_first4 {
|
|
||||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
|
||||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i, expected := range expectedFreqsH_last4 {
|
|
||||||
idx := 24 + i // last 4 of 28
|
|
||||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
|
||||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
|
||||||
func TestMakeFreqTable(t *testing.T) {
|
|
||||||
theta := float64(10000)
|
|
||||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
|
||||||
maxIdx := int32(4096)
|
|
||||||
|
|
||||||
// Test positive table
|
|
||||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
|
||||||
|
|
||||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
|
||||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
|
||||||
if posTable[0][i] != 1.0 {
|
|
||||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
|
||||||
}
|
|
||||||
if posTable[0][i+1] != 0.0 {
|
|
||||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
|
||||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
|
||||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
|
||||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
|
||||||
}
|
|
||||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
|
||||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test negative table
|
|
||||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
|
||||||
|
|
||||||
// negTable[4095] corresponds to position -1
|
|
||||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
|
||||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
|
||||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
|
||||||
}
|
|
||||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
|
||||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// negTable[4094] corresponds to position -2
|
|
||||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
|
||||||
cos2 := math.Cos(2.0)
|
|
||||||
sin2 := math.Sin(2.0)
|
|
||||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
|
||||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
|
||||||
}
|
|
||||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
|
||||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
|
||||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
|
||||||
if !mlx.GPUIsAvailable() {
|
|
||||||
t.Skip("GPU not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.SetDefaultDeviceCPU()
|
|
||||||
|
|
||||||
// 4x4 patch grid, single image
|
|
||||||
imgH, imgW := int32(4), int32(4)
|
|
||||||
txtLen := int32(5)
|
|
||||||
axesDims := []int32{16, 56, 56}
|
|
||||||
|
|
||||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
|
||||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
|
||||||
|
|
||||||
// Check shapes
|
|
||||||
imgShape := cache.ImgFreqs.Shape()
|
|
||||||
if imgShape[0] != 16 { // 4*4 patches
|
|
||||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
|
||||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
|
||||||
mlx.Eval(imgFreqsCPU)
|
|
||||||
imgData := imgFreqsCPU.Data()
|
|
||||||
|
|
||||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
|
||||||
for i := 0; i < 16; i += 2 {
|
|
||||||
cosVal := imgData[i]
|
|
||||||
sinVal := imgData[i+1]
|
|
||||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
|
||||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
|
||||||
}
|
|
||||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
|
||||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cache.ImgFreqs.Free()
|
|
||||||
cache.TxtFreqs.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
|
|
||||||
func TestScaleRopePositions(t *testing.T) {
|
|
||||||
// For a 4x4 grid with scale_rope=True:
|
|
||||||
// hHalf = 2, wHalf = 2
|
|
||||||
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
|
||||||
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
|
||||||
//
|
|
||||||
// Height positions:
|
|
||||||
// y=0: -(4-2) + 0 = -2
|
|
||||||
// y=1: -(4-2) + 1 = -1
|
|
||||||
// y=2: 2 - 2 = 0
|
|
||||||
// y=3: 3 - 2 = 1
|
|
||||||
//
|
|
||||||
// Same for width
|
|
||||||
|
|
||||||
pH, pW := int32(4), int32(4)
|
|
||||||
hHalf := pH / 2
|
|
||||||
wHalf := pW / 2
|
|
||||||
hNegCount := pH - hHalf
|
|
||||||
wNegCount := pW - wHalf
|
|
||||||
|
|
||||||
expectedH := []int32{-2, -1, 0, 1}
|
|
||||||
expectedW := []int32{-2, -1, 0, 1}
|
|
||||||
|
|
||||||
for y := int32(0); y < pH; y++ {
|
|
||||||
var hPos int32
|
|
||||||
if y < hNegCount {
|
|
||||||
hPos = -(pH - hHalf) + y
|
|
||||||
} else {
|
|
||||||
hPos = y - hNegCount
|
|
||||||
}
|
|
||||||
if hPos != expectedH[y] {
|
|
||||||
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for x := int32(0); x < pW; x++ {
|
|
||||||
var wPos int32
|
|
||||||
if x < wNegCount {
|
|
||||||
wPos = -(pW - wHalf) + x
|
|
||||||
} else {
|
|
||||||
wPos = x - wNegCount
|
|
||||||
}
|
|
||||||
if wPos != expectedW[x] {
|
|
||||||
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRoPEHeadDimensions verifies the head dimension breakdown
|
|
||||||
func TestRoPEHeadDimensions(t *testing.T) {
|
|
||||||
// axes_dims_rope = [16, 56, 56]
|
|
||||||
// Each dimension uses half the values for frequencies
|
|
||||||
// So we get: 8 + 28 + 28 = 64 frequency values
|
|
||||||
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
|
|
||||||
|
|
||||||
axesDims := []int32{16, 56, 56}
|
|
||||||
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
|
|
||||||
expectedHeadDim := expectedFreqs * 2
|
|
||||||
|
|
||||||
if expectedFreqs != 64 {
|
|
||||||
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
|
|
||||||
}
|
|
||||||
if expectedHeadDim != 128 {
|
|
||||||
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should match the transformer's attention head dimension
|
|
||||||
// hidden_size = 3072, num_heads = 24
|
|
||||||
// head_dim = 3072 / 24 = 128
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,642 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package qwen_image_edit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// VAEConfig holds Qwen-Image VAE configuration
|
|
||||||
type VAEConfig struct {
|
|
||||||
ZDim int32 `json:"z_dim"` // 16
|
|
||||||
BaseDim int32 `json:"base_dim"` // 96
|
|
||||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
|
||||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
|
||||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
|
||||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
|
||||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
|
||||||
func defaultVAEConfig() *VAEConfig {
|
|
||||||
return &VAEConfig{
|
|
||||||
ZDim: 16,
|
|
||||||
BaseDim: 96,
|
|
||||||
DimMult: []int32{1, 2, 4, 4},
|
|
||||||
NumResBlocks: 2,
|
|
||||||
LatentsMean: []float32{
|
|
||||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
|
||||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
|
||||||
0.4134, -0.0715, 0.5517, -0.3632,
|
|
||||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
|
||||||
},
|
|
||||||
LatentsStd: []float32{
|
|
||||||
2.8184, 1.4541, 2.3275, 2.6558,
|
|
||||||
1.2196, 1.7708, 2.6052, 2.0743,
|
|
||||||
3.2687, 2.1526, 2.8652, 1.5579,
|
|
||||||
1.6382, 1.1253, 2.8251, 1.916,
|
|
||||||
},
|
|
||||||
TemperalDownsample: []bool{false, true, true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAE is the full VAE with encoder and decoder
|
|
||||||
type VAE struct {
|
|
||||||
Config *VAEConfig
|
|
||||||
Encoder *VAEEncoder
|
|
||||||
Decoder *VAEDecoder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the VAE from a directory
|
|
||||||
func (m *VAE) Load(path string) error {
|
|
||||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
|
||||||
|
|
||||||
cfg := defaultVAEConfig()
|
|
||||||
m.Config = cfg
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load weights as f32 for quality (matches Python default behavior)
|
|
||||||
// VAE decoder precision is critical for final image quality
|
|
||||||
fmt.Print(" Loading weights as f32... ")
|
|
||||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
|
||||||
return fmt.Errorf("failed to load weights: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
|
||||||
|
|
||||||
// Load encoder
|
|
||||||
fmt.Print(" Loading encoder... ")
|
|
||||||
m.Encoder = &VAEEncoder{}
|
|
||||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
|
||||||
return fmt.Errorf("encoder: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
// Load decoder
|
|
||||||
fmt.Print(" Loading decoder... ")
|
|
||||||
m.Decoder = &VAEDecoder{}
|
|
||||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
|
||||||
return fmt.Errorf("decoder: %w", err)
|
|
||||||
}
|
|
||||||
fmt.Println("✓")
|
|
||||||
|
|
||||||
weights.ReleaseAll()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode encodes an image to latents
|
|
||||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
|
||||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
|
||||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
|
||||||
return m.Encoder.Encode(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode decodes latents to image
|
|
||||||
// z: [B, C, T, H, W] latents (denormalized)
|
|
||||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
|
||||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
|
||||||
return m.Decoder.Decode(z)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize applies latent normalization
|
|
||||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
|
||||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
|
||||||
shape := z.Shape()
|
|
||||||
C := shape[1]
|
|
||||||
|
|
||||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
|
|
||||||
// Mean/std are f32, will match z dtype through broadcasting
|
|
||||||
return mlx.Div(mlx.Sub(z, mean), std)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Denormalize reverses latent normalization
|
|
||||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
|
||||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
|
||||||
shape := z.Shape()
|
|
||||||
C := shape[1]
|
|
||||||
|
|
||||||
// Convert latents to f32 for VAE decoder quality
|
|
||||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
|
||||||
|
|
||||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
|
||||||
|
|
||||||
return mlx.Add(mlx.Mul(z, std), mean)
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAEEncoder is the encoder part of the VAE
|
|
||||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
|
||||||
// - Blocks 0,1: ResBlocks (base_dim)
|
|
||||||
// - Block 2: Downsample
|
|
||||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
|
||||||
// - Block 5: Downsample + temporal
|
|
||||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
|
||||||
// - Block 8: Downsample + temporal
|
|
||||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
|
||||||
type VAEEncoder struct {
|
|
||||||
Config *VAEConfig
|
|
||||||
|
|
||||||
ConvIn *CausalConv3d
|
|
||||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
|
||||||
MidBlock *MidBlock
|
|
||||||
NormOut *RMSNorm3D
|
|
||||||
ConvOut *CausalConv3d
|
|
||||||
QuantConv *CausalConv3d
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncoderBlock is either a ResBlock or a Downsample
|
|
||||||
type EncoderBlock interface {
|
|
||||||
Forward(x *mlx.Array) *mlx.Array
|
|
||||||
IsDownsample() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncoderResBlock wraps ResBlock
|
|
||||||
type EncoderResBlock struct {
|
|
||||||
*ResBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *EncoderResBlock) IsDownsample() bool { return false }
|
|
||||||
|
|
||||||
// EncoderDownsample is a downsample layer
|
|
||||||
type EncoderDownsample struct {
|
|
||||||
Resample *CausalConv3d
|
|
||||||
TimeConv *CausalConv3d // Optional temporal downsample
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *EncoderDownsample) IsDownsample() bool { return true }
|
|
||||||
|
|
||||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// Spatial downsample with stride 2
|
|
||||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
|
||||||
x = d.forwardSpatialDownsample(x)
|
|
||||||
|
|
||||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
|
||||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
|
||||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
|
||||||
// Since we don't support streaming, we skip time_conv entirely.
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
|
||||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
|
||||||
xShape := x.Shape()
|
|
||||||
B := xShape[0]
|
|
||||||
T := xShape[1]
|
|
||||||
H := xShape[2]
|
|
||||||
W := xShape[3]
|
|
||||||
C := xShape[4]
|
|
||||||
|
|
||||||
wShape := d.Resample.Weight.Shape()
|
|
||||||
outC := wShape[0]
|
|
||||||
|
|
||||||
// Reshape to [B*T, H, W, C] for 2D conv
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
|
|
||||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
|
||||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
|
||||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
|
||||||
|
|
||||||
// Apply 2D conv with stride 2
|
|
||||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
|
||||||
x = conv2DStrided(x, weight, 2)
|
|
||||||
|
|
||||||
if d.Resample.Bias != nil {
|
|
||||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
|
||||||
x = mlx.Add(x, bias)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
|
||||||
outH := (H + 1) / 2
|
|
||||||
outW := (W + 1) / 2
|
|
||||||
|
|
||||||
// Reshape back to [B, T, H', W', C]
|
|
||||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
|
||||||
mlx.Eval(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadFromWeights loads the encoder from pre-loaded weights
|
|
||||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
|
||||||
e.Config = cfg
|
|
||||||
|
|
||||||
// Conv in
|
|
||||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.ConvIn = convIn
|
|
||||||
|
|
||||||
// Encoder uses flat block structure:
|
|
||||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
|
||||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
|
||||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
|
||||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
|
||||||
|
|
||||||
// Track dimensions
|
|
||||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
|
||||||
blockIdx := 0
|
|
||||||
|
|
||||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
|
||||||
inDim := cfg.BaseDim
|
|
||||||
if stage > 0 {
|
|
||||||
inDim = dims[stage-1]
|
|
||||||
}
|
|
||||||
outDim := dims[stage]
|
|
||||||
|
|
||||||
// ResBlocks for this stage (num_res_blocks per stage)
|
|
||||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
|
||||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
|
||||||
currentInDim := inDim
|
|
||||||
if r > 0 {
|
|
||||||
currentInDim = outDim
|
|
||||||
}
|
|
||||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
|
||||||
}
|
|
||||||
e.Blocks = append(e.Blocks, block)
|
|
||||||
blockIdx++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Downsample after each stage except the last
|
|
||||||
if stage < len(cfg.DimMult)-1 {
|
|
||||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
|
||||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
|
||||||
}
|
|
||||||
e.Blocks = append(e.Blocks, down)
|
|
||||||
blockIdx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mid block
|
|
||||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
|
||||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.MidBlock = midBlock
|
|
||||||
|
|
||||||
// Norm out
|
|
||||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.NormOut = normOut
|
|
||||||
|
|
||||||
// Conv out
|
|
||||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.ConvOut = convOut
|
|
||||||
|
|
||||||
// Quant conv
|
|
||||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.QuantConv = quantConv
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
|
||||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
|
||||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &EncoderResBlock{block}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newEncoderDownsample creates a downsample layer for the encoder
|
|
||||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
|
||||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeConv *CausalConv3d
|
|
||||||
if temporal {
|
|
||||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &EncoderDownsample{
|
|
||||||
Resample: resample,
|
|
||||||
TimeConv: timeConv,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode encodes an image to latents
|
|
||||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
|
||||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
|
||||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
|
||||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
|
||||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
|
||||||
mlx.Eval(x)
|
|
||||||
|
|
||||||
// Conv in
|
|
||||||
x = e.ConvIn.Forward(x)
|
|
||||||
|
|
||||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
|
||||||
for _, block := range e.Blocks {
|
|
||||||
prev := x
|
|
||||||
x = block.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mid block
|
|
||||||
x = e.MidBlock.Forward(x)
|
|
||||||
|
|
||||||
// Norm + silu
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = e.NormOut.Forward(x)
|
|
||||||
x = silu3D(x)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conv out
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = e.ConvOut.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Quant conv
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = e.QuantConv.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get mode from distribution (first half of channels = mean)
|
|
||||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
|
||||||
shape := x.Shape()
|
|
||||||
latentC := shape[4] / 2
|
|
||||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
|
||||||
|
|
||||||
// Convert back to channels-first [N, C, T, H, W]
|
|
||||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
|
||||||
mlx.Eval(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAEDecoder is the decoder part of the VAE
|
|
||||||
type VAEDecoder struct {
|
|
||||||
Config *VAEConfig
|
|
||||||
|
|
||||||
PostQuantConv *CausalConv3d
|
|
||||||
ConvIn *CausalConv3d
|
|
||||||
MidBlock *MidBlock
|
|
||||||
UpBlocks []*UpBlock
|
|
||||||
NormOut *RMSNorm3D
|
|
||||||
ConvOut *CausalConv3d
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadFromWeights loads the decoder from pre-loaded weights
|
|
||||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
|
||||||
d.Config = cfg
|
|
||||||
|
|
||||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.PostQuantConv = postQuantConv
|
|
||||||
|
|
||||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.ConvIn = convIn
|
|
||||||
|
|
||||||
// Mid block
|
|
||||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
|
||||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.MidBlock = midBlock
|
|
||||||
|
|
||||||
// Up blocks (reversed dim_mult)
|
|
||||||
numUpBlocks := len(cfg.DimMult)
|
|
||||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
|
||||||
|
|
||||||
dimsMult := make([]int32, numUpBlocks+1)
|
|
||||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
|
||||||
for i := 0; i < numUpBlocks; i++ {
|
|
||||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
|
||||||
}
|
|
||||||
|
|
||||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
|
||||||
for i := range cfg.TemperalDownsample {
|
|
||||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < numUpBlocks; i++ {
|
|
||||||
inDim := cfg.BaseDim * dimsMult[i]
|
|
||||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
|
||||||
|
|
||||||
if i > 0 {
|
|
||||||
inDim = inDim / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
upsampleMode := ""
|
|
||||||
if i < numUpBlocks-1 {
|
|
||||||
if temporalUpsample[i] {
|
|
||||||
upsampleMode = "upsample3d"
|
|
||||||
} else {
|
|
||||||
upsampleMode = "upsample2d"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
|
||||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.UpBlocks[i] = upBlock
|
|
||||||
}
|
|
||||||
|
|
||||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.NormOut = normOut
|
|
||||||
|
|
||||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.ConvOut = convOut
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode converts latents to image
|
|
||||||
// z: [B, C, T, H, W] denormalized latents
|
|
||||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
|
||||||
var x *mlx.Array
|
|
||||||
|
|
||||||
// Convert from channels-first to channels-last
|
|
||||||
{
|
|
||||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
|
||||||
mlx.Eval(z)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostQuantConv
|
|
||||||
x = d.PostQuantConv.Forward(z)
|
|
||||||
z.Free()
|
|
||||||
|
|
||||||
// ConvIn
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = d.ConvIn.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mid block
|
|
||||||
x = d.MidBlock.Forward(x)
|
|
||||||
|
|
||||||
// Up blocks
|
|
||||||
for _, upBlock := range d.UpBlocks {
|
|
||||||
x = upBlock.Forward(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NormOut + silu
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = d.NormOut.Forward(x)
|
|
||||||
x = silu3D(x)
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvOut
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = d.ConvOut.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Post-processing: clamp and convert back to channels-first
|
|
||||||
{
|
|
||||||
prev := x
|
|
||||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
|
||||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// DownBlock handles downsampling in encoder
|
|
||||||
type DownBlock struct {
|
|
||||||
ResBlocks []*ResBlock
|
|
||||||
Downsampler *Downsample
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDownBlock creates a down block
|
|
||||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
|
||||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
|
||||||
|
|
||||||
currentDim := inDim
|
|
||||||
for i := int32(0); i <= numBlocks; i++ {
|
|
||||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
|
||||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resBlocks[i] = block
|
|
||||||
currentDim = outDim
|
|
||||||
}
|
|
||||||
|
|
||||||
var downsampler *Downsample
|
|
||||||
if downsampleMode != "" {
|
|
||||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DownBlock{
|
|
||||||
ResBlocks: resBlocks,
|
|
||||||
Downsampler: downsampler,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies down block
|
|
||||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
for _, block := range d.ResBlocks {
|
|
||||||
prev := x
|
|
||||||
x = block.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.Downsampler != nil {
|
|
||||||
prev := x
|
|
||||||
x = d.Downsampler.Forward(x)
|
|
||||||
prev.Free()
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Downsample handles spatial downsampling
|
|
||||||
type Downsample struct {
|
|
||||||
Conv *mlx.Array
|
|
||||||
Bias *mlx.Array
|
|
||||||
Mode string
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDownsample creates a downsampler
|
|
||||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
|
||||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
|
||||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
|
||||||
return &Downsample{
|
|
||||||
Conv: conv,
|
|
||||||
Bias: bias,
|
|
||||||
Mode: mode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
|
||||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B := shape[0]
|
|
||||||
T := shape[1]
|
|
||||||
H := shape[2]
|
|
||||||
W := shape[3]
|
|
||||||
C := shape[4]
|
|
||||||
outC := d.Conv.Shape()[0]
|
|
||||||
|
|
||||||
// Reshape to [B*T, H, W, C] for 2D conv
|
|
||||||
x = mlx.Reshape(x, B*T, H, W, C)
|
|
||||||
|
|
||||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
|
||||||
// For 3x3 stride 2: pad 1 on all sides
|
|
||||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
|
||||||
|
|
||||||
// Conv with stride 2 using manual strided patching
|
|
||||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
|
||||||
x = conv2DStrided(x, weight, 2)
|
|
||||||
if d.Bias != nil {
|
|
||||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
|
||||||
x = mlx.Add(x, bias)
|
|
||||||
}
|
|
||||||
|
|
||||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
|
||||||
mlx.Eval(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// modelConfig represents the HuggingFace config.json structure
|
// modelConfig represents the HuggingFace config.json structure
|
||||||
@@ -35,22 +36,22 @@ type modelConfig struct {
|
|||||||
|
|
||||||
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
||||||
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
||||||
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
|
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
||||||
manifest, err := imagegen.LoadManifest(modelName)
|
mf, err := manifest.ParseNamedManifest(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var config modelConfig
|
var config modelConfig
|
||||||
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
|
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
|
||||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate total tensor bytes from manifest layers
|
// Calculate total tensor bytes from manifest layers
|
||||||
var totalBytes int64
|
var totalBytes int64
|
||||||
var tensorCount int64
|
var tensorCount int64
|
||||||
for _, layer := range manifest.Manifest.Layers {
|
for _, layer := range mf.Layers {
|
||||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||||
totalBytes += layer.Size
|
totalBytes += layer.Size
|
||||||
tensorCount++
|
tensorCount++
|
||||||
}
|
}
|
||||||
@@ -151,27 +152,30 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
|||||||
|
|
||||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||||
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
|
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||||
manifest, err := imagegen.LoadManifest(modelName)
|
mf, err := manifest.ParseNamedManifest(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return getTensorInfoFromManifest(manifest)
|
return getTensorInfoFromManifest(mf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||||
// This is separated for testability.
|
// This is separated for testability.
|
||||||
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
|
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
|
||||||
var tensors []api.Tensor
|
var tensors []api.Tensor
|
||||||
|
|
||||||
for _, layer := range manifest.Manifest.Layers {
|
for _, layer := range mf.Layers {
|
||||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
if layer.MediaType != manifest.MediaTypeImageTensor {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the safetensors header from the blob
|
// Read the safetensors header from the blob
|
||||||
blobPath := manifest.BlobPath(layer.Digest)
|
blobPath, err := manifest.BlobsPath(layer.Digest)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
info, err := readSafetensorsHeader(blobPath)
|
info, err := readSafetensorsHeader(blobPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Skip tensors we can't read
|
// Skip tensors we can't read
|
||||||
@@ -197,15 +201,15 @@ func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor,
|
|||||||
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
||||||
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
||||||
// Otherwise returns the torch_dtype from config.json.
|
// Otherwise returns the torch_dtype from config.json.
|
||||||
func GetSafetensorsDtype(modelName string) (string, error) {
|
func GetSafetensorsDtype(name model.Name) (string, error) {
|
||||||
manifest, err := imagegen.LoadManifest(modelName)
|
mf, err := manifest.ParseNamedManifest(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if model is quantized by looking for _scale tensors
|
// Check if model is quantized by looking for _scale tensors
|
||||||
for _, layer := range manifest.Manifest.Layers {
|
for _, layer := range mf.Layers {
|
||||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||||
if strings.HasSuffix(layer.Name, "_scale") {
|
if strings.HasSuffix(layer.Name, "_scale") {
|
||||||
// Model is quantized - return FP8 (affine quantization)
|
// Model is quantized - return FP8 (affine quantization)
|
||||||
return "FP8", nil
|
return "FP8", nil
|
||||||
@@ -217,7 +221,7 @@ func GetSafetensorsDtype(modelName string) (string, error) {
|
|||||||
var cfg struct {
|
var cfg struct {
|
||||||
TorchDtype string `json:"torch_dtype"`
|
TorchDtype string `json:"torch_dtype"`
|
||||||
}
|
}
|
||||||
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
|
if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||||
return "", fmt.Errorf("failed to read config.json: %w", err)
|
return "", fmt.Errorf("failed to read config.json: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/manifest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBuildModelInfo(t *testing.T) {
|
func TestBuildModelInfo(t *testing.T) {
|
||||||
@@ -451,8 +451,14 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTensorInfoFromManifest(t *testing.T) {
|
func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||||
// Create a temp directory for blobs
|
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||||
|
|
||||||
|
blobDir := filepath.Join(tempDir, "blobs")
|
||||||
|
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create blobs dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create test tensor blobs
|
// Create test tensor blobs
|
||||||
tensors := []struct {
|
tensors := []struct {
|
||||||
@@ -463,26 +469,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "model.embed_tokens.weight",
|
name: "model.embed_tokens.weight",
|
||||||
digest: "sha256:abc123",
|
digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
|
||||||
dtype: "BF16",
|
dtype: "BF16",
|
||||||
shape: []int64{262144, 2560},
|
shape: []int64{262144, 2560},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model.layers.0.self_attn.q_proj.weight",
|
name: "model.layers.0.self_attn.q_proj.weight",
|
||||||
digest: "sha256:def456",
|
digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
|
||||||
dtype: "BF16",
|
dtype: "BF16",
|
||||||
shape: []int64{2560, 2560},
|
shape: []int64{2560, 2560},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model.norm.weight",
|
name: "model.norm.weight",
|
||||||
digest: "sha256:ghi789",
|
digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
|
||||||
dtype: "F32",
|
dtype: "F32",
|
||||||
shape: []int64{2560},
|
shape: []int64{2560},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create blob files
|
// Create blob files
|
||||||
var layers []imagegen.ManifestLayer
|
var layers []manifest.Layer
|
||||||
for _, tensor := range tensors {
|
for _, tensor := range tensors {
|
||||||
// Create safetensors blob
|
// Create safetensors blob
|
||||||
header := map[string]any{
|
header := map[string]any{
|
||||||
@@ -498,15 +504,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
|||||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||||
buf.Write(headerJSON)
|
buf.Write(headerJSON)
|
||||||
|
|
||||||
// Write blob file
|
// Write blob file using the digest format expected by GetBlobsPath
|
||||||
blobName := "sha256-" + tensor.digest[7:]
|
blobPath, err := manifest.BlobsPath(tensor.digest)
|
||||||
blobPath := filepath.Join(tempDir, blobName)
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get blob path: %v", err)
|
||||||
|
}
|
||||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||||
t.Fatalf("failed to write blob: %v", err)
|
t.Fatalf("failed to write blob: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
layers = append(layers, imagegen.ManifestLayer{
|
layers = append(layers, manifest.Layer{
|
||||||
MediaType: "application/vnd.ollama.image.tensor",
|
MediaType: manifest.MediaTypeImageTensor,
|
||||||
Digest: tensor.digest,
|
Digest: tensor.digest,
|
||||||
Size: int64(buf.Len() + 1000), // header + fake data
|
Size: int64(buf.Len() + 1000), // header + fake data
|
||||||
Name: tensor.name,
|
Name: tensor.name,
|
||||||
@@ -514,21 +522,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add a non-tensor layer (should be skipped)
|
// Add a non-tensor layer (should be skipped)
|
||||||
layers = append(layers, imagegen.ManifestLayer{
|
layers = append(layers, manifest.Layer{
|
||||||
MediaType: "application/vnd.ollama.image.json",
|
MediaType: "application/vnd.ollama.image.json",
|
||||||
Digest: "sha256:config",
|
Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
|
||||||
Size: 100,
|
Size: 100,
|
||||||
Name: "config.json",
|
Name: "config.json",
|
||||||
})
|
})
|
||||||
|
|
||||||
manifest := &imagegen.ModelManifest{
|
mf := &manifest.Manifest{
|
||||||
Manifest: &imagegen.Manifest{
|
SchemaVersion: 2,
|
||||||
Layers: layers,
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
},
|
Layers: layers,
|
||||||
BlobDir: tempDir,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := getTensorInfoFromManifest(manifest)
|
result, err := getTensorInfoFromManifest(mf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user