From 3701a9c8cfc9115baa18e38e1310e80e16b0ed12 Mon Sep 17 00:00:00 2001 From: Simon Let Date: Wed, 27 Nov 2019 16:21:27 +0100 Subject: [PATCH] handle signals - write pending records on shutdown --- cmd/daemon/main.go | 4 ++- cmd/daemon/run-server.go | 25 +++++++++---- pkg/histfile/histfile.go | 59 +++++++++++++++++++++---------- pkg/signalhandler/signalhander.go | 57 +++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 26 deletions(-) create mode 100644 pkg/signalhandler/signalhander.go diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index a0ff6df..8bc48f9 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -70,10 +70,12 @@ func main() { log.Fatal("Could not create pidfile", err) } runServer(config, historyPath) + log.Println("main: Removing pidfile ...") err = os.Remove(pidfilePath) if err != nil { log.Println("Could not delete pidfile", err) } + log.Println("main: Shutdown - bye") } func statusHandler(w http.ResponseWriter, r *http.Request) { @@ -92,7 +94,7 @@ func killDaemon(pidfile string) error { if err != nil { log.Fatal("Pidfile contents are malformed", err) } - cmd := exec.Command("kill", strconv.Itoa(pid)) + cmd := exec.Command("kill", "-s", "sigint", strconv.Itoa(pid)) err = cmd.Run() if err != nil { log.Printf("Command finished with error: %v", err) diff --git a/cmd/daemon/run-server.go b/cmd/daemon/run-server.go index 6eb44b7..34a530e 100644 --- a/cmd/daemon/run-server.go +++ b/cmd/daemon/run-server.go @@ -2,6 +2,7 @@ package main import ( "net/http" + "os" "strconv" "github.com/curusarn/resh/pkg/cfg" @@ -9,12 +10,16 @@ import ( "github.com/curusarn/resh/pkg/records" "github.com/curusarn/resh/pkg/sesshist" "github.com/curusarn/resh/pkg/sesswatch" + "github.com/curusarn/resh/pkg/signalhandler" ) func runServer(config cfg.Config, historyPath string) { var recordSubscribers []chan records.Record var sessionInitSubscribers []chan records.Record var sessionDropSubscribers []chan string + var signalSubscribers []chan os.Signal + + shutdown := make(chan string) // sessshist sesshistSessionsToInit := make(chan records.Record) @@ -29,7 +34,9 @@ func runServer(config cfg.Config, historyPath string) { recordSubscribers = append(recordSubscribers, histfileRecords) histfileSessionsToDrop := make(chan string) sessionDropSubscribers = append(sessionDropSubscribers, histfileSessionsToDrop) - histfileBox := histfile.New(histfileRecords, historyPath, 10000, histfileSessionsToDrop) + histfileSignals := make(chan os.Signal) + signalSubscribers = append(signalSubscribers, histfileSignals) + histfileBox := histfile.New(histfileRecords, historyPath, 10000, histfileSessionsToDrop, histfileSignals, shutdown) // sesshist New sesshistDispatch := sesshist.NewDispatch(sesshistSessionsToInit, sesshistSessionsToDrop, sesshistRecords, histfileBox, config.SesshistInitHistorySize) @@ -40,9 +47,15 @@ func runServer(config cfg.Config, historyPath string) { sesswatch.Go(sesswatchSessionsToWatch, sessionDropSubscribers, config.SesswatchPeriodSeconds) // handlers - http.HandleFunc("/status", statusHandler) - http.Handle("/record", &recordHandler{subscribers: recordSubscribers}) - http.Handle("/session_init", &sessionInitHandler{subscribers: sessionInitSubscribers}) - http.Handle("/recall", &recallHandler{sesshistDispatch: sesshistDispatch}) - http.ListenAndServe(":"+strconv.Itoa(config.Port), nil) + mux := http.NewServeMux() + mux.HandleFunc("/status", statusHandler) + mux.Handle("/record", &recordHandler{subscribers: recordSubscribers}) + mux.Handle("/session_init", &sessionInitHandler{subscribers: sessionInitSubscribers}) + mux.Handle("/recall", &recallHandler{sesshistDispatch: sesshistDispatch}) + + server := &http.Server{Addr: ":" + strconv.Itoa(config.Port), Handler: mux} + go server.ListenAndServe() + + // signalhandler - takes over the main goroutine so when signal handler exists the whole program exits + signalhandler.Run(signalSubscribers, shutdown, server) } diff --git a/pkg/histfile/histfile.go b/pkg/histfile/histfile.go index 5964603..7e1bcc2 100644 --- a/pkg/histfile/histfile.go +++ b/pkg/histfile/histfile.go @@ -23,14 +23,16 @@ type Histfile struct { } // New creates new histfile and runs two gorutines on it -func New(input chan records.Record, historyPath string, initHistSize int, sessionsToDrop chan string) *Histfile { +func New(input chan records.Record, historyPath string, initHistSize int, sessionsToDrop chan string, + signals chan os.Signal, shutdownDone chan string) *Histfile { + hf := Histfile{ sessions: map[string]records.Record{}, historyPath: historyPath, cmdLinesLastIndex: map[string]int{}, } go hf.loadHistory(initHistSize) - go hf.writer(input) + go hf.writer(input, signals, shutdownDone) go hf.sessionGC(sessionsToDrop) return &hf } @@ -61,33 +63,52 @@ func (h *Histfile) sessionGC(sessionsToDrop chan string) { } // writer reads records from channel, merges them and writes them to file -func (h *Histfile) writer(input chan records.Record) { +func (h *Histfile) writer(input chan records.Record, signals chan os.Signal, shutdownDone chan string) { for { func() { - record := <-input - h.sessionsMutex.Lock() - defer h.sessionsMutex.Unlock() + select { + case record := <-input: + h.sessionsMutex.Lock() + defer h.sessionsMutex.Unlock() - // allows nested sessions to merge records properly - mergeID := record.SessionID + "_" + strconv.Itoa(record.Shlvl) - if record.PartOne { - if _, found := h.sessions[mergeID]; found { - log.Println("histfile WARN: Got another first part of the records before merging the previous one - overwriting! " + - "(this happens in bash because bash-preexec runs when it's not supposed to)") - } - h.sessions[mergeID] = record - } else { - if part1, found := h.sessions[mergeID]; found == false { - log.Println("histfile ERROR: Got second part of records and nothing to merge it with - ignoring! (mergeID:", mergeID, ")") + // allows nested sessions to merge records properly + mergeID := record.SessionID + "_" + strconv.Itoa(record.Shlvl) + if record.PartOne { + if _, found := h.sessions[mergeID]; found { + log.Println("histfile WARN: Got another first part of the records before merging the previous one - overwriting! " + + "(this happens in bash because bash-preexec runs when it's not supposed to)") + } + h.sessions[mergeID] = record } else { - delete(h.sessions, mergeID) - go h.mergeAndWriteRecord(part1, record) + if part1, found := h.sessions[mergeID]; found == false { + log.Println("histfile ERROR: Got second part of records and nothing to merge it with - ignoring! (mergeID:", mergeID, ")") + } else { + delete(h.sessions, mergeID) + go h.mergeAndWriteRecord(part1, record) + } } + case sig := <-signals: + log.Println("histfile: Got signal " + sig.String()) + h.sessionsMutex.Lock() + defer h.sessionsMutex.Unlock() + log.Println("histfile DEBUG: Unlocked mutex") + + for sessID, record := range h.sessions { + log.Panicln("histfile WARN: Writing incomplete record for session " + sessID) + h.writeRecord(record) + } + log.Println("histfile DEBUG: Shutdown success") + shutdownDone <- "histfile" + return } }() } } +func (h *Histfile) writeRecord(part1 records.Record) { + writeRecord(part1, h.historyPath) +} + func (h *Histfile) mergeAndWriteRecord(part1, part2 records.Record) { err := part1.Merge(part2) if err != nil { diff --git a/pkg/signalhandler/signalhander.go b/pkg/signalhandler/signalhander.go new file mode 100644 index 0000000..c3c201b --- /dev/null +++ b/pkg/signalhandler/signalhander.go @@ -0,0 +1,57 @@ +package signalhandler + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" +) + +func sendSignals(sig os.Signal, subscribers []chan os.Signal, done chan string) { + for _, sub := range subscribers { + sub <- sig + } + chanCount := len(subscribers) + start := time.Now() + delay := time.Millisecond * 100 + timeout := time.Millisecond * 2000 + + for { + select { + case _ = <-done: + chanCount-- + if chanCount == 0 { + log.Println("signalhandler: All boxes shut down successfully") + return + } + default: + time.Sleep(delay) + } + if time.Since(start) > timeout { + log.Println("signalhandler: Timouted while waiting for proper shutdown - " + strconv.Itoa(chanCount) + " boxes are up after " + timeout.String()) + return + } + } +} + +// Run catches and handles signals +func Run(subscribers []chan os.Signal, done chan string, server *http.Server) { + signals := make(chan os.Signal, 1) + + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + sig := <-signals + log.Println("signalhandler: Got signal " + sig.String()) + + log.Println("signalhandler: Sending signals to Subscribers") + sendSignals(sig, subscribers, done) + + log.Println("signalhandler: Shutting down the server") + if err := server.Shutdown(context.Background()); err != nil { + log.Printf("HTTP server Shutdown: %v", err) + } +}