diff --git a/command/ask.go b/command/ask.go index 40a5ad9..07e0e17 100644 --- a/command/ask.go +++ b/command/ask.go @@ -32,7 +32,7 @@ func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons return lib.ErrorResponse(err) } - respString, err := lib.OpenAITextGeneration(options.Prompt) + respString, err := lib.ReplicateTextGeneration(options.Prompt) if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) diff --git a/command/pic.go b/command/pic.go index 7c5adeb..5e99451 100644 --- a/command/pic.go +++ b/command/pic.go @@ -37,7 +37,7 @@ func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons // Concatenate clean username and timestamp to form filename filename := data.Event.Sender().Username + "_" + timestamp + ".jpg" - imageFile, err := lib.OpenAIImageGeneration(options.Prompt, filename) + imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename) if err != nil { lib.CancelCooldown(data.Event.Member.User.ID.String(), "pic") diff --git a/go.mod b/go.mod index b8f6152..4e9b178 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect golang.org/x/net v0.21.0 // indirect + golang.org/x/sync v0.5.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) @@ -16,6 +17,7 @@ require ( github.com/gorilla/schema v1.2.1 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/joho/godotenv v1.5.1 + github.com/replicate/replicate-go v0.16.1 github.com/sashabaranov/go-openai v1.20.0 golang.org/x/time v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index 8f71464..3556e46 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/replicate/replicate-go v0.16.1 h1:LbImDfB6ef0yEfWbBNQdnC5CoKmHxonoa/UUJ6YrFC8= +github.com/replicate/replicate-go v0.16.1/go.mod h1:otIrl1vDmyjNhTzmVmp/mQU3Wt1+3387gFNEsAZq0ig= github.com/sashabaranov/go-openai v1.20.0 h1:r9WiwJY6Q2aPDhVyfOSKm83Gs04ogN1yaaBoQOnusS4= github.com/sashabaranov/go-openai v1.20.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -38,6 +40,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/lib/openai.go b/lib/openai.go deleted file mode 100644 index 5a16196..0000000 --- a/lib/openai.go +++ /dev/null @@ -1,107 +0,0 @@ -package lib - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - - "github.com/joho/godotenv" - "github.com/sashabaranov/go-openai" -) - -var PromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:" - -func OpenAITextGeneration(prompt string) (string, error) { - godotenv.Load(".env") - apiKey := os.Getenv("OPENAI_API_KEY") - - client := openai.NewClient(apiKey) - - resp, err := client.CreateChatCompletion( - context.Background(), - openai.ChatCompletionRequest{ - Model: openai.GPT4Turbo1106, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: PromptPrefix + prompt, - }, - }, - }, - ) - - if err != nil { - fmt.Printf("Ask command error: %v\n", err) - return "", errors.New("https://fly.storage.tigris.dev/atridad/himbot/no.gif") - } - - return resp.Choices[0].Message.Content, nil -} - -func OpenAIImageGeneration(prompt string, filename string) (*bytes.Buffer, error) { - godotenv.Load(".env") - apiKey := os.Getenv("OPENAI_API_KEY") - - client := openai.NewClient(apiKey) - - imageResponse, err := client.CreateImage( - context.Background(), - openai.ImageRequest{ - Model: openai.CreateImageModelDallE3, - Prompt: prompt, - Size: openai.CreateImageSize1024x1024, - Quality: openai.CreateImageQualityStandard, - ResponseFormat: openai.CreateImageResponseFormatURL, - N: 1, - }, - ) - - if err != nil { - fmt.Printf("Pic command error: %v\n", err) - return nil, errors.New("https://fly.storage.tigris.dev/atridad/himbot/hornypolice.gif") - } - - imgUrl := imageResponse.Data[0].URL - - imageRes, imageGetErr := http.Get(imgUrl) - if imageGetErr != nil { - return nil, imageGetErr - } - - defer imageRes.Body.Close() - - imageBytes, imgReadErr := io.ReadAll(imageRes.Body) - if imgReadErr != nil { - return nil, imgReadErr - } - - // Save image to a temporary file - tmpfile, err := os.Create(filename) - - if err != nil { - log.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) - - if _, err := tmpfile.Write(imageBytes); err != nil { - log.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - log.Fatal(err) - } - - // Upload the image to S3 - _, uploadErr := UploadToS3(tmpfile.Name()) - if uploadErr != nil { - log.Printf("Failed to upload image to S3: %v", uploadErr) - } - - imageFile := bytes.NewBuffer(imageBytes) - return imageFile, nil -} diff --git a/lib/replicate.go b/lib/replicate.go new file mode 100644 index 0000000..bebf6de --- /dev/null +++ b/lib/replicate.go @@ -0,0 +1,147 @@ +package lib + +import ( + "bytes" + "context" + "errors" + "io" + "log" + "net/http" + "os" + "strings" + + "github.com/replicate/replicate-go" +) + +var ReplicatePromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:" + +func ReplicateTextGeneration(prompt string) (string, error) { + client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) + if clientError != nil { + return "", clientError + } + + input := replicate.PredictionInput{ + "prompt": ReplicatePromptPrefix + prompt, + "max_new_tokens": 4096, + "prompt_template": "[INST] {prompt} [/INST]", + } + + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + + prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:cf18decbf51c27fed6bbdc3492312c1c903222a56e3fe9ca02d6cbe5198afc10", input, &webhook) + + if predictionError != nil { + return "", predictionError + } + + if prediction == nil { + return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + test, ok := prediction.([]interface{}) + + if !ok { + return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + strs := make([]string, len(test)) + for i, v := range test { + str, ok := v.(string) + if !ok { + return "", errors.New("element is not a string") + } + strs[i] = str + } + + result := strings.Join(strs, "") + + return result, nil +} + +func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) { + client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) + if clientError != nil { + return nil, clientError + } + + input := replicate.PredictionInput{ + "width": 1024, + "height": 1024, + "prompt": prompt, + "scheduler": "K_EULER", + "num_outputs": 1, + "guidance_scale": 7.5, + "lora_scale": 0.65, + "lora_weights": "https://replicate.delivery/pbxt/hM1H6f93HCVYQq471gZz6EYtRHPMJYAsyxeQXdGnozeDJKOkA/trained_model.tar", + "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, text, BadDream, lowres, low resolution, mutated body parts, extra limbs, mutated body parts, inaccurate hands, too many hands, deformed fingers, too many fingers, deformed eyes, deformed faces, unrealistic faces", + "num_inference_steps": 35, + "disable_safety_checker": true, + } + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + + prediction, predictionError := client.Run(context.Background(), "batouresearch/open-dalle-1.1-lora:2ade2cbfc88298b98366a6e361559e11666c17ed415d341c9ae776b30a61b196", input, &webhook) + + if predictionError != nil { + return nil, predictionError + } + + if prediction == nil { + return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + test, ok := prediction.([]interface{}) + + if !ok { + return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + imgUrl, ok := test[0].(string) + + if !ok { + return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + imageRes, imageGetErr := http.Get(imgUrl) + if imageGetErr != nil { + return nil, imageGetErr + } + + defer imageRes.Body.Close() + + imageBytes, imgReadErr := io.ReadAll(imageRes.Body) + if imgReadErr != nil { + return nil, imgReadErr + } + + // Save image to a temporary file + tmpfile, err := os.Create(filename) + + if err != nil { + log.Fatal(err) + } + + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(imageBytes); err != nil { + log.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + log.Fatal(err) + } + + // Upload the image to S3 + _, uploadErr := UploadToS3(tmpfile.Name()) + if uploadErr != nil { + log.Printf("Failed to upload image to S3: %v", uploadErr) + } + + imageFile := bytes.NewBuffer(imageBytes) + return imageFile, nil +}