Files
himbot/command/markov.go
Atridad Lahiji 9694a42f3f
All checks were successful
Docker Deploy / build-and-push (push) Successful in 3m23s
Trying some weird training nonsense. Maybe this will be fun.
2026-01-20 14:48:53 -07:00

190 lines
4.1 KiB
Go

package command
import (
"crypto/md5"
"encoding/gob"
"fmt"
"himbot/lib"
"os"
"strings"
"sync"
"github.com/bwmarrin/discordgo"
)
// MarkovCache memoizes Markov chains built from channel history.
//
// Both maps are keyed identically: setCachedChain writes and evicts the same
// key in each. All access goes through mu.
type MarkovCache struct {
	// mu guards data and hashes.
	mu sync.RWMutex
	// data maps a cache key to the chain built for it.
	data map[string]*lib.MarkovData
	// hashes maps a cache key to the hashMessages fingerprint of the messages
	// its chain was built from.
	hashes map[string]string
}
// markovCache holds the chains built by MarkovQuestionCommand, keyed by
// "<channelID>:<messageCount>".
var markovCache = &MarkovCache{
	hashes: make(map[string]string),
	data:   make(map[string]*lib.MarkovData),
}

// bardChain is the pre-trained chain decoded by InitBard; it stays nil until
// InitBard succeeds, which BardCommand checks for.
var bardChain *lib.MarkovData
// InitBard loads the pre-trained Markov chain used by BardCommand from the
// gob-encoded lib.MarkovData file at modelPath.
//
// On success the decoded chain replaces bardChain. On any error bardChain is
// left unchanged, and the returned error wraps the underlying os/gob error
// (so errors.Is / errors.As still work) with a note of which step failed.
//
// NOTE(review): bardChain is written without synchronization, so this
// presumably must run before commands are served. Confirm at the call site.
func InitBard(modelPath string) error {
	f, err := os.Open(modelPath)
	if err != nil {
		// The *os.PathError already names the path.
		return fmt.Errorf("loading bard model: %w", err)
	}
	// Read-only file: a Close error carries no useful information here.
	defer f.Close()

	var data lib.MarkovData
	if err := gob.NewDecoder(f).Decode(&data); err != nil {
		// gob errors do not mention the source, so name the file.
		return fmt.Errorf("decoding bard model %q: %w", modelPath, err)
	}
	bardChain = &data
	return nil
}
func BardCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
if bardChain == nil {
return "The bard is sleeping (dataset not loaded).", nil
}
var question string
for _, option := range i.ApplicationCommandData().Options {
if option.Name == "question" {
question = option.StringValue()
}
}
answer := lib.GenerateMessage(bardChain, question)
if answer == "" {
answer = "Words fail me."
}
if question != "" {
return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
}
return answer, nil
}
// MarkovQuestionCommand answers a question with text generated from an
// order-2 Markov chain. The chain is built from recent non-bot,
// non-blank messages in the invoking channel.
//
// Options:
//   - "question" (required): the prompt passed to lib.GenerateMessage.
//   - "messages" (optional): how many messages to build the chain from.
//     Values <= 0 fall back to lib.AppConfig.MarkovDefaultMessages, and values
//     above lib.AppConfig.MarkovMaxMessages are clamped to that maximum.
//
// Chains are cached per "<channelID>:<messageCount>".
//
// NOTE(review): a cache hit is used without re-fetching. The message hash
// stored by setCachedChain is not compared anywhere in this file, so a cached
// chain will not reflect newer channel messages until it is evicted.
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	chanID := i.ChannelID
	question := ""
	count := lib.AppConfig.MarkovDefaultMessages

	for _, opt := range i.ApplicationCommandData().Options {
		if opt.Name == "question" {
			question = opt.StringValue()
			continue
		}
		if opt.Name != "messages" {
			continue
		}
		count = int(opt.IntValue())
		switch {
		case count <= 0:
			count = lib.AppConfig.MarkovDefaultMessages
		case count > lib.AppConfig.MarkovMaxMessages:
			count = lib.AppConfig.MarkovMaxMessages
		}
	}

	if question == "" {
		return "Please provide a question!", nil
	}

	key := fmt.Sprintf("%s:%d", chanID, count)
	chain := getCachedChain(key)
	if chain == nil {
		msgs, err := fetchMessages(s, chanID, count)
		if err != nil {
			return "", err
		}
		var corpus []string
		for _, m := range msgs {
			corpus = append(corpus, m.Content)
		}
		// Order 2 suits the comparatively sparse text of chat history.
		chain = lib.BuildMarkovChain(corpus, 2)
		setCachedChain(key, chain, hashMessages(msgs))
	}

	answer := lib.GenerateMessage(chain, question)
	if answer == "" {
		answer = "I don't have enough context to answer that."
	}
	return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
}
// getCachedChain returns the chain cached under cacheKey, or nil if there is
// no entry. The lookup happens under markovCache's read lock.
func getCachedChain(cacheKey string) *lib.MarkovData {
	markovCache.mu.RLock()
	chain, found := markovCache.data[cacheKey]
	markovCache.mu.RUnlock()

	if !found {
		return nil
	}
	return chain
}
// setCachedChain stores data under cacheKey, together with hash (the
// fingerprint of the messages it was built from). It then trims the cache
// back to at most lib.AppConfig.MarkovCacheSize entries.
//
// Chains that are nil or have no start states are not cached. A
// non-positive cache size disables caching; the previous implementation
// stored the entry and then immediately evicted it in that case.
//
// Eviction removes arbitrary entries (Go map iteration order is unspecified)
// but never the entry that was just stored. Previously a random single
// eviction could drop the freshly built chain, and it only ever removed one
// entry even when the cache was further over its limit.
func setCachedChain(cacheKey string, data *lib.MarkovData, hash string) {
	if data == nil || len(data.Starts) == 0 {
		return
	}
	limit := lib.AppConfig.MarkovCacheSize
	if limit <= 0 {
		return
	}

	markovCache.mu.Lock()
	defer markovCache.mu.Unlock()

	markovCache.data[cacheKey] = data
	markovCache.hashes[cacheKey] = hash

	// Deleting the current key during range is permitted by the Go spec.
	// Since limit >= 1 and cacheKey is skipped, this loop always terminates
	// at or below the limit.
	for k := range markovCache.data {
		if len(markovCache.data) <= limit {
			break
		}
		if k == cacheKey {
			continue
		}
		delete(markovCache.data, k)
		delete(markovCache.hashes, k)
	}
}
func hashMessages(messages []*discordgo.Message) string {
var content strings.Builder
for _, msg := range messages {
content.WriteString(msg.ID)
}
return fmt.Sprintf("%x", md5.Sum([]byte(content.String())))
}
// fetchMessages pages backwards through channelID's history and returns up
// to numMessages messages that were not sent by bots and have non-blank
// content.
//
// History is read in pages of at most 100 messages. Each page is requested
// before the last message of the previous page. That cursor logic assumes
// Discord returns each page newest-first.
//
// Fetching stops when numMessages qualifying messages have been collected,
// when Discord returns an empty page, or when it returns a page shorter than
// requested (end of history). A non-positive numMessages returns an empty
// result without calling Discord; previously a negative value panicked in
// make.
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
	// Maximum page size for a single ChannelMessages call.
	const maxBatch = 100

	if numMessages <= 0 {
		return nil, nil
	}

	msgs := make([]*discordgo.Message, 0, numMessages)
	var beforeID string
	for len(msgs) < numMessages {
		batchSize := numMessages - len(msgs)
		if batchSize > maxBatch {
			batchSize = maxBatch
		}

		batch, err := s.ChannelMessages(channelID, batchSize, beforeID, "", "")
		if err != nil {
			return nil, fmt.Errorf("fetching messages for channel %s: %w", channelID, err)
		}
		if len(batch) == 0 {
			break
		}

		for _, msg := range batch {
			// Author is a pointer. Skip nil authors rather than panicking
			// (presumably rare; treated like non-human messages).
			if msg.Author == nil || msg.Author.Bot {
				continue
			}
			if strings.TrimSpace(msg.Content) == "" {
				continue
			}
			msgs = append(msgs, msg)
		}

		beforeID = batch[len(batch)-1].ID

		// A page shorter than requested means there is no older history.
		// Compare against batchSize, not the 100-message maximum. Otherwise
		// any request under 100 ends after one page, even when bot or blank
		// messages were filtered out and more history exists.
		if len(batch) < batchSize {
			break
		}
	}
	return msgs, nil
}