package config import ( "fmt" "github.com/spf13/viper" ) // Config 表示应用程序的配置 type Config struct { API APIConfig `mapstructure:"api"` PromptTemplates PromptTemplatesConfig `mapstructure:"prompt_templates"` Requests []RequestConfig `mapstructure:"requests"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Timeout int `mapstructure:"timeout"` PoissonLambda float64 `mapstructure:"poisson_lambda"` Tokenizer TokenizerConfig `mapstructure:"tokenizer"` } // APIConfig 表示API相关配置 type APIConfig struct { Endpoint string `mapstructure:"endpoint"` APIKey string `mapstructure:"api_key"` Model string `mapstructure:"model"` MaxTokens int `mapstructure:"max_tokens"` Temperature float64 `mapstructure:"temperature"` } // PromptTemplatesConfig 表示提示词模板配置 type PromptTemplatesConfig struct { Short PromptTypeConfig `mapstructure:"short"` Long PromptTypeConfig `mapstructure:"long"` } // PromptTypeConfig 表示特定类型提示词的配置 type PromptTypeConfig struct { TargetTokens int `mapstructure:"target_tokens"` Templates []string `mapstructure:"templates"` } // RequestConfig 表示请求配置 type RequestConfig struct { Type string `mapstructure:"type"` Weight float64 `mapstructure:"weight"` } // ConcurrencyConfig 表示并发配置 type ConcurrencyConfig struct { Steps []int `mapstructure:"steps"` DurationPerStep int `mapstructure:"duration_per_step"` } // TokenizerConfig 表示分词器配置 type TokenizerConfig struct { Model string `mapstructure:"model"` } // LoadConfig 从指定路径加载配置 func LoadConfig(path string) (*Config, error) { viper.SetConfigFile(path) if err := viper.ReadInConfig(); err != nil { return nil, fmt.Errorf("读取配置文件失败: %w", err) } var config Config if err := viper.Unmarshal(&config); err != nil { return nil, fmt.Errorf("解析配置失败: %w", err) } // 验证配置 if err := validateConfig(&config); err != nil { return nil, err } // 设置默认值 if config.API.MaxTokens == 0 { config.API.MaxTokens = 2048 } if config.API.Temperature == 0 { config.API.Temperature = 0.7 } return &config, nil } // validateConfig 验证配置的有效性 func validateConfig(config *Config) error { if config.API.Endpoint == "" { return fmt.Errorf("API端点不能为空") } if config.API.APIKey == "" { return fmt.Errorf("API密钥不能为空") } if config.API.Model == "" { return fmt.Errorf("模型名称不能为空") } if len(config.Requests) == 0 { return fmt.Errorf("请求配置不能为空") } if len(config.Concurrency.Steps) == 0 { return fmt.Errorf("并发步骤不能为空") } if config.Timeout <= 0 { return fmt.Errorf("超时时间必须大于0") } return nil }