From 08f1b82a655b75de9b07c9cb3a40c8050bc25ffd Mon Sep 17 00:00:00 2001 From: atridadl Date: Mon, 5 Feb 2024 00:17:25 -0700 Subject: [PATCH] Update Himbot's response format and add code command --- command/ask.go | 6 ++--- command/code.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++ lib/replicate.go | 49 +++++++++++++++++++++++++++++++++++- main.go | 12 +++++++++ 4 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 command/code.go diff --git a/command/ask.go b/command/ask.go index 2166b1a..dd9936b 100644 --- a/command/ask.go +++ b/command/ask.go @@ -47,18 +47,18 @@ func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons textFile := bytes.NewBuffer([]byte(respString)) file := sendpart.File{ - Name: "himbot_response.txt", + Name: "himbot_response.md", Reader: textFile, } return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"), + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n"), AllowedMentions: &api.AllowedMentions{}, Files: []sendpart.File{file}, } } return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString), + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n--------------------\n" + respString), AllowedMentions: &api.AllowedMentions{}, } } diff --git a/command/code.go b/command/code.go new file mode 100644 index 0000000..f0d5d78 --- /dev/null +++ b/command/code.go @@ -0,0 +1,64 @@ +package command + +import ( + "bytes" + "context" + "errors" + "fmt" + "himbot/lib" + "time" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/api/cmdroute" + "github.com/diamondburned/arikawa/v3/utils/json/option" + "github.com/diamondburned/arikawa/v3/utils/sendpart" +) + +func Code(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { + // Cooldown Logic + allowed, cooldownString := lib.CooldownHandler(*data.Event, "code", time.Minute) + + if !allowed { + return lib.ErrorResponse(errors.New(cooldownString)) + } + + // Command Logic + var options struct { + Prompt string `discord:"prompt"` + } + + if err := data.Options.Unmarshal(&options); err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "ask") + return lib.ErrorResponse(err) + } + + respString, err := lib.ReplicateCodeGeneration(options.Prompt) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + lib.CancelCooldown(data.Event.User.ID.String(), "ask") + return &api.InteractionResponseData{ + Content: option.NewNullableString("ChatCompletion Error!"), + AllowedMentions: &api.AllowedMentions{}, + } + } + + if len(respString) > 1800 { + textFile := bytes.NewBuffer([]byte(respString)) + + file := sendpart.File{ + Name: "himbot_response.md", + Reader: textFile, + } + + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n"), + AllowedMentions: &api.AllowedMentions{}, + Files: []sendpart.File{file}, + } + } + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n--------------------\n" + respString), + AllowedMentions: &api.AllowedMentions{}, + } +} diff --git a/lib/replicate.go b/lib/replicate.go index 8087be5..f92e67c 100644 --- a/lib/replicate.go +++ b/lib/replicate.go @@ -13,6 +13,8 @@ import ( "github.com/replicate/replicate-go" ) +var PromptPrefix = "Ready for a dose of sarcasm and wit? Himbot, your Discord assistant, is up for the challenge. Hit it with the prompt:" + func ReplicateTextGeneration(prompt string) (string, error) { client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) if clientError != nil { @@ -20,7 +22,7 @@ func ReplicateTextGeneration(prompt string) (string, error) { } input := replicate.PredictionInput{ - "prompt": "Respond to the following prompt as the helpful but sarcastic and witty discord assistant called Himbot: " + prompt, + "prompt": PromptPrefix + prompt, "max_new_tokens": 4096, "prompt_template": "[INST] {prompt} [/INST]", } @@ -59,6 +61,51 @@ func ReplicateTextGeneration(prompt string) (string, error) { return result, nil } +func ReplicateCodeGeneration(prompt string) (string, error) { + client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) + if clientError != nil { + return "", clientError + } + + input := replicate.PredictionInput{ + "prompt": PromptPrefix + prompt, + "max_new_tokens": 4096, + } + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + + prediction, predictionError := client.Run(context.Background(), "meta/codellama-70b-instruct:a279116fe47a0f65701a8817188601e2fe8f4b9e04a518789655ea7b995851bf", input, &webhook) + + if predictionError != nil { + return "", predictionError + } + + if prediction == nil { + return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + test, ok := prediction.([]interface{}) + + if !ok { + return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + strs := make([]string, len(test)) + for i, v := range test { + str, ok := v.(string) + if !ok { + return "", errors.New("element is not a string") + } + strs[i] = str + } + + result := strings.Join(strs, "") + + return result, nil +} + func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) { client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) if clientError != nil { diff --git a/main.go b/main.go index c44f77a..0fe451e 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,17 @@ var commands = []api.CreateCommandData{ }, }, }, + { + Name: "code", + Description: "Ask Himbot programming questions! Cooldown: 2 Minutes.", + Options: []discord.CommandOption{ + &discord.StringOption{ + OptionName: "prompt", + Description: "The prompt to send to Himbot.", + Required: true, + }, + }, + }, { Name: "pic", Description: "Generate an image! Cooldown: 1 Minute.", @@ -96,6 +107,7 @@ func newHandler(s *state.State) *handler { h.Use(cmdroute.Deferrable(s, cmdroute.DeferOpts{})) h.AddFunc("ping", command.Ping) h.AddFunc("ask", command.Ask) + h.AddFunc("code", command.Code) h.AddFunc("pic", command.Pic) h.AddFunc("hs", command.HS)