找回密码
 注册
搜索
[新手上路]批处理新手入门导读[视频教程]批处理基础视频教程[视频教程]VBS基础视频教程[批处理精品]批处理版照片整理器
[批处理精品]纯批处理备份&还原驱动[批处理精品]CMD命令50条不能说的秘密[在线下载]第三方命令行工具[在线帮助]VBScript / JScript 在线参考
查看: 24133|回复: 3

[网络工具] 命令行版的ChatGPT(修改版)

[复制链接]
发表于 2023-7-28 23:28:10 | 显示全部楼层 |阅读模式
本帖最后由 CrLf 于 2023-7-29 00:17 编辑

命令行下调用OpenAI接口,从标准输入中读取用户输入并将其发送到GPT模型,再将响应写入标准输出。因原版默认是UTF8,所以我改成默认以GBK编码读取输入,并增加 --utf8 开关兼容utf8编码。

原版GitHub:https://github.com/pdfinn/sgpt

用法:
  1. sgpt -k <API_KEY> -i <INSTRUCTION> [-t TEMPERATURE] [-m MODEL] [-s SEPARATOR] [-u] [-d]
复制代码
参数说明:
短参数长参数 环境变量描述 默认值
-k --api_key SGPT_API_KEY 配置OpenAI的API KEY
-i --instruction SGPT_INSTRUCTION 系统指令,用于补充一些背景信息或要求
-t --temperature SGPT_TEMPERATURE 温度值,范围是0~1,数值越高,给出的答案越有想象力但也更倾向于编造 0.5
-m --model SGPT_MODEL 所采用的模型 gpt-3.5-turbo
-s --separator SGPT_SEPARATOR 不同内容的分隔符 \n
-u --utf8 SGPT_UTF8 以UTF8编码解读输入内容(该参数由CrLf添加,使默认编码是GBK) false
-d --debug SGPT_DEBUG 启用调试模式,将输出很多调试信息 false

