Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 112 additions & 45 deletions compliance/compliance.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,44 @@ import (
"github.com/stackrox/rox/pkg/metrics"
"github.com/stackrox/rox/pkg/mtls"
"github.com/stackrox/rox/pkg/protoutils"
"github.com/stackrox/rox/pkg/sync"
"github.com/stackrox/rox/pkg/utils"
"github.com/stackrox/rox/pkg/version"
"google.golang.org/grpc/metadata"
)

var log = logging.LoggerForModule()

const (
// nodeResourceID is the resource ID used for node scanning UMH.
// Compliance handles exactly one node, so a single constant suffices.
nodeResourceID = "this-node"
)

// Compliance represents the Compliance app
type Compliance struct {
nodeNameProvider node.NodeNameProvider
nodeScanner node.NodeScanner
nodeIndexer node.NodeIndexer
umhNodeInventory node.UnconfirmedMessageHandler
umhNodeIndex node.UnconfirmedMessageHandler
cache *sensor.MsgFromCompliance
nodeNameProvider node.NodeNameProvider
nodeScanner node.NodeScanner
nodeIndexer node.NodeIndexer
umhNodeInventory node.UnconfirmedMessageHandler
umhNodeIndex node.UnconfirmedMessageHandler
nodeInventoryCache *sensor.MsgFromCompliance
nodeInventoryCacheMu sync.Mutex
nodeIndexCache *sensor.MsgFromCompliance
nodeIndexCacheMu sync.Mutex
}

// NewComplianceApp constructs the Compliance app object
func NewComplianceApp(nnp node.NodeNameProvider, scanner node.NodeScanner, nodeIndexer node.NodeIndexer,
umhNodeInv, umhNodeIndex node.UnconfirmedMessageHandler) *Compliance {
return &Compliance{
nodeNameProvider: nnp,
nodeScanner: scanner,
nodeIndexer: nodeIndexer,
umhNodeInventory: umhNodeInv,
umhNodeIndex: umhNodeIndex,
cache: nil,
nodeNameProvider: nnp,
nodeScanner: scanner,
nodeIndexer: nodeIndexer,
umhNodeInventory: umhNodeInv,
umhNodeIndex: umhNodeIndex,
nodeInventoryCache: nil,
nodeIndexCache: nil,
}
}

Expand Down Expand Up @@ -179,13 +190,18 @@ func (c *Compliance) manageNodeInventoryScanLoop(ctx context.Context) <-chan *se
select {
case <-ctx.Done():
return
case _, ok := <-c.umhNodeInventory.RetryCommand():
if c.cache == nil {
log.Debug("Requested to retry but cache is empty. Resetting scan timer.")
case resourceID, ok := <-c.umhNodeInventory.RetryCommand():
if !ok {
log.Info("UMH retry channel for node inventory closed; stopping scan loop")
return
}
cachedMsg := c.getNodeInventoryCache()
if cachedMsg == nil {
log.Debugf("Requested to retry %s but cache is empty. Resetting scan timer.", resourceID)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheMiss, cmetrics.ScannerVersionV2)
t.Reset(time.Second)
} else if ok {
nodeInventoriesC <- c.cache
} else {
nodeInventoriesC <- cachedMsg
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheHit, cmetrics.ScannerVersionV2)
}
case <-t.C:
Expand Down Expand Up @@ -215,13 +231,18 @@ func (c *Compliance) manageNodeIndexScanLoop(ctx context.Context) <-chan *sensor
select {
case <-ctx.Done():
return
case _, ok := <-c.umhNodeIndex.RetryCommand():
if c.cache == nil {
log.Debug("Requested to retry but cache is empty. Resetting scan timer.")
case resourceID, ok := <-c.umhNodeIndex.RetryCommand():
if !ok {
log.Info("UMH retry channel for node index closed; stopping scan loop")
return
}
cachedMsg := c.getNodeIndexCache()
if cachedMsg == nil {
log.Debugf("Requested to retry %s but cache is empty. Resetting scan timer.", resourceID)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheMiss, cmetrics.ScannerVersionV4)
t.Reset(time.Second)
} else if ok {
nodeIndexesC <- c.cache
} else {
nodeIndexesC <- cachedMsg
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheHit, cmetrics.ScannerVersionV4)
}
case <-t.C:
Expand Down Expand Up @@ -249,8 +270,8 @@ func (c *Compliance) runNodeInventoryScan(ctx context.Context) *sensor.MsgFromCo
}
cmetrics.ObserveNodeInventoryScan(msg.GetNodeInventory())
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionScan, cmetrics.ScannerVersionV2)
c.umhNodeInventory.ObserveSending()
c.cache = msg.CloneVT()
c.umhNodeInventory.ObserveSending(nodeResourceID)
c.setNodeInventoryCache(msg.CloneVT())
return msg
}

