Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pkg/clientconn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func TLSConfig(server mtls.Subject, opts TLSConfigOptions) (*tls.Config, error)

type connectionOptions struct {
useServiceCertToken bool
useInsecureNoTLS bool
dialTLSFunc DialTLSFunc
rootCAs *x509.CertPool
}
Expand Down Expand Up @@ -198,14 +199,22 @@ func AddRootCAs(certs ...*x509.Certificate) ConnectionOption {
})
}

// UseServiceCertToken specifies whether or not a `ServiceCert` token should be used.
// UseServiceCertToken specifies whether a `ServiceCert` token should be used.
func UseServiceCertToken(use bool) ConnectionOption {
return connectOptFunc(func(opts *connectionOptions) error {
opts.useServiceCertToken = use
return nil
})
}

// UseInsecureNoTLS specifies whether to use insecure, non-TLS connections.
func UseInsecureNoTLS(use bool) ConnectionOption {
return connectOptFunc(func(opts *connectionOptions) error {
opts.useInsecureNoTLS = use
return nil
})
}

// UseDialTLSFunc uses the given connection function for dialing.
func UseDialTLSFunc(fn DialTLSFunc) ConnectionOption {
return connectOptFunc(func(opts *connectionOptions) error {
Expand All @@ -229,6 +238,7 @@ func OptionsForEndpoint(endpoint string, extraConnOpts ...ConnectionOption) (Opt
}

clientConnOpts := Options{
InsecureNoTLS: connOpts.useInsecureNoTLS,
TLS: TLSConfigOptions{
UseClientCert: MustUseClientCert,
ServerName: host,
Expand Down Expand Up @@ -383,6 +393,10 @@ func GRPCConnection(dialCtx context.Context, server mtls.Subject, endpoint strin

// NewHTTPClient creates an HTTP client for the given service using the client
// certificate of the calling service.
// When specifying the url.URL for the *http.Request for the returned *http.Client to complete,
// there is no need to specify the Host nor Scheme; however,
// if provided, they both must match the expected values.
// See AuthenticatedHTTPTransport for more information.
func NewHTTPClient(serviceIdentity mtls.Subject, serviceEndpoint string, timeout time.Duration) (*http.Client, error) {
transport, err := AuthenticatedHTTPTransport(
serviceEndpoint, serviceIdentity, nil, UseServiceCertToken(true))
Expand Down
90 changes: 90 additions & 0 deletions pkg/clientconn/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ package clientconn

import (
"crypto/x509"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"testing"

"github.com/stackrox/rox/pkg/httputil"
"github.com/stackrox/rox/pkg/mtls"
"github.com/stackrox/rox/pkg/mtls/verifier"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -48,3 +53,88 @@ func (t *ClientTestSuite) TestRootCA_WithNilCA_ShouldPanic() {
_, _ = OptionsForEndpoint(centralEndpoint, AddRootCAs(nil))
})
}

func (t *ClientTestSuite) TestAuthenticatedHTTPTransport_WebSocket() {
noopServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
baseTransport := httputil.DefaultTransport()

testcases := []struct {
name string
scheme string
valid bool
}{
{
name: "valid wss",
scheme: "wss",
valid: true,
},
{
name: "invalid wss",
scheme: "wss",
valid: false,
},
{
name: "valid ws",
scheme: "ws",
valid: true,
},
{
name: "invalid ws",
scheme: "ws",
valid: false,
},
}
for _, testcase := range testcases {
t.Run(testcase.name, func() {
// Ensure the request's URL drops the WebSocket.
baseTransport.Proxy = func(r *http.Request) (*url.URL, error) {
if !testcase.valid {
t.FailNow("Should not make it this far")
}

// http because TLS is disabled.
t.Equal("http://central.stackrox.svc:443/hello/howdy?file=rhelv2%2Frepository-to-cpe.json&uuid=f81dbc6b-5899-433b-bc86-9127219a9d89", r.URL.String())

// Forward traffic to the NO-OP Server
return url.Parse(noopServer.URL)
}

host := testcase.scheme + "://central.stackrox.svc:443"
// This is sorted by key.
rawQuery := url.Values{
"uuid": []string{"f81dbc6b-5899-433b-bc86-9127219a9d89"},
"file": []string{"rhelv2/repository-to-cpe.json"},
}.Encode()
endpoint := (&url.URL{Path: "/hello/howdy", RawQuery: rawQuery}).String()
if !testcase.valid {
endpoint = (&url.URL{
Scheme: "https",
Host: host,
Path: "hello/howdy",
RawQuery: rawQuery,
}).String()
}

req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if testcase.valid {
t.NoError(err)
} else {
errEndpoint := `"https://` + testcase.scheme + `:%2F%2Fcentral.stackrox.svc:443/hello/howdy?file=rhelv2%2Frepository-to-cpe.json&uuid=f81dbc6b-5899-433b-bc86-9127219a9d89"`
errString := `parse ` + errEndpoint + `: invalid URL escape "%2F"`
t.EqualError(err, errString)
return
}

transport, err := AuthenticatedHTTPTransport(host, mtls.CentralSubject, baseTransport, UseInsecureNoTLS(true))
t.Require().NoError(err)
client := &http.Client{
Transport: transport,
Timeout: 0,
}

resp, err := client.Do(req)
t.NoError(err)
t.Equal(http.StatusOK, resp.StatusCode)
})
}
}
15 changes: 7 additions & 8 deletions sensor/common/scannerdefinitions/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"google.golang.org/grpc/codes"
)

const scannerDefsPath = "/api/extensions/scannerdefinitions"

var (
headersToProxy = set.NewFrozenStringSet("If-Modified-Since", "Accept-Encoding")
)
Expand All @@ -22,18 +24,16 @@ var (
// from Central.
type scannerDefinitionsHandler struct {
centralClient *http.Client
centralHost string
}

// NewDefinitionsHandler creates a new scanner definitions handler.
func NewDefinitionsHandler(centralHost string) (http.Handler, error) {
client, err := clientconn.NewHTTPClient(mtls.CentralSubject, centralHost, 0)
func NewDefinitionsHandler(centralEndpoint string) (http.Handler, error) {
client, err := clientconn.NewHTTPClient(mtls.CentralSubject, centralEndpoint, 0)
if err != nil {
return nil, errors.Wrap(err, "instantiating central HTTP transport")
}
return &scannerDefinitionsHandler{
centralClient: client,
centralHost: centralHost,
}, nil
}

Expand All @@ -44,11 +44,10 @@ func (h *scannerDefinitionsHandler) ServeHTTP(writer http.ResponseWriter, reques
return
}
// Prepare the Central's request, proxy relevant headers and all parameters.
// No need to set Scheme nor Host, as the client will already do that for us.
centralURL := url.URL{
Scheme: "https",
Host: h.centralHost,
Path: "api/extensions/scannerdefinitions",
RawQuery: request.URL.Query().Encode(),
Path: scannerDefsPath,
RawQuery: request.URL.RawQuery,
}
centralRequest, err := http.NewRequestWithContext(
request.Context(), http.MethodGet, centralURL.String(), nil)
Expand Down
14 changes: 5 additions & 9 deletions sensor/common/scannerdefinitions/http_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"net/url"
"testing"

"github.com/stackrox/rox/pkg/httputil"
"github.com/stretchr/testify/assert"
)

var _ http.ResponseWriter = (*responseWriterMock)(nil)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you have this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ensures the code will not compile if *responseWriterMock no longer implements the http.ResponseWriter interface. See https://github.com/uber-go/guide/blob/master/style.md#verify-interface-compliance


type responseWriterMock struct {
bytes.Buffer
statusCode int
Expand All @@ -30,14 +33,7 @@ func (m *responseWriterMock) WriteHeader(statusCode int) {
m.statusCode = statusCode
}

// transportMockFunc is a transport mock that call itself to implement http.Transport's RoundTrip.
type transportMockFunc func(req *http.Request) (*http.Response, error)

func (f transportMockFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func Test_scannerDefinitionsHandler_ServeHTTP(t *testing.T) {
func TestServeHTTP_Responses(t *testing.T) {
type args struct {
writer *responseWriterMock
request *http.Request
Expand Down Expand Up @@ -109,7 +105,7 @@ func Test_scannerDefinitionsHandler_ServeHTTP(t *testing.T) {
}
h := &scannerDefinitionsHandler{
centralClient: &http.Client{
Transport: transportMockFunc(func(req *http.Request) (*http.Response, error) {
Transport: httputil.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, tt.args.request.URL.RawQuery, req.URL.RawQuery)
for _, header := range headersToProxy.AsSlice() {
assert.Equal(t, tt.args.request.Header.Values(header), req.Header.Values(header))
Expand Down