This commit is contained in:
Atridad Lahiji 2024-04-20 12:06:49 -06:00
parent b317ebed86
commit 9ad6aa0c59
6 changed files with 0 additions and 292 deletions

BIN
.DS_Store vendored

Binary file not shown.

View file

@ -1,6 +1,5 @@
# Tokens
DISCORD_TOKEN=""
REPLICATE_API_TOKEN=""
# Comma separated
COOLDOWN_ALLOW_LIST=""
# S3

View file

@ -1,64 +0,0 @@
package command
import (
"bytes"
"context"
"errors"
"fmt"
"himbot/lib"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/api/cmdroute"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/sendpart"
)
func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData {
// Cooldown Logic
allowed, cooldownString := lib.CooldownHandler(*data.Event, "ask", time.Minute)
if !allowed {
return lib.ErrorResponse(errors.New(cooldownString))
}
// Command Logic
var options struct {
Prompt string `discord:"prompt"`
}
if err := data.Options.Unmarshal(&options); err != nil {
lib.CancelTimer(data.Event.Member.User.ID.String(), "ask")
return lib.ErrorResponse(err)
}
respString, err := lib.ReplicateTextGeneration(options.Prompt)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
lib.CancelTimer(data.Event.Member.User.ID.String(), "ask")
return &api.InteractionResponseData{
Content: option.NewNullableString("ChatCompletion Error!"),
AllowedMentions: &api.AllowedMentions{},
}
}
if len(respString) > 1800 {
textFile := bytes.NewBuffer([]byte(respString))
file := sendpart.File{
Name: "himbot_response.md",
Reader: textFile,
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n"),
AllowedMentions: &api.AllowedMentions{},
Files: []sendpart.File{file},
}
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n--------------------\n" + respString),
AllowedMentions: &api.AllowedMentions{},
}
}

View file

@ -1,56 +0,0 @@
package command
import (
"context"
"errors"
"himbot/lib"
"strconv"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/api/cmdroute"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/sendpart"
)
func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData {
// Cooldown Logic
allowed, cooldownString := lib.CooldownHandler(*data.Event, "pic", time.Minute*5)
if !allowed {
return lib.ErrorResponse(errors.New(cooldownString))
}
// Command Logic
var options struct {
Prompt string `discord:"prompt"`
}
if err := data.Options.Unmarshal(&options); err != nil {
lib.CancelTimer(data.Event.Member.User.ID.String(), "pic")
return lib.ErrorResponse(err)
}
// Get current epoch timestamp
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
// Concatenate clean username and timestamp to form filename
filename := data.Event.Sender().Username + "_" + timestamp + ".jpg"
imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename)
if err != nil {
lib.CancelTimer(data.Event.Member.User.ID.String(), "pic")
return lib.ErrorResponse(err)
}
file := sendpart.File{
Name: filename,
Reader: imageFile,
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt),
Files: []sendpart.File{file},
}
}

View file

@ -1,149 +0,0 @@
package lib
import (
"bytes"
"context"
"errors"
"io"
"log"
"net/http"
"os"
"strings"
"github.com/replicate/replicate-go"
)
var SystemPrompt = "Your name is Himbot. You are an assistant bot designed to provide helpful responses. Your responses should be natural and engaging. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
var PromptTemplate = `<s>[INST] Using this information:` + SystemPrompt + `answer the following Prompt: {prompt} [/INST]`
func ReplicateTextGeneration(prompt string) (string, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return "", clientError
}
input := replicate.PredictionInput{
"prompt": prompt,
"max_new_tokens": 1024,
"prompt_template": PromptTemplate,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:5d78bcd7a992c4b793465bcdcf551dc2ab9668d12bb7aa714557a21c1e77041c", input, &webhook)
if predictionError != nil {
return "", predictionError
}
if prediction == nil {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
test, ok := prediction.([]interface{})
if !ok {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
strs := make([]string, len(test))
for i, v := range test {
str, ok := v.(string)
if !ok {
return "", errors.New("element is not a string")
}
strs[i] = str
}
result := strings.Join(strs, "")
return result, nil
}
func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return nil, clientError
}
input := replicate.PredictionInput{
"width": 1024,
"height": 1024,
"prompt": prompt,
"num_inference_steps": 25,
"negative_prompt": "deformed, noisy, blurry, distorted",
"scheduler": "DPMSolver++",
"guidance_scale": 3,
"prompt_strength": 0.8,
"apply_watermark": false,
"num_outputs": 1,
"disable_safety_checker": true,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24", input, &webhook)
if predictionError != nil {
return nil, predictionError
}
if prediction == nil {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
test, ok := prediction.([]interface{})
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
imgUrl, ok := test[0].(string)
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
imageRes, imageGetErr := http.Get(imgUrl)
if imageGetErr != nil {
return nil, imageGetErr
}
defer imageRes.Body.Close()
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
if imgReadErr != nil {
return nil, imgReadErr
}
// Save image to a temporary file
tmpfile, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(imageBytes); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
// Upload the image to S3
_, uploadErr := UploadToS3(tmpfile.Name())
if uploadErr != nil {
log.Printf("Failed to upload image to S3: %v", uploadErr)
}
imageFile := bytes.NewBuffer(imageBytes)
return imageFile, nil
}

22
main.go
View file

@ -20,28 +20,6 @@ var commands = []api.CreateCommandData{
Name: "ping",
Description: "ping pong!",
},
{
Name: "ask",
Description: "Ask Himbot! Cooldown: 1 Minute.",
Options: []discord.CommandOption{
&discord.StringOption{
OptionName: "prompt",
Description: "The prompt to send to Himbot.",
Required: true,
},
},
},
{
Name: "pic",
Description: "Generate an image! Cooldown: 5 Minutes.",
Options: []discord.CommandOption{
&discord.StringOption{
OptionName: "prompt",
Description: "The prompt for the image generation.",
Required: true,
},
},
},
{
Name: "hs",
Description: "This command was your nickname in highschool!",