Thumbnail

rani/matterbridge.git

Clone URL: https://git.buni.party/rani/matterbridge.git

commit b192d92b06bfbcce7fe1174b766798ad3a69eb4b Author: Duco van Amstel <duco.vanamstel@gmail.com> Date: Tue Nov 13 22:30:56 2018 +0000 Make config.Config more unit-test friendly (#586) diff --git a/bridge/bridge.go b/bridge/bridge.go index 0436eeb..debe2d6 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -227 +227 @@ type Bridge struct {   Channels map[string]config.ChannelInfo   Joined map[string]bool   Log *log.Entry - Config *config.Config + Config config.Config   General *config.Protocol  }   @@ -6936 +6941 @@ func (b *Bridge) joinChannels(channels map[string]config.ChannelInfo, exists map  }    func (b *Bridge) GetBool(key string) bool { - if b.Config.GetBool(b.Account + "." + key) { - return b.Config.GetBool(b.Account + "." + key) + val, ok := b.Config.GetBool(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetBool("general." + key)   } - return b.Config.GetBool("general." + key) + return val  }    func (b *Bridge) GetInt(key string) int { - if b.Config.GetInt(b.Account+"."+key) != 0 { - return b.Config.GetInt(b.Account + "." + key) + val, ok := b.Config.GetInt(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetInt("general." + key)   } - return b.Config.GetInt("general." + key) + return val  }    func (b *Bridge) GetString(key string) string { - if b.Config.GetString(b.Account+"."+key) != "" { - return b.Config.GetString(b.Account + "." + key) + val, ok := b.Config.GetString(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetString("general." + key)   } - return b.Config.GetString("general." + key) + return val  }    func (b *Bridge) GetStringSlice(key string) []string { - if len(b.Config.GetStringSlice(b.Account+"."+key)) != 0 { - return b.Config.GetStringSlice(b.Account + "." + key) + val, ok := b.Config.GetStringSlice(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetStringSlice("general." + key)   } - return b.Config.GetStringSlice("general." + key) + return val  }    func (b *Bridge) GetStringSlice2D(key string) [][]string { - if len(b.Config.GetStringSlice2D(b.Account+"."+key)) != 0 { - return b.Config.GetStringSlice2D(b.Account + "." + key) + val, ok := b.Config.GetStringSlice2D(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetStringSlice2D("general." + key)   } - return b.Config.GetStringSlice2D("general." + key) + return val  } diff --git a/bridge/config/config.go b/bridge/config/config.go index 503a8de..258401d 100644 --- a/bridge/config/config.go +++ b/bridge/config/config.go @@ -27 +29 @@ package config    import (   "bytes" + "fmt"   "io/ioutil" + "os"   "strings"   "sync"   "time" @@ -17713 +17923 @@ type ConfigValues struct {   SameChannelGateway []SameChannelGateway  }   -type Config struct { +type Config interface { + ConfigValues() *ConfigValues + GetBool(key string) (bool, bool) + GetInt(key string) (int, bool) + GetString(key string) (string, bool) + GetStringSlice(key string) ([]string, bool) + GetStringSlice2D(key string) ([][]string, bool) +} + +type config struct {   v *viper.Viper - *ConfigValues   sync.RWMutex + + cv *ConfigValues  }   -func NewConfig(cfgfile string) *Config { +func NewConfig(cfgfile string) Config {   log.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false})   flog := log.WithFields(log.Fields{"prefix": "config"})   viper.SetConfigFile(cfgfile) @@ -1919 +2039 @@ func NewConfig(cfgfile string) *Config {   if err != nil {   log.Fatal(err)   } - mycfg := NewConfigFromString(input) - if mycfg.ConfigValues.General.MediaDownloadSize == 0 { - mycfg.ConfigValues.General.MediaDownloadSize = 1000000 + mycfg := newConfigFromString(input) + if mycfg.cv.General.MediaDownloadSize == 0 { + mycfg.cv.General.MediaDownloadSize = 1000000   }   viper.WatchConfig()   viper.OnConfigChange(func(e fsnotify.Event) { @@ -2118 +22311 @@ func getFileContents(filename string) ([]byte, error) {   return input, nil  }   -func NewConfigFromString(input []byte) *Config { - var cfg ConfigValues +func NewConfigFromString(input []byte) Config { + return newConfigFromString(input) +} + +func newConfigFromString(input []byte) *config {   viper.SetConfigType("toml")   viper.SetEnvPrefix("matterbridge")   viper.AddConfigPath(".") @@ -22245 +23751 @@ func NewConfigFromString(input []byte) *Config {   if err != nil {   log.Fatal(err)   } - err = viper.Unmarshal(&cfg) + + cfg := &ConfigValues{} + err = viper.Unmarshal(cfg)   if err != nil {   log.Fatal(err)   } - mycfg := new(Config) - mycfg.v = viper.GetViper() - mycfg.ConfigValues = &cfg - return mycfg + return &config{ + v: viper.GetViper(), + cv: cfg, + } +} + +func (c *config) ConfigValues() *ConfigValues { + return c.cv  }   -func (c *Config) GetBool(key string) bool { +func (c *config) GetBool(key string) (bool, bool) {   c.RLock()   defer c.RUnlock()   // log.Debugf("getting bool %s = %#v", key, c.v.GetBool(key)) - return c.v.GetBool(key) + return c.v.GetBool(key), c.v.IsSet(key)  }   -func (c *Config) GetInt(key string) int { +func (c *config) GetInt(key string) (int, bool) {   c.RLock()   defer c.RUnlock()   // log.Debugf("getting int %s = %d", key, c.v.GetInt(key)) - return c.v.GetInt(key) + return c.v.GetInt(key), c.v.IsSet(key)  }   -func (c *Config) GetString(key string) string { +func (c *config) GetString(key string) (string, bool) {   c.RLock()   defer c.RUnlock()   // log.Debugf("getting String %s = %s", key, c.v.GetString(key)) - return c.v.GetString(key) + return c.v.GetString(key), c.v.IsSet(key)  }   -func (c *Config) GetStringSlice(key string) []string { +func (c *config) GetStringSlice(key string) ([]string, bool) {   c.RLock()   defer c.RUnlock()   // log.Debugf("getting StringSlice %s = %#v", key, c.v.GetStringSlice(key)) - return c.v.GetStringSlice(key) + return c.v.GetStringSlice(key), c.v.IsSet(key)  }   -func (c *Config) GetStringSlice2D(key string) [][]string { +func (c *config) GetStringSlice2D(key string) ([][]string, bool) {   c.RLock()   defer c.RUnlock()   result := [][]string{} @@ -2729 +2939 @@ func (c *Config) GetStringSlice2D(key string) [][]string {   }   result = append(result, result2)   } - return result + return result, true   } - return result + return result, false  }    func GetIconURL(msg *Message, iconURL string) string { @@ -2863 +30746 @@ func GetIconURL(msg *Message, iconURL string) string {   iconURL = strings.Replace(iconURL, "{PROTOCOL}", protocol, -1)   return iconURL  } + +type TestConfig struct { + Config + + Overrides map[string]interface{} +} + +func (c *TestConfig) GetBool(key string) (bool, bool) { + val, ok := c.Overrides[key] + fmt.Fprintln(os.Stderr, "DEBUG:", c.Overrides, key, ok, val) + if ok { + return val.(bool), true + } + return c.Config.GetBool(key) +} + +func (c *TestConfig) GetInt(key string) (int, bool) { + if val, ok := c.Overrides[key]; ok { + return val.(int), true + } + return c.Config.GetInt(key) +} + +func (c *TestConfig) GetString(key string) (string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.(string), true + } + return c.Config.GetString(key) +} + +func (c *TestConfig) GetStringSlice(key string) ([]string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.([]string), true + } + return c.Config.GetStringSlice(key) +} + +func (c *TestConfig) GetStringSlice2D(key string) ([][]string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.([][]string), true + } + return c.Config.GetStringSlice2D(key) +} diff --git a/gateway/gateway.go b/gateway/gateway.go index 9baeef9..bbaef04 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -337 +338 @@ import (  )    type Gateway struct { - *config.Config + config.Config +   Router *Router   MyConfig *config.Gateway   Bridges map[string]*bridge.Bridge @@ -1077 +1087 @@ func (gw *Gateway) AddBridge(cfg *config.Bridge) error {   if br == nil {   br = bridge.New(cfg)   br.Config = gw.Router.Config - br.General = &gw.General + br.General = &gw.ConfigValues().General   // set logging   br.Log = log.WithFields(log.Fields{"prefix": "bridge"})   brconfig := &bridge.Config{Remote: gw.Message, Log: log.WithFields(log.Fields{"prefix": br.Protocol}), Bridge: br} @@ -2787 +2797 @@ func (gw *Gateway) handleMessage(msg config.Message, dest *bridge.Bridge) []*BrM     // Get the ID of the parent message in thread   var canonicalParentMsgID string - if msg.ParentID != "" && (gw.Config.General.PreserveThreading || dest.GetBool("PreserveThreading")) { + if msg.ParentID != "" && (gw.ConfigValues().General.PreserveThreading || dest.GetBool("PreserveThreading")) {   thisParentMsgID := dest.Protocol + " " + msg.ParentID   canonicalParentMsgID = gw.FindCanonicalMsgID(thisParentMsgID)   } @@ -39113 +39213 @@ func (gw *Gateway) ignoreMessage(msg *config.Message) bool {  func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) string {   br := gw.Bridges[msg.Account]   msg.Protocol = br.Protocol - if gw.Config.General.StripNick || dest.GetBool("StripNick") { + if gw.ConfigValues().General.StripNick || dest.GetBool("StripNick") {   re := regexp.MustCompile("[^a-zA-Z0-9]+")   msg.Username = re.ReplaceAllString(msg.Username, "")   }   nick := dest.GetString("RemoteNickFormat")   if nick == "" { - nick = gw.Config.General.RemoteNickFormat + nick = gw.ConfigValues().General.RemoteNickFormat   }     // loop to replace nicks @@ -4367 +4377 @@ func (gw *Gateway) modifyUsername(msg config.Message, dest *bridge.Bridge) strin  }    func (gw *Gateway) modifyAvatar(msg config.Message, dest *bridge.Bridge) string { - iconurl := gw.Config.General.IconURL + iconurl := gw.ConfigValues().General.IconURL   if iconurl == "" {   iconurl = dest.GetString("IconURL")   } @@ -4777 +4789 @@ func (gw *Gateway) handleFiles(msg *config.Message) {   reg := regexp.MustCompile("[^a-zA-Z0-9]+")     // If we don't have a attachfield or we don't have a mediaserver configured return - if msg.Extra == nil || (gw.Config.General.MediaServerUpload == "" && gw.Config.General.MediaDownloadPath == "") { + if msg.Extra == nil || + (gw.ConfigValues().General.MediaServerUpload == "" && + gw.ConfigValues().General.MediaDownloadPath == "") {   return   }   @@ -49910 +50210 @@ func (gw *Gateway) handleFiles(msg *config.Message) {     sha1sum := fmt.Sprintf("%x", sha1.Sum(*fi.Data))[:8]   - if gw.Config.General.MediaServerUpload != "" { + if gw.ConfigValues().General.MediaServerUpload != "" {   // Use MediaServerUpload. Upload using a PUT HTTP request and basicauth.   - url := gw.Config.General.MediaServerUpload + "/" + sha1sum + "/" + fi.Name + url := gw.ConfigValues().General.MediaServerUpload + "/" + sha1sum + "/" + fi.Name     req, err := http.NewRequest("PUT", url, bytes.NewReader(*fi.Data))   if err != nil { @@ -5217 +5247 @@ func (gw *Gateway) handleFiles(msg *config.Message) {   } else {   // Use MediaServerPath. Place the file on the current filesystem.   - dir := gw.Config.General.MediaDownloadPath + "/" + sha1sum + dir := gw.ConfigValues().General.MediaDownloadPath + "/" + sha1sum   err := os.Mkdir(dir, os.ModePerm)   if err != nil && !os.IsExist(err) {   flog.Errorf("mediaserver path failed, could not mkdir: %s %#v", err, err) @@ -5397 +5427 @@ func (gw *Gateway) handleFiles(msg *config.Message) {   }     // Download URL. - durl := gw.Config.General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name + durl := gw.ConfigValues().General.MediaServerDownload + "/" + sha1sum + "/" + fi.Name     flog.Debugf("mediaserver download URL = %s", durl)   diff --git a/gateway/router.go b/gateway/router.go index 030b7b1..a6c6daf 100644 --- a/gateway/router.go +++ b/gateway/router.go @@ -227 +232 @@ package gateway    import (   "fmt" + "time"     "github.com/42wim/matterbridge/bridge"   "github.com/42wim/matterbridge/bridge/config"   samechannelgateway "github.com/42wim/matterbridge/gateway/samechannel" - // "github.com/davecgh/go-spew/spew" - "time"  )    type Router struct { + config.Config +   Gateways map[string]*Gateway   Message chan config.Message   MattermostPlugin chan config.Message - *config.Config  }   -func NewRouter(cfg *config.Config) (*Router, error) { - r := &Router{Message: make(chan config.Message), MattermostPlugin: make(chan config.Message), Gateways: make(map[string]*Gateway), Config: cfg} +func NewRouter(cfg config.Config) (*Router, error) { + r := &Router{ + Config: cfg, + Message: make(chan config.Message), + MattermostPlugin: make(chan config.Message), + Gateways: make(map[string]*Gateway), + }   sgw := samechannelgateway.New(cfg)   gwconfigs := sgw.GetConfig()   - for _, entry := range append(gwconfigs, cfg.Gateway...) { + for _, entry := range append(gwconfigs, cfg.ConfigValues().Gateway...) {   if !entry.Enable {   continue   } diff --git a/gateway/samechannel/samechannel.go b/gateway/samechannel/samechannel.go index 937d769..ea846e9 100644 --- a/gateway/samechannel/samechannel.go +++ b/gateway/samechannel/samechannel.go @@ -517 +517 @@ import (  )    type SameChannelGateway struct { - *config.Config + config.Config  }   -func New(cfg *config.Config) *SameChannelGateway { +func New(cfg config.Config) *SameChannelGateway {   return &SameChannelGateway{Config: cfg}  }    func (sgw *SameChannelGateway) GetConfig() []config.Gateway {   var gwconfigs []config.Gateway   cfg := sgw.Config - for _, gw := range cfg.SameChannelGateway { + for _, gw := range cfg.ConfigValues().SameChannelGateway {   gwconfig := config.Gateway{Name: gw.Name, Enable: gw.Enable}   for _, account := range gw.Accounts {   for _, channel := range gw.Channels { diff --git a/gateway/samechannel/samechannel_test.go b/gateway/samechannel/samechannel_test.go index 7c75444..c0e579a 100644 --- a/gateway/samechannel/samechannel_test.go +++ b/gateway/samechannel/samechannel_test.go @@ -116 +113 @@  package samechannelgateway    import ( - "fmt" -   "github.com/42wim/matterbridge/bridge/config" - "github.com/BurntSushi/toml"   "github.com/stretchr/testify/assert"     "testing"  )   -var testconfig = ` +const testConfig = `  [mattermost.test]  [slack.test]   @@ -2112 +1856 @@ var testconfig = ` channels = [ "testing","testing2","testing10"]  `   -func TestGetConfig(t *testing.T) { - var cfg *config.Config - if _, err := toml.Decode(testconfig, &cfg); err != nil { - fmt.Println(err) +var ( + expectedConfig = config.Gateway{ + Name: "blah", + Enable: true, + In: []config.Bridge(nil), + Out: []config.Bridge(nil), + InOut: []config.Bridge{ + { + Account: "mattermost.test", + Channel: "testing", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + { + Account: "mattermost.test", + Channel: "testing2", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + { + Account: "mattermost.test", + Channel: "testing10", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + { + Account: "slack.test", + Channel: "testing", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + { + Account: "slack.test", + Channel: "testing2", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + { + Account: "slack.test", + Channel: "testing10", + Options: config.ChannelOptions{Key: ""}, + SameChannel: true, + }, + },   } +) + +func TestGetConfig(t *testing.T) { + cfg := config.NewConfigFromString([]byte(testConfig))   sgw := New(cfg)   configs := sgw.GetConfig() - assert.Equal(t, []config.Gateway{{Name: "blah", Enable: true, In: []config.Bridge(nil), Out: []config.Bridge(nil), InOut: []config.Bridge{{Account: "mattermost.test", Channel: "testing", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "mattermost.test", Channel: "testing2", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "mattermost.test", Channel: "testing10", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing2", Options: config.ChannelOptions{Key: ""}, SameChannel: true}, {Account: "slack.test", Channel: "testing10", Options: config.ChannelOptions{Key: ""}, SameChannel: true}}}}, configs) + assert.Equal(t, []config.Gateway{expectedConfig}, configs)  } diff --git a/matterbridge.go b/matterbridge.go index 8e852b0..90c0437 100644 --- a/matterbridge.go +++ b/matterbridge.go @@ -447 +447 @@ func main() {   flog.Println("WARNING: THIS IS A DEVELOPMENT VERSION. Things may break.")   }   cfg := config.NewConfig(*flagConfig) - cfg.General.Debug = *flagDebug + cfg.ConfigValues().General.Debug = *flagDebug   r, err := gateway.NewRouter(cfg)   if err != nil {   flog.Fatalf("Starting gateway failed: %s", err)