CrLf修改后的源码:
  1. package main

  2. import (
  3.         "bufio"
  4.         "encoding/json"
  5.         "fmt"
  6.         "github.com/spf13/pflag"
  7.         "github.com/spf13/viper"
  8.         "io"
  9.         "io/ioutil"
  10.         "log"
  11.         "net/http"
  12.         "os"
  13.         "strconv"
  14.         "strings"

  15.         // mod by CrLf 添加必要的模块
  16.         "bytes"
  17.         "golang.org/x/text/encoding/simplifiedchinese"
  18.         "golang.org/x/text/transform"

  19. )

  20. // mod by CrLf 用于将UTF8转码为GBK
  21. // UTF-8 转 GBK
  22. func Utf8ToGbk(s []byte) ([]byte, error) {
  23.         reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder())
  24.         d, e := ioutil.ReadAll(reader)
  25.         if e != nil {
  26.                 return nil, e
  27.         }
  28.         return d, nil
  29. }

  30. func GbkToUtf8(s []byte) ([]byte, error) {
  31.     reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder())
  32.     d, e := ioutil.ReadAll(reader)
  33.     if e != nil {
  34.         return nil, e
  35.     }
  36.     return d, nil
  37. }


  38. type OpenAIResponse struct {
  39.         Choices []struct {
  40.                 Text    string `json:"text,omitempty"`
  41.                 Message struct {
  42.                         Role    string `json:"role,omitempty"`
  43.                         Content string `json:"content,omitempty"`
  44.                 } `json:"message,omitempty"`
  45.         } `json:"choices"`
  46. }

  47. // mod by CrLf 声明utf8变量
  48. var utf8 *bool
  49. var debug *bool

  50. func init() {
  51.         // mod by CrLf 去除重复的提醒
  52.        
  53.         // envUTF8 := os.Getenv("SGPT_UTF8")
  54.         // envDebug := os.Getenv("SGPT_DEBUG")
  55.         // utf8 = pflag.Bool("u", parseBoolWithDefault(envUTF8, false), "Enable UTF8 input")
  56.         // debug = pflag.Bool("d", parseBoolWithDefault(envDebug, false), "Enable debug output")
  57. }

  58. func main() {
  59.         // Default values
  60.         defaultTemperature := 0.5
  61.         defaultModel := "gpt-3.5-turbo"

  62.         // Check environment variables
  63.         envApiKey := os.Getenv("SGPT_API_KEY")
  64.         envInstruction := os.Getenv("SGPT_INSTRUCTION")
  65.         envTemperature, err := strconv.ParseFloat(os.Getenv("SGPT_TEMPERATURE"), 64)
  66.         if err != nil {
  67.                 envTemperature = defaultTemperature
  68.         }
  69.         envModel := os.Getenv("SGPT_MODEL")
  70.         envSeparator := os.Getenv("SGPT_SEPARATOR")
  71.        
  72.         // mod by CrLf 增加对环境变量 SGPT_UTF8 的支持
  73.         envUTF8 := parseBoolWithDefault(os.Getenv("SGPT_UTF8"), false)
  74.         envDebug := parseBoolWithDefault(os.Getenv("SGPT_DEBUG"), false)

  75.         // Command line arguments
  76.         apiKey := pflag.StringP("api_key", "k", envApiKey, "OpenAI API key")
  77.         instruction := pflag.StringP("instruction", "i", envInstruction, "Instruction for the GPT model")
  78.         temperature := pflag.Float64P("temperature", "t", envTemperature, "Temperature for the GPT model")
  79.         model := pflag.StringP("model", "m", envModel, "GPT model to use")
  80.         defaulSeparator := "\n"
  81.         separator := pflag.StringP("separator", "s", envSeparator, "Separator character for input")
  82.         if *separator == "" {
  83.                 *separator = defaulSeparator
  84.         }
  85.        
  86.         // mod by CrLf 增加对参数 --utf8 或 -u 的支持
  87.         utf8 = pflag.BoolP("utf8", "u", envUTF8, "Enable UTF8 input")
  88.         debug = pflag.BoolP("debug", "d", envDebug, "Enable debug output")
  89.         pflag.Parse()

  90.         // Read the configuration file
  91.         viper.SetConfigName("sgpt")
  92.         viper.AddConfigPath(".")
  93.         viper.AddConfigPath("$HOME/.sgpt")
  94.         viper.SetConfigType("yaml")

  95.         err = viper.ReadInConfig()
  96.        
  97.         // mod by CrLf 默认屏蔽无用警告,仅在debug模式下展示
  98.         if _, ok := err.(viper.ConfigFileNotFoundError); ok {
  99.                 debugOutput(*debug, "Warning: Config file not found: %v", err)
  100.         } else if err != nil {
  101.                 debugOutput(*debug, "Warning: Error reading config file: %v", err)
  102.         }

  103.         // Set default values and bind configuration values to flags
  104.         viper.SetDefault("model", defaultModel)
  105.         viper.SetDefault("temperature", defaultTemperature)
  106.         viper.BindPFlag("api_key", pflag.Lookup("k"))
  107.         viper.BindPFlag("instruction", pflag.Lookup("i"))
  108.         viper.BindPFlag("model", pflag.Lookup("m"))
  109.         viper.BindPFlag("temperature", pflag.Lookup("t"))
  110.         viper.BindPFlag("separator", pflag.Lookup("s"))
  111.         viper.BindPFlag("debug", pflag.Lookup("d"))

  112.         // Use default values if neither flags nor environment variables are set
  113.         if *model == "" {
  114.                 *model = defaultModel
  115.         }

  116.         if *apiKey == "" {
  117.                 log.Fatal("API key is required")
  118.         }


  119.         // Read input from stdin continuously
  120.         // mod by CrLf 根据utf8开关的启禁用状态判断以utf8还是gbk读取stdin
  121.         var reader io.RuneReader
  122.         if *utf8 {
  123.                 reader = bufio.NewReader(os.Stdin)
  124.         } else {
  125.                 byteInput, _ := io.ReadAll(os.Stdin)
  126.                 gbkBytes, _ := GbkToUtf8(byteInput)
  127.                 reader = bytes.NewReader(gbkBytes)
  128.         }

  129.         var inputBuffer strings.Builder

  130.         for {
  131.                 inputChar, _, err := reader.ReadRune()
  132.                 if err == io.EOF {
  133.                         input := inputBuffer.String()
  134.                         if input != "" {
  135.                                 response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
  136.                                 if err != nil {
  137.                                         log.Fatal(err)
  138.                                 }
  139.                                 fmt.Println(response)
  140.                         }
  141.                         break
  142.                 }
  143.                 if err != nil {
  144.                         log.Fatal(err)
  145.                 }

  146.                 if string(inputChar) == *separator {
  147.                         input := inputBuffer.String()
  148.                         inputBuffer.Reset()

  149.                         response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
  150.                         if err != nil {
  151.                                 log.Fatal(err)
  152.                         }

  153.                         fmt.Println(response)
  154.                 } else {
  155.                         inputBuffer.WriteRune(inputChar)
  156.                 }
  157.         }
  158. }

  159. func debugOutput(debug bool, format string, a ...interface{}) {
  160.         if debug {
  161.                 log.Printf(format, a...)
  162.         }
  163. }

  164. func parseFloatWithDefault(value string, defaultValue float64) float64 {
  165.         if value == "" {
  166.                 return defaultValue
  167.         }
  168.         parsedValue, err := strconv.ParseFloat(value, 64)
  169.         if err != nil {
  170.                 log.Printf("Warning: Failed to parse float value: %v", err)
  171.                 return defaultValue
  172.         }
  173.         return parsedValue
  174. }

  175. func parseBoolWithDefault(value string, defaultValue bool) bool {
  176.         if value == "" {
  177.                 return defaultValue
  178.         }
  179.         parsedValue, err := strconv.ParseBool(value)
  180.         if err != nil {
  181.                 log.Printf("Warning: Failed to parse bool value: %v", err)
  182.                 return defaultValue
  183.         }
  184.         return parsedValue
  185. }

  186. func callOpenAI(apiKey, instruction, input string, temperature float64, model string) (string, error) {
  187.         var url string
  188.         var jsonData []byte
  189.         var err error

  190.         switch model {
  191.         case "gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo":
  192.                 url = "https://api.openai.com/v1/chat/completions"

  193.                 // Prepare JSON data for GPT-4 models
  194.                 messages := []map[string]string{
  195.                         {"role": "system", "content": instruction},
  196.                         {"role": "user", "content": input},
  197.                 }

  198.                 jsonData, err = json.Marshal(map[string]interface{}{
  199.                         "model":       model,
  200.                         "messages":    messages,
  201.                         "temperature": temperature,
  202.                         "max_tokens":  100,
  203.                         "stop":        []string{"\n"},
  204.                 })

  205.         case "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001":
  206.                 url = "https://api.openai.com/v1/completions"

  207.                 // Prepare JSON data for GPT-3 models
  208.                 prompt := instruction + " " + input
  209.                 jsonData, err = json.Marshal(map[string]interface{}{
  210.                         "model":       model,
  211.                         "prompt":      prompt,
  212.                         "temperature": temperature,
  213.                         "max_tokens":  100,
  214.                         "stop":        []string{"\n"},
  215.                 })

  216.         case "whisper-1":
  217.                 url = "https://api.openai.com/v1/audio/transcriptions"
  218.         default:
  219.                 return "", fmt.Errorf("unsupported model: %s", model)
  220.         }

  221.         if err != nil {
  222.                 return "", err
  223.         }

  224.         data := strings.NewReader(string(jsonData))

  225.         req, err := http.NewRequest("POST", url, data)
  226.         if err != nil {
  227.                 return "", err
  228.         }

  229.         req.Header.Set("Content-Type", "application/json")
  230.         req.Header.Set("Authorization", "Bearer "+apiKey)

  231.         client := &http.Client{}
  232.         resp, err := client.Do(req)
  233.         if err != nil {
  234.                 return "", err
  235.         }
  236.         defer resp.Body.Close()

  237.         body, err := ioutil.ReadAll(resp.Body)
  238.         if err != nil {
  239.                 return "", err
  240.         }

  241.         debugOutput(*debug, "API response: %s\n", string(body))

  242.         var openAIResponse OpenAIResponse
  243.         err = json.Unmarshal(body, &openAIResponse)
  244.         if err != nil {
  245.                 return "", err
  246.         }

  247.         if len(openAIResponse.Choices) == 0 {
  248.                 debugOutput(*debug, "API response: %s\n", string(body))
  249.                 debugOutput(*debug, "HTTP status code: %s\n", strconv.Itoa(resp.StatusCode))
  250.                 return "", fmt.Errorf("no choices returned from the API")
  251.         }

  252.         assistantMessage := ""
  253.         for _, choice := range openAIResponse.Choices {
  254.                 if choice.Message.Role == "assistant" {
  255.                         assistantMessage = strings.TrimSpace(choice.Message.Content)
  256.                         break
  257.                 }
  258.                 if choice.Text != "" {
  259.                         assistantMessage = strings.TrimSpace(choice.Text)
  260.                         break
  261.                 }
  262.         }

  263.         if assistantMessage == "" {
  264.                 return "", fmt.Errorf("no assistant message found in the API response")
  265.         }

  266.         return assistantMessage, nil
  267. }
