Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion api/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"io"
"net/http"
"os"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -48,8 +50,9 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
clientOpts.LogVerboseHTTP = opts.LogVerboseHTTP
}

userAgentValue := fmt.Sprintf("GitHub CLI %s", opts.AppVersion)
headers := map[string]string{
userAgent: fmt.Sprintf("GitHub CLI %s", opts.AppVersion),
userAgent: getUserAgentForActions(userAgentValue),
}
clientOpts.Headers = headers

Expand Down Expand Up @@ -140,3 +143,20 @@ func getHost(r *http.Request) string {
}
return r.URL.Host
}

// getUserAgentForActions appends the ACTIONS_ORCHESTRATION_ID to the user agent
// if the environment variable is set. The orchestration ID is sanitized to only allow
// alphanumeric characters, underscores, hyphens, and dots.
func getUserAgentForActions(baseUserAgent string) string {
orchID := os.Getenv("ACTIONS_ORCHESTRATION_ID")
if orchID == "" {
return baseUserAgent
}

// Sanitize the orchestration ID to ensure it contains only valid characters
// Valid characters: 0-9, a-z, A-Z, _, -, .
re := regexp.MustCompile(`[^a-zA-Z0-9_.-]`)
sanitizedID := re.ReplaceAllString(orchID, "_")

return fmt.Sprintf("%s (actions_orchestration_id/%s)", baseUserAgent, sanitizedID)
}
106 changes: 106 additions & 0 deletions api/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
"testing"
Expand Down Expand Up @@ -313,3 +314,108 @@ func normalizeVerboseLog(t string) string {
t = timezoneRE.ReplaceAllString(t, "> Time-Zone: <timezone>")
return t
}

func TestGetUserAgentForActions(t *testing.T) {
tests := []struct {
name string
baseUserAgent string
orchID string
expectedResult string
}{
{
name: "with orchestration ID",
baseUserAgent: "GitHub CLI v1.2.3",
orchID: "workflow-12345-job-67890",
expectedResult: "GitHub CLI v1.2.3 (actions_orchestration_id/workflow-12345-job-67890)",
},
{
name: "without orchestration ID",
baseUserAgent: "GitHub CLI v1.2.3",
orchID: "",
expectedResult: "GitHub CLI v1.2.3",
},
{
name: "with special characters",
baseUserAgent: "GitHub CLI v1.2.3",
orchID: "test (with) special/chars",
expectedResult: "GitHub CLI v1.2.3 (actions_orchestration_id/test__with__special_chars)",
},
{
name: "with various special characters",
baseUserAgent: "GitHub CLI v1.2.3",
orchID: "test!@#$%^&*()+=[]{}|\\:;\"'<>?,/",
expectedResult: "GitHub CLI v1.2.3 (actions_orchestration_id/test___________________________)",
},
{
name: "with allowed characters only",
baseUserAgent: "GitHub CLI v1.2.3",
orchID: "test_with-allowed.chars123",
expectedResult: "GitHub CLI v1.2.3 (actions_orchestration_id/test_with-allowed.chars123)",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set environment variable
if tt.orchID != "" {
t.Setenv("ACTIONS_ORCHESTRATION_ID", tt.orchID)
} else {
os.Unsetenv("ACTIONS_ORCHESTRATION_ID")
}

result := getUserAgentForActions(tt.baseUserAgent)
assert.Equal(t, tt.expectedResult, result)
})
}
}

func TestNewHTTPClientWithActionsOrchestrationID(t *testing.T) {
// Set the orchestration ID
orchID := "test-orch-id-12345"
t.Setenv("ACTIONS_ORCHESTRATION_ID", orchID)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userAgent := r.Header.Get("User-Agent")
assert.Contains(t, userAgent, "GitHub CLI v1.2.3")
assert.Contains(t, userAgent, "actions_orchestration_id/test-orch-id-12345")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

client, err := NewHTTPClient(HTTPClientOptions{
AppVersion: "v1.2.3",
})
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL, nil)
require.NoError(t, err)
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}

func TestNewHTTPClientWithoutActionsOrchestrationID(t *testing.T) {
// Ensure the environment variable is not set
os.Unsetenv("ACTIONS_ORCHESTRATION_ID")

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userAgent := r.Header.Get("User-Agent")
assert.Equal(t, "GitHub CLI v1.2.3", userAgent)
assert.NotContains(t, userAgent, "actions_orchestration_id")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

client, err := NewHTTPClient(HTTPClientOptions{
AppVersion: "v1.2.3",
})
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL, nil)
require.NoError(t, err)
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}