From dc1d5e4848f238be643c0023a0fe4ae4d887612a Mon Sep 17 00:00:00 2001 From: Simon Let Date: Fri, 4 Oct 2019 14:35:47 +0200 Subject: [PATCH] Add session watcher, improve history file handler Session watcher recieves all incoming records and and periodically checks if the session is still running. If session exits it sends message to other parts of RESH to drop the session. History handler recieves sessions to drop from session watcher. --- cmd/daemon/histfile.go | 57 ---------------------- cmd/daemon/main.go | 58 ++++++++++++++--------- go.mod | 1 + go.sum | 2 + pkg/histfile/histfile.go | 97 ++++++++++++++++++++++++++++++++++++++ pkg/sesswatch/sesswatch.go | 69 +++++++++++++++++++++++++++ 6 files changed, 205 insertions(+), 79 deletions(-) delete mode 100644 cmd/daemon/histfile.go create mode 100644 pkg/histfile/histfile.go create mode 100644 pkg/sesswatch/sesswatch.go diff --git a/cmd/daemon/histfile.go b/cmd/daemon/histfile.go deleted file mode 100644 index d608ca0..0000000 --- a/cmd/daemon/histfile.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "encoding/json" - "log" - "os" - - "github.com/curusarn/resh/pkg/records" -) - -// HistfileWriter - reads records from channel, merges them and wrotes them to file -func HistfileWriter(input chan records.Record, outputPath string) { - sessions := map[string]records.Record{} - - for { - record := <-input - if record.PartOne { - if _, found := sessions[record.SessionID]; found { - log.Println("ERROR: Got another first part of the records before merging the previous one - overwriting!") - } - sessions[record.SessionID] = record - } else { - part1, found := sessions[record.SessionID] - if found == false { - log.Println("ERROR: Got second part of records and nothing to merge it with - ignoring!") - continue - } - delete(sessions, record.SessionID) - go mergeAndWriteRecord(part1, record, outputPath) - } - } -} - -func mergeAndWriteRecord(part1, part2 records.Record, outputPath string) { - err := part1.Merge(part2) - if err != nil { - log.Println("Error while merging", err) - return - } - recJSON, err := json.Marshal(part1) - if err != nil { - log.Println("Marshalling error", err) - return - } - f, err := os.OpenFile(outputPath, - os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Println("Could not open file", err) - return - } - defer f.Close() - _, err = f.Write(append(recJSON, []byte("\n")...)) - if err != nil { - log.Printf("Error while writing: %v, %s\n", part1, err) - return - } -} diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 69c031f..5b9c4ea 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -15,7 +15,9 @@ import ( "github.com/BurntSushi/toml" "github.com/curusarn/resh/pkg/cfg" + "github.com/curusarn/resh/pkg/histfile" "github.com/curusarn/resh/pkg/records" + "github.com/curusarn/resh/pkg/sesswatch" ) // Version from git set during build @@ -84,31 +86,35 @@ func statusHandler(w http.ResponseWriter, r *http.Request) { } type recordHandler struct { - histfile chan records.Record + subscribers []chan records.Record } func (h *recordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK\n")) - record := records.Record{} - jsn, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println("Error reading the body", err) - return - } + // run rest of the handler as goroutine to prevent any hangups + go func() { + if err != nil { + log.Println("Error reading the body", err) + return + } - err = json.Unmarshal(jsn, &record) - if err != nil { - log.Println("Decoding error: ", err) - log.Println("Payload: ", jsn) - return - } - h.histfile <- record - part := "2" - if record.PartOne { - part = "1" - } - log.Println("Received:", record.CmdLine, " - part", part) + record := records.Record{} + err = json.Unmarshal(jsn, &record) + if err != nil { + log.Println("Decoding error: ", err) + log.Println("Payload: ", jsn) + return + } + for _, sub := range h.subscribers { + sub <- record + } + part := "2" + if record.PartOne { + part = "1" + } + log.Println("Received:", record.CmdLine, " - part", part) + }() // fmt.Println("cmd:", r.CmdLine) // fmt.Println("pwd:", r.Pwd) @@ -117,11 +123,19 @@ func (h *recordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func runServer(port int, outputPath string) { - histfile := make(chan records.Record) - go HistfileWriter(histfile, outputPath) + var recordSubscribers []chan records.Record + + histfileChan := make(chan records.Record) + recordSubscribers = append(recordSubscribers, histfileChan) + sessionsToDrop := make(chan string) + histfile.Go(histfileChan, outputPath, sessionsToDrop) + + sesswatchChan := make(chan records.Record) + recordSubscribers = append(recordSubscribers, sesswatchChan) + sesswatch.Go(sesswatchChan, []chan string{sessionsToDrop}, 10) http.HandleFunc("/status", statusHandler) - http.Handle("/record", &recordHandler{histfile: histfile}) + http.Handle("/record", &recordHandler{subscribers: recordSubscribers}) //http.Handle("/session_init", &sessionInitHandler{OutputPath: outputPath}) //http.Handle("/recall", &recallHandler{OutputPath: outputPath}) http.ListenAndServe(":"+strconv.Itoa(port), nil) diff --git a/go.mod b/go.mod index 64dd27e..910abe6 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jpillora/longestcommon v0.0.0-20161227235612-adb9d91ee629 github.com/mattn/go-shellwords v1.0.6 github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7 + github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b github.com/schollz/progressbar v1.0.0 github.com/spf13/cobra v0.0.5 github.com/wcharczuk/go-chart v2.0.1+incompatible diff --git a/go.sum b/go.sum index e6a3477..beb087d 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/mattn/go-shellwords v1.0.6/go.mod h1:3xCvwCdWdlDJUrvuMn7Wuy9eWs4pE8vq github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7 h1:VsJjhYhufMGXICLwLYr8mFVMp8/A+YqmagMHnG/BA/4= github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7/go.mod h1:zQmHoMvvVJb7cxyt1wGT77lqUaeOFXlogOppOr4uHVo= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b h1:9+ke9YJ9KGWw5ANXK6ozjoK47uI3uNbXv4YVINBnGm8= +github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b/go.mod h1:r1VsdOzOPt1ZSrGZWFoNhsAedKnEd6r9Np1+5blZCWk= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/pkg/histfile/histfile.go b/pkg/histfile/histfile.go new file mode 100644 index 0000000..5418c45 --- /dev/null +++ b/pkg/histfile/histfile.go @@ -0,0 +1,97 @@ +package histfile + +import ( + "encoding/json" + "log" + "os" + "sync" + + "github.com/curusarn/resh/pkg/records" +) + +type histfile struct { + mutex sync.Mutex + sessions map[string]records.Record + outputPath string +} + +// Go creates histfile and runs two gorutines on it +func Go(input chan records.Record, outputPath string, sessionsToDrop chan string) { + hf := histfile{sessions: map[string]records.Record{}, outputPath: outputPath} + go hf.writer(input) + go hf.sessionGC(sessionsToDrop) +} + +// sessionGC reads sessionIDs from channel and deletes them from histfile struct +func (h *histfile) sessionGC(sessionsToDrop chan string) { + for { + func() { + session := <-sessionsToDrop + log.Println("histfile: got session to drop", session) + h.mutex.Lock() + defer h.mutex.Unlock() + if part1, found := h.sessions[session]; found == true { + log.Println("histfile: Dropping session:", session) + delete(h.sessions, session) + go writeRecord(part1, h.outputPath) + } else { + log.Println("histfile: No hanging parts for session:", session) + } + }() + } +} + +// writer reads records from channel, merges them and writes them to file +func (h *histfile) writer(input chan records.Record) { + for { + func() { + record := <-input + h.mutex.Lock() + defer h.mutex.Unlock() + + if record.PartOne { + if _, found := h.sessions[record.SessionID]; found { + log.Println("histfile ERROR: Got another first part of the records before merging the previous one - overwriting!") + } + h.sessions[record.SessionID] = record + } else { + part1, found := h.sessions[record.SessionID] + if found == false { + log.Println("histfile ERROR: Got second part of records and nothing to merge it with - ignoring!") + } else { + delete(h.sessions, record.SessionID) + go mergeAndWriteRecord(part1, record, h.outputPath) + } + } + }() + } +} + +func mergeAndWriteRecord(part1, part2 records.Record, outputPath string) { + err := part1.Merge(part2) + if err != nil { + log.Println("Error while merging", err) + return + } + writeRecord(part1, outputPath) +} + +func writeRecord(rec records.Record, outputPath string) { + recJSON, err := json.Marshal(rec) + if err != nil { + log.Println("Marshalling error", err) + return + } + f, err := os.OpenFile(outputPath, + os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Println("Could not open file", err) + return + } + defer f.Close() + _, err = f.Write(append(recJSON, []byte("\n")...)) + if err != nil { + log.Printf("Error while writing: %v, %s\n", rec, err) + return + } +} diff --git a/pkg/sesswatch/sesswatch.go b/pkg/sesswatch/sesswatch.go new file mode 100644 index 0000000..54b2c09 --- /dev/null +++ b/pkg/sesswatch/sesswatch.go @@ -0,0 +1,69 @@ +package sesswatch + +import ( + "log" + "sync" + "time" + + "github.com/curusarn/resh/pkg/records" + "github.com/mitchellh/go-ps" +) + +type sesswatch struct { + sessionsToDrop []chan string + sleepSeconds uint + + watchedSessions map[string]bool + mutex sync.Mutex +} + +// Go runs the session watcher - watches sessions and sends +func Go(input chan records.Record, sessionsToDrop []chan string, sleepSeconds uint) { + sw := sesswatch{sessionsToDrop: sessionsToDrop, sleepSeconds: sleepSeconds, watchedSessions: map[string]bool{}} + go sw.waiter(input) +} + +func (s *sesswatch) waiter(sessionsToWatch chan records.Record) { + for { + func() { + record := <-sessionsToWatch + session := record.SessionID + pid := record.SessionPid + if record.PartOne == false { + log.Println("sesswatch: part2 - ignoring:", session, "~", pid) + return // continue + } + log.Println("sesswatch: got session ~ pid:", session, "~", pid) + s.mutex.Lock() + defer s.mutex.Unlock() + if s.watchedSessions[session] == false { + log.Println("sesswatch: start watching NEW session ~ pid:", session, "~", pid) + s.watchedSessions[session] = true + go s.watcher(session, record.SessionPid) + } + }() + } +} + +func (s *sesswatch) watcher(sessionID string, sessionPID int) { + for { + time.Sleep(time.Duration(s.sleepSeconds) * time.Second) + proc, err := ps.FindProcess(sessionPID) + if err != nil { + log.Println("sesswatch ERROR: error while finding process:", sessionPID) + } else if proc == nil { + log.Println("sesswatch: Dropping session ~ pid:", sessionID, "~", sessionPID) + func() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.watchedSessions[sessionID] = false + }() + for _, ch := range s.sessionsToDrop { + log.Println("sesswatch: sending 'drop session' message ...") + ch <- sessionID + log.Println("sesswatch: sending 'drop session' message DONE") + } + break + } + } +}