package handlers import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/gorilla/websocket" "supervisor/internal/domain" "supervisor/internal/ws" ) type WSTerminalHandler struct { manager SessionService upgrader websocket.Upgrader } func NewWSTerminalHandler(manager SessionService) *WSTerminalHandler { return &WSTerminalHandler{ manager: manager, upgrader: websocket.Upgrader{ CheckOrigin: func(_ *http.Request) bool { return true }, }, } } func (h *WSTerminalHandler) Handle(w http.ResponseWriter, r *http.Request) { sessionID := r.PathValue("id") current, err := h.manager.GetSession(r.Context(), sessionID) if err != nil { writeError(w, http.StatusNotFound, err) return } conn, err := h.upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() events, cancel, err := h.manager.Subscribe(sessionID) if err != nil { _ = conn.WriteJSON(errorEnvelope(sessionID, err.Error())) return } defer cancel() if err := conn.WriteJSON(map[string]any{ "type": string(domain.EventSessionStatus), "session": sessionID, "payload": ws.SessionStatusPayload{Status: string(current.Status), ExitCode: current.ExitCode}, }); err != nil { return } scrollback, _ := h.manager.Scrollback(sessionID) if len(scrollback) > 0 { if err := conn.WriteJSON(map[string]any{ "type": string(domain.EventTerminalOutput), "session": sessionID, "payload": ws.TerminalOutputPayload{Data: string(scrollback)}, }); err != nil { return } } ctx, stop := context.WithCancel(r.Context()) defer stop() readErr := make(chan error, 1) go func() { readErr <- h.readLoop(ctx, conn, sessionID) }() for { select { case err := <-readErr: if err == nil || websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { return } _ = conn.WriteJSON(errorEnvelope(sessionID, err.Error())) return case event, ok := <-events: if !ok { return } if err := conn.WriteJSON(toEnvelope(event)); err != nil { return } } } } func (h *WSTerminalHandler) readLoop(ctx context.Context, conn *websocket.Conn, sessionID string) error { conn.SetReadLimit(1 << 20) _ = conn.SetReadDeadline(time.Time{}) for { select { case <-ctx.Done(): return nil default: } _, msg, err := conn.ReadMessage() if err != nil { return err } var envelope ws.Envelope if err := json.Unmarshal(msg, &envelope); err != nil { return err } switch envelope.Type { case "terminal.input": var payload ws.TerminalInputPayload if err := json.Unmarshal(envelope.Payload, &payload); err != nil { return err } if err := h.manager.WriteInput(ctx, sessionID, payload.Data); err != nil { return err } case "terminal.resize": var payload ws.TerminalResizePayload if err := json.Unmarshal(envelope.Payload, &payload); err != nil { return err } if err := h.manager.Resize(ctx, sessionID, payload.Cols, payload.Rows); err != nil { return err } default: return fmt.Errorf("unsupported message type: %s", envelope.Type) } } } func toEnvelope(event domain.Event) map[string]any { return map[string]any{ "type": string(event.Type), "session": event.SessionID, "payload": event.Payload, } } func errorEnvelope(sessionID string, message string) map[string]any { return map[string]any{ "type": "error", "session": sessionID, "payload": ws.ErrorPayload{Message: message}, } }