Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions app/jobs/heartbeatjob/heartbeatjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import (
"sync"

"hostlink/app/services/heartbeat"
"hostlink/domain/task"
"hostlink/internal/telemetry"

"github.com/labstack/gommon/log"
)

type TriggerFunc func(context.Context, func() error)

type TaskEnqueuer interface {
Enqueue(ctx context.Context, t task.Task) error
}

type HeartbeatJobConfig struct {
Trigger TriggerFunc
}
Expand All @@ -36,21 +44,43 @@ func NewWithConfig(cfg HeartbeatJobConfig) HeartbeatJob {
}
}

func (hj *HeartbeatJob) Register(ctx context.Context, svc heartbeat.Service) context.CancelFunc {
func (hj *HeartbeatJob) Register(ctx context.Context, svc heartbeat.Service, enqueuers ...TaskEnqueuer) context.CancelFunc {
ctx, cancel := context.WithCancel(ctx)
hj.cancel = cancel

hj.wg.Add(1)
go func() {
defer hj.wg.Done()
hj.config.Trigger(ctx, func() error {
return svc.Send()
pendingTasks, err := svc.Send()
if err != nil {
return err
}
hj.enqueueTasks(ctx, pendingTasks, enqueuers...)
return nil
})
}()

return cancel
}

func (hj *HeartbeatJob) enqueueTasks(ctx context.Context, tasks []task.Task, enqueuers ...TaskEnqueuer) {
if len(tasks) == 0 || len(enqueuers) == 0 {
return
}
for _, t := range tasks {
if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" {
continue
}
for _, enq := range enqueuers {
if err := enq.Enqueue(ctx, t); err != nil {
log.Errorf("failed to enqueue task %s from heartbeat: %v", t.ID, err)
}
}
}
telemetry.Metric("hostlink.heartbeat.tasks_delivered", len(tasks), map[string]any{})
}

func (hj *HeartbeatJob) Shutdown() {
if hj.cancel != nil {
hj.cancel()
Expand Down
14 changes: 9 additions & 5 deletions app/jobs/heartbeatjob/heartbeatjob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"hostlink/domain/task"
)

type MockHeartbeatService struct {
mock.Mock
}

func (m *MockHeartbeatService) Send() error {
func (m *MockHeartbeatService) Send() ([]task.Task, error) {
args := m.Called()
return args.Error(0)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]task.Task), args.Error(1)
}