Expand All @@ -266,14 +287,39 @@ func (c *Compliance) runNodeIndex(ctx context.Context) *sensor.MsgFromCompliance
log.Errorf("Error creating node index: %v", err)
return nil
}
c.umhNodeIndex.ObserveSending()
c.umhNodeIndex.ObserveSending(nodeResourceID)
cmetrics.ObserveNodeIndexReport(report, nodeName)
msg := c.createIndexMsg(report, nodeName)
cmetrics.ObserveReportProtobufMessage(msg, cmetrics.ScannerVersionV4)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionScan, cmetrics.ScannerVersionV4)
c.setNodeIndexCache(msg.CloneVT())
return msg
}

func (c *Compliance) getNodeInventoryCache() *sensor.MsgFromCompliance {
c.nodeInventoryCacheMu.Lock()
defer c.nodeInventoryCacheMu.Unlock()
return c.nodeInventoryCache
}

func (c *Compliance) setNodeInventoryCache(msg *sensor.MsgFromCompliance) {
c.nodeInventoryCacheMu.Lock()
defer c.nodeInventoryCacheMu.Unlock()
c.nodeInventoryCache = msg
}

func (c *Compliance) getNodeIndexCache() *sensor.MsgFromCompliance {
c.nodeIndexCacheMu.Lock()
defer c.nodeIndexCacheMu.Unlock()
return c.nodeIndexCache
}

func (c *Compliance) setNodeIndexCache(msg *sensor.MsgFromCompliance) {
c.nodeIndexCacheMu.Lock()
defer c.nodeIndexCacheMu.Unlock()
c.nodeIndexCache = msg
}

func (c *Compliance) manageStream(ctx context.Context, cli sensor.ComplianceServiceClient, sig *concurrency.Signal, toSensorC <-chan *sensor.MsgFromCompliance) {
for {
select {
Expand Down Expand Up @@ -343,43 +389,64 @@ func (c *Compliance) runRecv(ctx context.Context, client sensor.ComplianceServic
}
}
case *sensor.MsgToCompliance_ComplianceAck:
complianceAck := t.ComplianceAck
log.Debugf("Received ComplianceACK: type=%s, action=%s, resource_id=%s, reason=%s",
complianceAck.GetMessageType(),
complianceAck.GetAction(),
complianceAck.GetResourceId(),
complianceAck.GetReason(),
)
c.handleNodeScanningComplianceAck(complianceAck)
// New ComplianceACK from Sensor 4.10+
c.handleComplianceACK(t.ComplianceAck)
default:
utils.Should(errors.Errorf("Unhandled msg type: %T", t))
}
}
}

