74 lines
2.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package prompt
import (
"fmt"
"math/rand"
"strings"
"time"
"github.com/tiktoken-go/tokenizer"
"mingor/llm-api-benchmark-tool/internal/config"
)
// PromptGenerator builds prompts (e.g. "short" / "long") from configured
// templates. Token counts are computed with the tiktoken-go tokenizer.
type PromptGenerator struct {
	// cfg supplies the prompt templates; Generate looks them up in
	// cfg.PromptTemplates by prompt type.
	cfg *config.Config
	// tok encodes text into tokens; used by Generate and CountTokens.
	tok *tokenizer.Tokenizer
}
// NewPromptGenerator creates a PromptGenerator whose tokenizer is built for
// model and whose prompt templates come from cfg.
//
// If the tokenizer cannot be created, the error is returned wrapped with the
// model name; the underlying error remains reachable via errors.Is / errors.As.
func NewPromptGenerator(model string, cfg *config.Config) (*PromptGenerator, error) {
	tok, err := tokenizer.NewTokenizer(model)
	if err != nil {
		return nil, fmt.Errorf("creating tokenizer for model %q: %w", model, err)
	}
	// NOTE(review): rand.Seed reseeds the process-wide math/rand source on every
	// constructor call. It is deprecated since Go 1.20 (the global source is
	// seeded automatically) and is a no-op by default from Go 1.24; it is kept
	// only so older toolchains still get a time-based seed.
	rand.Seed(time.Now().UnixNano())
	return &PromptGenerator{cfg: cfg, tok: tok}, nil
}
// Generate picks a random template for prompt type tp and fills its {key}
// placeholders from values.
//
// tp is a key of the config's PromptTemplates map (for example "short" or
// "long"). An error is returned when the generator has no config, when tp is
// unknown or has no templates, or when the tokenizer fails to encode the
// filled prompt.
//
// The filled prompt is encoded so that tokenizer failures surface here, but
// its token count is deliberately NOT checked against the type's
// TargetTokens: templates are expected to be sized appropriately in the
// config, and adjusting them is left to the user.
func (g *PromptGenerator) Generate(tp string, values map[string]string) (string, error) {
	if g.cfg == nil {
		// Fail with an error instead of a nil-pointer panic.
		return "", fmt.Errorf("prompt generator has no config")
	}
	pt, ok := g.cfg.PromptTemplates[tp]
	if !ok {
		return "", fmt.Errorf("unknown prompt type: %s", tp)
	}
	if len(pt.Templates) == 0 {
		return "", fmt.Errorf("no templates for type: %s", tp)
	}
	// The top-level math/rand functions are safe for concurrent use.
	tpl := pt.Templates[rand.Intn(len(pt.Templates))]
	text := fillTemplate(tpl, values)
	if _, err := g.tok.Encode(text); err != nil {
		return "", fmt.Errorf("encoding %q prompt: %w", tp, err)
	}
	return text, nil
}
// CountTokens reports how many tokens the generator's tokenizer produces for
// text. On an encoding failure it returns 0 together with the tokenizer's
// error, unchanged.
func (g *PromptGenerator) CountTokens(text string) (int, error) {
	encoded, encErr := g.tok.Encode(text)
	if encErr != nil {
		return 0, encErr
	}
	n := len(encoded)
	return n, nil
}
// fillTemplate replaces every "{key}" in tpl with values[key] and returns
// the result. Placeholders whose key is absent from values are left as-is.
//
// Substitution is a single left-to-right pass over tpl, so text inserted from
// values is never itself re-expanded. The previous per-key ReplaceAll loop
// ranged over the map, and Go's random map order made the output
// nondeterministic whenever a value contained another placeholder. The single
// pass also scans tpl once instead of once per key.
//
// NOTE: if one placeholder is a prefix of another (possible only when a key
// contains '}'), which of them wins at that position is unspecified.
func fillTemplate(tpl string, values map[string]string) string {
	if len(values) == 0 {
		return tpl
	}
	pairs := make([]string, 0, 2*len(values))
	for k, v := range values {
		pairs = append(pairs, "{"+k+"}", v)
	}
	return strings.NewReplacer(pairs...).Replace(tpl)
}