func immediateTrigger(callCount int, done chan struct{}) TriggerFunc {
Expand Down Expand Up @@ -71,7 +75,7 @@ func TestNewWithConfig_DefaultsNilTrigger(t *testing.T) {
// TestRegister_CallsServiceSend - trigger calls heartbeat.Service.Send()
func TestRegister_CallsServiceSend(t *testing.T) {
svc := new(MockHeartbeatService)
svc.On("Send").Return(nil).Times(3)
svc.On("Send").Return(nil, nil).Times(3)

done := make(chan struct{})
job := NewWithConfig(HeartbeatJobConfig{
Expand All @@ -90,7 +94,7 @@ func TestRegister_CallsServiceSend(t *testing.T) {
// TestRegister_ContinuesOnError - job continues running after Send() error
func TestRegister_ContinuesOnError(t *testing.T) {
svc := new(MockHeartbeatService)
svc.On("Send").Return(errors.New("connection refused")).Times(3)
svc.On("Send").Return(nil, errors.New("connection refused")).Times(3)

done := make(chan struct{})
job := NewWithConfig(HeartbeatJobConfig{
Expand Down Expand Up @@ -128,7 +132,7 @@ func TestRegister_ReturnsCancel(t *testing.T) {
func TestShutdown_StopsJob(t *testing.T) {
var callCount atomic.Int32
svc := new(MockHeartbeatService)
svc.On("Send").Return(nil).Run(func(args mock.Arguments) {
svc.On("Send").Return(nil, nil).Run(func(args mock.Arguments) {
callCount.Add(1)
})

Expand Down
13 changes: 12 additions & 1 deletion app/jobs/heartbeatjob/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package heartbeatjob

import (
"context"
"fmt"
"time"

"github.com/labstack/gommon/log"
Expand All @@ -23,13 +24,23 @@ func TriggerWithConfig(ctx context.Context, fn func() error, config TriggerConfi
case <-ctx.Done():
return
case <-time.After(config.Interval):
if err := fn(); err != nil {
if err := safeCall(fn); err != nil {
log.Errorf("heartbeat failed: %s", err)
}
}
}
}

func safeCall(fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("Panic recovered in heartbeat: %v", r)
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
}

func Trigger(ctx context.Context, fn func() error) {
TriggerWithConfig(ctx, fn, DefaultTriggerConfig())
}
13 changes: 12 additions & 1 deletion app/jobs/metricsjob/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package metricsjob

import (
"context"
"fmt"
"time"

"github.com/labstack/gommon/log"
Expand All @@ -27,13 +28,23 @@ func triggerWithConfig(ctx context.Context, fn func() error, config TriggerConfi
case <-ctx.Done():
return
case <-time.After(config.InitialDelay):
if err := fn(); err != nil {
if err := safeCall(fn); err != nil {
log.Errorf("Failed while running metrics job: %s", err)
}
}
}
}

func safeCall(fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("Panic recovered in metrics: %v", r)
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
}

func TriggerWithConfig(ctx context.Context, fn func() error, config TriggerConfig) {
triggerWithConfig(ctx, fn, config)
}
Expand Down
13 changes: 12 additions & 1 deletion app/jobs/selfupdatejob/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package selfupdatejob

import (
"context"
"fmt"
"time"

log "github.com/sirupsen/logrus"
Expand All @@ -26,13 +27,23 @@ func TriggerWithConfig(ctx context.Context, fn func() error, config TriggerConfi
case <-ctx.Done():
return
case <-time.After(config.Interval):
if err := fn(); err != nil {
if err := safeCall(fn); err != nil {
log.Errorf("self-update check failed: %s", err)
}
}
}
}

func safeCall(fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("Panic recovered in self-update: %v", r)
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
}

// Trigger runs fn with the default configuration.
func Trigger(ctx context.Context, fn func() error) {
TriggerWithConfig(ctx, fn, DefaultTriggerConfig())
Expand Down
16 changes: 14 additions & 2 deletions app/jobs/taskjob/taskjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr
for {
select {
case queued := <-tj.enqueueCh:
tj.processTask(ctx, queued, tr, channel)
tj.processTaskSafe(ctx, queued, tr, channel)
case <-ctx.Done():
return
}
Expand All @@ -109,7 +109,7 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr
}
}
for _, t := range incompleteTasks {
tj.processTask(ctx, t, tr, channel)
tj.processTaskSafe(ctx, t, tr, channel)
}
return nil
})
Expand All @@ -130,6 +130,18 @@ func (tj *TaskJob) Enqueue(ctx context.Context, t task.Task) error {
}
}

func (tj *TaskJob) processTaskSafe(ctx context.Context, t task.Task, tr taskreporter.TaskReporter, channel ResultChannel) {
defer func() {
if r := recover(); r != nil {
log.Errorf("Panic recovered in processTask: %v", r)
telemetry.Metric("hostlink.task_runner.panic", 1, map[string]any{
"task_id": t.ID,
})
}
}()
tj.processTask(ctx, t, tr, channel)
}

func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter.TaskReporter, channel ResultChannel) {
tempFile, err := os.CreateTemp("", "*_script.sh")
if err != nil {
Expand Down
15 changes: 13 additions & 2 deletions app/jobs/taskjob/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package taskjob

import (
"context"
"fmt"
"time"

"github.com/labstack/gommon/log"
Expand All @@ -27,13 +28,23 @@ func triggerWithConfig(ctx context.Context, fn func() error, config TriggerConfi
case <-ctx.Done():
return
case <-time.After(config.InitialDelay):
if err := fn(); err != nil {
log.Errorf("Failed while running metrics job: %s", err)
if err := safeCall(fn); err != nil {
log.Errorf("Failed while running task poller: %s", err)
}
}
}
}

func safeCall(fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("Panic recovered: %v", r)
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
}

func TriggerWithConfig(ctx context.Context, fn func() error, config TriggerConfig) {
triggerWithConfig(ctx, fn, config)
}
Expand Down
13 changes: 9 additions & 4 deletions app/services/heartbeat/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (

"hostlink/app/services/agentstate"
"hostlink/config/appconf"
"hostlink/domain/task"
"hostlink/internal/apiserver"
)

type Service interface {
Send() error
Send() ([]task.Task, error)
}

type heartbeatService struct {
Expand Down Expand Up @@ -49,11 +50,15 @@ func NewWithDependencies(
}
}

func (s *heartbeatService) Send() error {
func (s *heartbeatService) Send() ([]task.Task, error) {
agentID := s.agentstate.GetAgentID()
if agentID == "" {
return fmt.Errorf("agent not registered: missing agent ID")
return nil, fmt.Errorf("agent not registered: missing agent ID")
}

return s.apiserver.Heartbeat(context.Background(), agentID)
resp, err := s.apiserver.Heartbeat(context.Background(), agentID)
if err != nil {
return nil, err
}
return resp.PendingTasks, nil
}
Loading
Loading