Trying some weird training nonsense. Maybe this will be fun.
All checks were successful
Docker Deploy / build-and-push (push) Successful in 3m23s
All checks were successful
Docker Deploy / build-and-push (push) Successful in 3m23s
This commit is contained in:
@@ -2,73 +2,69 @@ package command
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"himbot/lib"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
// MarkovData is a second-order, word-level Markov model.
type MarkovData struct {
	// Chain maps a "word1 word2" prefix to the words seen directly after it:
	// "word1 word2" -> ["word3", ...].
	Chain map[string][]string
	// Starts lists "word1 word2" prefixes a generated message may begin with.
	Starts []string
}
|
||||
|
||||
type MarkovCache struct {
|
||||
data map[string]*MarkovData
|
||||
data map[string]*lib.MarkovData
|
||||
hashes map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
markovCache = &MarkovCache{
|
||||
data: make(map[string]*MarkovData),
|
||||
data: make(map[string]*lib.MarkovData),
|
||||
hashes: make(map[string]string),
|
||||
}
|
||||
urlRegex = regexp.MustCompile(`https?://[^\s]+`)
|
||||
mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`)
|
||||
bardChain *lib.MarkovData
|
||||
)
|
||||
|
||||
func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
|
||||
channelID := i.ChannelID
|
||||
|
||||
numMessages := lib.AppConfig.MarkovDefaultMessages
|
||||
if len(i.ApplicationCommandData().Options) > 0 {
|
||||
if i.ApplicationCommandData().Options[0].Name == "messages" {
|
||||
numMessages = int(i.ApplicationCommandData().Options[0].IntValue())
|
||||
if numMessages <= 0 {
|
||||
numMessages = lib.AppConfig.MarkovDefaultMessages
|
||||
} else if numMessages > lib.AppConfig.MarkovMaxMessages {
|
||||
numMessages = lib.AppConfig.MarkovMaxMessages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
|
||||
if data := getCachedChain(cacheKey); data != nil {
|
||||
if msg := generateMessage(data, ""); msg != "" {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
allMessages, err := fetchMessages(s, channelID, numMessages)
|
||||
func InitBard(modelPath string) error {
|
||||
f, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var data lib.MarkovData
|
||||
decoder := gob.NewDecoder(f)
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := buildMarkovChain(allMessages)
|
||||
setCachedChain(cacheKey, data, allMessages)
|
||||
bardChain = &data
|
||||
return nil
|
||||
}
|
||||
|
||||
newMessage := generateMessage(data, "")
|
||||
if newMessage == "" {
|
||||
newMessage = "Not enough text data to generate a message."
|
||||
func BardCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
|
||||
if bardChain == nil {
|
||||
return "The bard is sleeping (dataset not loaded).", nil
|
||||
}
|
||||
|
||||
return newMessage, nil
|
||||
var question string
|
||||
for _, option := range i.ApplicationCommandData().Options {
|
||||
if option.Name == "question" {
|
||||
question = option.StringValue()
|
||||
}
|
||||
}
|
||||
|
||||
answer := lib.GenerateMessage(bardChain, question)
|
||||
if answer == "" {
|
||||
answer = "Words fail me."
|
||||
}
|
||||
|
||||
if question != "" {
|
||||
return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
|
||||
}
|
||||
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
|
||||
@@ -95,7 +91,7 @@ func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate)
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
|
||||
var data *MarkovData
|
||||
var data *lib.MarkovData
|
||||
|
||||
if cachedData := getCachedChain(cacheKey); cachedData != nil {
|
||||
data = cachedData
|
||||
@@ -104,11 +100,18 @@ func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data = buildMarkovChain(allMessages)
|
||||
setCachedChain(cacheKey, data, allMessages)
|
||||
|
||||
var texts []string
|
||||
for _, msg := range allMessages {
|
||||
texts = append(texts, msg.Content)
|
||||
}
|
||||
|
||||
// Use order 2 for chat history (sparse data)
|
||||
data = lib.BuildMarkovChain(texts, 2)
|
||||
setCachedChain(cacheKey, data, hashMessages(allMessages))
|
||||
}
|
||||
|
||||
answer := generateMessage(data, question)
|
||||
answer := lib.GenerateMessage(data, question)
|
||||
if answer == "" {
|
||||
answer = "I don't have enough context to answer that."
|
||||
}
|
||||
@@ -116,15 +119,13 @@ func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate)
|
||||
return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
|
||||
}
|
||||
|
||||
func getCachedChain(cacheKey string) *MarkovData {
|
||||
func getCachedChain(cacheKey string) *lib.MarkovData {
|
||||
markovCache.mu.RLock()
|
||||
defer markovCache.mu.RUnlock()
|
||||
return markovCache.data[cacheKey]
|
||||
}
|
||||
|
||||
func setCachedChain(cacheKey string, data *MarkovData, messages []*discordgo.Message) {
|
||||
hash := hashMessages(messages)
|
||||
|
||||
func setCachedChain(cacheKey string, data *lib.MarkovData, hash string) {
|
||||
markovCache.mu.Lock()
|
||||
defer markovCache.mu.Unlock()
|
||||
|
||||
@@ -186,106 +187,3 @@ func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*
|
||||
|
||||
return allMessages, nil
|
||||
}
|
||||
|
||||
func cleanText(text string) string {
|
||||
text = urlRegex.ReplaceAllString(text, "")
|
||||
text = mentionRegex.ReplaceAllString(text, "")
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
func buildMarkovChain(messages []*discordgo.Message) *MarkovData {
|
||||
data := &MarkovData{
|
||||
Chain: make(map[string][]string),
|
||||
Starts: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
cleaned := cleanText(msg.Content)
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
words := strings.Fields(cleaned)
|
||||
if len(words) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
startKey := key(words[0], words[1])
|
||||
data.Starts = append(data.Starts, startKey)
|
||||
|
||||
for i := 0; i < len(words)-2; i++ {
|
||||
k := key(words[i], words[i+1])
|
||||
val := words[i+2]
|
||||
data.Chain[k] = append(data.Chain[k], val)
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func generateMessage(data *MarkovData, seed string) string {
|
||||
if len(data.Starts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var w1, w2 string
|
||||
var currentKey string
|
||||
|
||||
// Try to seed based on input question
|
||||
if seed != "" {
|
||||
seedWords := strings.Fields(cleanText(seed))
|
||||
var candidates []string
|
||||
|
||||
for k := range data.Chain {
|
||||
for _, sw := range seedWords {
|
||||
if len(sw) > 3 && strings.Contains(strings.ToLower(k), strings.ToLower(sw)) {
|
||||
candidates = append(candidates, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) > 0 {
|
||||
currentKey = candidates[rand.Intn(len(candidates))]
|
||||
}
|
||||
}
|
||||
|
||||
if currentKey == "" {
|
||||
currentKey = data.Starts[rand.Intn(len(data.Starts))]
|
||||
}
|
||||
|
||||
parts := strings.Split(currentKey, " ")
|
||||
w1, w2 = parts[0], parts[1]
|
||||
|
||||
output := []string{w1, w2}
|
||||
|
||||
for i := 0; i < 40; i++ {
|
||||
nextOptions, exists := data.Chain[currentKey]
|
||||
if !exists || len(nextOptions) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
nextWord := nextOptions[rand.Intn(len(nextOptions))]
|
||||
output = append(output, nextWord)
|
||||
|
||||
w1 = w2
|
||||
w2 = nextWord
|
||||
currentKey = key(w1, w2)
|
||||
|
||||
// Soft stop on punctuation
|
||||
if i > 5 && strings.ContainsAny(nextWord, ".!?") {
|
||||
if rand.Float32() > 0.3 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(output, " ")
|
||||
}
|
||||
|
||||
// key joins two consecutive words with a single space, producing the
// "word1 word2" form used to index the Markov chain.
func key(w1, w2 string) string {
	return w1 + " " + w2
}
|
||||
|
||||
// init seeds the package-global math/rand source with the current time when
// the package is loaded.
//
// NOTE(review): rand.Seed is deprecated as of Go 1.20, where the global source
// is seeded randomly by default; this call can be removed once the module is
// confirmed to target Go 1.20 or newer.
func init() {
	rand.Seed(time.Now().UnixNano())
}
|
||||
|
||||
Reference in New Issue
Block a user