// Package prompt builds benchmark prompts from configured templates and
// counts their tokens.
package prompt
import (
	"fmt"
	"math/rand"
	"sort"
	"strings"
	"time"

	"github.com/tiktoken-go/tokenizer"
	"mingor/llm-api-benchmark-tool/internal/config"
)
// PromptGenerator 生成短/长提示词
|
||
// 使用 tiktoken-go tokenizer 计算 token 数
|
||
|
||
type PromptGenerator struct {
|
||
cfg *config.Config
|
||
tok *tokenizer.Tokenizer
|
||
}
// NewPromptGenerator 初始化生成器,model 用于 tokenizer
|
||
func NewPromptGenerator(model string, cfg *config.Config) (*PromptGenerator, error) {
|
||
tok, err := tokenizer.NewTokenizer(model)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
rand.Seed(time.Now().UnixNano())
|
||
return &PromptGenerator{cfg: cfg, tok: tok}, nil
|
||
}
// Generate 根据类型(tp)和占位值(values)生成提示词
|
||
// tp 可为 "short" 或 "long" 等配置中的键
|
||
func (g *PromptGenerator) Generate(tp string, values map[string]string) (string, error) {
|
||
pt, ok := g.cfg.PromptTemplates[tp]
|
||
if !ok {
|
||
return "", fmt.Errorf("unknown prompt type: %s", tp)
|
||
}
|
||
if len(pt.Templates) == 0 {
|
||
return "", fmt.Errorf("no templates for type: %s", tp)
|
||
}
|
||
tpl := pt.Templates[rand.Intn(len(pt.Templates))]
|
||
prompt := fillTemplate(tpl, values)
|
||
// 计算 token 数
|
||
toks, err := g.tok.Encode(prompt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
target := pt.TargetTokens
|
||
tol := int(float64(target) * 0.05)
|
||
cnt := len(toks)
|
||
if cnt < target-tol || cnt > target+tol {
|
||
// 长度不在范围,用户可根据需要调整
|
||
}
|
||
return prompt, nil
|
||
}
// CountTokens 返回文本的 token 数
|
||
func (g *PromptGenerator) CountTokens(text string) (int, error) {
|
||
toks, err := g.tok.Encode(text)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return len(toks), nil
|
||
}
// fillTemplate substitutes every {key} placeholder in tpl with values[key] and
// returns the result.
//
// All placeholders are replaced in one left-to-right pass, so text inserted
// from values is never scanned again: a value that itself contains "{other}"
// is emitted literally. Placeholders with no entry in values are left
// untouched, and a nil or empty map returns tpl unchanged.
//
// Keys are sorted before the replacer is built because map iteration order is
// random and strings.NewReplacer resolves competing matches by argument order;
// sorting keeps the output identical from call to call.
func fillTemplate(tpl string, values map[string]string) string {
	if len(values) == 0 {
		return tpl
	}

	keys := make([]string, 0, len(values))
	for k := range values {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	pairs := make([]string, 0, 2*len(keys))
	for _, k := range keys {
		pairs = append(pairs, "{"+k+"}", values[k])
	}
	return strings.NewReplacer(pairs...).Replace(tpl)
}