rwkv

package module
v0.0.0-...-51a35e7 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jan 12, 2024 License: MIT Imports: 19 Imported by: 0

README

rwkv

Pure Go bindings for RWKV, with cross-platform support.

Go Reference

rwkv.go is a wrapper around rwkv-cpp, which is an adaptation of ggml.cpp.

Installation

go get github.com/seasonjs/rwkv

AutoModel Compatibility

See the deps folder for dynamic-library compatibility. You can also build the library yourself; pull requests are welcome.

So far, GPU acceleration in NewRwkvAutoModel is supported only on Windows with ROCm (GFX1100).

If you want to use the GPU, make sure your GPU supports ROCm GFX1100 on Windows.

Click here to see your Windows ROCM architecture.

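| platform | x32           | x64                           | arm           | AMD/ROCM              | NVIDIA/CUDA   |
|----------|---------------|-------------------------------|---------------|-----------------------|---------------|
| windows  | not supported | supported (avx2/avx512/avx)   | not supported | supported (GFX1100)   | not supported |
| linux    | not supported | supported                     | not supported | not supported         | not supported |
| darwin   | not supported | supported                     | supported     | not supported         | not supported |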

Usage

You don't need to download the rwkv dynamic library: NewRwkvAutoModel loads the one bundled with this package.

package main

import (
	"fmt"
	"github.com/seasonjs/rwkv"
)

// main loads a small RWKV model through the automatically selected bundled
// library, then runs one blocking prediction and one streamed prediction.
// Every failure aborts the example: each step depends on the previous one,
// and continuing after an error would dereference a nil model or state.
func main() {
	model, err := rwkv.NewRwkvAutoModel(rwkv.RwkvOptions{
		MaxTokens:     100,
		StopString:    "\n",
		Temperature:   0.8,
		TopP:          0.5,
		TokenizerType: rwkv.Normal, // or rwkv.World
		CpuThreads:    10,
	})
	if err != nil {
		panic(err)
	}
	// Release the model only once it was successfully created.
	defer func() {
		if err := model.Close(); err != nil {
			panic(err)
		}
	}()

	if err := model.LoadFromFile("./data/rwkv-110M-Q5.bin"); err != nil {
		panic(err)
	}

	// A state holds the logits and status of one conversation; call
	// InitState again whenever you need an independent one.
	state, err := model.InitState()
	if err != nil {
		panic(err)
	}
	out, err := state.Predict("hello ")
	if err != nil {
		panic(err)
	}
	fmt.Print(out)

	// PredictStream delivers the response piece by piece, like `ChatGPT`.
	streamState, err := model.InitState()
	if err != nil {
		panic(err)
	}
	msg := make(chan string)
	// PredictStream has no error result. The loop below only terminates when
	// msg is closed. NOTE(review): this relies on PredictStream sending from
	// its own goroutine and closing msg when generation ends — confirm.
	streamState.PredictStream("hello ", msg)
	for value := range msg {
		fmt.Print(value)
	}
	fmt.Println()
}

GPU is now supported: create the model with NewRwkvAutoModel and set GpuEnable. See AutoModel Compatibility for details on GPU support.

package main

import (
	"fmt"
	"github.com/seasonjs/rwkv"
)

func main() {
	model, err := NewRwkvAutoModel(rwkv.RwkvOptions{
		MaxTokens:     100,
		StopString:    "/n",
		Temperature:   0.8,
		TopP:          0.5,
		TokenizerType: World, //or World
		PrintError:    true,
		CpuThreads:    10,
		GpuEnable:     true,
		//GpuOffLoadLayers:      0, //default 0 means all layers will offload to gpu
	})

	if err != nil {
		panic(err)
	}

	defer model.Close()

	err = model.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-Q5_1.bin")
	if err != nil {
		t.Error(err)
		return
	}

	// NOTICE: This context hold the logits and status, as well can int a new one.
	ctx, err := rwkv.InitState()
	if err != nil {
		panic(err)
	}
	out, err := ctx.Predict("hello ")
	if err != nil {
		fmt.Printf("error: %v", err)
	}
	fmt.Print(out)
}

If NewRwkvAutoModel cannot load the dynamic library automatically, use NewRwkvModel to load it manually from a path you provide.

package main

import (
	"fmt"
	"github.com/seasonjs/rwkv"
	"runtime"
)

// getLibrary returns the path of the bundled rwkv dynamic library that
// matches runtime.GOOS. It panics on any operating system without a bundled
// library, because the example cannot continue without one.
func getLibrary() string {
	libraryByOS := map[string]string{
		"darwin":  "./deps/darwin/librwkv_x86.dylib",
		"linux":   "./deps/linux/librwkv.so",
		"windows": "./deps/windows/rwkv_avx2_x64.dll",
	}
	if path, found := libraryByOS[runtime.GOOS]; found {
		return path
	}
	panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS))
}

// main loads the rwkv dynamic library from an explicit path (see getLibrary),
// then runs one blocking prediction and one streamed prediction. Every
// failure aborts the example, because continuing would dereference a nil
// model or state.
func main() {
	model, err := rwkv.NewRwkvModel(getLibrary(), rwkv.RwkvOptions{
		MaxTokens:     100,
		StopString:    "\n",
		Temperature:   0.8,
		TopP:          0.5,
		TokenizerType: rwkv.Normal, // or rwkv.World
		CpuThreads:    10,
	})
	if err != nil {
		panic(err)
	}
	// Release the model only once it was successfully created.
	defer func() {
		if err := model.Close(); err != nil {
			panic(err)
		}
	}()

	if err := model.LoadFromFile("./data/rwkv-110M-Q5.bin"); err != nil {
		panic(err)
	}

	// A state holds the logits and status of one conversation; call
	// InitState again whenever you need an independent one.
	state, err := model.InitState()
	if err != nil {
		panic(err)
	}
	out, err := state.Predict("hello ")
	if err != nil {
		panic(err)
	}
	fmt.Print(out)

	// PredictStream delivers the response piece by piece, like `ChatGPT`.
	streamState, err := model.InitState()
	if err != nil {
		panic(err)
	}
	msg := make(chan string)
	// PredictStream has no error result. The loop below only terminates when
	// msg is closed. NOTE(review): this relies on PredictStream sending from
	// its own goroutine and closing msg when generation ends — confirm.
	streamState.PredictStream("hello ", msg)
	for value := range msg {
		fmt.Print(value)
	}
	fmt.Println()
}

Packaging

To ship a working program that includes this AI, you will need to include the following files:

  • librwkv.dylib / librwkv.so / rwkv.dll (built in)
  • the model file
  • the tokenizer file (built in)

Low level API

This package also provides a low-level API that mirrors the rwkv-cpp C API. See rwkv-doc for details.

Thanks

License

Copyright (c) seasonjs. All rights reserved. Licensed under the MIT License. See License.txt in the project root for license information.

Documentation

Index

Constants

View Source
// Repetition penalties applied while sampling generated tokens.
const (
	GEN_alpha_presence  = 0.4   // Presence Penalty
	GEN_alpha_frequency = 0.4   // Frequency Penalty
	GEN_penalty_decay   = 0.996 // decay factor for accumulated penalties (presumably applied per step — confirm)
)

// Token ids with special meaning (presumably in the model vocabulary — confirm).
const (
	END_OF_TEXT = 0
	END_OF_LINE = 11
)

// AVOID_REPEAT lists punctuation characters given special treatment by the
// repetition logic. NOTE(review): the exact use is not visible here — confirm.
const AVOID_REPEAT = ",:?!"
View Source
// Error categories of RwkvErrors. They occupy the bits above the low byte
// (values are multiples of 1<<8), while the detail values (RwkvErrorAlloc,
// RwkvErrorFileOpen, ...) are small integers. Since RwkvErrors values are
// flags, a category and a detail presumably combine into one value — confirm.
//
// Every constant carries the RwkvErrors type explicitly. This lets it be
// compared with, returned as, or assigned to an error exactly like
// RwkvErrorArgs, instead of defaulting to a plain int.
const (
	RwkvErrorArgs        RwkvErrors = 1 << 8
	RwkvErrorFile        RwkvErrors = 2 << 8
	RwkvErrorModel       RwkvErrors = 3 << 8
	RwkvErrorModelParams RwkvErrors = 4 << 8
	RwkvErrorGraph       RwkvErrors = 5 << 8
	RwkvErrorCtx         RwkvErrors = 6 << 8
)

Variables

This section is empty.

Functions

func GetGPUInfo

func GetGPUInfo() (string, error)

func SampleLogits

func SampleLogits(logits v32.V32, temperature float32, topP float32, logitBias map[int]float32) (int, error)

Types

type CRwkv

// CRwkv is the low-level binding to the rwkv.cpp C API. Each method
// corresponds to the C function named in its comment (rwkv_eval,
// rwkv_init_from_file, ...). Where a comment mentions NULL, the Go side
// presumably uses a nil *RwkvCtx / nil slice — confirm against the
// implementation.
type CRwkv interface {
	// RwkvSetPrintErrors sets whether errors are automatically printed to stderr.
	// If this is set to false, you are responsible for calling RwkvGetLastError
	// (rwkv_last_error in C) manually if an operation fails.
	// - ctx: the context to suppress error messages for.
	//   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
	//   as well as the default for new contexts.
	// - enable: whether error messages should be automatically printed.
	RwkvSetPrintErrors(ctx *RwkvCtx, enable bool)

	// RwkvGetPrintErrors reports whether errors are automatically printed to stderr.
	// - ctx: the context to retrieve the setting for, or NULL for the global setting.
	RwkvGetPrintErrors(ctx *RwkvCtx) bool

	// RwkvGetLastError retrieves and clears the error flags.
	// - ctx: the context to retrieve the error for, or NULL for the global error.
	// The flags are returned as an error (presumably an RwkvErrors value, and
	// nil when no flag is set — confirm).
	RwkvGetLastError(ctx *RwkvCtx) error

	// RwkvInitFromFile loads the model from a file and prepares it for inference.
	// The C function returns NULL on any error (presumably a nil *RwkvCtx here).
	// - filePath: path to the model file in ggml format.
	// - threads: count of threads to use, must be positive.
	RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx

	// RwkvCloneContext creates a new context from an existing one.
	// This allows several evaluations to run in parallel without loading the model multiple times.
	// Each context can have one eval running at a time.
	// Every context must be freed using RwkvFree.
	// - ctx: context to be cloned.
	// - threads: count of threads to use, must be positive.
	RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx

	// RwkvGpuOffloadLayers offloads the specified number of layers of ctx onto the GPU using cuBLAS, if it is enabled.
	// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
	RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error

	// RwkvEval evaluates the model for a single token.
	// Not thread-safe. For parallel inference, call RwkvCloneContext to create one context for each thread.
	// The C function returns false on any error; here failure is reported through
	// the returned error (presumably non-nil on failure).
	// - token: next token index, in range 0 <= token < n_vocab.
	// - stateIn: FP32 buffer of size RwkvGetStateLength(); or NULL, if this is a first pass.
	// - stateOut: FP32 buffer of size RwkvGetStateLength(). This buffer will be written to if non-NULL.
	// - logitsOut: FP32 buffer of size RwkvGetLogitsLength(). This buffer will be written to if non-NULL.
	RwkvEval(ctx *RwkvCtx, token uint32, stateIn []float32, stateOut []float32, logitsOut []float32) error

	// RwkvEvalSequence evaluates the model for a sequence of tokens.
	// Uses a faster algorithm than RwkvEval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
	// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
	// Not thread-safe. For parallel inference, call RwkvCloneContext to create one context for each thread.
	// The C function returns false on any error; here failure is reported through the returned error.
	// - token: the C API takes a pointer to an array of tokens (NULL builds and caches the graph
	//   without executing it). NOTE(review): this Go signature accepts a single uint32,
	//   so it is unclear how a multi-token sequence is passed — confirm.
	// - sequenceLen: number of tokens to read from the array.
	// - stateIn: FP32 buffer of size RwkvGetStateLength(), or NULL if this is a first pass.
	// - stateOut: FP32 buffer of size RwkvGetStateLength(). This buffer will be written to if non-NULL.
	// - logitsOut: FP32 buffer of size RwkvGetLogitsLength(). This buffer will be written to if non-NULL.
	RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn []float32, stateOut []float32, logitsOut []float32) error

	// RwkvGetNVocab returns the number of tokens in the given model's vocabulary.
	// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
	RwkvGetNVocab(ctx *RwkvCtx) uint64

	// RwkvGetNEmbedding returns the number of elements in the given model's embedding.
	// Useful for reading individual fields of a model's hidden state.
	RwkvGetNEmbedding(ctx *RwkvCtx) uint64

	// RwkvGetNLayer returns the number of layers in the given model.
	// Useful for always offloading the entire model to GPU.
	RwkvGetNLayer(ctx *RwkvCtx) uint64

	// RwkvGetStateLength returns the number of float elements in a complete state for the given model.
	// This is the number of elements you'll need to allocate for a call to RwkvEval, RwkvEvalSequence, or RwkvInitState.
	RwkvGetStateLength(ctx *RwkvCtx) uint64

	// RwkvGetLogitsLength returns the number of float elements in the logits output of a given model.
	// This is currently always identical to n_vocab.
	RwkvGetLogitsLength(ctx *RwkvCtx) uint64

	// RwkvInitState initializes the given state so that passing it to RwkvEval or RwkvEvalSequence would be identical to passing NULL.
	// Useful in cases where tracking the first call to these functions may be annoying or expensive.
	// State must be initialized for behavior to be defined; passing a zeroed state to rwkv.cpp functions will result in NaNs.
	// - state: FP32 buffer of size RwkvGetStateLength() to initialize.
	RwkvInitState(ctx *RwkvCtx, state []float32)

	// RwkvFree frees all allocated memory and the context.
	// Does not need to be called on the same thread that created the context.
	RwkvFree(ctx *RwkvCtx) error

	// RwkvQuantizeModelFile quantizes an FP32 or FP16 model to one of the quantized formats.
	// The C function returns false on any error and prints messages to stderr;
	// here failure is reported through the returned error.
	// - ctx: NOTE(review): the C function takes no context; how this argument is used is not visible here — confirm.
	// - in: path to the model file in ggml format, must be either FP32 or FP16.
	// - out: the quantized model will be written here.
	// - format: one of the QuantizedFormat constants:
	//   - Q4_0
	//   - Q4_1
	//   - Q5_0
	//   - Q5_1
	//   - Q8_0
	RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error

	// RwkvGetSystemInfoString returns a system information string.
	RwkvGetSystemInfoString() string
}

type CRwkvImpl

// CRwkvImpl is the concrete low-level binding returned by
// NewCRwkv(libraryPath), which presumably loads the rwkv dynamic library at
// that path. It provides every method of the CRwkv interface, all with
// pointer receivers, so *CRwkvImpl can be used wherever a CRwkv is expected.
type CRwkvImpl struct {
	// contains filtered or unexported fields
}

func NewCRwkv

func NewCRwkv(libraryPath string) (*CRwkvImpl, error)

func (*CRwkvImpl) RwkvCloneContext

func (c *CRwkvImpl) RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx

func (*CRwkvImpl) RwkvEval

func (c *CRwkvImpl) RwkvEval(ctx *RwkvCtx, token uint32, stateIn []float32, stateOut []float32, logitsOut []float32) error

func (*CRwkvImpl) RwkvEvalSequence

func (c *CRwkvImpl) RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn []float32, stateOut []float32, logitsOut []float32) error

func (*CRwkvImpl) RwkvFree

func (c *CRwkvImpl) RwkvFree(ctx *RwkvCtx) error

func (*CRwkvImpl) RwkvGetLastError

func (c *CRwkvImpl) RwkvGetLastError(ctx *RwkvCtx) error

func (*CRwkvImpl) RwkvGetLogitsLength

func (c *CRwkvImpl) RwkvGetLogitsLength(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNEmbedding

func (c *CRwkvImpl) RwkvGetNEmbedding(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNLayer

func (c *CRwkvImpl) RwkvGetNLayer(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNVocab

func (c *CRwkvImpl) RwkvGetNVocab(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetPrintErrors

func (c *CRwkvImpl) RwkvGetPrintErrors(ctx *RwkvCtx) bool

func (*CRwkvImpl) RwkvGetStateLength

func (c *CRwkvImpl) RwkvGetStateLength(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetSystemInfoString

func (c *CRwkvImpl) RwkvGetSystemInfoString() string

func (*CRwkvImpl) RwkvGpuOffloadLayers

func (c *CRwkvImpl) RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error

func (*CRwkvImpl) RwkvInitFromFile

func (c *CRwkvImpl) RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx

func (*CRwkvImpl) RwkvInitState

func (c *CRwkvImpl) RwkvInitState(ctx *RwkvCtx, state []float32)

func (*CRwkvImpl) RwkvQuantizeModelFile

func (c *CRwkvImpl) RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error

func (*CRwkvImpl) RwkvSetPrintErrors

func (c *CRwkvImpl) RwkvSetPrintErrors(ctx *RwkvCtx, enable bool)

type ChatModel

// ChatModel is created by NewChatModel(modelPath, options). It exposes
// Encode/Decode for converting between text and token ids, and
// Eval/EvalSequence for running the model on tokens. NewChatbot wraps a
// *ChatModel into a Chatbot.
type ChatModel struct {
	// contains filtered or unexported fields
}

func NewChatModel

func NewChatModel(modelPath string, options RwkvOptions) (*ChatModel, error)

func (*ChatModel) Decode

func (my *ChatModel) Decode(input []int) string

func (*ChatModel) Encode

func (my *ChatModel) Encode(input string) []int

func (*ChatModel) Eval

func (my *ChatModel) Eval(tokens []int) (string, error)

func (*ChatModel) EvalSequence

func (my *ChatModel) EvalSequence(tokens []int, state []float32) ([]float32, []float32)

type Chatbot

// Chatbot is a conversation helper built by NewChatbot from a *ChatModel, a
// user name, a bot name and a prompt. Process takes one user message and
// returns a string (presumably the bot's reply — confirm).
type Chatbot struct {
	// contains filtered or unexported fields
}

func NewChatbot

func NewChatbot(model *ChatModel, userName string, botName string, prompt string) *Chatbot

func (*Chatbot) Process

func (my *Chatbot) Process(message string) string

type GpuType

// GpuType names a GPU backend. NOTE(review): no values or users of this type are visible here — confirm its purpose.
type GpuType string

type NormalTokenizer

// NormalTokenizer is a Tokenizer (Encode/Decode) created by
// NewNormalTokenizer. It is presumably the tokenizer selected by
// TokenizerType Normal — confirm.
type NormalTokenizer struct {
	// contains filtered or unexported fields
}

func NewNormalTokenizer

func NewNormalTokenizer() (*NormalTokenizer, error)

func (*NormalTokenizer) Decode

func (t *NormalTokenizer) Decode(ids []int) string

func (*NormalTokenizer) Encode

func (t *NormalTokenizer) Encode(input string) ([]int, error)

type QuantizedFormat

// QuantizedFormat names a quantization format accepted by
// RwkvQuantizeModelFile and RwkvModel.QuantizeModelFile.
type QuantizedFormat string

// Supported quantization formats. Each value must be spelled exactly like the
// format name expected by rwkv.cpp, which is also the constant's own name.
// Q5_1 previously carried "Q5_0", so requesting Q5_1 silently produced a Q5_0
// model.
const (
	Q4_0 QuantizedFormat = "Q4_0"
	Q4_1 QuantizedFormat = "Q4_1"
	Q5_0 QuantizedFormat = "Q5_0"
	Q5_1 QuantizedFormat = "Q5_1"
	Q8_0 QuantizedFormat = "Q8_0"
)

type RwkvCtx

// RwkvCtx is an opaque handle to an rwkv.cpp context. It is obtained from
// CRwkv.RwkvInitFromFile or CRwkv.RwkvCloneContext, passed to the other CRwkv
// methods, and released with CRwkv.RwkvFree.
type RwkvCtx struct {
	// contains filtered or unexported fields
}

type RwkvErrors

// RwkvErrors represents an error reported by rwkv.cpp. Values are flags, so a
// single value may contain several errors; the type implements error through
// its Error method.
type RwkvErrors uint32

// Detail values of RwkvErrors, numbered consecutively from RwkvErrorNone (0).
const (
	RwkvErrorNone         RwkvErrors = 0
	RwkvErrorAlloc        RwkvErrors = 1
	RwkvErrorFileOpen     RwkvErrors = 2
	RwkvErrorFileStat     RwkvErrors = 3
	RwkvErrorFileRead     RwkvErrors = 4
	RwkvErrorFileWrite    RwkvErrors = 5
	RwkvErrorFileMagic    RwkvErrors = 6
	RwkvErrorFileVersion  RwkvErrors = 7
	RwkvErrorDataType     RwkvErrors = 8
	RwkvErrorUnsupported  RwkvErrors = 9
	RwkvErrorShape        RwkvErrors = 10
	RwkvErrorDimension    RwkvErrors = 11
	RwkvErrorKey          RwkvErrors = 12
	RwkvErrorData         RwkvErrors = 13
	RwkvErrorParamMissing RwkvErrors = 14
)

Represents an error encountered during a function call. These are flags, so an actual value might contain multiple errors.

func (RwkvErrors) Error

func (err RwkvErrors) Error() string

type RwkvModel

// RwkvModel is the high-level model API, created by NewRwkvAutoModel (bundled
// library) or NewRwkvModel (library loaded from dylibPath). Typical use:
// LoadFromFile, then InitState for each independent conversation, and Close
// when finished. QuantizeModelFile converts a model file to a QuantizedFormat.
type RwkvModel struct {
	// contains filtered or unexported fields
}

func NewRwkvAutoModel

func NewRwkvAutoModel(options RwkvOptions) (*RwkvModel, error)

func NewRwkvModel

func NewRwkvModel(dylibPath string, options RwkvOptions) (*RwkvModel, error)

func (*RwkvModel) Close

func (m *RwkvModel) Close() error

func (*RwkvModel) Gpu

func (m *RwkvModel) Gpu()

func (*RwkvModel) InitState

func (m *RwkvModel) InitState() (*RwkvState, error)

InitState returns a new state for a new chat context.

func (*RwkvModel) LoadFromFile

func (m *RwkvModel) LoadFromFile(path string) error

func (*RwkvModel) QuantizeModelFile

func (m *RwkvModel) QuantizeModelFile(in, out string, format QuantizedFormat) error

type RwkvOptions

// RwkvOptions configures a model created by NewRwkvAutoModel or NewRwkvModel.
type RwkvOptions struct {
	// PrintError presumably enables rwkv.cpp's automatic error printing to stderr — confirm.
	PrintError       bool
	// MaxTokens presumably caps the number of tokens generated per prediction — confirm.
	MaxTokens        int
	// StopString presumably ends generation when produced (the examples use "\n") — confirm.
	StopString       string
	Temperature      float32 // It could be a good idea to increase temperature when top_p is low
	TopP             float32 // Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)
	// TokenizerType selects the tokenizer: Normal or World.
	TokenizerType    TokenizerType
	// CpuThreads is presumably the thread count passed to rwkv.cpp — confirm.
	CpuThreads       uint32
	// GpuEnable turns on GPU offloading (see the AutoModel compatibility notes).
	GpuEnable        bool
	// GpuOffLoadLayers: default 0 means all layers will offload to gpu.
	GpuOffLoadLayers uint32
}

type RwkvState

// RwkvState holds the state of one conversation, created by
// RwkvModel.InitState. Predict returns a complete response, and PredictStream
// sends the response piece by piece on the given channel.
type RwkvState struct {
	// contains filtered or unexported fields
}

func (*RwkvState) Predict

func (s *RwkvState) Predict(input string) (string, error)

Predict generates a response to input within the current chat.

func (*RwkvState) PredictStream

func (s *RwkvState) PredictStream(input string, output chan string)

type Tokenizer

// Tokenizer converts between text and model token ids. NormalTokenizer and
// WorldTokenizer both provide these methods.
type Tokenizer interface {
	// Encode splits text into token ids, or reports why it could not.
	Encode(text string) ([]int, error)
	// Decode turns token ids back into text.
	Decode(ids []int) string
}

type TokenizerType

// TokenizerType selects which tokenizer a model uses; Normal is the zero value.
type TokenizerType uint8

const (
	// Normal selects the normal tokenizer (presumably NormalTokenizer — confirm).
	Normal TokenizerType = 0
	// World selects the World tokenizer (presumably WorldTokenizer — confirm).
	World TokenizerType = 1
)

type Trie

// Trie represents the trie data structure used by WorldTokenizer. Create one
// with NewTrie, insert entries with Add, and query with FindLongest
// (presumably the longest key match starting at index — confirm).
type Trie struct {
	// contains filtered or unexported fields
}

Trie represents the trie data structure

func NewTrie

func NewTrie() *Trie

func (*Trie) Add

func (my *Trie) Add(key string, index int, value int) *Trie

func (*Trie) FindLongest

func (my *Trie) FindLongest(key string, index int) (retIndex int, retToken int)

type WorldTokenizer

// WorldTokenizer represents a tokenizer for encoding and decoding bytes to
// tokens; create one with NewWorldTokenizer.
type WorldTokenizer struct {
	// IndexToToken maps a token id to its token text (presumably used by Decode — confirm).
	IndexToToken map[int]string
	// Trie holds the vocabulary for lookups (presumably used by Encode — confirm).
	Trie         *Trie
}

WorldTokenizer represents a tokenizer for encoding and decoding bytes to tokens

func NewWorldTokenizer

func NewWorldTokenizer() (*WorldTokenizer, error)

NewWorldTokenizer initializes a new world tokenizer

func (*WorldTokenizer) Decode

func (wt *WorldTokenizer) Decode(tokens []int) string

Decode decodes tokens to a string

func (*WorldTokenizer) Encode

func (wt *WorldTokenizer) Encode(text string) ([]int, error)

Encode encodes a string to tokens

func (*WorldTokenizer) EncodeBytes

func (wt *WorldTokenizer) EncodeBytes(src string) []int

EncodeBytes encodes bytes to tokens

Directories

Path Synopsis
deps

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL