Skip to content

Commit 6ed390c

Browse files
author
Manish Ranjan Mahanta
committed
Erroring out instead of silent override, and bubbling up the error to the caller
Signed-off-by: Manish Ranjan Mahanta <mmahanta@microsoft.com>
1 parent e3349bb commit 6ed390c

4 files changed

Lines changed: 190 additions & 38 deletions

File tree

internal/gcs-sidecar/handlers.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,10 @@ func (b *Bridge) modifyServiceSettings(req *request) (err error) {
526526
log.G(req.ctx).Tracef("Allowed log sources after policy enforcement: %v", allowedLogSources)
527527

528528
// Update the allowed log sources in the settings. This will be forwarded to inbox GCS which expects the log sources in a JSON string format with GUIDs for providers included.
529-
allowedLogSources = etw.UpdateLogSources(req.ctx, allowedLogSources, false, true)
529+
allowedLogSources, err := etw.UpdateLogSources(allowedLogSources, false, true)
530+
if err != nil {
531+
return fmt.Errorf("failed to update log sources: %w", err)
532+
}
530533
settings.Settings = allowedLogSources
531534
}
532535
default:

internal/uvm/log_wcow.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package uvm
44

55
import (
66
"context"
7+
"fmt"
78

89
"github.com/Microsoft/hcsshim/internal/gcs"
910
"github.com/Microsoft/hcsshim/internal/gcs/prot"
@@ -69,15 +70,17 @@ func (uvm *UtilityVM) SetLogSources(ctx context.Context) error {
6970
// For confidential WCOw, we skip the adding guids to the log sources as the sidecar-GCS will verify the
7071
// allowed log sources against policy and append the necessary GUIDs to the ones allowed. Rest are dropped.
7172
// For non-confidential WCOW, we include the GUIDs in the log sources as the hcsshim communicates directly with the inboxGCS.
72-
settings := etw.UpdateLogSources(ctx, uvm.logSources, uvm.defaultLogSourcesEnabled, !uvm.HasConfidentialPolicy())
73-
73+
settings, err := etw.UpdateLogSources(uvm.logSources, uvm.defaultLogSourcesEnabled, !uvm.HasConfidentialPolicy())
74+
if err != nil {
75+
return fmt.Errorf("failed to parse log sources: %w", err)
76+
}
7477
req := guestrequest.LogForwardServiceRPCRequest{
7578
RPCType: guestrequest.RPCModifyServiceSettings,
7679
Settings: settings,
7780
}
78-
err := uvm.gc.ModifyServiceSettings(ctx, prot.LogForwardService, req)
81+
err = uvm.gc.ModifyServiceSettings(ctx, prot.LogForwardService, req)
7982
if err != nil {
80-
return err
83+
return fmt.Errorf("failed to modify service settings: %w", err)
8184
}
8285
}
8386
return nil

internal/vm/vmutils/etw/provider_map.go

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package etw
22

33
import (
4-
"context"
54
"encoding/base64"
65
"encoding/json"
6+
"fmt"
77
"strings"
88

99
"github.com/Microsoft/go-winio/pkg/guid"
10-
"github.com/Microsoft/hcsshim/internal/log"
1110
)
1211

1312
// Log Sources JSON structure
@@ -94,17 +93,15 @@ func mergeLogSources(resultSources []Source, userSources []Source) []Source {
9493
}
9594

9695
// decodeAndUnmarshalLogSources decodes a base64-encoded JSON string and unmarshals it into a LogSourcesInfo.
97-
func decodeAndUnmarshalLogSources(ctx context.Context, base64EncodedJSONLogConfig string) (LogSourcesInfo, error) {
96+
func decodeAndUnmarshalLogSources(base64EncodedJSONLogConfig string) (LogSourcesInfo, error) {
9897
jsonBytes, err := base64.StdEncoding.DecodeString(base64EncodedJSONLogConfig)
9998
if err != nil {
100-
log.G(ctx).Errorf("Error decoding base64 log config: %v", err)
101-
return LogSourcesInfo{}, err
99+
return LogSourcesInfo{}, fmt.Errorf("error decoding base64 log config: %w", err)
102100
}
103101

104102
var userLogSources LogSourcesInfo
105103
if err := json.Unmarshal(jsonBytes, &userLogSources); err != nil {
106-
log.G(ctx).Errorf("Error unmarshalling user log config: %v", err)
107-
return LogSourcesInfo{}, err
104+
return LogSourcesInfo{}, fmt.Errorf("error unmarshalling user log config: %w", err)
108105
}
109106
return userLogSources, nil
110107
}
@@ -119,14 +116,13 @@ func trimGUID(in string) string {
119116

120117
// resolveGUIDsWithLookup normalizes and fills in provider GUIDs from the well-known ETW map
121118
// for all providers across all sources. Providers with an invalid GUID are warned and skipped.
122-
func resolveGUIDsWithLookup(ctx context.Context, sources []Source) []Source {
119+
func resolveGUIDsWithLookup(sources []Source) ([]Source, error) {
123120
for i, src := range sources {
124121
for j, provider := range src.Providers {
125122
if provider.ProviderGUID != "" {
126123
guid, err := guid.FromString(trimGUID(provider.ProviderGUID))
127124
if err != nil {
128-
log.G(ctx).Warningf("Skipping invalid GUID %q for provider %q: %v", provider.ProviderGUID, provider.ProviderName, err)
129-
continue
125+
return nil, fmt.Errorf("invalid GUID %q for provider %q: %w", provider.ProviderGUID, provider.ProviderName, err)
130126
}
131127
sources[i].Providers[j].ProviderGUID = strings.ToLower(guid.String())
132128
}
@@ -135,77 +131,86 @@ func resolveGUIDsWithLookup(ctx context.Context, sources []Source) []Source {
135131
}
136132
}
137133
}
138-
return sources
134+
return sources, nil
139135
}
140136

141137
// stripRedundantGUIDs removes the GUID from providers where both Name and GUID are present and
142138
// the GUID matches the well-known lookup by name. This ensures sidecar-GCS prefers name-based
143-
// policy verification. Invalid GUIDs are warned and left as-is after normalization.
144-
func stripRedundantGUIDs(ctx context.Context, sources []Source) []Source {
139+
// policy verification. Invalid GUIDs are errored out.
140+
func stripRedundantGUIDs(sources []Source) ([]Source, error) {
145141
for i, src := range sources {
146142
for j, provider := range src.Providers {
147143
if provider.ProviderName == "" || provider.ProviderGUID == "" {
148144
continue
149145
}
150146
guid, err := guid.FromString(trimGUID(provider.ProviderGUID))
151147
if err != nil {
152-
log.G(ctx).Warningf("Skipping invalid GUID %q for provider %q: %v", provider.ProviderGUID, provider.ProviderName, err)
153-
continue
148+
return nil, fmt.Errorf("invalid GUID %q for provider %q: %w", provider.ProviderGUID, provider.ProviderName, err)
154149
}
155150
if strings.EqualFold(guid.String(), getProviderGUIDFromName(provider.ProviderName)) {
156151
sources[i].Providers[j].ProviderGUID = ""
157152
} else {
153+
// If the GUID doesn't match the well-known GUID for the provider name,
154+
// we keep it but ensure it's normalized to lowercase without braces.
155+
// However, we remove the provider name to avoid incorrect policy matches
156+
// in sidecar-GCS, since the GUID is the source of truth in this case.
157+
sources[i].Providers[j].ProviderName = ""
158158
sources[i].Providers[j].ProviderGUID = strings.ToLower(guid.String())
159159
}
160160
}
161161
}
162-
return sources
162+
return sources, nil
163163
}
164164

165165
// applyGUIDPolicy applies GUID resolution or stripping to all sources depending on the includeGUIDs flag.
166166
// See resolveGUIDsWithLookup and stripRedundantGUIDs for the respective behaviors.
167-
func applyGUIDPolicy(ctx context.Context, sources []Source, includeGUIDs bool) []Source {
167+
func applyGUIDPolicy(sources []Source, includeGUIDs bool) ([]Source, error) {
168168
if len(sources) == 0 {
169-
return sources
169+
return sources, nil
170170
}
171171
if includeGUIDs {
172-
return resolveGUIDsWithLookup(ctx, sources)
172+
return resolveGUIDsWithLookup(sources)
173173
}
174-
return stripRedundantGUIDs(ctx, sources)
174+
return stripRedundantGUIDs(sources)
175175
}
176176

177177
// marshalAndEncodeLogSources marshals the given LogSourcesInfo to JSON and encodes it as a base64 string.
178178
// On error, it logs and returns the original fallback string.
179-
func marshalAndEncodeLogSources(ctx context.Context, logCfg LogSourcesInfo, fallback string) (string, error) {
179+
func marshalAndEncodeLogSources(logCfg LogSourcesInfo) (string, error) {
180180
jsonBytes, err := json.Marshal(logCfg)
181181
if err != nil {
182-
log.G(ctx).Errorf("Error marshalling log config: %v", err)
183-
return fallback, err
182+
return "", fmt.Errorf("error marshalling log config: %w", err)
184183
}
185184
return base64.StdEncoding.EncodeToString(jsonBytes), nil
186185
}
187186

188187
// UpdateLogSources updates the user provided log sources with the default log sources based on the
189188
// configuration and returns the updated log sources as a base64 encoded JSON string.
190189
// If there is an error in the process, it returns the original user provided log sources string.
191-
func UpdateLogSources(ctx context.Context, base64EncodedJSONLogConfig string, useDefaultLogSources bool, includeGUIDs bool) string {
190+
func UpdateLogSources(base64EncodedJSONLogConfig string, useDefaultLogSources bool, includeGUIDs bool) (string, error) {
192191
var resultLogCfg LogSourcesInfo
193192
if useDefaultLogSources {
194193
resultLogCfg = defaultLogSourcesInfo
195194
}
196195

197196
if base64EncodedJSONLogConfig != "" {
198-
userLogSources, err := decodeAndUnmarshalLogSources(ctx, base64EncodedJSONLogConfig)
199-
if err == nil {
200-
resultLogCfg.LogConfig.Sources = mergeLogSources(resultLogCfg.LogConfig.Sources, userLogSources.LogConfig.Sources)
197+
userLogSources, err := decodeAndUnmarshalLogSources(base64EncodedJSONLogConfig)
198+
if err != nil {
199+
return "", fmt.Errorf("failed to decode and unmarshal user log sources: %w", err)
201200
}
201+
resultLogCfg.LogConfig.Sources = mergeLogSources(resultLogCfg.LogConfig.Sources, userLogSources.LogConfig.Sources)
202+
202203
}
203204

204-
resultLogCfg.LogConfig.Sources = applyGUIDPolicy(ctx, resultLogCfg.LogConfig.Sources, includeGUIDs)
205+
var err error
206+
resultLogCfg.LogConfig.Sources, err = applyGUIDPolicy(resultLogCfg.LogConfig.Sources, includeGUIDs)
207+
if err != nil {
208+
return "", fmt.Errorf("failed to apply GUID policy: %w", err)
209+
}
205210

206-
result, err := marshalAndEncodeLogSources(ctx, resultLogCfg, base64EncodedJSONLogConfig)
211+
result, err := marshalAndEncodeLogSources(resultLogCfg)
207212
if err != nil {
208-
return base64EncodedJSONLogConfig
213+
return "", fmt.Errorf("failed to marshal and encode log sources: %w", err)
209214
}
210-
return result
215+
return result, nil
211216
}

internal/vm/vmutils/etw/provider_map_test.go

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package etw
22

33
import (
4-
"context"
54
"encoding/base64"
65
"encoding/json"
76
"reflect"
@@ -108,7 +107,10 @@ func TestUpdateLogSources_Combinations(t *testing.T) {
108107
t.Run(tt.name, func(t *testing.T) {
109108
defaultLogSourcesInfo = cloneLogSourcesInfo(originalDefaults)
110109

111-
gotEncoded := UpdateLogSources(context.Background(), tt.base64Input, tt.useDefault, tt.includeGUIDs)
110+
gotEncoded, err := UpdateLogSources(tt.base64Input, tt.useDefault, tt.includeGUIDs)
111+
if err != nil {
112+
t.Fatalf("UpdateLogSources returned error: %v", err)
113+
}
112114
got := mustDecodeLogSources(t, gotEncoded)
113115

114116
if !reflect.DeepEqual(got, tt.expectedLogCfg) {
@@ -243,3 +245,142 @@ func mustDecodeLogSources(t *testing.T, encoded string) LogSourcesInfo {
243245
}
244246
return cfg
245247
}
248+
249+
func TestUpdateLogSources_ErrorCases(t *testing.T) {
250+
originalDefaults := cloneLogSourcesInfo(defaultLogSourcesInfo)
251+
t.Cleanup(func() {
252+
defaultLogSourcesInfo = cloneLogSourcesInfo(originalDefaults)
253+
})
254+
255+
// Build a config with an invalid GUID to trigger applyGUIDPolicy errors.
256+
invalidGUIDConfig := LogSourcesInfo{
257+
LogConfig: LogConfig{
258+
Sources: []Source{
259+
{
260+
Type: "ETW",
261+
Providers: []EtwProvider{
262+
{
263+
ProviderName: "SomeProvider",
264+
ProviderGUID: "not-a-valid-guid",
265+
},
266+
},
267+
},
268+
},
269+
},
270+
}
271+
invalidGUIDBase64 := mustEncodeLogSources(t, invalidGUIDConfig)
272+
273+
// Build a config with an invalid GUID but no provider name (only GUID set),
274+
// to trigger the resolveGUIDsWithLookup path specifically.
275+
invalidGUIDOnlyConfig := LogSourcesInfo{
276+
LogConfig: LogConfig{
277+
Sources: []Source{
278+
{
279+
Type: "ETW",
280+
Providers: []EtwProvider{
281+
{
282+
ProviderGUID: "zzz-invalid",
283+
},
284+
},
285+
},
286+
},
287+
},
288+
}
289+
invalidGUIDOnlyBase64 := mustEncodeLogSources(t, invalidGUIDOnlyConfig)
290+
291+
tests := []struct {
292+
name string
293+
base64Input string
294+
useDefault bool
295+
includeGUIDs bool
296+
errContains string
297+
}{
298+
{
299+
name: "invalid_base64_input",
300+
base64Input: "not-valid-base64!@#$",
301+
useDefault: false,
302+
includeGUIDs: false,
303+
errContains: "failed to decode and unmarshal user log sources",
304+
},
305+
{
306+
name: "valid_base64_invalid_json",
307+
base64Input: base64.StdEncoding.EncodeToString([]byte("{{not json}}")),
308+
useDefault: false,
309+
includeGUIDs: false,
310+
errContains: "failed to decode and unmarshal user log sources",
311+
},
312+
{
313+
name: "invalid_base64_with_defaults",
314+
base64Input: "!!!bad-base64!!!",
315+
useDefault: true,
316+
includeGUIDs: false,
317+
errContains: "failed to decode and unmarshal user log sources",
318+
},
319+
{
320+
name: "invalid_base64_with_defaults_and_guids",
321+
base64Input: "???",
322+
useDefault: true,
323+
includeGUIDs: true,
324+
errContains: "failed to decode and unmarshal user log sources",
325+
},
326+
{
327+
name: "valid_base64_malformed_json_structure",
328+
base64Input: base64.StdEncoding.EncodeToString([]byte(`{"LogConfig": {"sources": "not_an_array"}}`)),
329+
useDefault: false,
330+
includeGUIDs: false,
331+
errContains: "failed to decode and unmarshal user log sources",
332+
},
333+
{
334+
name: "invalid_guid_with_includeGUIDs_resolveGUIDsWithLookup",
335+
base64Input: invalidGUIDBase64,
336+
useDefault: false,
337+
includeGUIDs: true,
338+
errContains: "failed to apply GUID policy",
339+
},
340+
{
341+
name: "invalid_guid_without_includeGUIDs_stripRedundantGUIDs",
342+
base64Input: invalidGUIDBase64,
343+
useDefault: false,
344+
includeGUIDs: false,
345+
errContains: "failed to apply GUID policy",
346+
},
347+
{
348+
name: "invalid_guid_only_no_name_with_includeGUIDs",
349+
base64Input: invalidGUIDOnlyBase64,
350+
useDefault: false,
351+
includeGUIDs: true,
352+
errContains: "failed to apply GUID policy",
353+
},
354+
{
355+
name: "invalid_guid_with_defaults_and_includeGUIDs",
356+
base64Input: invalidGUIDBase64,
357+
useDefault: true,
358+
includeGUIDs: true,
359+
errContains: "failed to apply GUID policy",
360+
},
361+
{
362+
name: "invalid_guid_with_defaults_without_includeGUIDs",
363+
base64Input: invalidGUIDBase64,
364+
useDefault: true,
365+
includeGUIDs: false,
366+
errContains: "failed to apply GUID policy",
367+
},
368+
}
369+
370+
for _, tt := range tests {
371+
t.Run(tt.name, func(t *testing.T) {
372+
defaultLogSourcesInfo = cloneLogSourcesInfo(originalDefaults)
373+
374+
got, err := UpdateLogSources(tt.base64Input, tt.useDefault, tt.includeGUIDs)
375+
if err == nil {
376+
t.Fatalf("expected error containing %q, got nil (result: %q)", tt.errContains, got)
377+
}
378+
if !strings.Contains(err.Error(), tt.errContains) {
379+
t.Fatalf("expected error containing %q, got: %v", tt.errContains, err)
380+
}
381+
if got != "" {
382+
t.Fatalf("expected empty result on error, got %q", got)
383+
}
384+
})
385+
}
386+
}

0 commit comments

Comments
 (0)