diff --git a/log/hook/splitstream.go b/log/hook/splitstream.go index 3ba12d8..2ec293a 100644 --- a/log/hook/splitstream.go +++ b/log/hook/splitstream.go @@ -2,41 +2,46 @@ package hook import ( "github.com/sirupsen/logrus" - "os" + "io" + "io/ioutil" ) -type SplitStdoutStreamHook struct{} - -func (soh *SplitStdoutStreamHook) Levels() []logrus.Level { - return []logrus.Level{ - logrus.DebugLevel, - logrus.InfoLevel, - } +type writerHook struct { + writer io.Writer + levels []logrus.Level } -func (soh *SplitStdoutStreamHook) Fire(entry *logrus.Entry) error { + +func (h *writerHook) Levels() []logrus.Level { + return h.levels +} + +func (h *writerHook) Fire(entry *logrus.Entry) error { serialized, err := entry.Logger.Formatter.Format(entry) if err != nil { return err } - _, err = os.Stdout.Write(serialized) + _, err = h.writer.Write(serialized) return err } -type SplitStderrStreamHook struct{} +func RegisterSplitLogger(logger *logrus.Logger, outWriter io.Writer, errWriter io.Writer) { + logger.SetOutput(ioutil.Discard) -func (seh *SplitStderrStreamHook) Levels() []logrus.Level { - return []logrus.Level{ - logrus.WarnLevel, - logrus.ErrorLevel, - logrus.FatalLevel, - logrus.PanicLevel, - } -} -func (seh *SplitStderrStreamHook) Fire(entry *logrus.Entry) error { - serialized, err := entry.Logger.Formatter.Format(entry) - if err != nil { - return err - } - _, err = os.Stderr.Write(serialized) - return err + logger.AddHook(&writerHook{ + writer: outWriter, + levels: []logrus.Level{ + logrus.DebugLevel, + logrus.InfoLevel, + }, + }) + + logger.AddHook(&writerHook{ + writer: errWriter, + levels: []logrus.Level{ + logrus.WarnLevel, + logrus.ErrorLevel, + logrus.FatalLevel, + logrus.PanicLevel, + }, + }) } diff --git a/log/hook/splitstream_test.go b/log/hook/splitstream_test.go index 2c53301..6ec6537 100644 --- a/log/hook/splitstream_test.go +++ b/log/hook/splitstream_test.go @@ -1,79 +1,69 @@ package hook import ( - "bytes" "github.com/sirupsen/logrus" - "io" - "io/ioutil" - "os" - "strings" + "github.com/stretchr/testify/assert" "testing" + "time" ) +type testWriter struct { + c chan []byte +} + +func (w testWriter) Write(p []byte) (int, error) { + w.c <- p + return len(p), nil +} + func TestSplitStdoutStreamHook_Fire(t *testing.T) { + outWriter := testWriter{c: make(chan []byte, 2)} + errWriter := testWriter{c: make(chan []byte, 2)} + defaultWriter := testWriter{c: make(chan []byte, 2)} - logrus.SetLevel(logrus.DebugLevel) - logrus.SetOutput(ioutil.Discard) - logrus.AddHook(&SplitStdoutStreamHook{}) + log := logrus.New() + log.SetOutput(defaultWriter) + log.SetLevel(logrus.DebugLevel) - oldStdout := os.Stdout - sr, sw, _ := os.Pipe() - os.Stdout = sw + RegisterSplitLogger(log, outWriter, errWriter) - outChan := make(chan string, 2) - go func() { - var stdoutBuff bytes.Buffer - io.Copy(&stdoutBuff, sr) - outChan <- stdoutBuff.String() - outChan <- stdoutBuff.String() - }() + log.Debug("out1") + log.Info("out2") + log.Warn("err1") + log.Error("err2") - logrus.Debug("out1") - logrus.Info("out2") - - sw.Close() - os.Stdout = oldStdout - - stdoutMsg := <-outChan - if !strings.Contains(stdoutMsg, "msg=out1") { - t.Fatalf("failed to split info level log into stdout") + select { + case log := <-outWriter.c: + assert.Contains(t, string(log), "out1") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for out log") } - stdoutMsg = <-outChan - if !strings.Contains(stdoutMsg, "msg=out2") { - t.Fatalf("failed to split info level log into stdout") - } -} - -func TestSplitStderrStreamHook_Fire(t *testing.T) { - logrus.SetOutput(ioutil.Discard) - logrus.AddHook(&SplitStderrStreamHook{}) - - oldStderr := os.Stderr - er, ew, _ := os.Pipe() - os.Stderr = ew - - errChan := make(chan string, 2) - go func() { - var stderrOut bytes.Buffer - io.Copy(&stderrOut, er) - errChan <- stderrOut.String() - errChan <- stderrOut.String() - }() - - logrus.Warn("err1") - logrus.Error("err2") - - ew.Close() - os.Stderr = oldStderr - - stderrMsg := <-errChan - if !strings.Contains(stderrMsg, "msg=err1") { - t.Fatalf("failed to split error level log into stderr") - } - - stderrMsg = <-errChan - if !strings.Contains(stderrMsg, "msg=err2") { - t.Fatalf("failed to split error level log into stderr") + select { + case log := <-outWriter.c: + assert.Contains(t, string(log), "out2") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for out log") + } + + select { + case log := <-errWriter.c: + assert.Contains(t, string(log), "err1") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for err log") + } + + select { + case log := <-errWriter.c: + assert.Contains(t, string(log), "err2") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for err log") + } + + select { + case <-defaultWriter.c: + t.Fatalf("got default log") + case <-time.After(time.Second): + // Noop } } diff --git a/main.go b/main.go index 4586b9a..407a731 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "github.com/aptible/supercronic/crontab" "github.com/aptible/supercronic/log/hook" "github.com/sirupsen/logrus" - "io/ioutil" "os" "os/signal" "sync" @@ -38,9 +37,11 @@ func main() { } if *splitLogs { - logrus.SetOutput(ioutil.Discard) - logrus.AddHook(&hook.SplitStderrStreamHook{}) - logrus.AddHook(&hook.SplitStdoutStreamHook{}) + hook.RegisterSplitLogger( + logrus.StandardLogger(), + os.Stdout, + os.Stderr, + ) } if flag.NArg() != 1 {