104 lines
2.6 KiB
Go
104 lines
2.6 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Config 表示应用程序的配置
|
|
type Config struct {
|
|
API APIConfig `mapstructure:"api"`
|
|
PromptTemplates PromptTemplatesConfig `mapstructure:"prompt_templates"`
|
|
Requests []RequestConfig `mapstructure:"requests"`
|
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
|
Timeout int `mapstructure:"timeout"`
|
|
PoissonLambda float64 `mapstructure:"poisson_lambda"`
|
|
Tokenizer TokenizerConfig `mapstructure:"tokenizer"`
|
|
}
|
|
|
|
// APIConfig 表示API相关配置
|
|
type APIConfig struct {
|
|
Endpoint string `mapstructure:"endpoint"`
|
|
APIKey string `mapstructure:"api_key"`
|
|
Model string `mapstructure:"model"`
|
|
}
|
|
|
|
// PromptTemplatesConfig 表示提示词模板配置
|
|
type PromptTemplatesConfig struct {
|
|
Short PromptTypeConfig `mapstructure:"short"`
|
|
Long PromptTypeConfig `mapstructure:"long"`
|
|
}
|
|
|
|
// PromptTypeConfig 表示特定类型提示词的配置
|
|
type PromptTypeConfig struct {
|
|
TargetTokens int `mapstructure:"target_tokens"`
|
|
Templates []string `mapstructure:"templates"`
|
|
}
|
|
|
|
// RequestConfig 表示请求配置
|
|
type RequestConfig struct {
|
|
Type string `mapstructure:"type"`
|
|
Weight float64 `mapstructure:"weight"`
|
|
}
|
|
|
|
// ConcurrencyConfig 表示并发配置
|
|
type ConcurrencyConfig struct {
|
|
Steps []int `mapstructure:"steps"`
|
|
DurationPerStep int `mapstructure:"duration_per_step"`
|
|
}
|
|
|
|
// TokenizerConfig 表示分词器配置
|
|
type TokenizerConfig struct {
|
|
Model string `mapstructure:"model"`
|
|
}
|
|
|
|
// LoadConfig 从指定路径加载配置
|
|
func LoadConfig(path string) (*Config, error) {
|
|
viper.SetConfigFile(path)
|
|
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
|
}
|
|
|
|
var config Config
|
|
if err := viper.Unmarshal(&config); err != nil {
|
|
return nil, fmt.Errorf("解析配置失败: %w", err)
|
|
}
|
|
|
|
// 验证配置
|
|
if err := validateConfig(&config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// validateConfig 验证配置的有效性
|
|
func validateConfig(config *Config) error {
|
|
if config.API.Endpoint == "" {
|
|
return fmt.Errorf("API端点不能为空")
|
|
}
|
|
|
|
if config.API.APIKey == "" {
|
|
return fmt.Errorf("API密钥不能为空")
|
|
}
|
|
|
|
if config.API.Model == "" {
|
|
return fmt.Errorf("模型名称不能为空")
|
|
}
|
|
|
|
if len(config.Requests) == 0 {
|
|
return fmt.Errorf("请求配置不能为空")
|
|
}
|
|
|
|
if len(config.Concurrency.Steps) == 0 {
|
|
return fmt.Errorf("并发步骤不能为空")
|
|
}
|
|
|
|
if config.Timeout <= 0 {
|
|
return fmt.Errorf("超时时间必须大于0")
|
|
}
|
|
|
|
return nil
|
|
}
|