2025-04-21 18:19:09 +08:00

104 lines
2.6 KiB
Go

package config
import (
"fmt"
"github.com/spf13/viper"
)
// Config 表示应用程序的配置
type Config struct {
API APIConfig `mapstructure:"api"`
PromptTemplates PromptTemplatesConfig `mapstructure:"prompt_templates"`
Requests []RequestConfig `mapstructure:"requests"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
Timeout int `mapstructure:"timeout"`
PoissonLambda float64 `mapstructure:"poisson_lambda"`
Tokenizer TokenizerConfig `mapstructure:"tokenizer"`
}
// APIConfig 表示API相关配置
type APIConfig struct {
Endpoint string `mapstructure:"endpoint"`
APIKey string `mapstructure:"api_key"`
Model string `mapstructure:"model"`
}
// PromptTemplatesConfig 表示提示词模板配置
type PromptTemplatesConfig struct {
Short PromptTypeConfig `mapstructure:"short"`
Long PromptTypeConfig `mapstructure:"long"`
}
// PromptTypeConfig 表示特定类型提示词的配置
type PromptTypeConfig struct {
TargetTokens int `mapstructure:"target_tokens"`
Templates []string `mapstructure:"templates"`
}
// RequestConfig 表示请求配置
type RequestConfig struct {
Type string `mapstructure:"type"`
Weight float64 `mapstructure:"weight"`
}
// ConcurrencyConfig 表示并发配置
type ConcurrencyConfig struct {
Steps []int `mapstructure:"steps"`
DurationPerStep int `mapstructure:"duration_per_step"`
}
// TokenizerConfig 表示分词器配置
type TokenizerConfig struct {
Model string `mapstructure:"model"`
}
// LoadConfig 从指定路径加载配置
func LoadConfig(path string) (*Config, error) {
viper.SetConfigFile(path)
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
// 验证配置
if err := validateConfig(&config); err != nil {
return nil, err
}
return &config, nil
}
// validateConfig 验证配置的有效性
func validateConfig(config *Config) error {
if config.API.Endpoint == "" {
return fmt.Errorf("API端点不能为空")
}
if config.API.APIKey == "" {
return fmt.Errorf("API密钥不能为空")
}
if config.API.Model == "" {
return fmt.Errorf("模型名称不能为空")
}
if len(config.Requests) == 0 {
return fmt.Errorf("请求配置不能为空")
}
if len(config.Concurrency.Steps) == 0 {
return fmt.Errorf("并发步骤不能为空")
}
if config.Timeout <= 0 {
return fmt.Errorf("超时时间必须大于0")
}
return nil
}