Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions central/apirequestlog/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package apirequestlog

import (
"context"
"net/http"

"github.com/stackrox/rox/central/metrics/custom/api_requests"
"github.com/stackrox/rox/pkg/httputil"
"github.com/stackrox/rox/pkg/telemetry/phonehome"
"google.golang.org/grpc"
)

// UnaryServerInterceptor creates a gRPC unary interceptor that tracks API
// request metadata.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
resp, err := handler(ctx, req)
rp := phonehome.GetGRPCRequestDetails(ctx, err, info.FullMethod, req)
api_requests.RecordRequest(rp)
return resp, err
}
}

// StreamServerInterceptor creates a gRPC stream interceptor that tracks API
// request metadata.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := handler(srv, ss)
rp := phonehome.GetGRPCRequestDetails(ss.Context(), err,
info.FullMethod, nil)
api_requests.RecordRequest(rp)
return err
}
}

// HTTPInterceptor creates an HTTP interceptor that tracks API request metadata.
func HTTPInterceptor() httputil.HTTPInterceptor {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrappedWriter := httputil.NewStatusTrackingWriter(w)
handler.ServeHTTP(wrappedWriter, r)
statusCode := 0
if statusCodePtr := wrappedWriter.GetStatusCode(); statusCodePtr != nil {
statusCode = *statusCodePtr
}
rp := phonehome.GetHTTPRequestDetails(r.Context(), r, statusCode)
api_requests.RecordRequest(rp)
})
}
}
89 changes: 89 additions & 0 deletions central/apirequestlog/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package apirequestlog

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func TestUnaryServerInterceptor(t *testing.T) {
interceptor := UnaryServerInterceptor()

// Create a test handler that returns success
successHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
return "success", nil
}

// Create a test handler that returns an error
errorHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, status.Error(codes.NotFound, "not found")
}

tests := []struct {
name string
handler grpc.UnaryHandler
fullMethod string
expectError bool
}{
{
name: "successful request",
handler: successHandler,
fullMethod: "/v1.ClusterService/GetClusters",
expectError: false,
},
{
name: "failed request",
handler: errorHandler,
fullMethod: "/v1.PolicyService/GetPolicy",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()

info := &grpc.UnaryServerInfo{
FullMethod: tt.fullMethod,
}

// The interceptor will use requestinfo.FromContext which returns zero value
resp, err := interceptor(ctx, nil, info, tt.handler)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, "success", resp)
}
})
}
}

func TestHTTPInterceptor(t *testing.T) {
interceptor := HTTPInterceptor()

// Create a test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})

wrappedHandler := interceptor(testHandler)

req := httptest.NewRequest(http.MethodGet, "/api/v1/clusters", nil)
req.Header.Set("User-Agent", "test-agent/1.0")

recorder := httptest.NewRecorder()

wrappedHandler.ServeHTTP(recorder, req)

assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "OK", recorder.Body.String())
}
6 changes: 6 additions & 0 deletions central/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
administrationUsageService "github.com/stackrox/rox/central/administration/usage/service"
alertDatastore "github.com/stackrox/rox/central/alert/datastore"
alertService "github.com/stackrox/rox/central/alert/service"
"github.com/stackrox/rox/central/apirequestlog"
apitokenDS "github.com/stackrox/rox/central/apitoken/datastore"
apiTokenExpiration "github.com/stackrox/rox/central/apitoken/expiration"
apiTokenService "github.com/stackrox/rox/central/apitoken/service"
Expand Down Expand Up @@ -629,6 +630,11 @@ func startGRPCServer() {
)
config.HTTPInterceptors = append(config.HTTPInterceptors, observe.AuthzTraceHTTPInterceptor(authzTraceSink))

// API request tracking for Prometheus metrics
config.UnaryInterceptors = append(config.UnaryInterceptors, apirequestlog.UnaryServerInterceptor())
config.StreamInterceptors = append(config.StreamInterceptors, apirequestlog.StreamServerInterceptor())
config.HTTPInterceptors = append(config.HTTPInterceptors, apirequestlog.HTTPInterceptor())

