2024-01-09 21:34:47 -07:00
package lib
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"

	"github.com/replicate/replicate-go"
)
2024-01-09 23:05:32 -07:00
func ReplicateTextGeneration ( prompt string ) ( string , error ) {
client , clientError := replicate . NewClient ( replicate . WithTokenFromEnv ( ) )
if clientError != nil {
return "" , clientError
}
input := replicate . PredictionInput {
2024-01-10 01:03:25 -07:00
"prompt" : prompt ,
"max_new_tokens" : 4096 ,
2024-01-09 23:05:32 -07:00
}
webhook := replicate . Webhook {
URL : "https://example.com/webhook" ,
Events : [ ] replicate . WebhookEventType { "start" , "completed" } ,
}
2024-01-12 17:49:30 -07:00
prediction , predictionError := client . Run ( context . Background ( ) , "mistralai/mistral-7b-instruct-v0.2:79052a3adbba8116ebc6697dcba67ad0d58feff23e7aeb2f103fc9aa545f9269" , input , & webhook )
2024-01-09 23:05:32 -07:00
if predictionError != nil {
return "" , predictionError
}
test , ok := prediction . ( [ ] interface { } )
if ! ok {
return "" , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
strs := make ( [ ] string , len ( test ) )
for i , v := range test {
str , ok := v . ( string )
if ! ok {
return "" , errors . New ( "element is not a string" )
}
strs [ i ] = str
}
result := strings . Join ( strs , "" )
return result , nil
}
2024-01-19 14:33:37 -07:00
func ReplicateImageGeneration ( prompt string , filename string ) ( * bytes . Buffer , error ) {
2024-01-09 21:34:47 -07:00
client , clientError := replicate . NewClient ( replicate . WithTokenFromEnv ( ) )
if clientError != nil {
return nil , clientError
}
input := replicate . PredictionInput {
2024-01-10 00:46:10 -07:00
"prompt" : prompt ,
"refiner" : "expert_ensemble_refiner" ,
"num_inference_steps" : 69 ,
"disable_safety_checker" : true ,
2024-01-09 21:34:47 -07:00
}
webhook := replicate . Webhook {
URL : "https://example.com/webhook" ,
Events : [ ] replicate . WebhookEventType { "start" , "completed" } ,
}
2024-01-10 00:36:58 -07:00
prediction , predictionError := client . Run ( context . Background ( ) , "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" , input , & webhook )
2024-01-09 21:34:47 -07:00
if predictionError != nil {
return nil , predictionError
}
test , ok := prediction . ( [ ] interface { } )
if ! ok {
return nil , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
imgUrl , ok := test [ 0 ] . ( string )
if ! ok {
return nil , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
imageRes , imageGetErr := http . Get ( imgUrl )
if imageGetErr != nil {
return nil , imageGetErr
}
defer imageRes . Body . Close ( )
imageBytes , imgReadErr := io . ReadAll ( imageRes . Body )
if imgReadErr != nil {
return nil , imgReadErr
}
2024-01-19 14:08:13 -07:00
// Save image to a temporary file
2024-01-19 14:33:37 -07:00
var tmpfile * os . File
var err error
if filename != "" {
tmpfile , err = os . CreateTemp ( "" , filename )
} else {
tmpfile , err = os . CreateTemp ( "" , "image.*.jpg" )
}
2024-01-19 14:08:13 -07:00
if err != nil {
log . Fatal ( err )
}
2024-01-19 14:33:37 -07:00
2024-01-19 14:08:13 -07:00
defer os . Remove ( tmpfile . Name ( ) )
if _ , err := tmpfile . Write ( imageBytes ) ; err != nil {
log . Fatal ( err )
}
if err := tmpfile . Close ( ) ; err != nil {
log . Fatal ( err )
}
// Upload the image to S3
_ , uploadErr := UploadToS3 ( tmpfile . Name ( ) )
if uploadErr != nil {
log . Printf ( "Failed to upload image to S3: %v" , uploadErr )
}
2024-01-09 21:34:47 -07:00
imageFile := bytes . NewBuffer ( imageBytes )
return imageFile , nil
}