Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions coderd/aibridge/aibridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
// that use per-user LLM credentials but cannot set custom headers.
const HeaderCoderToken = "X-Coder-AI-Governance-Token" //nolint:gosec // This is a header name, not a credential.

// HeaderCoderRequestID is a header set by aibridgeproxyd on each
// request forwarded to aibridged for cross-service log correlation.
const HeaderCoderRequestID = "X-Coder-AI-Governance-Request-Id"

// IsBYOK reports whether the request is using BYOK mode, determined
// by the presence of the X-Coder-AI-Governance-Token header.
func IsBYOK(header http.Header) bool {
Expand Down
62 changes: 62 additions & 0 deletions enterprise/aibridged/aibridged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,65 @@ func TestRouting(t *testing.T) {
})
}
}

// TestServeHTTP_StripInternalHeaders verifies that internal X-Coder-*
// headers are never forwarded to upstream LLM providers.
func TestServeHTTP_StripInternalHeaders(t *testing.T) {
t.Parallel()

cases := []struct {
name string
header string
value string
}{
{
name: "X-Coder-AI-Governance-Token",
header: agplaibridge.HeaderCoderToken,
value: "coder-token",
},
{
name: "X-Coder-AI-Governance-Request-Id",
header: agplaibridge.HeaderCoderRequestID,
value: uuid.NewString(),
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

mockH := &mockHandler{}

srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(mockH, nil)

httpSrv := httptest.NewServer(srv)
t.Cleanup(httpSrv.Close)

ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpSrv.URL+"/anthropic/v1/messages", nil)
require.NoError(t, err)

// Always set a valid auth token so the request reaches
// the upstream handler.
req.Header.Set("Authorization", "Bearer coder-token")
req.Header.Set(tc.header, tc.value)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived)

// Assert no X-Coder-* headers were forwarded upstream.
for name := range mockH.headersReceived {
require.NotContains(t, name, "X-Coder-",
"internal header %q must not be forwarded to upstream providers", name)
}
})
}
}
10 changes: 10 additions & 0 deletions enterprise/aibridged/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

logger := s.logger.With(slog.F("path", r.URL.Path))

// Extract and strip proxy request ID for cross-service log
// correlation. Absent for direct requests not routed through
// aibridgeproxyd.
if proxyReqID := r.Header.Get(agplaibridge.HeaderCoderRequestID); proxyReqID != "" {
// Inject into context so downstream loggers include it.
ctx = slog.With(ctx, slog.F("aibridgeproxy_id", proxyReqID))
logger = logger.With(slog.F("aibridgeproxy_id", proxyReqID))
}
r.Header.Del(agplaibridge.HeaderCoderRequestID)

byok := agplaibridge.IsBYOK(r.Header)
authMode := "centralized"
if byok {
Expand Down
8 changes: 2 additions & 6 deletions enterprise/aibridgeproxyd/aibridgeproxyd.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ const (
)

const (
// HeaderAIBridgeRequestID is the header used to correlate requests
// between aibridgeproxyd and aibridged.
HeaderAIBridgeRequestID = "X-AI-Bridge-Request-Id"
// ProxyAuthRealm is the realm used in Proxy-Authenticate challenges.
// The realm helps clients identify which credentials to use.
ProxyAuthRealm = `"Coder AI Bridge Proxy"`
Expand Down Expand Up @@ -960,9 +957,8 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.

injectBYOKHeaderIfNeeded(req.Header, reqCtx.CoderToken)

// Set custom header for cross-service log correlation.
// This allows correlating aibridgeproxyd logs with aibridged logs.
req.Header.Set(HeaderAIBridgeRequestID, reqCtx.RequestID.String())
// Set request ID header to correlate requests between aibridgeproxyd and aibridged.
req.Header.Set(agplaibridge.HeaderCoderRequestID, reqCtx.RequestID.String())

logger.Info(s.ctx, "routing MITM request to aibridged",
slog.F("aibridged_url", aiBridgeParsedURL.String()),
Expand Down
2 changes: 1 addition & 1 deletion enterprise/aibridgeproxyd/aibridgeproxyd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ func TestProxy_MITM(t *testing.T) {
receivedPath = r.URL.Path
receivedAuthz = r.Header.Get("Authorization")
receivedBYOK = r.Header.Get(agplaibridge.HeaderCoderToken)
receivedRequestID = r.Header.Get(aibridgeproxyd.HeaderAIBridgeRequestID)
receivedRequestID = r.Header.Get(agplaibridge.HeaderCoderRequestID)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello from aibridged"))
}))
Expand Down
Loading