diff --git a/premium/monitor.go b/premium/monitor.go index 02573b4672..cf3deb99f6 100644 --- a/premium/monitor.go +++ b/premium/monitor.go @@ -60,20 +60,23 @@ func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...Quot return newCtx, nil } -func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { - hasQuota, err := qc.qm.HasQuota(ctx) +func (qc *quotaChecker) checkInitialQuota(ctx context.Context) error { + result, err := qc.qm.CheckQuota(ctx) if err != nil { return err } + if result.SuggestedQueryInterval > 0 { + qc.duration = result.SuggestedQueryInterval + } - if !hasQuota { + if !result.HasQuota { return ErrNoQuota{team: qc.qm.TeamName()} } return nil } -func (qc quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { +func (qc *quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { newCtx, cancelWithCause := context.WithCancelCause(ctx) go func() { ticker := time.NewTicker(qc.duration) @@ -84,7 +87,7 @@ func (qc quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { case <-newCtx.Done(): return case <-ticker.C: - hasQuota, err := qc.qm.HasQuota(newCtx) + result, err := qc.qm.CheckQuota(newCtx) if err != nil { consecutiveFailures++ hasQuotaErrors = errors.Join(hasQuotaErrors, err) @@ -94,9 +97,13 @@ func (qc quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { } continue } + if result.SuggestedQueryInterval > 0 && qc.duration != result.SuggestedQueryInterval { + qc.duration = result.SuggestedQueryInterval + ticker.Reset(qc.duration) + } consecutiveFailures = 0 hasQuotaErrors = nil - if !hasQuota { + if !result.HasQuota { cancelWithCause(ErrNoQuota{team: qc.qm.TeamName()}) return } diff --git a/premium/monitor_test.go b/premium/monitor_test.go index 0ce3876bda..a319e84550 100644 --- a/premium/monitor_test.go +++ b/premium/monitor_test.go @@ -10,8 +10,8 @@ import ( ) type quotaResponse struct { - hasQuota bool - err error + result CheckQuotaResult + err error } func newFakeQuotaMonitor(hasQuota ...quotaResponse) *fakeQuotaMonitor { @@ -23,12 +23,12 @@ type fakeQuotaMonitor struct { calls int } -func (f *fakeQuotaMonitor) HasQuota(_ context.Context) (bool, error) { +func (f *fakeQuotaMonitor) CheckQuota(_ context.Context) (CheckQuotaResult, error) { resp := f.responses[f.calls] if f.calls < len(f.responses)-1 { f.calls++ } - return resp.hasQuota, resp.err + return resp.result, resp.err } func (*fakeQuotaMonitor) TeamName() string { @@ -39,7 +39,7 @@ func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { ctx := context.Background() responses := []quotaResponse{ - {false, nil}, + {CheckQuotaResult{HasQuota: false}, nil}, } _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...)) @@ -50,8 +50,8 @@ func TestWithCancelOnQuotaExceeded_NoQuota(t *testing.T) { ctx := context.Background() responses := []quotaResponse{ - {true, nil}, - {false, nil}, + {CheckQuotaResult{HasQuota: true}, nil}, + {CheckQuotaResult{HasQuota: false}, nil}, } ctx, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...), WithQuotaCheckPeriod(1*time.Millisecond)) require.NoError(t, err) @@ -65,9 +65,9 @@ func TestWithCancelOnQuotaCheckConsecutiveFailures(t *testing.T) { ctx := context.Background() responses := []quotaResponse{ - {true, nil}, - {false, errors.New("test2")}, - {false, errors.New("test3")}, + {CheckQuotaResult{HasQuota: true}, nil}, + {CheckQuotaResult{HasQuota: false}, errors.New("test2")}, + {CheckQuotaResult{HasQuota: false}, errors.New("test3")}, } ctx, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...), diff --git a/premium/usage.go b/premium/usage.go index a9c5d5a844..b180bb1e8e 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -43,6 +43,7 @@ const ( BatchLimitHeader = "x-cq-batch-limit" MinimumUpdateIntervalHeader = "x-cq-minimum-update-interval" MaximumUpdateIntervalHeader = "x-cq-maximum-update-interval" + QueryIntervalHeader = "x-cq-query-interval" ) //go:generate mockgen -package=mocks -destination=../premium/mocks/marketplacemetering.go -source=usage.go AWSMarketplaceClientInterface @@ -55,11 +56,19 @@ type TokenClient interface { GetTokenType() auth.TokenType } +type CheckQuotaResult struct { + // HasQuota is true if the quota has not been exceeded + HasQuota bool + + // SuggestedQueryInterval is the suggested interval to wait before querying the API again + SuggestedQueryInterval time.Duration +} + type QuotaMonitor interface { // TeamName returns the team name TeamName() string - // HasQuota returns true if the quota has not been exceeded - HasQuota(context.Context) (bool, error) + // CheckQuota checks if the quota has been exceeded + CheckQuota(context.Context) (CheckQuotaResult, error) } type UsageClient interface { @@ -359,21 +368,34 @@ func (u *BatchUpdater) TeamName() string { return u.teamName } -func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { +func (u *BatchUpdater) CheckQuota(ctx context.Context) (CheckQuotaResult, error) { if u.awsMarketplaceClient != nil { - return true, nil + return CheckQuotaResult{HasQuota: true}, nil } u.logger.Debug().Str("url", u.url).Str("team", u.teamName).Str("pluginTeam", u.pluginMeta.Team).Str("pluginKind", string(u.pluginMeta.Kind)).Str("pluginName", u.pluginMeta.Name).Msg("checking quota") usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginMeta.Team, u.pluginMeta.Kind, u.pluginMeta.Name) if err != nil { - return false, fmt.Errorf("failed to get usage: %w", err) + return CheckQuotaResult{HasQuota: false}, fmt.Errorf("failed to get usage: %w", err) } if usage.StatusCode() != http.StatusOK { - return false, fmt.Errorf("failed to get usage: %s", usage.Status()) + return CheckQuotaResult{HasQuota: false}, fmt.Errorf("failed to get usage: %s", usage.Status()) } - hasQuota := usage.JSON200.RemainingRows == nil || *usage.JSON200.RemainingRows > 0 - return hasQuota, nil + res := CheckQuotaResult{ + HasQuota: usage.JSON200.RemainingRows == nil || *usage.JSON200.RemainingRows > 0, + } + if usage.HTTPResponse == nil { + return res, nil + } + if headerValue := usage.HTTPResponse.Header.Get(QueryIntervalHeader); headerValue != "" { + interval, err := strconv.ParseUint(headerValue, 10, 32) + if interval > 0 { + res.SuggestedQueryInterval = time.Duration(interval) * time.Second + } else { + u.logger.Warn().Err(err).Str(QueryIntervalHeader, headerValue).Msg("failed to parse query interval") + } + } + return res, nil } func (u *BatchUpdater) Close() error { @@ -700,8 +722,8 @@ func (n *NoOpUsageClient) TeamName() string { return n.TeamNameValue } -func (NoOpUsageClient) HasQuota(_ context.Context) (bool, error) { - return true, nil +func (NoOpUsageClient) CheckQuota(_ context.Context) (CheckQuotaResult, error) { + return CheckQuotaResult{HasQuota: true}, nil } func (NoOpUsageClient) Increase(_ uint32) error { diff --git a/premium/usage_test.go b/premium/usage_test.go index 9a3a17b3a6..aa17e2f58b 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -113,10 +113,10 @@ func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { usageClient := newClient(t, apiClient, WithBatchLimit(0)) - hasQuota, err := usageClient.HasQuota(ctx) + result, err := usageClient.CheckQuota(ctx) require.NoError(t, err) - assert.False(t, hasQuota, "should not have quota") + assert.False(t, result.HasQuota, "should not have quota") } func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { @@ -130,10 +130,10 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { usageClient := newClient(t, apiClient, WithBatchLimit(0)) - hasQuota, err := usageClient.HasQuota(ctx) + result, err := usageClient.CheckQuota(ctx) require.NoError(t, err) - assert.True(t, hasQuota, "should have quota") + assert.True(t, result.HasQuota, "should have quota") } func TestUsageService_Increase_ZeroBatchSize(t *testing.T) {