add config plugin ability

This commit is contained in:
notch
2021-01-15 14:12:02 +08:00
parent f462fc4f3b
commit 3f0ae42e2b
2 changed files with 89 additions and 11 deletions

View File

@@ -1,3 +1,6 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.
//
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
@@ -6,7 +9,13 @@ package config
import (
"errors"
"io"
"net/http"
"os"
"path/filepath"
"plugin"
"strings"
"time"
)
// Provider 提供者接口
@@ -33,8 +42,38 @@ func (c *ProviderConfig) Load(builtins ...Provider) (Provider, error) {
}
}
// TODO: load a plugin provider
return nil, errors.New("The provider '" + c.Provider + "' could not be loaded. ")
// Attempt to load a plugin provider
p, err := plugin.Open(resolvePath(c.Provider))
if err != nil {
return nil, errors.New("The provider plugin '" + c.Provider + "' could not be opened. " + err.Error())
}
// Get the symbol
sym, err := p.Lookup("New")
if err != nil {
return nil, errors.New("The provider '" + c.Provider + "' does not contain 'func New() interface{}' symbol")
}
// Resolve the
pFactory, validFunc := sym.(*func() interface{})
if !validFunc {
return nil, errors.New("The provider '" + c.Provider + "' does not contain 'func New() interface{}' symbol")
}
// Construct the provider
provider, validProv := ((*pFactory)()).(Provider)
if !validProv {
return nil, errors.New("The provider '" + c.Provider + "' does not implement 'Provider'")
}
// Configure the provider
err = provider.Configure(c.Config)
if err != nil {
return nil, errors.New("The provider '" + c.Provider + "' could not be configured")
}
// Succesfully opened and configured a provider
return provider, nil
}
// LoadOrPanic 加载 Provider 如果失败直接 panics.
@@ -58,3 +97,49 @@ func LoadProvider(config *ProviderConfig, providers ...Provider) Provider {
// Load the provider according to the configuration
return config.LoadOrPanic(providers...)
}
func resolvePath(path string) string {
// If it's an url, download the file
if strings.HasPrefix(path, "http") {
f, err := httpFile(path)
if err != nil {
panic(err)
}
// Get the downloaded file path
path = f.Name()
}
// Make sure the path is absolute
path, _ = filepath.Abs(path)
return path
}
// DefaultClient used for http with a shorter timeout.
var defaultClient = &http.Client{
Timeout: 5 * time.Second,
}
// httpFile downloads a file from HTTP
var httpFile = func(url string) (*os.File, error) {
tokens := strings.Split(url, "/")
fileName := tokens[len(tokens)-1]
output, err := os.Create(fileName)
if err != nil {
return nil, err
}
defer output.Close()
response, err := http.Get(url)
if err != nil {
return nil, err
}
defer response.Body.Close()
if _, err := io.Copy(output, response.Body); err != nil {
return nil, err
}
return output, nil
}

View File

@@ -9,7 +9,6 @@ import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
)
@@ -50,9 +49,3 @@ func (c *TLSConfig) Load() (*tls.Config, error) {
Certificates: []tls.Certificate{cer},
}, err
}
func resolvePath(path string) string {
// Make sure the path is absolute
path, _ = filepath.Abs(path)
return path
}