package command

import (
	"crypto/md5"
	"encoding/gob"
	"fmt"
	"himbot/lib"
	"os"
	"strings"
	"sync"

	"github.com/bwmarrin/discordgo"
)

// MarkovCache stores built Markov chains together with a hash of the message
// IDs each chain was built from. Both maps share the same keys and are guarded
// by mu.
type MarkovCache struct {
	data   map[string]*lib.MarkovData
	hashes map[string]string
	mu     sync.RWMutex
}

var (
	// markovCache is the process-wide cache used by MarkovQuestionCommand.
	markovCache = &MarkovCache{
		data:   make(map[string]*lib.MarkovData),
		hashes: make(map[string]string),
	}

	// bardChain is the chain loaded by InitBard; it stays nil until a model
	// has been decoded successfully.
	bardChain *lib.MarkovData
)

// InitBard loads a gob-encoded lib.MarkovData from modelPath and installs it as
// the chain used by BardCommand. On any error the previously installed chain
// (if any) is left untouched.
func InitBard(modelPath string) error {
	f, err := os.Open(modelPath)
	if err != nil {
		return fmt.Errorf("opening bard model %q: %w", modelPath, err)
	}
	defer f.Close()

	var data lib.MarkovData
	if err := gob.NewDecoder(f).Decode(&data); err != nil {
		return fmt.Errorf("decoding bard model %q: %w", modelPath, err)
	}

	bardChain = &data
	return nil
}

// BardCommand answers the optional "question" option using the chain loaded by
// InitBard. It returns a fixed notice when no chain has been loaded and a
// fallback sentence when generation yields an empty string. When a question was
// supplied, the reply is formatted as a Q/A pair.
func BardCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	if bardChain == nil {
		return "The bard is sleeping (dataset not loaded).", nil
	}

	var question string
	for _, option := range i.ApplicationCommandData().Options {
		if option.Name == "question" {
			question = option.StringValue()
		}
	}

	answer := lib.GenerateMessage(bardChain, question)
	if answer == "" {
		answer = "Words fail me."
	}

	if question != "" {
		return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
	}
	return answer, nil
}

// MarkovQuestionCommand answers the "question" option with a Markov chain built
// from recent messages of the channel the interaction came from.
//
// The optional "messages" option selects how many messages to use; values <= 0
// fall back to lib.AppConfig.MarkovDefaultMessages and values above
// lib.AppConfig.MarkovMaxMessages are capped. Built chains are cached per
// channel and message count.
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	channelID := i.ChannelID

	var question string
	numMessages := lib.AppConfig.MarkovDefaultMessages
	for _, option := range i.ApplicationCommandData().Options {
		switch option.Name {
		case "question":
			question = option.StringValue()
		case "messages":
			numMessages = int(option.IntValue())
		}
	}

	if question == "" {
		return "Please provide a question!", nil
	}

	// Clamp the requested message count to the configured bounds.
	if numMessages <= 0 {
		numMessages = lib.AppConfig.MarkovDefaultMessages
	} else if numMessages > lib.AppConfig.MarkovMaxMessages {
		numMessages = lib.AppConfig.MarkovMaxMessages
	}

	cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)

	data := getCachedChain(cacheKey)
	if data == nil {
		allMessages, err := fetchMessages(s, channelID, numMessages)
		if err != nil {
			return "", err
		}

		texts := make([]string, 0, len(allMessages))
		for _, msg := range allMessages {
			texts = append(texts, msg.Content)
		}

		// Use order 2 for chat history (sparse data)
		data = lib.BuildMarkovChain(texts, 2)
		setCachedChain(cacheKey, data, hashMessages(allMessages))
	}

	answer := lib.GenerateMessage(data, question)
	if answer == "" {
		answer = "I don't have enough context to answer that."
	}
	return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
}

// getCachedChain returns the chain cached under cacheKey, or nil if there is
// none.
func getCachedChain(cacheKey string) *lib.MarkovData {
	markovCache.mu.RLock()
	defer markovCache.mu.RUnlock()
	return markovCache.data[cacheKey]
}

// setCachedChain stores data and hash under cacheKey. Chains that are nil or
// have no start states are not cached. When the cache grows beyond
// lib.AppConfig.MarkovCacheSize one other, arbitrarily chosen entry is evicted.
//
// NOTE(review): the stored hash is never read back in this file, so cached
// chains are not invalidated when the channel receives new messages.
func setCachedChain(cacheKey string, data *lib.MarkovData, hash string) {
	if data == nil || len(data.Starts) == 0 {
		return
	}

	markovCache.mu.Lock()
	defer markovCache.mu.Unlock()

	// A non-positive size means nothing may stay cached.
	if lib.AppConfig.MarkovCacheSize <= 0 {
		return
	}

	markovCache.data[cacheKey] = data
	markovCache.hashes[cacheKey] = hash

	if len(markovCache.data) <= lib.AppConfig.MarkovCacheSize {
		return
	}

	// Map iteration order is random, so the victim must be checked explicitly:
	// never evict the entry that was just stored.
	for k := range markovCache.data {
		if k == cacheKey {
			continue
		}
		delete(markovCache.data, k)
		delete(markovCache.hashes, k)
		break
	}
}

// hashMessages returns the hex-encoded MD5 digest of the concatenated message
// IDs, in slice order. MD5 serves only as a fingerprint here, not for security.
func hashMessages(messages []*discordgo.Message) string {
	h := md5.New()
	for _, msg := range messages {
		// hash.Hash.Write never returns an error.
		h.Write([]byte(msg.ID))
	}
	return fmt.Sprintf("%x", h.Sum(nil))
}

// fetchMessages collects up to numMessages messages from channelID, paging
// backwards through the history in batches of at most 100. Messages written by
// bots, messages without an author, and messages whose content is blank are
// skipped. Paging stops once enough messages were collected or the API returns
// fewer messages than requested, which is taken as the end of the history.
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
	if numMessages <= 0 {
		// A negative capacity would make the allocation below panic.
		return []*discordgo.Message{}, nil
	}

	allMessages := make([]*discordgo.Message, 0, numMessages)
	var lastMessageID string

	for len(allMessages) < numMessages {
		batchSize := numMessages - len(allMessages)
		if batchSize > 100 {
			batchSize = 100
		}

		batch, err := s.ChannelMessages(channelID, batchSize, lastMessageID, "", "")
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			break
		}

		for _, msg := range batch {
			if msg.Author == nil || msg.Author.Bot {
				continue
			}
			if strings.TrimSpace(msg.Content) == "" {
				continue
			}
			allMessages = append(allMessages, msg)
		}

		lastMessageID = batch[len(batch)-1].ID

		// Compare against the size actually requested: a full page smaller
		// than 100 does not mean the history is exhausted.
		if len(batch) < batchSize {
			break
		}
	}

	return allMessages, nil
}