diff --git a/.env.example b/.env.example index acd882c..bddc1d6 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,28 @@ -# Tokens +# Discord Configuration DISCORD_TOKEN="" + +# Container configuration ROOT_DIR="" + +# Himbucks System Configuration +HIMBUCKS_PER_REWARD=10 +MESSAGE_COUNT_THRESHOLD=5 +HIMBUCKS_COOLDOWN_MINUTES=1 + +# Markov Chain Configuration +MARKOV_DEFAULT_MESSAGES=100 +MARKOV_MAX_MESSAGES=1000 +MARKOV_CACHE_SIZE=10 + +# Database Configuration +DB_MAX_OPEN_CONNS=25 +DB_MAX_IDLE_CONNS=5 +DB_CONN_MAX_LIFETIME_MINUTES=5 + +# Command Cooldowns (in seconds) +PING_COOLDOWN_SECONDS=5 +HS_COOLDOWN_SECONDS=10 +MARKOV_COOLDOWN_SECONDS=30 +HIMBUCKS_COOLDOWN_SECONDS=5 +HIMBOARD_COOLDOWN_SECONDS=5 +SENDBUCKS_COOLDOWN_SECONDS=1800 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 55fd29a..90a55d1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build stage -FROM golang:1.23.2 AS build +FROM golang:1.24.3 AS build WORKDIR /app diff --git a/command/markov.go b/command/markov.go index abb00ab..53d599b 100644 --- a/command/markov.go +++ b/command/markov.go @@ -1,27 +1,60 @@ package command import ( + "crypto/md5" + "fmt" + "himbot/lib" "math/rand" + "regexp" "strings" + "sync" + "time" "github.com/bwmarrin/discordgo" ) +// Cache for Markov chains to avoid rebuilding for the same channel/message count +type MarkovCache struct { + chains map[string]map[string][]string + hashes map[string]string + mu sync.RWMutex +} + +var ( + markovCache = &MarkovCache{ + chains: make(map[string]map[string][]string), + hashes: make(map[string]string), + } + // Regex for cleaning text + urlRegex = regexp.MustCompile(`https?://[^\s]+`) + mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`) + emojiRegex = regexp.MustCompile(``) +) + func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) { channelID := i.ChannelID - numMessages := 100 // Default value + numMessages := lib.AppConfig.MarkovDefaultMessages // Default value from config if len(i.ApplicationCommandData().Options) > 0 { if i.ApplicationCommandData().Options[0].Name == "messages" { numMessages = int(i.ApplicationCommandData().Options[0].IntValue()) if numMessages <= 0 { - numMessages = 100 - } else if numMessages > 1000 { - numMessages = 1000 // Limit to 1000 messages max + numMessages = lib.AppConfig.MarkovDefaultMessages + } else if numMessages > lib.AppConfig.MarkovMaxMessages { + numMessages = lib.AppConfig.MarkovMaxMessages // Limit from config } } } + // Check cache first + cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages) + if chain := getCachedChain(cacheKey); chain != nil { + newMessage := generateMessage(chain) + if newMessage != "" { + return newMessage, nil + } + } + // Fetch messages allMessages, err := fetchMessages(s, channelID, numMessages) if err != nil { @@ -31,6 +64,9 @@ func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string // Build the Markov chain from the fetched messages chain := buildMarkovChain(allMessages) + // Cache the chain + setCachedChain(cacheKey, chain, allMessages) + // Generate a new message using the Markov chain newMessage := generateMessage(chain) @@ -42,6 +78,49 @@ func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string return newMessage, nil } +func getCachedChain(cacheKey string) map[string][]string { + markovCache.mu.RLock() + defer markovCache.mu.RUnlock() + + if chain, exists := markovCache.chains[cacheKey]; exists { + return chain + } + return nil +} + +func setCachedChain(cacheKey string, chain map[string][]string, messages []*discordgo.Message) { + // Create a hash of the messages to detect changes + hash := hashMessages(messages) + + markovCache.mu.Lock() + defer markovCache.mu.Unlock() + + // Only cache if we have a meaningful chain + if len(chain) > 10 { + markovCache.chains[cacheKey] = chain + markovCache.hashes[cacheKey] = hash + + // Simple cache cleanup - keep only last N entries from config + if len(markovCache.chains) > lib.AppConfig.MarkovCacheSize { + // Remove oldest entry (simple FIFO) + for k := range markovCache.chains { + delete(markovCache.chains, k) + delete(markovCache.hashes, k) + break + } + } + } +} + +func hashMessages(messages []*discordgo.Message) string { + var content strings.Builder + for _, msg := range messages { + content.WriteString(msg.ID) + content.WriteString(msg.Content) + } + return fmt.Sprintf("%x", md5.Sum([]byte(content.String()))) +} + func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) { var allMessages []*discordgo.Message var lastMessageID string @@ -61,7 +140,13 @@ func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]* break // No more messages to fetch } - allMessages = append(allMessages, batch...) + // Filter out bot messages and empty messages during fetch + for _, msg := range batch { + if !msg.Author.Bot && len(strings.TrimSpace(msg.Content)) > 0 { + allMessages = append(allMessages, msg) + } + } + lastMessageID = batch[len(batch)-1].ID if len(batch) < 100 { @@ -72,20 +157,52 @@ func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]* return allMessages, nil } -// buildMarkovChain creates a Markov chain from a list of messages +// cleanText removes URLs, mentions, emojis, and normalizes text +func cleanText(text string) string { + // Remove URLs + text = urlRegex.ReplaceAllString(text, "") + // Remove mentions + text = mentionRegex.ReplaceAllString(text, "") + // Remove custom emojis + text = emojiRegex.ReplaceAllString(text, "") + // Normalize whitespace + text = strings.Join(strings.Fields(text), " ") + return strings.TrimSpace(text) +} + +// buildMarkovChain creates an improved Markov chain from a list of messages func buildMarkovChain(messages []*discordgo.Message) map[string][]string { chain := make(map[string][]string) + for _, msg := range messages { - words := strings.Fields(msg.Content) + cleanedContent := cleanText(msg.Content) + if len(cleanedContent) < 3 { // Skip very short messages + continue + } + + words := strings.Fields(cleanedContent) + if len(words) < 2 { // Need at least 2 words for a chain + continue + } + // Build the chain by associating each word with the word that follows it for i := 0; i < len(words)-1; i++ { - chain[words[i]] = append(chain[words[i]], words[i+1]) + currentWord := strings.ToLower(words[i]) + nextWord := words[i+1] // Keep original case for next word + + // Skip very short words or words with special characters + if len(currentWord) < 2 || strings.ContainsAny(currentWord, "!@#$%^&*()[]{}") { + continue + } + + chain[currentWord] = append(chain[currentWord], nextWord) } } + return chain } -// generateMessage creates a new message using the Markov chain +// generateMessage creates a new message using the Markov chain with improved logic func generateMessage(chain map[string][]string) string { if len(chain) == 0 { return "" @@ -94,22 +211,66 @@ func generateMessage(chain map[string][]string) string { words := []string{} var currentWord string - // Start with a random word from the chain - for word := range chain { - currentWord = word - break - } - - // Generate up to 20 words - for i := 0; i < 20; i++ { - words = append(words, currentWord) - if nextWords, ok := chain[currentWord]; ok && len(nextWords) > 0 { - // Randomly select the next word from the possible follow-ups - currentWord = nextWords[rand.Intn(len(nextWords))] - } else { + // Start with a random word that has good follow-ups + attempts := 0 + for word, nextWords := range chain { + if len(nextWords) >= 2 && len(word) > 2 { // Prefer words with multiple options + currentWord = word + break + } + attempts++ + if attempts > 50 { // Fallback to any word + currentWord = word break } } - return strings.Join(words, " ") + if currentWord == "" { + return "" + } + + // Generate between 5 and 25 words + maxWords := 5 + rand.Intn(20) + for i := 0; i < maxWords; i++ { + // Add current word (capitalize first word) + if i == 0 { + words = append(words, strings.Title(currentWord)) + } else { + words = append(words, currentWord) + } + + if nextWords, ok := chain[strings.ToLower(currentWord)]; ok && len(nextWords) > 0 { + // Randomly select the next word from the possible follow-ups + currentWord = nextWords[rand.Intn(len(nextWords))] + } else { + // Try to find a new starting point + found := false + for word, nextWords := range chain { + if len(nextWords) > 0 && len(word) > 2 { + currentWord = word + found = true + break + } + } + if !found { + break + } + } + } + + result := strings.Join(words, " ") + + // Add punctuation if missing + if len(result) > 0 && !strings.ContainsAny(result[len(result)-1:], ".!?") { + // Randomly add punctuation + punctuation := []string{".", "!", "?"} + result += punctuation[rand.Intn(len(punctuation))] + } + + return result +} + +func init() { + // Seed random number generator + rand.Seed(time.Now().UnixNano()) } diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml deleted file mode 100644 index b8c756c..0000000 --- a/docker-compose.dev.yml +++ /dev/null @@ -1,11 +0,0 @@ -services: - app: - build: - context: . - dockerfile: Dockerfile - image: your-app-image:latest - command: ["/app"] - pull_policy: build - environment: - - DISCORD_TOKEN=$DISCORD_TOKEN - - COOLDOWN_ALLOW_LIST=$COOLDOWN_ALLOW_LIST diff --git a/docker-compose.yml b/docker-compose.yml index 9ce7178..04bbeb4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,30 @@ services: ports: - "3117:3000" environment: + # Discord Configuration - DISCORD_TOKEN=${DISCORD_TOKEN} - - ROOT_DIR=${ROOT_DIR} + + # Himbucks System Configuration + - HIMBUCKS_PER_REWARD=${HIMBUCKS_PER_REWARD:-10} + - MESSAGE_COUNT_THRESHOLD=${MESSAGE_COUNT_THRESHOLD:-5} + - HIMBUCKS_COOLDOWN_MINUTES=${HIMBUCKS_COOLDOWN_MINUTES:-1} + + # Markov Chain Configuration + - MARKOV_DEFAULT_MESSAGES=${MARKOV_DEFAULT_MESSAGES:-100} + - MARKOV_MAX_MESSAGES=${MARKOV_MAX_MESSAGES:-1000} + - MARKOV_CACHE_SIZE=${MARKOV_CACHE_SIZE:-10} + + # Database Configuration + - DB_MAX_OPEN_CONNS=${DB_MAX_OPEN_CONNS:-25} + - DB_MAX_IDLE_CONNS=${DB_MAX_IDLE_CONNS:-5} + - DB_CONN_MAX_LIFETIME_MINUTES=${DB_CONN_MAX_LIFETIME_MINUTES:-5} + + # Command Cooldowns (in seconds) + - PING_COOLDOWN_SECONDS=${PING_COOLDOWN_SECONDS:-5} + - HS_COOLDOWN_SECONDS=${HS_COOLDOWN_SECONDS:-10} + - MARKOV_COOLDOWN_SECONDS=${MARKOV_COOLDOWN_SECONDS:-30} + - HIMBUCKS_COOLDOWN_SECONDS=${HIMBUCKS_COOLDOWN_SECONDS:-5} + - HIMBOARD_COOLDOWN_SECONDS=${HIMBOARD_COOLDOWN_SECONDS:-5} + - SENDBUCKS_COOLDOWN_SECONDS=${SENDBUCKS_COOLDOWN_SECONDS:-1800} volumes: - ${ROOT_DIR}/himbot_data:/data diff --git a/go.mod b/go.mod index dc3e479..90e62aa 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,18 @@ module himbot -go 1.23 +go 1.24 require ( github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/libsql/sqlite-antlr4-parser v0.0.0-20240721121621-c0bdc870f11c // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect + golang.org/x/sys v0.33.0 // indirect ) require ( - github.com/bwmarrin/discordgo v0.28.1 + github.com/bwmarrin/discordgo v0.29.0 github.com/gorilla/websocket v1.5.3 // indirect github.com/joho/godotenv v1.5.1 - github.com/tursodatabase/go-libsql v0.0.0-20241011135853-3effbb6dea5c + github.com/tursodatabase/go-libsql v0.0.0-20250416102726-983f7e9acb0e ) diff --git a/go.sum b/go.sum index 7dcb545..db62ba4 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= -github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4= -github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= +github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -13,19 +13,19 @@ github.com/libsql/sqlite-antlr4-parser v0.0.0-20240721121621-c0bdc870f11c h1:WsJ github.com/libsql/sqlite-antlr4-parser v0.0.0-20240721121621-c0bdc870f11c/go.mod h1:gIcFddvsvPcRCO6QDmWH9/zcFd5U26QWWRMgZh4ddyo= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/tursodatabase/go-libsql v0.0.0-20241011135853-3effbb6dea5c h1:a8TrFzP+zK+uYcMWuLQoNOR78SG/yISSnHwMIcyWa2Q= -github.com/tursodatabase/go-libsql v0.0.0-20241011135853-3effbb6dea5c/go.mod h1:TjsB2miB8RW2Sse8sdxzVTdeGlx74GloD5zJYUC38d8= +github.com/tursodatabase/go-libsql v0.0.0-20250416102726-983f7e9acb0e h1:DUEcD8ukLWxIlcRWWJSuAX6IbEQln2bc7t9HOT45FFk= +github.com/tursodatabase/go-libsql v0.0.0-20250416102726-983f7e9acb0e/go.mod h1:TjsB2miB8RW2Sse8sdxzVTdeGlx74GloD5zJYUC38d8= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/lib/command.go b/lib/command.go index 845e91f..f0044ee 100644 --- a/lib/command.go +++ b/lib/command.go @@ -14,8 +14,15 @@ func HandleCommand(commandName string, cooldownDuration time.Duration, handler C return } + // Get user information (handle both guild and DM contexts) + user, userErr := GetUser(i) + if userErr != nil { + RespondWithError(s, i, "Error getting user information: "+userErr.Error()) + return + } + // Get or create user and guild profile - _, createUserError := GetOrCreateUserWithGuild(i.Member.User.ID, i.Member.User.Username, i.GuildID) + _, createUserError := GetOrCreateUserWithGuild(user.ID, user.Username, i.GuildID) if createUserError != nil { RespondWithError(s, i, "Error creating user profile: "+createUserError.Error()) diff --git a/lib/config.go b/lib/config.go new file mode 100644 index 0000000..9d13360 --- /dev/null +++ b/lib/config.go @@ -0,0 +1,90 @@ +package lib + +import ( + "os" + "strconv" + "time" +) + +// Config holds all configuration values +type Config struct { + // Discord settings + DiscordToken string + + // Himbucks settings + HimbucksPerReward int + MessageCountThreshold int + CooldownPeriod time.Duration + + // Markov settings + MarkovDefaultMessages int + MarkovMaxMessages int + MarkovCacheSize int + + // Database settings + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration + + // Command cooldowns (in seconds) + PingCooldown int + HsCooldown int + MarkovCooldown int + HimbucksCooldown int + HimboardCooldown int + SendbucksCooldown int +} + +var AppConfig *Config + +// LoadConfig loads configuration from environment variables +func LoadConfig() *Config { + config := &Config{ + // Discord settings + DiscordToken: getEnv("DISCORD_TOKEN", ""), + + // Himbucks settings + HimbucksPerReward: getEnvInt("HIMBUCKS_PER_REWARD", 10), + MessageCountThreshold: getEnvInt("MESSAGE_COUNT_THRESHOLD", 5), + CooldownPeriod: time.Duration(getEnvInt("HIMBUCKS_COOLDOWN_MINUTES", 1)) * time.Minute, + + // Markov settings + MarkovDefaultMessages: getEnvInt("MARKOV_DEFAULT_MESSAGES", 100), + MarkovMaxMessages: getEnvInt("MARKOV_MAX_MESSAGES", 1000), + MarkovCacheSize: getEnvInt("MARKOV_CACHE_SIZE", 10), + + // Database settings + MaxOpenConns: getEnvInt("DB_MAX_OPEN_CONNS", 25), + MaxIdleConns: getEnvInt("DB_MAX_IDLE_CONNS", 5), + ConnMaxLifetime: time.Duration(getEnvInt("DB_CONN_MAX_LIFETIME_MINUTES", 5)) * time.Minute, + + // Command cooldowns (in seconds) + PingCooldown: getEnvInt("PING_COOLDOWN_SECONDS", 5), + HsCooldown: getEnvInt("HS_COOLDOWN_SECONDS", 10), + MarkovCooldown: getEnvInt("MARKOV_COOLDOWN_SECONDS", 30), + HimbucksCooldown: getEnvInt("HIMBUCKS_COOLDOWN_SECONDS", 5), + HimboardCooldown: getEnvInt("HIMBOARD_COOLDOWN_SECONDS", 5), + SendbucksCooldown: getEnvInt("SENDBUCKS_COOLDOWN_SECONDS", 1800), + } + + AppConfig = config + return config +} + +// getEnv gets an environment variable with a default value +func getEnv(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +// getEnvInt gets an environment variable as an integer with a default value +func getEnvInt(key string, defaultValue int) int { + if value := os.Getenv(key); value != "" { + if intValue, err := strconv.Atoi(value); err == nil { + return intValue + } + } + return defaultValue +} \ No newline at end of file diff --git a/lib/db.go b/lib/db.go index e2b3b59..10da77c 100644 --- a/lib/db.go +++ b/lib/db.go @@ -16,6 +16,15 @@ import ( var DBClient *sql.DB var DBConnector *libsql.Connector +// Prepared statements +var ( + stmtGetBalance *sql.Stmt + stmtUpdateBalance *sql.Stmt + stmtGetLeaderboard *sql.Stmt + stmtGetUserProfile *sql.Stmt + stmtUpdateProfile *sql.Stmt +) + func InitDB() error { // Determine DB path based on /data directory existence var dbPath string @@ -31,9 +40,24 @@ func InitDB() error { os.Exit(1) } + // Configure connection pool using config values + db.SetMaxOpenConns(AppConfig.MaxOpenConns) + db.SetMaxIdleConns(AppConfig.MaxIdleConns) + db.SetConnMaxLifetime(AppConfig.ConnMaxLifetime) + + // Test the connection + if err := db.Ping(); err != nil { + return fmt.Errorf("failed to ping database: %w", err) + } + DBClient = db - return runMigrations() + if err := runMigrations(); err != nil { + return err + } + + // Prepare frequently used statements + return prepareStatements() } type Migration struct { @@ -140,3 +164,61 @@ func runMigrations() error { log.Println("Database migrations completed successfully") return nil } + +func prepareStatements() error { + var err error + + // Prepare balance query + stmtGetBalance, err = DBClient.Prepare(` + SELECT gp.currency_balance + FROM guild_profiles gp + JOIN users u ON gp.user_id = u.id + WHERE u.discord_id = ? AND gp.guild_id = ?`) + if err != nil { + return fmt.Errorf("failed to prepare balance query: %w", err) + } + + // Prepare leaderboard query + stmtGetLeaderboard, err = DBClient.Prepare(` + SELECT u.username, gp.currency_balance, gp.message_count + FROM guild_profiles gp + JOIN users u ON gp.user_id = u.id + WHERE gp.guild_id = ? + ORDER BY gp.currency_balance DESC + LIMIT ?`) + if err != nil { + return fmt.Errorf("failed to prepare leaderboard query: %w", err) + } + + // Prepare user profile query + stmtGetUserProfile, err = DBClient.Prepare(` + SELECT message_count, last_reward_at + FROM guild_profiles + WHERE user_id = ? AND guild_id = ?`) + if err != nil { + return fmt.Errorf("failed to prepare user profile query: %w", err) + } + + log.Println("Prepared statements initialized successfully") + return nil +} + +// CleanupDB closes all prepared statements +func CleanupDB() { + if stmtGetBalance != nil { + stmtGetBalance.Close() + } + if stmtUpdateBalance != nil { + stmtUpdateBalance.Close() + } + if stmtGetLeaderboard != nil { + stmtGetLeaderboard.Close() + } + if stmtGetUserProfile != nil { + stmtGetUserProfile.Close() + } + if stmtUpdateProfile != nil { + stmtUpdateProfile.Close() + } + log.Println("Database cleanup completed") +} diff --git a/lib/himbucks.go b/lib/himbucks.go index 0500376..9995a38 100644 --- a/lib/himbucks.go +++ b/lib/himbucks.go @@ -8,12 +8,6 @@ import ( "github.com/bwmarrin/discordgo" ) -const ( - HimbucksPerReward = 10 - MessageCountThreshold = 5 - CooldownPeriod = time.Minute -) - type HimbucksEntry struct { Username string Balance int @@ -41,8 +35,8 @@ func ProcessHimbucks(s *discordgo.Session, m *discordgo.MessageCreate, ctx *Proc } messageCount++ - shouldReward := messageCount >= MessageCountThreshold && - (!lastRewardAt.Valid || time.Since(lastRewardAt.Time) >= CooldownPeriod) + shouldReward := messageCount >= AppConfig.MessageCountThreshold && + (!lastRewardAt.Valid || time.Since(lastRewardAt.Time) >= AppConfig.CooldownPeriod) if shouldReward { _, err = tx.Exec(` @@ -51,7 +45,7 @@ func ProcessHimbucks(s *discordgo.Session, m *discordgo.MessageCreate, ctx *Proc message_count = 0, last_reward_at = CURRENT_TIMESTAMP WHERE user_id = ? AND guild_id = ?`, - HimbucksPerReward, ctx.UserID, ctx.GuildID) + AppConfig.HimbucksPerReward, ctx.UserID, ctx.GuildID) } else { _, err = tx.Exec(` UPDATE guild_profiles @@ -69,12 +63,7 @@ func ProcessHimbucks(s *discordgo.Session, m *discordgo.MessageCreate, ctx *Proc func GetBalance(discordID, guildID string) (int, error) { var balance int - err := DBClient.QueryRow(` - SELECT gp.currency_balance - FROM guild_profiles gp - JOIN users u ON gp.user_id = u.id - WHERE u.discord_id = ? AND gp.guild_id = ?`, - discordID, guildID).Scan(&balance) + err := stmtGetBalance.QueryRow(discordID, guildID).Scan(&balance) if err == sql.ErrNoRows { return 0, nil } @@ -173,14 +162,7 @@ func SendBalance(fromDiscordID, toDiscordID, guildID string, amount int) error { } func GetLeaderboard(guildID string, limit int) ([]HimbucksEntry, error) { - rows, err := DBClient.Query(` - SELECT u.username, gp.currency_balance, gp.message_count - FROM guild_profiles gp - JOIN users u ON gp.user_id = u.id - WHERE gp.guild_id = ? - ORDER BY gp.currency_balance DESC - LIMIT ?`, - guildID, limit) + rows, err := stmtGetLeaderboard.Query(guildID, limit) if err != nil { return nil, fmt.Errorf("failed to get leaderboard: %w", err) } diff --git a/main.go b/main.go index 3a8a822..1f6d1ad 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "himbot/command" "himbot/lib" "log" @@ -14,6 +15,164 @@ import ( ) var ( + commands []*discordgo.ApplicationCommand + commandHandlers map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate) +) + +func main() { + godotenv.Load(".env") + + // Load configuration + config := lib.LoadConfig() + + // Initialize commands and handlers with config + initCommands(config) + initCommandHandlers(config) + + err := lib.InitDB() + if err != nil { + log.Fatalf("Failed to initialize database: %v", err) + } + + if config.DiscordToken == "" { + log.Fatalln("No $DISCORD_TOKEN given.") + } + + dg, err := discordgo.New("Bot " + config.DiscordToken) + if err != nil { + log.Fatalf("Error creating Discord session: %v", err) + } + + dg.AddHandler(ready) + dg.AddHandler(interactionCreate) + + processorManager := lib.NewMessageProcessorManager() + + // Register processors + processorManager.RegisterProcessor(lib.ProcessHimbucks) + + dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { + processorManager.ProcessMessage(s, m) + }) + + dg.Identify.Intents = discordgo.IntentsGuilds | discordgo.IntentsGuildMessages + + err = dg.Open() + if err != nil { + log.Fatalf("Error opening connection: %v", err) + } + + log.Println("Bot is now running. Press CTRL-C to exit.") + registerCommands(dg) + + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + <-sc + + log.Println("Shutting down gracefully...") + + if lib.DBClient != nil { + // Close prepared statements + lib.CleanupDB() + lib.DBClient.Close() + } + + dg.Close() +} + +func ready(s *discordgo.Session, event *discordgo.Ready) { + log.Printf("Logged in as: %v#%v", s.State.User.Username, s.State.User.Discriminator) +} + +func interactionCreate(s *discordgo.Session, i *discordgo.InteractionCreate) { + if h, ok := commandHandlers[i.ApplicationCommandData().Name]; ok { + h(s, i) + } +} + +func registerCommands(s *discordgo.Session) { + log.Println("Checking command registration...") + + existingCommands, err := s.ApplicationCommands(s.State.User.ID, "") + if err != nil { + log.Printf("Error fetching existing commands: %v", err) + return + } + + // Create maps for easier comparison + existingMap := make(map[string]*discordgo.ApplicationCommand) + for _, cmd := range existingCommands { + existingMap[cmd.Name] = cmd + } + + desiredMap := make(map[string]*discordgo.ApplicationCommand) + for _, cmd := range commands { + desiredMap[cmd.Name] = cmd + } + + // Delete commands that no longer exist + for name, existingCmd := range existingMap { + if _, exists := desiredMap[name]; !exists { + log.Printf("Deleting removed command: %s", name) + err := s.ApplicationCommandDelete(s.State.User.ID, "", existingCmd.ID) + if err != nil { + log.Printf("Error deleting command %s: %v", name, err) + } + } + } + + // Update or create commands + for _, desiredCmd := range commands { + if existingCmd, exists := existingMap[desiredCmd.Name]; exists { + // Check if command needs updating (simple comparison) + if !commandsEqual(existingCmd, desiredCmd) { + log.Printf("Updating command: %s", desiredCmd.Name) + _, err := s.ApplicationCommandEdit(s.State.User.ID, "", existingCmd.ID, desiredCmd) + if err != nil { + log.Printf("Error updating command %s: %v", desiredCmd.Name, err) + } + } else { + log.Printf("Command %s is up to date", desiredCmd.Name) + } + } else { + log.Printf("Creating new command: %s", desiredCmd.Name) + _, err := s.ApplicationCommandCreate(s.State.User.ID, "", desiredCmd) + if err != nil { + log.Printf("Error creating command %s: %v", desiredCmd.Name, err) + } + } + } + + log.Println("Command registration completed") +} + +// commandsEqual performs a basic comparison between two commands +func commandsEqual(existing, desired *discordgo.ApplicationCommand) bool { + if existing.Name != desired.Name || + existing.Description != desired.Description || + len(existing.Options) != len(desired.Options) { + return false + } + + // Compare options (basic comparison) + for i, existingOpt := range existing.Options { + if i >= len(desired.Options) { + return false + } + desiredOpt := desired.Options[i] + if existingOpt.Name != desiredOpt.Name || + existingOpt.Description != desiredOpt.Description || + existingOpt.Type != desiredOpt.Type || + existingOpt.Required != desiredOpt.Required { + return false + } + } + + return true +} + +// initCommands initializes command definitions with configuration +func initCommands(config *lib.Config) { commands = []*discordgo.ApplicationCommand{ { Name: "ping", @@ -38,7 +197,7 @@ var ( { Type: discordgo.ApplicationCommandOptionInteger, Name: "messages", - Description: "Number of messages to use (default: 100, max: 1000)", + Description: fmt.Sprintf("Number of messages to use (default: %d, max: %d)", config.MarkovDefaultMessages, config.MarkovMaxMessages), Required: false, }, }, @@ -71,103 +230,16 @@ var ( }, }, } +} +// initCommandHandlers initializes command handlers with configuration +func initCommandHandlers(config *lib.Config) { commandHandlers = map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate){ - "ping": lib.HandleCommand("ping", 5*time.Second, command.PingCommand), - "hs": lib.HandleCommand("hs", 10*time.Second, command.HsCommand), - "markov": lib.HandleCommand("markov", 30*time.Second, command.MarkovCommand), - "himbucks": lib.HandleCommand("himbucks", 5*time.Second, command.BalanceGetCommand), - "himboard": lib.HandleCommand("himboard", 5*time.Second, command.LeaderboardCommand), - "sendbucks": lib.HandleCommand("sendbucks", 1800*time.Second, command.BalanceSendCommand), - } -) - -func main() { - godotenv.Load(".env") - - err := lib.InitDB() - if err != nil { - log.Fatalf("Failed to initialize database: %v", err) - } - - token := os.Getenv("DISCORD_TOKEN") - - if token == "" { - log.Fatalln("No $DISCORD_TOKEN given.") - } - - dg, err := discordgo.New("Bot " + token) - if err != nil { - log.Fatalf("Error creating Discord session: %v", err) - } - - dg.AddHandler(ready) - dg.AddHandler(interactionCreate) - - processorManager := lib.NewMessageProcessorManager() - - // Register processors - processorManager.RegisterProcessor(lib.ProcessHimbucks) - - dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { - processorManager.ProcessMessage(s, m) - }) - - dg.Identify.Intents = discordgo.IntentsGuilds | discordgo.IntentsGuildMessages - - err = dg.Open() - if err != nil { - log.Fatalf("Error opening connection: %v", err) - } - - log.Println("Bot is now running. Press CTRL-C to exit.") - registerCommands(dg) - - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - <-sc - - if lib.DBClient != nil { - lib.DBClient.Close() - } - - dg.Close() -} - -func ready(s *discordgo.Session, event *discordgo.Ready) { - log.Printf("Logged in as: %v#%v", s.State.User.Username, s.State.User.Discriminator) -} - -func interactionCreate(s *discordgo.Session, i *discordgo.InteractionCreate) { - if h, ok := commandHandlers[i.ApplicationCommandData().Name]; ok { - h(s, i) - } -} - -func registerCommands(s *discordgo.Session) { - // First, delete all existing commands - log.Println("Deleting existing commands...") - - existingCommands, err := s.ApplicationCommands(s.State.User.ID, "") - if err != nil { - log.Printf("Error fetching existing commands: %v", err) - } - - for _, cmd := range existingCommands { - err := s.ApplicationCommandDelete(s.State.User.ID, "", cmd.ID) - if err != nil { - log.Printf("Error deleting command %s: %v", cmd.Name, err) - } - } - - // Then register the new commands - log.Println("Registering new commands...") - registeredCommands := make([]*discordgo.ApplicationCommand, len(commands)) - for i, v := range commands { - cmd, err := s.ApplicationCommandCreate(s.State.User.ID, "", v) - if err != nil { - log.Panicf("Cannot create '%v' command: %v", v.Name, err) - } - registeredCommands[i] = cmd + "ping": lib.HandleCommand("ping", time.Duration(config.PingCooldown)*time.Second, command.PingCommand), + "hs": lib.HandleCommand("hs", time.Duration(config.HsCooldown)*time.Second, command.HsCommand), + "markov": lib.HandleCommand("markov", time.Duration(config.MarkovCooldown)*time.Second, command.MarkovCommand), + "himbucks": lib.HandleCommand("himbucks", time.Duration(config.HimbucksCooldown)*time.Second, command.BalanceGetCommand), + "himboard": lib.HandleCommand("himboard", time.Duration(config.HimboardCooldown)*time.Second, command.LeaderboardCommand), + "sendbucks": lib.HandleCommand("sendbucks", time.Duration(config.SendbucksCooldown)*time.Second, command.BalanceSendCommand), } }