本帖最后由 CrLf 于 2023-7-29 00:17 编辑
在命令行下调用OpenAI接口:从标准输入读取用户输入,按分隔符切分后逐段发送给GPT模型,再将每段的响应写入标准输出。因原版默认按UTF-8解读输入,我将其改为默认以GBK编码读取输入,并增加 --utf8 开关以兼容UTF-8编码。
原版GitHub:https://github.com/pdfinn/sgpt
用法:`sgpt -k <API_KEY> -i <INSTRUCTION> [-t TEMPERATURE] [-m MODEL] [-s SEPARATOR] [-u] [-d]`
参数说明:

| 短参数 | 长参数 | 环境变量 | 描述 | 默认值 |
|---|---|---|---|---|
| -k | --api_key | SGPT_API_KEY | 配置OpenAI的API KEY | 无 |
| -i | --instruction | SGPT_INSTRUCTION | 系统指令,用于补充一些背景信息或要求 | 无 |
| -t | --temperature | SGPT_TEMPERATURE | 温度值,范围是0~2,数值越高,给出的答案越有想象力但也更倾向于编造 | 0.5 |
| -m | --model | SGPT_MODEL | 所采用的模型 | gpt-3.5-turbo |
| -s | --separator | SGPT_SEPARATOR | 不同内容的分隔符(单个字符) | \n |
| -u | --utf8 | SGPT_UTF8 | 以UTF-8编码解读输入内容(该参数由CrLf添加,使默认编码是GBK) | false |
| -d | --debug | SGPT_DEBUG | 启用调试模式,将输出很多调试信息 | false |
CrLf修改后的源码:- package main
-
- import (
- "bufio"
- "encoding/json"
- "fmt"
- "github.com/spf13/pflag"
- "github.com/spf13/viper"
- "io"
- "io/ioutil"
- "log"
- "net/http"
- "os"
- "strconv"
- "strings"
-
- // mod by CrLf 添加必要的模块
- "bytes"
- "golang.org/x/text/encoding/simplifiedchinese"
- "golang.org/x/text/transform"
-
- )
-
- // mod by CrLf 用于将UTF8转码为GBK
- // UTF-8 转 GBK
- func Utf8ToGbk(s []byte) ([]byte, error) {
- reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder())
- d, e := ioutil.ReadAll(reader)
- if e != nil {
- return nil, e
- }
- return d, nil
- }
-
- func GbkToUtf8(s []byte) ([]byte, error) {
- reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder())
- d, e := ioutil.ReadAll(reader)
- if e != nil {
- return nil, e
- }
- return d, nil
- }
-
-
// OpenAIResponse is the subset of an OpenAI API response body that callOpenAI
// decodes. callOpenAI takes the first choice whose Message.Role is
// "assistant", or else the first choice with a non-empty Text.
type OpenAIResponse struct {
	Choices []struct {
		// Text is read as the answer when present; presumably filled by the
		// /v1/completions endpoint — confirm against the API docs.
		Text string `json:"text,omitempty"`
		// Message is read when Role == "assistant"; presumably filled by the
		// /v1/chat/completions endpoint — confirm against the API docs.
		// (omitempty only affects encoding; this type is only decoded here.)
		Message struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"message,omitempty"`
	} `json:"choices"`
}
-
// Command-line switches that main assigns from pflag before anything reads
// them (utf8 was added by CrLf). callOpenAI dereferences debug, so it must
// only be called after main has set it.
var (
	utf8  *bool // --utf8 / -u: read stdin as UTF-8 instead of GBK
	debug *bool // --debug / -d: enable debugOutput messages
)
-
// All flags, including the -u/--utf8 and -d/--debug switches, are registered
// in main together with their SGPT_* environment defaults, so this program
// needs no package init function.
-
- func main() {
- // Default values
- defaultTemperature := 0.5
- defaultModel := "gpt-3.5-turbo"
-
- // Check environment variables
- envApiKey := os.Getenv("SGPT_API_KEY")
- envInstruction := os.Getenv("SGPT_INSTRUCTION")
- envTemperature, err := strconv.ParseFloat(os.Getenv("SGPT_TEMPERATURE"), 64)
- if err != nil {
- envTemperature = defaultTemperature
- }
- envModel := os.Getenv("SGPT_MODEL")
- envSeparator := os.Getenv("SGPT_SEPARATOR")
-
- // mod by CrLf 增加对环境变量 SGPT_UTF8 的支持
- envUTF8 := parseBoolWithDefault(os.Getenv("SGPT_UTF8"), false)
- envDebug := parseBoolWithDefault(os.Getenv("SGPT_DEBUG"), false)
-
- // Command line arguments
- apiKey := pflag.StringP("api_key", "k", envApiKey, "OpenAI API key")
- instruction := pflag.StringP("instruction", "i", envInstruction, "Instruction for the GPT model")
- temperature := pflag.Float64P("temperature", "t", envTemperature, "Temperature for the GPT model")
- model := pflag.StringP("model", "m", envModel, "GPT model to use")
- defaulSeparator := "\n"
- separator := pflag.StringP("separator", "s", envSeparator, "Separator character for input")
- if *separator == "" {
- *separator = defaulSeparator
- }
-
- // mod by CrLf 增加对参数 --utf8 或 -u 的支持
- utf8 = pflag.BoolP("utf8", "u", envUTF8, "Enable UTF8 input")
- debug = pflag.BoolP("debug", "d", envDebug, "Enable debug output")
- pflag.Parse()
-
- // Read the configuration file
- viper.SetConfigName("sgpt")
- viper.AddConfigPath(".")
- viper.AddConfigPath("$HOME/.sgpt")
- viper.SetConfigType("yaml")
-
- err = viper.ReadInConfig()
-
- // mod by CrLf 默认屏蔽无用警告,仅在debug模式下展示
- if _, ok := err.(viper.ConfigFileNotFoundError); ok {
- debugOutput(*debug, "Warning: Config file not found: %v", err)
- } else if err != nil {
- debugOutput(*debug, "Warning: Error reading config file: %v", err)
- }
-
- // Set default values and bind configuration values to flags
- viper.SetDefault("model", defaultModel)
- viper.SetDefault("temperature", defaultTemperature)
- viper.BindPFlag("api_key", pflag.Lookup("k"))
- viper.BindPFlag("instruction", pflag.Lookup("i"))
- viper.BindPFlag("model", pflag.Lookup("m"))
- viper.BindPFlag("temperature", pflag.Lookup("t"))
- viper.BindPFlag("separator", pflag.Lookup("s"))
- viper.BindPFlag("debug", pflag.Lookup("d"))
-
- // Use default values if neither flags nor environment variables are set
- if *model == "" {
- *model = defaultModel
- }
-
- if *apiKey == "" {
- log.Fatal("API key is required")
- }
-
-
- // Read input from stdin continuously
- // mod by CrLf 根据utf8开关的启禁用状态判断以utf8还是gbk读取stdin
- var reader io.RuneReader
- if *utf8 {
- reader = bufio.NewReader(os.Stdin)
- } else {
- byteInput, _ := io.ReadAll(os.Stdin)
- gbkBytes, _ := GbkToUtf8(byteInput)
- reader = bytes.NewReader(gbkBytes)
- }
-
- var inputBuffer strings.Builder
-
- for {
- inputChar, _, err := reader.ReadRune()
- if err == io.EOF {
- input := inputBuffer.String()
- if input != "" {
- response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
- if err != nil {
- log.Fatal(err)
- }
- fmt.Println(response)
- }
- break
- }
- if err != nil {
- log.Fatal(err)
- }
-
- if string(inputChar) == *separator {
- input := inputBuffer.String()
- inputBuffer.Reset()
-
- response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
- if err != nil {
- log.Fatal(err)
- }
-
- fmt.Println(response)
- } else {
- inputBuffer.WriteRune(inputChar)
- }
- }
- }
-
// debugOutput forwards a printf-style message to the standard logger when
// debug is true and does nothing otherwise.
func debugOutput(debug bool, format string, a ...interface{}) {
	if !debug {
		return
	}
	log.Printf(format, a...)
}
-
// parseFloatWithDefault parses value as a 64-bit float. An empty value
// silently yields defaultValue; an unparsable value logs a warning and also
// yields defaultValue.
func parseFloatWithDefault(value string, defaultValue float64) float64 {
	if value != "" {
		f, err := strconv.ParseFloat(value, 64)
		if err == nil {
			return f
		}
		log.Printf("Warning: Failed to parse float value: %v", err)
	}
	return defaultValue
}
-
// parseBoolWithDefault interprets value with strconv.ParseBool (accepting
// forms such as "1", "t", "true", "0", "f", "false"). An empty value silently
// yields defaultValue; an unparsable value logs a warning and also yields
// defaultValue.
func parseBoolWithDefault(value string, defaultValue bool) bool {
	if value == "" {
		return defaultValue
	}
	b, err := strconv.ParseBool(value)
	if err == nil {
		return b
	}
	log.Printf("Warning: Failed to parse bool value: %v", err)
	return defaultValue
}
-
- func callOpenAI(apiKey, instruction, input string, temperature float64, model string) (string, error) {
- var url string
- var jsonData []byte
- var err error
-
- switch model {
- case "gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo":
- url = "https://api.openai.com/v1/chat/completions"
-
- // Prepare JSON data for GPT-4 models
- messages := []map[string]string{
- {"role": "system", "content": instruction},
- {"role": "user", "content": input},
- }
-
- jsonData, err = json.Marshal(map[string]interface{}{
- "model": model,
- "messages": messages,
- "temperature": temperature,
- "max_tokens": 100,
- "stop": []string{"\n"},
- })
-
- case "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001":
- url = "https://api.openai.com/v1/completions"
-
- // Prepare JSON data for GPT-3 models
- prompt := instruction + " " + input
- jsonData, err = json.Marshal(map[string]interface{}{
- "model": model,
- "prompt": prompt,
- "temperature": temperature,
- "max_tokens": 100,
- "stop": []string{"\n"},
- })
-
- case "whisper-1":
- url = "https://api.openai.com/v1/audio/transcriptions"
- default:
- return "", fmt.Errorf("unsupported model: %s", model)
- }
-
- if err != nil {
- return "", err
- }
-
- data := strings.NewReader(string(jsonData))
-
- req, err := http.NewRequest("POST", url, data)
- if err != nil {
- return "", err
- }
-
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+apiKey)
-
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- body, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- return "", err
- }
-
- debugOutput(*debug, "API response: %s\n", string(body))
-
- var openAIResponse OpenAIResponse
- err = json.Unmarshal(body, &openAIResponse)
- if err != nil {
- return "", err
- }
-
- if len(openAIResponse.Choices) == 0 {
- debugOutput(*debug, "API response: %s\n", string(body))
- debugOutput(*debug, "HTTP status code: %s\n", strconv.Itoa(resp.StatusCode))
- return "", fmt.Errorf("no choices returned from the API")
- }
-
- assistantMessage := ""
- for _, choice := range openAIResponse.Choices {
- if choice.Message.Role == "assistant" {
- assistantMessage = strings.TrimSpace(choice.Message.Content)
- break
- }
- if choice.Text != "" {
- assistantMessage = strings.TrimSpace(choice.Text)
- break
- }
- }
-
- if assistantMessage == "" {
- return "", fmt.Errorf("no assistant message found in the API response")
- }
-
- return assistantMessage, nil
- }
编译后的下载地址:http://bcn.bathome.net/s/tool/index.html?key=sgpt