func (c *Compliance) handleNodeScanningComplianceAck(complianceAck *sensor.MsgToCompliance_ComplianceACK) {
if complianceAck == nil {
// handleComplianceACK handles the new ComplianceACK message from Sensor 4.10+.
// This is the generic ACK/NACK message that replaces the legacy NodeInventoryACK.
func (c *Compliance) handleComplianceACK(ack *sensor.MsgToCompliance_ComplianceACK) {
if ack == nil {
log.Error("Received nil ComplianceACK")
return
}

var handler node.UnconfirmedMessageHandler
switch complianceAck.GetMessageType() {
log.Debugf("Received ComplianceACK: type=%s, action=%s, resource_id=%s, reason=%s",
ack.GetMessageType(), ack.GetAction(), ack.GetResourceId(), ack.GetReason())

switch ack.GetMessageType() {
case sensor.MsgToCompliance_ComplianceACK_NODE_INVENTORY:
handler = c.umhNodeInventory
c.handleNodeInventoryACK(ack.GetAction(), ack.GetReason())
case sensor.MsgToCompliance_ComplianceACK_NODE_INDEX_REPORT:
handler = c.umhNodeIndex
c.handleNodeIndexACK(ack.GetAction(), ack.GetReason())
case sensor.MsgToCompliance_ComplianceACK_VM_INDEX_REPORT:
// TODO: Implement basic handling of VM_INDEX_REPORT ACK/NACK messages in ROX-33555.
default:
log.Debugf("Ignoring ComplianceACK with unsupported message type: %s", complianceAck.GetMessageType())
return
log.Errorf("Unknown ComplianceACK message type: %s", ack.GetMessageType())
}
}

// handleNodeInventoryACK handles ACK/NACK for node inventory messages.
func (c *Compliance) handleNodeInventoryACK(action sensor.MsgToCompliance_ComplianceACK_Action, reason string) {
switch action {
case sensor.MsgToCompliance_ComplianceACK_ACK:
c.umhNodeInventory.HandleACK(nodeResourceID)
case sensor.MsgToCompliance_ComplianceACK_NACK:
if reason != "" {
log.Infof("Node inventory NACK received: %s", reason)
}
c.umhNodeInventory.HandleNACK(nodeResourceID)
default:
log.Errorf("Unknown ComplianceACK action for node inventory: %s", action)
}
}

switch complianceAck.GetAction() {
// handleNodeIndexACK handles ACK/NACK for node index report messages.
func (c *Compliance) handleNodeIndexACK(action sensor.MsgToCompliance_ComplianceACK_Action, reason string) {
switch action {
case sensor.MsgToCompliance_ComplianceACK_ACK:
handler.HandleACK()
c.umhNodeIndex.HandleACK(nodeResourceID)
case sensor.MsgToCompliance_ComplianceACK_NACK:
handler.HandleNACK()
if reason != "" {
log.Infof("Node index NACK received: %s", reason)
}
c.umhNodeIndex.HandleNACK(nodeResourceID)
default:
log.Errorf("Unknown ComplianceACK action: %s", complianceAck.GetAction())
log.Errorf("Unknown ComplianceACK action for node index: %s", action)
}
}

Expand Down
13 changes: 8 additions & 5 deletions compliance/compliance_ack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stackrox/rox/generated/internalapi/sensor"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stretchr/testify/assert"
)

Expand All @@ -12,10 +13,12 @@ type fakeUMH struct {
nackCount int
}

func (f *fakeUMH) HandleACK() { f.ackCount++ }
func (f *fakeUMH) HandleNACK() { f.nackCount++ }
func (f *fakeUMH) ObserveSending() {}
func (f *fakeUMH) RetryCommand() <-chan struct{} { return nil }
func (f *fakeUMH) HandleACK(string) { f.ackCount++ }
func (f *fakeUMH) HandleNACK(string) { f.nackCount++ }
func (f *fakeUMH) ObserveSending(string) {}
func (f *fakeUMH) RetryCommand() <-chan string { return nil }
func (f *fakeUMH) OnACK(func(resourceID string)) {}
func (f *fakeUMH) Stopped() concurrency.ReadOnlyErrorSignal { return nil }

func TestHandleNodeScanningComplianceAck(t *testing.T) {
inv := &fakeUMH{}
Expand Down Expand Up @@ -89,7 +92,7 @@ func TestHandleNodeScanningComplianceAck(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
inv.ackCount, inv.nackCount = 0, 0
idx.ackCount, idx.nackCount = 0, 0
c.handleNodeScanningComplianceAck(tt.ack)
c.handleComplianceACK(tt.ack)
assert.Equal(t, tt.wantInvACK, inv.ackCount)
assert.Equal(t, tt.wantInvNACK, inv.nackCount)
assert.Equal(t, tt.wantIdxACK, idx.ackCount)
Expand Down
Loading
Loading