Replicate
This commit is contained in:
parent
870f024420
commit
f4aab9470f
6 changed files with 155 additions and 109 deletions
|
@ -32,7 +32,7 @@ func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons
|
|||
return lib.ErrorResponse(err)
|
||||
}
|
||||
|
||||
respString, err := lib.OpenAITextGeneration(options.Prompt)
|
||||
respString, err := lib.ReplicateTextGeneration(options.Prompt)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||
|
|
|
@ -37,7 +37,7 @@ func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons
|
|||
// Concatenate clean username and timestamp to form filename
|
||||
filename := data.Event.Sender().Username + "_" + timestamp + ".jpg"
|
||||
|
||||
imageFile, err := lib.OpenAIImageGeneration(options.Prompt, filename)
|
||||
imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename)
|
||||
|
||||
if err != nil {
|
||||
lib.CancelCooldown(data.Event.Member.User.ID.String(), "pic")
|
||||
|
|
2
go.mod
2
go.mod
|
@ -8,6 +8,7 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.5.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
|
@ -16,6 +17,7 @@ require (
|
|||
github.com/gorilla/schema v1.2.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.1 // indirect
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/replicate/replicate-go v0.16.1
|
||||
github.com/sashabaranov/go-openai v1.20.0
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
)
|
||||
|
|
4
go.sum
4
go.sum
|
@ -19,6 +19,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
|||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/replicate/replicate-go v0.16.1 h1:LbImDfB6ef0yEfWbBNQdnC5CoKmHxonoa/UUJ6YrFC8=
|
||||
github.com/replicate/replicate-go v0.16.1/go.mod h1:otIrl1vDmyjNhTzmVmp/mQU3Wt1+3387gFNEsAZq0ig=
|
||||
github.com/sashabaranov/go-openai v1.20.0 h1:r9WiwJY6Q2aPDhVyfOSKm83Gs04ogN1yaaBoQOnusS4=
|
||||
github.com/sashabaranov/go-openai v1.20.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
@ -38,6 +40,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
|||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
|
107
lib/openai.go
107
lib/openai.go
|
@ -1,107 +0,0 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var PromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:"
|
||||
|
||||
func OpenAITextGeneration(prompt string) (string, error) {
|
||||
godotenv.Load(".env")
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4Turbo1106,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: PromptPrefix + prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Ask command error: %v\n", err)
|
||||
return "", errors.New("https://fly.storage.tigris.dev/atridad/himbot/no.gif")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func OpenAIImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
|
||||
godotenv.Load(".env")
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
imageResponse, err := client.CreateImage(
|
||||
context.Background(),
|
||||
openai.ImageRequest{
|
||||
Model: openai.CreateImageModelDallE3,
|
||||
Prompt: prompt,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
Quality: openai.CreateImageQualityStandard,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
N: 1,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Pic command error: %v\n", err)
|
||||
return nil, errors.New("https://fly.storage.tigris.dev/atridad/himbot/hornypolice.gif")
|
||||
}
|
||||
|
||||
imgUrl := imageResponse.Data[0].URL
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return nil, imageGetErr
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return nil, imgReadErr
|
||||
}
|
||||
|
||||
// Save image to a temporary file
|
||||
tmpfile, err := os.Create(filename)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(imageBytes); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Upload the image to S3
|
||||
_, uploadErr := UploadToS3(tmpfile.Name())
|
||||
if uploadErr != nil {
|
||||
log.Printf("Failed to upload image to S3: %v", uploadErr)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
147
lib/replicate.go
Normal file
147
lib/replicate.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/replicate/replicate-go"
|
||||
)
|
||||
|
||||
var ReplicatePromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:"
|
||||
|
||||
func ReplicateTextGeneration(prompt string) (string, error) {
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return "", clientError
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"prompt": ReplicatePromptPrefix + prompt,
|
||||
"max_new_tokens": 4096,
|
||||
"prompt_template": "<s>[INST] {prompt} [/INST]",
|
||||
}
|
||||
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:cf18decbf51c27fed6bbdc3492312c1c903222a56e3fe9ca02d6cbe5198afc10", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return "", predictionError
|
||||
}
|
||||
|
||||
if prediction == nil {
|
||||
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
strs := make([]string, len(test))
|
||||
for i, v := range test {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", errors.New("element is not a string")
|
||||
}
|
||||
strs[i] = str
|
||||
}
|
||||
|
||||
result := strings.Join(strs, "")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return nil, clientError
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"prompt": prompt,
|
||||
"scheduler": "K_EULER",
|
||||
"num_outputs": 1,
|
||||
"guidance_scale": 7.5,
|
||||
"lora_scale": 0.65,
|
||||
"lora_weights": "https://replicate.delivery/pbxt/hM1H6f93HCVYQq471gZz6EYtRHPMJYAsyxeQXdGnozeDJKOkA/trained_model.tar",
|
||||
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, text, BadDream, lowres, low resolution, mutated body parts, extra limbs, mutated body parts, inaccurate hands, too many hands, deformed fingers, too many fingers, deformed eyes, deformed faces, unrealistic faces",
|
||||
"num_inference_steps": 35,
|
||||
"disable_safety_checker": true,
|
||||
}
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "batouresearch/open-dalle-1.1-lora:2ade2cbfc88298b98366a6e361559e11666c17ed415d341c9ae776b30a61b196", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return nil, predictionError
|
||||
}
|
||||
|
||||
if prediction == nil {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
imgUrl, ok := test[0].(string)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return nil, imageGetErr
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return nil, imgReadErr
|
||||
}
|
||||
|
||||
// Save image to a temporary file
|
||||
tmpfile, err := os.Create(filename)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(imageBytes); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Upload the image to S3
|
||||
_, uploadErr := UploadToS3(tmpfile.Name())
|
||||
if uploadErr != nil {
|
||||
log.Printf("Failed to upload image to S3: %v", uploadErr)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
Loading…
Add table
Reference in a new issue