// Before authorization is checked, we want to inject the sac client into the context.
config.PreAuthContextEnrichers = append(config.PreAuthContextEnrichers,
centralSAC.GetEnricher().GetPreAuthContextEnricher(authzTraceSink),
Expand Down
36 changes: 36 additions & 0 deletions central/metrics/custom/api_requests/labels.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package api_requests

import (
"strconv"

"github.com/stackrox/rox/central/metrics/custom/tracker"
"github.com/stackrox/rox/pkg/telemetry/phonehome"
)

var LazyLabels = tracker.LazyLabelGetters[*finding]{
"UserID": func(f *finding) string {
if f.UserID != nil {
return f.UserID.UID()
}
return ""
},
"UserAgent": func(f *finding) string { return getUserAgentFromHeaders(f.Headers) },
"Path": func(f *finding) string { return f.Path },
"Method": func(f *finding) string { return f.Method },
"Status": func(f *finding) string { return strconv.Itoa(f.Code) },
}

type finding = phonehome.RequestParams

func getUserAgentFromHeaders(headers func(string) []string) string {
if headers == nil {
return ""
}
if userAgents := headers("User-Agent"); len(userAgents) > 0 {
return userAgents[0]
}
if userAgents := headers("user-agent"); len(userAgents) > 0 {
return userAgents[0]
}
return ""
}
38 changes: 38 additions & 0 deletions central/metrics/custom/api_requests/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package api_requests

import (
"github.com/stackrox/rox/central/metrics/custom/tracker"
"github.com/stackrox/rox/pkg/sync"
"github.com/stackrox/rox/pkg/telemetry/phonehome"
)

var (
singleton *tracker.TrackerBase[*finding]
singletonOnce sync.Once
)

// new creates a new API request tracker using TrackerBase with counter support.
// The tracker is created with a nil generator since it uses real-time counter
// increments.
func new() *tracker.TrackerBase[*finding] {
return tracker.MakeGlobalTrackerBase(
"api_request",
"API requests",
LazyLabels,
nil, // nil generator = counter tracker
)
}

// Singleton returns the global API request tracker instance.
func Singleton() *tracker.TrackerBase[*finding] {
singletonOnce.Do(func() {
singleton = new()
})
return singleton
}

// RecordRequest records an API request by incrementing the counter.
// This is a convenience wrapper around TrackerBase.IncrementCounter.
func RecordRequest(rp *phonehome.RequestParams) {
Singleton().IncrementCounter(rp)
}
2 changes: 1 addition & 1 deletion central/metrics/custom/expiry/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func New(s service.Service) *tracker.TrackerBase[*finding] {
return tracker.MakeTrackerBase(
return tracker.MakeGlobalTrackerBase(
"cert_exp",
"certificate expiry",
LazyLabels,
Expand Down
2 changes: 1 addition & 1 deletion central/metrics/custom/policies/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func New(ds policyDS.DataStore) *tracker.TrackerBase[*finding] {
return tracker.MakeTrackerBase(
return tracker.MakeGlobalTrackerBase(
"cfg",
"policies",
LazyLabels,
Expand Down
28 changes: 24 additions & 4 deletions central/metrics/custom/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"context"
"net/http"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
alertDS "github.com/stackrox/rox/central/alert/datastore"
clusterDS "github.com/stackrox/rox/central/cluster/datastore"
configDS "github.com/stackrox/rox/central/config/datastore"
expiryS "github.com/stackrox/rox/central/credentialexpiry/service"
deploymentDS "github.com/stackrox/rox/central/deployment/datastore"
"github.com/stackrox/rox/central/metrics"
"github.com/stackrox/rox/central/metrics/custom/api_requests"
"github.com/stackrox/rox/central/metrics/custom/clusters"
"github.com/stackrox/rox/central/metrics/custom/expiry"
"github.com/stackrox/rox/central/metrics/custom/image_vulnerabilities"
Expand Down Expand Up @@ -95,6 +98,9 @@ func makeRunner(ds *runnerDatastores) trackerRunner {
// rox_central_cert_exp_hours
"hours": expiry.LazyLabels.GetLabels(),
}),
}, {
api_requests.Singleton(),
(*storage.PrometheusMetrics).GetApiRequests,
},
}
}
Expand Down Expand Up @@ -162,14 +168,28 @@ func (tr trackerRunner) ServeHTTP(w http.ResponseWriter, req *http.Request) {
go tracker.Gather(newCtx)
}
}
registry, err := metrics.GetCustomRegistry(userID)

userRegistry, err := metrics.GetCustomRegistry(userID)
if err != nil {
httputil.WriteError(w, err)
return
}

globalRegistry, err := metrics.GetGlobalRegistry()
if err != nil {
httputil.WriteError(w, err)
return
}
registry.Lock()
defer registry.Unlock()
registry.ServeHTTP(w, req)

userRegistry.Lock()
defer userRegistry.Unlock()
globalRegistry.Lock()
defer globalRegistry.Unlock()

promhttp.HandlerFor(
prometheus.Gatherers{userRegistry, globalRegistry},
promhttp.HandlerOpts{}).ServeHTTP(w, req)

go phonehome()
}

Expand Down
1 change: 1 addition & 0 deletions central/metrics/custom/tracker/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@
toAdd []MetricName
toDelete []MetricName
period time.Duration
enabled bool

Check failure on line 43 in central/metrics/custom/tracker/configuration.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not properly formatted (gofmt)
}
Loading
Loading