diff --git a/cmd/daemon/histfile.go b/cmd/daemon/histfile.go deleted file mode 100644 index d608ca0..0000000 --- a/cmd/daemon/histfile.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "encoding/json" - "log" - "os" - - "github.com/curusarn/resh/pkg/records" -) - -// HistfileWriter - reads records from channel, merges them and wrotes them to file -func HistfileWriter(input chan records.Record, outputPath string) { - sessions := map[string]records.Record{} - - for { - record := <-input - if record.PartOne { - if _, found := sessions[record.SessionID]; found { - log.Println("ERROR: Got another first part of the records before merging the previous one - overwriting!") - } - sessions[record.SessionID] = record - } else { - part1, found := sessions[record.SessionID] - if found == false { - log.Println("ERROR: Got second part of records and nothing to merge it with - ignoring!") - continue - } - delete(sessions, record.SessionID) - go mergeAndWriteRecord(part1, record, outputPath) - } - } -} - -func mergeAndWriteRecord(part1, part2 records.Record, outputPath string) { - err := part1.Merge(part2) - if err != nil { - log.Println("Error while merging", err) - return - } - recJSON, err := json.Marshal(part1) - if err != nil { - log.Println("Marshalling error", err) - return - } - f, err := os.OpenFile(outputPath, - os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Println("Could not open file", err) - return - } - defer f.Close() - _, err = f.Write(append(recJSON, []byte("\n")...)) - if err != nil { - log.Printf("Error while writing: %v, %s\n", part1, err) - return - } -} diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 69c031f..5b9c4ea 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -15,7 +15,9 @@ import ( "github.com/BurntSushi/toml" "github.com/curusarn/resh/pkg/cfg" + "github.com/curusarn/resh/pkg/histfile" "github.com/curusarn/resh/pkg/records" + "github.com/curusarn/resh/pkg/sesswatch" ) // Version from git set during build @@ -84,31 +86,35 @@ func statusHandler(w http.ResponseWriter, r *http.Request) { } type recordHandler struct { - histfile chan records.Record + subscribers []chan records.Record } func (h *recordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK\n")) - record := records.Record{} - jsn, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println("Error reading the body", err) - return - } + // run rest of the handler as goroutine to prevent any hangups + go func() { + if err != nil { + log.Println("Error reading the body", err) + return + } - err = json.Unmarshal(jsn, &record) - if err != nil { - log.Println("Decoding error: ", err) - log.Println("Payload: ", jsn) - return - } - h.histfile <- record - part := "2" - if record.PartOne { - part = "1" - } - log.Println("Received:", record.CmdLine, " - part", part) + record := records.Record{} + err = json.Unmarshal(jsn, &record) + if err != nil { + log.Println("Decoding error: ", err) + log.Println("Payload: ", jsn) + return + } + for _, sub := range h.subscribers { + sub <- record + } + part := "2" + if record.PartOne { + part = "1" + } + log.Println("Received:", record.CmdLine, " - part", part) + }() // fmt.Println("cmd:", r.CmdLine) // fmt.Println("pwd:", r.Pwd) @@ -117,11 +123,19 @@ func (h *recordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func runServer(port int, outputPath string) { - histfile := make(chan records.Record) - go HistfileWriter(histfile, outputPath) + var recordSubscribers []chan records.Record + + histfileChan := make(chan records.Record) + recordSubscribers = append(recordSubscribers, histfileChan) + sessionsToDrop := make(chan string) + histfile.Go(histfileChan, outputPath, sessionsToDrop) + + sesswatchChan := make(chan records.Record) + recordSubscribers = append(recordSubscribers, sesswatchChan) + sesswatch.Go(sesswatchChan, []chan string{sessionsToDrop}, 10) http.HandleFunc("/status", statusHandler) - http.Handle("/record", &recordHandler{histfile: histfile}) + http.Handle("/record", &recordHandler{subscribers: recordSubscribers}) //http.Handle("/session_init", &sessionInitHandler{OutputPath: outputPath}) //http.Handle("/recall", &recallHandler{OutputPath: outputPath}) http.ListenAndServe(":"+strconv.Itoa(port), nil) diff --git a/go.mod b/go.mod index 64dd27e..910abe6 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jpillora/longestcommon v0.0.0-20161227235612-adb9d91ee629 github.com/mattn/go-shellwords v1.0.6 github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7 + github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b github.com/schollz/progressbar v1.0.0 github.com/spf13/cobra v0.0.5 github.com/wcharczuk/go-chart v2.0.1+incompatible diff --git a/go.sum b/go.sum index e6a3477..beb087d 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/mattn/go-shellwords v1.0.6/go.mod h1:3xCvwCdWdlDJUrvuMn7Wuy9eWs4pE8vq github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7 h1:VsJjhYhufMGXICLwLYr8mFVMp8/A+YqmagMHnG/BA/4= github.com/mb-14/gomarkov v0.0.0-20190125094512-044dd0dcb5e7/go.mod h1:zQmHoMvvVJb7cxyt1wGT77lqUaeOFXlogOppOr4uHVo= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b h1:9+ke9YJ9KGWw5ANXK6ozjoK47uI3uNbXv4YVINBnGm8= +github.com/mitchellh/go-ps v0.0.0-20190716172923-621e5597135b/go.mod h1:r1VsdOzOPt1ZSrGZWFoNhsAedKnEd6r9Np1+5blZCWk= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/pkg/histfile/histfile.go b/pkg/histfile/histfile.go new file mode 100644 index 0000000..5418c45 --- /dev/null +++ b/pkg/histfile/histfile.go @@ -0,0 +1,97 @@ +package histfile + +import ( + "encoding/json" + "log" + "os" + "sync" + + "github.com/curusarn/resh/pkg/records" +) + +type histfile struct { + mutex sync.Mutex + sessions map[string]records.Record + outputPath string +} + +// Go creates histfile and runs two gorutines on it +func Go(input chan records.Record, outputPath string, sessionsToDrop chan string) { + hf := histfile{sessions: map[string]records.Record{}, outputPath: outputPath} + go hf.writer(input) + go hf.sessionGC(sessionsToDrop) +} + +// sessionGC reads sessionIDs from channel and deletes them from histfile struct +func (h *histfile) sessionGC(sessionsToDrop chan string) { + for { + func() { + session := <-sessionsToDrop + log.Println("histfile: got session to drop", session) + h.mutex.Lock() + defer h.mutex.Unlock() + if part1, found := h.sessions[session]; found == true { + log.Println("histfile: Dropping session:", session) + delete(h.sessions, session) + go writeRecord(part1, h.outputPath) + } else { + log.Println("histfile: No hanging parts for session:", session) + } + }() + } +} + +// writer reads records from channel, merges them and writes them to file +func (h *histfile) writer(input chan records.Record) { + for { + func() { + record := <-input + h.mutex.Lock() + defer h.mutex.Unlock() + + if record.PartOne { + if _, found := h.sessions[record.SessionID]; found { + log.Println("histfile ERROR: Got another first part of the records before merging the previous one - overwriting!") + } + h.sessions[record.SessionID] = record + } else { + part1, found := h.sessions[record.SessionID] + if found == false { + log.Println("histfile ERROR: Got second part of records and nothing to merge it with - ignoring!") + } else { + delete(h.sessions, record.SessionID) + go mergeAndWriteRecord(part1, record, h.outputPath) + } + } + }() + } +} + +func mergeAndWriteRecord(part1, part2 records.Record, outputPath string) { + err := part1.Merge(part2) + if err != nil { + log.Println("Error while merging", err) + return + } + writeRecord(part1, outputPath) +} + +func writeRecord(rec records.Record, outputPath string) { + recJSON, err := json.Marshal(rec) + if err != nil { + log.Println("Marshalling error", err) + return + } + f, err := os.OpenFile(outputPath, + os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Println("Could not open file", err) + return + } + defer f.Close() + _, err = f.Write(append(recJSON, []byte("\n")...)) + if err != nil { + log.Printf("Error while writing: %v, %s\n", rec, err) + return + } +} diff --git a/pkg/sesswatch/sesswatch.go b/pkg/sesswatch/sesswatch.go new file mode 100644 index 0000000..54b2c09 --- /dev/null +++ b/pkg/sesswatch/sesswatch.go @@ -0,0 +1,69 @@ +package sesswatch + +import ( + "log" + "sync" + "time" + + "github.com/curusarn/resh/pkg/records" + "github.com/mitchellh/go-ps" +) + +type sesswatch struct { + sessionsToDrop []chan string + sleepSeconds uint + + watchedSessions map[string]bool + mutex sync.Mutex +} + +// Go runs the session watcher - watches sessions and sends +func Go(input chan records.Record, sessionsToDrop []chan string, sleepSeconds uint) { + sw := sesswatch{sessionsToDrop: sessionsToDrop, sleepSeconds: sleepSeconds, watchedSessions: map[string]bool{}} + go sw.waiter(input) +} + +func (s *sesswatch) waiter(sessionsToWatch chan records.Record) { + for { + func() { + record := <-sessionsToWatch + session := record.SessionID + pid := record.SessionPid + if record.PartOne == false { + log.Println("sesswatch: part2 - ignoring:", session, "~", pid) + return // continue + } + log.Println("sesswatch: got session ~ pid:", session, "~", pid) + s.mutex.Lock() + defer s.mutex.Unlock() + if s.watchedSessions[session] == false { + log.Println("sesswatch: start watching NEW session ~ pid:", session, "~", pid) + s.watchedSessions[session] = true + go s.watcher(session, record.SessionPid) + } + }() + } +} + +func (s *sesswatch) watcher(sessionID string, sessionPID int) { + for { + time.Sleep(time.Duration(s.sleepSeconds) * time.Second) + proc, err := ps.FindProcess(sessionPID) + if err != nil { + log.Println("sesswatch ERROR: error while finding process:", sessionPID) + } else if proc == nil { + log.Println("sesswatch: Dropping session ~ pid:", sessionID, "~", sessionPID) + func() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.watchedSessions[sessionID] = false + }() + for _, ch := range s.sessionsToDrop { + log.Println("sesswatch: sending 'drop session' message ...") + ch <- sessionID + log.Println("sesswatch: sending 'drop session' message DONE") + } + break + } + } +}