2024-01-09 21:34:47 -07:00
package lib
import (
"bytes"
"context"
"errors"
"io"
"net/http"
2024-01-09 23:05:32 -07:00
"strings"
2024-01-09 21:34:47 -07:00
"github.com/replicate/replicate-go"
)
2024-01-09 23:05:32 -07:00
func ReplicateTextGeneration ( prompt string ) ( string , error ) {
client , clientError := replicate . NewClient ( replicate . WithTokenFromEnv ( ) )
if clientError != nil {
return "" , clientError
}
input := replicate . PredictionInput {
2024-01-10 01:03:25 -07:00
"prompt" : prompt ,
"max_new_tokens" : 4096 ,
2024-01-09 23:05:32 -07:00
}
webhook := replicate . Webhook {
URL : "https://example.com/webhook" ,
Events : [ ] replicate . WebhookEventType { "start" , "completed" } ,
}
prediction , predictionError := client . Run ( context . Background ( ) , "mistralai/mixtral-8x7b-instruct-v0.1:7b3212fbaf88310cfef07a061ce94224e82efc8403c26fc67e8f6c065de51f21" , input , & webhook )
if predictionError != nil {
return "" , predictionError
}
test , ok := prediction . ( [ ] interface { } )
if ! ok {
return "" , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
strs := make ( [ ] string , len ( test ) )
for i , v := range test {
str , ok := v . ( string )
if ! ok {
return "" , errors . New ( "element is not a string" )
}
strs [ i ] = str
}
result := strings . Join ( strs , "" )
return result , nil
}
2024-01-09 21:34:47 -07:00
func ReplicateImageGeneration ( prompt string ) ( * bytes . Buffer , error ) {
client , clientError := replicate . NewClient ( replicate . WithTokenFromEnv ( ) )
if clientError != nil {
return nil , clientError
}
input := replicate . PredictionInput {
2024-01-10 00:46:10 -07:00
"prompt" : prompt ,
"refiner" : "expert_ensemble_refiner" ,
"num_inference_steps" : 69 ,
"disable_safety_checker" : true ,
2024-01-09 21:34:47 -07:00
}
webhook := replicate . Webhook {
URL : "https://example.com/webhook" ,
Events : [ ] replicate . WebhookEventType { "start" , "completed" } ,
}
2024-01-10 00:36:58 -07:00
prediction , predictionError := client . Run ( context . Background ( ) , "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" , input , & webhook )
2024-01-09 21:34:47 -07:00
if predictionError != nil {
return nil , predictionError
}
test , ok := prediction . ( [ ] interface { } )
if ! ok {
return nil , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
imgUrl , ok := test [ 0 ] . ( string )
if ! ok {
return nil , errors . New ( "there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements" )
}
imageRes , imageGetErr := http . Get ( imgUrl )
if imageGetErr != nil {
return nil , imageGetErr
}
defer imageRes . Body . Close ( )
imageBytes , imgReadErr := io . ReadAll ( imageRes . Body )
if imgReadErr != nil {
return nil , imgReadErr
}
imageFile := bytes . NewBuffer ( imageBytes )
return imageFile , nil
}