Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
asyncRecorder.WithClient(string(client))
interceptor.Setup(logger, asyncRecorder, mcpProxy)

cred := interceptor.Credential()
if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Expand All @@ -228,6 +229,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
Client: string(client),
ClientSessionID: sessionID,
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
CredentialKind: string(cred.Kind),
CredentialHint: cred.Hint,
}); err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
logger.Warn(ctx, "failed to record interception", slog.Error(err))
Expand All @@ -242,6 +245,9 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
slog.F("interception_id", interceptor.ID()),
slog.F("user_agent", r.UserAgent()),
slog.F("streaming", interceptor.Streaming()),
slog.F("credential_kind", string(cred.Kind)),
slog.F("credential_hint", cred.Hint),
slog.F("credential_length", cred.Length),
)

log.Debug(ctx, "interception started")
Expand Down
9 changes: 7 additions & 2 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ type interceptionBase struct {
logger slog.Logger
tracer trace.Tracer

recorder recorder.Recorder
mcpProxy mcp.ServerProxier
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
credential intercept.CredentialInfo
}

func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
Expand Down Expand Up @@ -74,6 +75,10 @@ func (i *interceptionBase) ID() uuid.UUID {
return i.id
}

func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -45,6 +46,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -50,6 +51,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
3 changes: 2 additions & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
Expand Down Expand Up @@ -86,7 +87,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
31 changes: 31 additions & 0 deletions intercept/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package intercept

import "github.com/coder/aibridge/utils"

// CredentialKind identifies how a request was authenticated.
// Keep in sync with the credential_kind enum in coderd's database.
type CredentialKind string

// Credential kind constants for interception recording.
const (
CredentialKindCentralized CredentialKind = "centralized"
CredentialKindBYOK CredentialKind = "byok"
)

// CredentialInfo holds credential metadata for an interception.
type CredentialInfo struct {
Kind CredentialKind
Hint string
Length int
}

// NewCredentialInfo creates a CredentialInfo from a raw credential.
// The credential is automatically masked before storage so that the
// original secret is never retained.
func NewCredentialInfo(kind CredentialKind, credential string) CredentialInfo {
return CredentialInfo{
Kind: kind,
Hint: utils.MaskSecret(credential),
Comment thread
evgeniy-scherbina marked this conversation as resolved.
Length: len(credential),
}
}
2 changes: 2 additions & 0 deletions intercept/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Interceptor interface {
Streaming() bool
// TraceAttributes returns tracing attributes for this [Interceptor]
TraceAttributes(*http.Request) []attribute.KeyValue
// Credential returns the credential metadata for this interception.
Credential() CredentialInfo
// CorrelatingToolCallID returns the ID of a tool call result submitted
// in the request, if present. This is used to correlate the current
// interception back to the previous interception that issued those tool
Expand Down
9 changes: 7 additions & 2 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,19 @@ type interceptionBase struct {
tracer trace.Tracer
logger slog.Logger

recorder recorder.Recorder
mcpProxy mcp.ServerProxier
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
credential intercept.CredentialInfo
}

func (i *interceptionBase) ID() uuid.UUID {
return i.id
}

func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -47,6 +48,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -53,6 +54,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
9 changes: 7 additions & 2 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ type responsesInterceptionBase struct {
recorder recorder.Recorder
mcpProxy mcp.ServerProxier

logger slog.Logger
tracer trace.Tracer
logger slog.Logger
tracer trace.Tracer
credential intercept.CredentialInfo
}

func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
Expand Down Expand Up @@ -83,6 +84,10 @@ func (i *responsesInterceptionBase) ID() uuid.UUID {
return i.id
}

func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger.With(slog.F("model", i.Model()))
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingResponsesInterceptor {
return &BlockingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
Expand All @@ -43,6 +44,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
},
}
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingResponsesInterceptor {
return &StreamingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
Expand All @@ -50,6 +51,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
},
}
}
Expand Down
12 changes: 10 additions & 2 deletions provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,29 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr
// set BYOKBearerToken and clear the centralized key.
// When both are present, X-Api-Key takes priority to match
// claude-code behavior.
credKind := intercept.CredentialKindCentralized
credSecret := cfg.Key
authHeaderName := p.AuthHeader()
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
cfg.Key = apiKey
authHeaderName = "X-Api-Key"
credKind = intercept.CredentialKindBYOK
credSecret = apiKey
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.BYOKBearerToken = token
cfg.Key = ""
authHeaderName = "Authorization"
credKind = intercept.CredentialKindBYOK
credSecret = token
}

cred := intercept.NewCredentialInfo(credKind, credSecret)

var interceptor intercept.Interceptor
if reqPayload.Stream() {
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
} else {
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
Expand Down
43 changes: 29 additions & 14 deletions provider/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/coder/aibridge/config"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
)

Expand Down Expand Up @@ -163,33 +164,43 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setHeaders map[string]string
wantXApiKey string
wantAuthorization string
name string
setHeaders map[string]string
wantXApiKey string
wantAuthorization string
wantCredentialKind intercept.CredentialKind
wantCredentialHint string
}{
{
name: "Messages_BYOK_BearerToken",
setHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
name: "Messages_BYOK_BearerToken",
setHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
wantCredentialKind: intercept.CredentialKindBYOK,
wantCredentialHint: "us...en",
},
{
name: "Messages_BYOK_APIKey",
setHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
name: "Messages_BYOK_APIKey",
setHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
wantCredentialKind: intercept.CredentialKindBYOK,
wantCredentialHint: "us...ey",
},
{
name: "Messages_Centralized_UsesCentralizedKey",
setHeaders: map[string]string{},
wantXApiKey: "test-key",
name: "Messages_Centralized",
setHeaders: map[string]string{},
wantXApiKey: "test-key",
wantCredentialKind: intercept.CredentialKindCentralized,
wantCredentialHint: "***",
},
{
name: "Messages_BYOK_BearerToken_And_APIKey",
setHeaders: map[string]string{
"Authorization": "Bearer user-access-token",
"X-Api-Key": "user-api-key",
},
wantXApiKey: "user-api-key",
wantXApiKey: "user-api-key",
wantCredentialKind: intercept.CredentialKindBYOK,
wantCredentialHint: "us...ey",
},
}

Expand Down Expand Up @@ -223,6 +234,10 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, interceptor)

cred := interceptor.Credential()
assert.Equal(t, tc.wantCredentialKind, cred.Kind, "credential kind mismatch")
assert.Equal(t, tc.wantCredentialHint, cred.Hint, "credential hint mismatch")

logger := slog.Make()
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)

Expand Down
10 changes: 6 additions & 4 deletions provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
ExtraHeaders: extractCopilotHeaders(r),
}

cred := intercept.NewCredentialInfo(intercept.CredentialKindBYOK, key)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it makes sense to add the github token here 🤔 I don't think the copilot case fits the reasoning for logging and debugging on the UI.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ssncferreira
I had very similar proposal, see: #216 (comment)


var interceptor intercept.Interceptor

path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
Expand All @@ -156,9 +158,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
}

case routeCopilotResponses:
Expand All @@ -172,9 +174,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
}

default:
Expand Down
Loading
Loading