复制代码
编译后的下载地址:http://bcn.bathome.net/s/tool/index.html?key=sgpt

评分

参与人数 1技术 +1 收起 理由
老刘1号 + 1 感谢分享

查看全部评分

 楼主| 发表于 2023-7-28 23:31:32 | 显示全部楼层
本帖最后由 CrLf 于 2023-7-29 00:13 编辑

举个例子:
  1. echo 柬埔寨在哪里|sgpt.exe --api_key "***这里是你的openai_api_key***" --instruction "请用中文回答:" --model "gpt-3.5-turbo"
  2. :: 回答为:柬埔寨位于东南亚,东临越南,南接泰国,西邻泰国和洞朗,北界老挝。
复制代码
如果要传入非GBK字符,请 chcp 65001 后使用 --utf8 开关
发表于 2023-12-15 08:34:48 | 显示全部楼层
回复 2# CrLf


    感谢大佬分享, 请教一下,
1. APIkey能不挂梯直接在国内用吗? 会被封吗号?
2. 免费的帐号创建的api能直接用吗? 听说好像有几$的额度, 我试了一下能创建key, 但是没有看到额度,
先感谢
发表于 2024-5-21 20:49:53 | 显示全部楼层
这个 代码能改一下 用过的 大模型吗
您需要登录后才可以回帖 登录 | 注册

本版积分规则

Archiver|手机版|小黑屋|批处理之家 ( 渝ICP备10000708号 )

GMT+8, 2026-3-16 22:02 , Processed in 0.018524 second(s), 8 queries , File On.

Powered by Discuz! X3.5

© 2001-2026 Discuz! Team.

快速回复 返回顶部 返回列表