diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c6a599ad20..11da447395 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: - name: Run golangci-lint uses: golangci/golangci-lint-action@v9 with: - version: v2.1 + version: v2.4 args: >- --verbose --max-issues-per-linter=0 diff --git a/internal/hcs/callback.go b/internal/hcs/callback.go index 7b27173c3a..80c23fd2a4 100644 --- a/internal/hcs/callback.go +++ b/internal/hcs/callback.go @@ -3,20 +3,27 @@ package hcs import ( + "context" + "encoding/json" "fmt" "sync" + "sync/atomic" "syscall" + "github.com/sirupsen/logrus" + "github.com/Microsoft/hcsshim/internal/interop" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" "github.com/Microsoft/hcsshim/internal/vmcompute" - "github.com/sirupsen/logrus" ) var ( - nextCallback uintptr - callbackMap = map[uintptr]*notificationWatcherContext{} + // TODO: don't delete notification contexts on close, so callback can handle delayed notifications + + // used to lock [callbackMap]. callbackMapLock = sync.RWMutex{} + callbackMap = map[callbackNumber]*notificationWatcherContext{} notificationWatcherCallback = syscall.NewCallback(notificationWatcher) @@ -87,6 +94,31 @@ func (hn hcsNotification) String() string { } } +// HCS callbacks take the form: +// +// typedef void (CALLBACK *HCS_NOTIFICATION_CALLBACK)( +// _In_ DWORD notificationType, +// _In_opt_ void* context, +// _In_ HRESULT notificationStatus, +// _In_opt_ PCWSTR notificationData +// ); +// +// where the context is a pointer to the data that is associated with a particular notification. +// +// However, since Golang can freely move structs, pointer values are not stable. +// Therefore, interpret the pointer as the unique ID of a [notificationWatcherContext] +// stored in [callbackMap]. +// +// Note: Pointer stability via converting to [unsafe.Pointer] for syscalls is only guaranteed +// until the syscall returns, and the same pointer value is therefore invalid across different +// syscall invocations. +// See point (4) of the [unsafe.Pointer] documentation. +type callbackNumber uintptr + +var callbackCounter atomic.Uintptr + +func nextCallback() callbackNumber { return callbackNumber(callbackCounter.Add(1)) } + type notificationChannel chan error type notificationWatcherContext struct { @@ -132,32 +164,81 @@ func closeChannels(channels notificationChannels) { } } -func notificationWatcher(notificationType hcsNotification, callbackNumber uintptr, notificationStatus uintptr, notificationData *uint16) uintptr { - var result error - if int32(notificationStatus) < 0 { - result = interop.Win32FromHresult(notificationStatus) +func notificationWatcher( + notificationType hcsNotification, + callbackNum callbackNumber, + notificationStatus uintptr, + notificationData *uint16, +) uintptr { + ctx, entry := log.SetEntry(context.Background(), logrus.Fields{ + logfields.CallbackNumber: callbackNum, + "notification-type": notificationType.String(), + }) + + result := processNotification(ctx, notificationStatus, notificationData) + if result != nil { + entry.Data[logrus.ErrorKey] = result } callbackMapLock.RLock() - context := callbackMap[callbackNumber] + callbackCtx := callbackMap[callbackNum] callbackMapLock.RUnlock() - if context == nil { + if callbackCtx == nil { + entry.Warn("received HCS notification for unknown callback number") return 0 } - log := logrus.WithFields(logrus.Fields{ - "notification-type": notificationType.String(), - "system-id": context.systemID, - }) - if context.processID != 0 { - log.Data[logfields.ProcessID] = context.processID + entry.Data[logfields.SystemID] = callbackCtx.systemID + if callbackCtx.processID != 0 { + entry.Data[logfields.ProcessID] = callbackCtx.processID } - log.Debug("HCS notification") + entry.Debug("received HCS notification") - if channel, ok := context.channels[notificationType]; ok { + if channel, ok := callbackCtx.channels[notificationType]; ok { channel <- result } return 0 } + +// processNotification parses and validates HCS notifications and returns the result as an error. +func processNotification(ctx context.Context, notificationStatus uintptr, notificationData *uint16) (err error) { + // TODO: merge/unify with [processHcsResult] + status := int32(notificationStatus) + if status < 0 { + err = interop.Win32FromHresult(notificationStatus) + } + + if notificationData == nil { + return err + } + + // don't call CoTaskMemFree since HCS_NOTIFICATION_CALLBACK's notificationData is PCWSTR. + resultJSON := interop.ConvertString(notificationData) + result := &hcsResult{} + if jsonErr := json.Unmarshal([]byte(resultJSON), result); jsonErr != nil { + log.G(ctx).WithFields(logrus.Fields{ + logfields.JSON: resultJSON, + logrus.ErrorKey: jsonErr, + }).Warn("could not unmarshal HCS result") + return err + } + log.G(ctx).WithField("result", result).Trace("parsed notification data") + + // the HResult and data payload should have the same error value + if result.Error < 0 && status < 0 && status != result.Error { + log.G(ctx).WithFields(logrus.Fields{ + "status": status, + "data": result.Error, + }).Warn("mismatched notification status and data HResult values") + } + + if len(result.ErrorEvents) > 0 { + return &resultError{ + Err: err, + Events: result.ErrorEvents, + } + } + return err +} diff --git a/internal/hcs/errors.go b/internal/hcs/errors.go index 3e10f5c7e0..6546853d57 100644 --- a/internal/hcs/errors.go +++ b/internal/hcs/errors.go @@ -8,9 +8,13 @@ import ( "errors" "fmt" "net" + "strings" "syscall" + "github.com/sirupsen/logrus" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" ) var ( @@ -102,43 +106,95 @@ type ErrorEvent struct { //Data []EventData `json:"Data,omitempty"` // Omit this as HCS doesn't encode this well. It's more confusing to include. It is however logged in debug mode (see processHcsResult function) } -type hcsResult struct { - Error int32 - ErrorMessage string - ErrorEvents []ErrorEvent `json:"ErrorEvents,omitempty"` +func (ev *ErrorEvent) String() string { + sb := new(strings.Builder) + ev.writeTo(sb) + return sb.String() } -func (ev *ErrorEvent) String() string { - evs := "[Event Detail: " + ev.Message +func (ev *ErrorEvent) writeTo(b *strings.Builder) { + // rough wag at needed length + b.Grow(64 + len(ev.Message) + len(ev.StackTrace) + len(ev.Provider) + len(ev.Source)) + + // [strings.Builder] Write* functions always return nil errors + _, _ = b.WriteString("[Event Detail: " + ev.Message) if ev.StackTrace != "" { - evs += " Stack Trace: " + ev.StackTrace + _, _ = b.WriteString(" Stack Trace: " + ev.StackTrace) } if ev.Provider != "" { - evs += " Provider: " + ev.Provider + _, _ = b.WriteString(" Provider: " + ev.Provider) } if ev.EventID != 0 { - evs = fmt.Sprintf("%s EventID: %d", evs, ev.EventID) + fmt.Fprintf(b, " EventID: %d", ev.EventID) } if ev.Flags != 0 { - evs = fmt.Sprintf("%s flags: %d", evs, ev.Flags) + fmt.Fprintf(b, " flags: %d", ev.Flags) } if ev.Source != "" { - evs += " Source: " + ev.Source + _, _ = b.WriteString(" Source: " + ev.Source) } - evs += "]" - return evs + _ = b.WriteByte(']') +} + +type hcsResult struct { + Error int32 + ErrorMessage string // ErrorMessage should be the same as `windows.Errno(Err).Error()`. + ErrorEvents []ErrorEvent `json:"ErrorEvents,omitempty"` + // TODO: AttributionRecords } func processHcsResult(ctx context.Context, resultJSON string) []ErrorEvent { - if resultJSON != "" { - result := &hcsResult{} - if err := json.Unmarshal([]byte(resultJSON), result); err != nil { - log.G(ctx).WithError(err).Warning("Could not unmarshal HCS result") - return nil - } - return result.ErrorEvents + if resultJSON == "" { + return nil } - return nil + + result := &hcsResult{} + if err := json.Unmarshal([]byte(resultJSON), result); err != nil { + log.G(ctx).WithFields(logrus.Fields{ + logfields.JSON: resultJSON, + logrus.ErrorKey: err, + }).Warn("Could not unmarshal HCS result") + return nil + } + return result.ErrorEvents +} + +// TODO: move [resultError] and [ErrorEvent] to schema2 and handle parsing results/data directly in vmcompute +// See: https://learn.microsoft.com/en-us/virtualization/api/hcs/schemareference#ResultError + +// resultError describes an HCS operation result and allows retaining the ErrorEvents +// when an [error] is expected but an [HcsError] is not appropriate (i.e., the operation is not known). +// +// It is used in [notificationWatcher] to send the [ErrorEvent]s in a callback's notification data +// to the corresponding [waitForNotification] via a [notificationChannel]. +type resultError struct { + Err error + Events []ErrorEvent +} + +func (e *resultError) Error() string { + return appendErrorEvents(e.Err.Error(), e.Events) +} + +func (e *resultError) Is(target error) bool { + return errors.Is(e.Err, target) +} + +func (e *resultError) Unwrap() error { + return e.Err +} + +// getEvents checks to see if err is a [resultError], and if so, returns the unwrapped error +// and [ErrorEvent]s. +// +// Currently, only [notificationWatcher] creates [resultError]s, and they are subsequently +// handled in [waitForNotification]. +func getEvents(err error) ([]ErrorEvent, error) { + var rErr *resultError + if errors.As(err, &rErr) { + return rErr.Events, rErr.Err + } + return nil, err } type HcsError struct { @@ -149,12 +205,22 @@ type HcsError struct { var _ net.Error = &HcsError{} -func (e *HcsError) Error() string { - s := e.Op + ": " + e.Err.Error() - for _, ev := range e.Events { - s += "\n" + ev.String() +func makeHCSError(op string, err error, events []ErrorEvent) error { + // Don't double wrap errors + var e *HcsError + if errors.As(err, &e) { + return err + } + + return &HcsError{ + Op: op, + Err: err, + Events: events, } - return s +} + +func (e *HcsError) Error() string { + return appendErrorEvents(e.Op+": "+e.Err.Error(), e.Events) } func (e *HcsError) Is(target error) bool { @@ -194,11 +260,7 @@ type SystemError struct { var _ net.Error = &SystemError{} func (e *SystemError) Error() string { - s := e.Op + " " + e.ID + ": " + e.Err.Error() - for _, ev := range e.Events { - s += "\n" + ev.String() - } - return s + return appendErrorEvents(fmt.Sprintf("%s %s: %s", e.Op, e.ID, e.Err.Error()), e.Events) } func makeSystemError(system *System, op string, err error, events []ErrorEvent) error { @@ -228,11 +290,7 @@ type ProcessError struct { var _ net.Error = &ProcessError{} func (e *ProcessError) Error() string { - s := fmt.Sprintf("%s %s:%d: %s", e.Op, e.SystemID, e.Pid, e.Err.Error()) - for _, ev := range e.Events { - s += "\n" + ev.String() - } - return s + return appendErrorEvents(fmt.Sprintf("%s %s:%d: %s", e.Op, e.SystemID, e.Pid, e.Err.Error()), e.Events) } func makeProcessError(process *Process, op string, err error, events []ErrorEvent) error { @@ -241,6 +299,7 @@ func makeProcessError(process *Process, op string, err error, events []ErrorEven if errors.As(err, &e) { return err } + return &ProcessError{ Pid: process.Pid(), SystemID: process.SystemID(), @@ -252,6 +311,22 @@ func makeProcessError(process *Process, op string, err error, events []ErrorEven } } +// common formatting for error strings followed by event data, +func appendErrorEvents(s string, events []ErrorEvent) string { + if len(events) == 0 { + return s + } + + sb := new(strings.Builder) + _, _ = sb.WriteString(s) + for _, ev := range events { + // don't join with newlines since those are ... awkward within error strings + _, _ = sb.WriteString(": ") + ev.writeTo(sb) + } + return sb.String() +} + // IsNotExist checks if an error is caused by the Container or Process not existing. // Note: Currently, ErrElementNotFound can mean that a Process has either // already exited, or does not exist. Both IsAlreadyStopped and IsNotExist diff --git a/internal/hcs/process.go b/internal/hcs/process.go index fef2bf546c..d8d9ea63a0 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -12,14 +12,17 @@ import ( "syscall" "time" + "github.com/sirupsen/logrus" "go.opencensus.io/trace" "github.com/Microsoft/hcsshim/internal/cow" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/vmcompute" + "github.com/Microsoft/hcsshim/internal/winapi" ) type Process struct { @@ -32,7 +35,7 @@ type Process struct { stdin io.WriteCloser stdout io.ReadCloser stderr io.ReadCloser - callbackNumber uintptr + callbackNumber callbackNumber killSignalDelivered bool closedWaitOnce sync.Once @@ -97,7 +100,7 @@ func (process *Process) Signal(ctx context.Context, options interface{}) (bool, operation := "hcs::Process::Signal" - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return false, makeProcessError(process, operation, ErrAlreadyClosed, nil) } @@ -116,13 +119,21 @@ func (process *Process) Signal(ctx context.Context, options interface{}) (bool, } // Kill signals the process to terminate but does not wait for it to finish terminating. -func (process *Process) Kill(ctx context.Context) (bool, error) { +func (process *Process) Kill(ctx context.Context) (_ bool, err error) { + operation := "hcs::Process::Kill" + ctx, span := oc.StartSpan(ctx, operation) + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) + + ctxNoCancel := context.WithoutCancel(ctx) + process.handleLock.RLock() defer process.handleLock.RUnlock() - operation := "hcs::Process::Kill" - - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return false, makeProcessError(process, operation, ErrAlreadyClosed, nil) } @@ -139,6 +150,9 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { return true, nil } + // NOTE: this re-registers callbacks for the same underlying compute system and process, + // but with a different handle, which is ... excessive. + // HCS serializes the signals sent to a target pid per compute system handle. // To avoid SIGKILL being serialized behind other signals, we open a new compute // system handle to deliver the kill signal. @@ -154,10 +168,10 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { log.G(ctx).WithField("err", err).Error("Terminate() call failed") return false, err } - process.system.Close() + process.system.CloseCtx(ctxNoCancel) //nolint:errcheck return true, nil } - defer hcsSystem.Close() + defer hcsSystem.CloseCtx(ctxNoCancel) //nolint:errcheck newProcessHandle, err := hcsSystem.OpenProcess(ctx, process.Pid()) if err != nil { @@ -169,7 +183,7 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { return false, err } } - defer newProcessHandle.Close() + defer newProcessHandle.CloseCtx(ctxNoCancel) //nolint:errcheck resultJSON, err := vmcompute.HcsTerminateProcess(ctx, newProcessHandle.handle) if err != nil { @@ -214,20 +228,19 @@ func (process *Process) waitBackground() { ctx, span := oc.StartSpan(context.Background(), operation) defer span.End() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) var ( - err error exitCode = -1 propertiesJSON string resultJSON string ) - err = waitForNotification(ctx, process.callbackNumber, hcsNotificationProcessExited, nil) + events, err := waitForNotification(ctx, process.callbackNumber, hcsNotificationProcessExited, nil) if err != nil { - err = makeProcessError(process, operation, err, nil) - log.G(ctx).WithError(err).Error("failed wait") + err = makeProcessError(process, operation, err, events) + log.G(ctx).WithError(err).Error("failed wait on process exit") } else { process.handleLock.RLock() defer process.handleLock.RUnlock() @@ -287,7 +300,7 @@ func (process *Process) ResizeConsole(ctx context.Context, width, height uint16) operation := "hcs::Process::ResizeConsole" - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return makeProcessError(process, operation, ErrAlreadyClosed, nil) } modifyRequest := hcsschema.ProcessModifyRequest{ @@ -333,13 +346,13 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) process.handleLock.RLock() defer process.handleLock.RUnlock() - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return nil, nil, nil, makeProcessError(process, operation, ErrAlreadyClosed, nil) } @@ -382,13 +395,13 @@ func (process *Process) CloseStdin(ctx context.Context) (err error) { defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) process.handleLock.RLock() defer process.handleLock.RUnlock() - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return makeProcessError(process, operation, ErrAlreadyClosed, nil) } @@ -428,13 +441,13 @@ func (process *Process) CloseStdout(ctx context.Context) (err error) { defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) process.handleLock.Lock() defer process.handleLock.Unlock() - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return nil } @@ -452,13 +465,13 @@ func (process *Process) CloseStderr(ctx context.Context) (err error) { defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) process.handleLock.Lock() defer process.handleLock.Unlock() - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return nil } @@ -473,20 +486,28 @@ func (process *Process) CloseStderr(ctx context.Context) (err error) { // Close cleans up any state associated with the process but does not kill // or wait on it. -func (process *Process) Close() (err error) { +func (process *Process) Close() error { + return process.CloseCtx(context.Background()) +} + +// CloseCtx is similar to [System.Close], but accepts a context. +// +// The context is used for all operations, including waits, so timeouts/cancellations may prevent +// proper system cleanup. +func (process *Process) CloseCtx(ctx context.Context) (err error) { operation := "hcs::Process::Close" - ctx, span := oc.StartSpan(context.Background(), operation) + ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( - trace.StringAttribute("cid", process.SystemID()), - trace.Int64Attribute("pid", int64(process.processID))) + trace.StringAttribute(logfields.SystemID, process.SystemID()), + trace.Int64Attribute(logfields.ProcessID, int64(process.processID))) process.handleLock.Lock() defer process.handleLock.Unlock() // Don't double free this - if process.handle == 0 { + if winapi.IsInvalidHandle(process.handle) { return nil } @@ -523,7 +544,16 @@ func (process *Process) Close() (err error) { return nil } +// Requires holding [Process.handleLock]. func (process *Process) registerCallback(ctx context.Context) error { + callbackNum := nextCallback() + + log.G(ctx).WithFields(logrus.Fields{ + logfields.SystemID: process.SystemID(), + logfields.ProcessID: process.processID, + logfields.CallbackNumber: callbackNum, + }).Trace("register process callback") + callbackContext := ¬ificationWatcherContext{ channels: newProcessChannels(), systemID: process.SystemID(), @@ -531,26 +561,31 @@ func (process *Process) registerCallback(ctx context.Context) error { } callbackMapLock.Lock() - callbackNumber := nextCallback - nextCallback++ - callbackMap[callbackNumber] = callbackContext + callbackMap[callbackNum] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) + callbackHandle, err := vmcompute.HcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, uintptr(callbackNum)) if err != nil { return err } callbackContext.handle = callbackHandle - process.callbackNumber = callbackNumber + process.callbackNumber = callbackNum return nil } +// Requires holding [Process.handleLock]. func (process *Process) unregisterCallback(ctx context.Context) error { - callbackNumber := process.callbackNumber + callbackNum := process.callbackNumber + + log.G(ctx).WithFields(logrus.Fields{ + logfields.SystemID: process.SystemID(), + logfields.ProcessID: process.processID, + logfields.CallbackNumber: callbackNum, + }).Trace("unregister process callback") callbackMapLock.RLock() - callbackContext := callbackMap[callbackNumber] + callbackContext := callbackMap[callbackNum] callbackMapLock.RUnlock() if callbackContext == nil { @@ -559,7 +594,7 @@ func (process *Process) unregisterCallback(ctx context.Context) error { handle := callbackContext.handle - if handle == 0 { + if winapi.IsInvalidHandle(handle) { return nil } @@ -573,7 +608,7 @@ func (process *Process) unregisterCallback(ctx context.Context) error { closeChannels(callbackContext.channels) callbackMapLock.Lock() - delete(callbackMap, callbackNumber) + delete(callbackMap, callbackNum) callbackMapLock.Unlock() handle = 0 //nolint:ineffassign diff --git a/internal/hcs/service.go b/internal/hcs/service.go index a46b0051df..dc0adcbf6a 100644 --- a/internal/hcs/service.go +++ b/internal/hcs/service.go @@ -21,7 +21,7 @@ func GetServiceProperties(ctx context.Context, q hcsschema.PropertyQuery) (*hcss propertiesJSON, resultJSON, err := vmcompute.HcsGetServiceProperties(ctx, string(queryb)) events := processHcsResult(ctx, resultJSON) if err != nil { - return nil, &HcsError{Op: operation, Err: err, Events: events} + return nil, makeHCSError(operation, err, events) } if propertiesJSON == "" { @@ -45,7 +45,7 @@ func ModifyServiceSettings(ctx context.Context, settings hcsschema.ModificationR resultJSON, err := vmcompute.HcsModifyServiceSettings(ctx, string(settingsJSON)) events := processHcsResult(ctx, resultJSON) if err != nil { - return &HcsError{Op: operation, Err: err, Events: events} + return makeHCSError(operation, err, events) } return nil } diff --git a/internal/hcs/system.go b/internal/hcs/system.go index b1597466f6..08b655fbc3 100644 --- a/internal/hcs/system.go +++ b/internal/hcs/system.go @@ -12,6 +12,9 @@ import ( "syscall" "time" + "github.com/sirupsen/logrus" + "go.opencensus.io/trace" + "github.com/Microsoft/hcsshim/internal/cow" "github.com/Microsoft/hcsshim/internal/hcs/schema1" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" @@ -21,15 +24,14 @@ import ( "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/timeout" "github.com/Microsoft/hcsshim/internal/vmcompute" - "github.com/sirupsen/logrus" - "go.opencensus.io/trace" + "github.com/Microsoft/hcsshim/internal/winapi" ) type System struct { handleLock sync.RWMutex handle vmcompute.HcsSystem id string - callbackNumber uintptr + callbackNumber callbackNumber closedWaitOnce sync.Once waitBlock chan struct{} @@ -63,7 +65,7 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, id)) computeSystem := newSystem(id) @@ -83,7 +85,7 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in if createError == nil || IsPending(createError) { defer func() { if err != nil { - computeSystem.Close() + computeSystem.CloseCtx(context.WithoutCancel(ctx)) //nolint:errcheck } }() if err = computeSystem.registerCallback(ctx); err != nil { @@ -115,6 +117,8 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in func OpenComputeSystem(ctx context.Context, id string) (*System, error) { operation := "hcs::OpenComputeSystem" + log.G(ctx).WithField(logfields.SystemID, id).Trace(operation) + computeSystem := newSystem(id) handle, resultJSON, err := vmcompute.HcsOpenComputeSystem(ctx, id) events := processHcsResult(ctx, resultJSON) @@ -124,7 +128,7 @@ func OpenComputeSystem(ctx context.Context, id string) (*System, error) { computeSystem.handle = handle defer func() { if err != nil { - computeSystem.Close() + computeSystem.CloseCtx(context.WithoutCancel(ctx)) //nolint:errcheck } }() if err = computeSystem.registerCallback(ctx); err != nil { @@ -176,7 +180,7 @@ func GetComputeSystems(ctx context.Context, q schema1.ComputeSystemQuery) ([]sch computeSystemsJSON, resultJSON, err := vmcompute.HcsEnumerateComputeSystems(ctx, string(queryb)) events := processHcsResult(ctx, resultJSON) if err != nil { - return nil, &HcsError{Op: operation, Err: err, Events: events} + return nil, makeHCSError(operation, err, events) } if computeSystemsJSON == "" { @@ -199,14 +203,14 @@ func (computeSystem *System) Start(ctx context.Context) (err error) { ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) computeSystem.handleLock.RLock() defer computeSystem.handleLock.RUnlock() // prevent starting an exited system because waitblock we do not recreate waitBlock // or rerun waitBackground, so we have no way to be notified of it closing again - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -232,7 +236,7 @@ func (computeSystem *System) Shutdown(ctx context.Context) error { operation := "hcs::System::Shutdown" - if computeSystem.handle == 0 || computeSystem.stopped() { + if winapi.IsInvalidHandle(computeSystem.handle) || computeSystem.stopped() { return nil } @@ -254,7 +258,7 @@ func (computeSystem *System) Terminate(ctx context.Context) error { operation := "hcs::System::Terminate" - if computeSystem.handle == 0 || computeSystem.stopped() { + if winapi.IsInvalidHandle(computeSystem.handle) || computeSystem.stopped() { return nil } @@ -278,9 +282,9 @@ func (computeSystem *System) waitBackground() { operation := "hcs::System::waitBackground" ctx, span := oc.StartSpan(context.Background(), operation) defer span.End() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) - err := waitForNotification(ctx, computeSystem.callbackNumber, hcsNotificationSystemExited, nil) + events, err := waitForNotification(ctx, computeSystem.callbackNumber, hcsNotificationSystemExited, nil) if err == nil { log.G(ctx).Debug("system exited") } else if errors.Is(err, ErrVmcomputeUnexpectedExit) { @@ -288,7 +292,8 @@ func (computeSystem *System) waitBackground() { computeSystem.exitError = makeSystemError(computeSystem, operation, err, nil) err = nil } else { - err = makeSystemError(computeSystem, operation, err, nil) + log.G(ctx).WithError(err).Error("failed wait on system exit") + err = makeSystemError(computeSystem, operation, err, events) } computeSystem.closedWaitOnce.Do(func() { computeSystem.waitError = err @@ -351,7 +356,7 @@ func (computeSystem *System) Properties(ctx context.Context, types ...schema1.Pr operation := "hcs::System::Properties" - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -489,10 +494,12 @@ func (computeSystem *System) statisticsInProc(job *jobobject.JobObject) (*hcssch } // hcsPropertiesV2Query is a helper to make a HcsGetComputeSystemProperties call using the V2 schema property types. +// +// Requires holding [System.handleLock]. func (computeSystem *System) hcsPropertiesV2Query(ctx context.Context, types []hcsschema.PropertyType) (*hcsschema.Properties, error) { operation := "hcs::System::PropertiesV2" - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -581,12 +588,12 @@ func (computeSystem *System) Pause(ctx context.Context) (err error) { ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) computeSystem.handleLock.RLock() defer computeSystem.handleLock.RUnlock() - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -609,12 +616,12 @@ func (computeSystem *System) Resume(ctx context.Context) (err error) { ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) computeSystem.handleLock.RLock() defer computeSystem.handleLock.RUnlock() - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -637,7 +644,7 @@ func (computeSystem *System) Save(ctx context.Context, options interface{}) (err ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) saveOptions, err := json.Marshal(options) if err != nil { @@ -647,7 +654,7 @@ func (computeSystem *System) Save(ctx context.Context, options interface{}) (err computeSystem.handleLock.RLock() defer computeSystem.handleLock.RUnlock() - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -665,7 +672,7 @@ func (computeSystem *System) createProcess(ctx context.Context, operation string computeSystem.handleLock.RLock() defer computeSystem.handleLock.RUnlock() - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return nil, nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -686,7 +693,7 @@ func (computeSystem *System) createProcess(ctx context.Context, operation string return nil, nil, makeSystemError(computeSystem, operation, err, events) } - log.G(ctx).WithField("pid", processInfo.ProcessId).Debug("created process pid") + log.G(ctx).WithField(logfields.ProcessID, processInfo.ProcessId).Debug("created process pid") return newProcess(processHandle, int(processInfo.ProcessId), computeSystem), &processInfo, nil } @@ -727,7 +734,7 @@ func (computeSystem *System) OpenProcess(ctx context.Context, pid int) (*Process operation := "hcs::System::OpenProcess" - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } @@ -760,13 +767,13 @@ func (computeSystem *System) CloseCtx(ctx context.Context) (err error) { ctx, span := oc.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) + span.AddAttributes(trace.StringAttribute(logfields.SystemID, computeSystem.id)) computeSystem.handleLock.Lock() defer computeSystem.handleLock.Unlock() // Don't double free this - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return nil } @@ -788,34 +795,46 @@ func (computeSystem *System) CloseCtx(ctx context.Context) (err error) { return nil } +// Requires holding [System.handleLock]. func (computeSystem *System) registerCallback(ctx context.Context) error { + callbackNum := nextCallback() + + log.G(ctx).WithFields(logrus.Fields{ + logfields.SystemID: computeSystem.id, + logfields.CallbackNumber: callbackNum, + }).Trace("register computer system callback") + callbackContext := ¬ificationWatcherContext{ channels: newSystemChannels(), systemID: computeSystem.id, } callbackMapLock.Lock() - callbackNumber := nextCallback - nextCallback++ - callbackMap[callbackNumber] = callbackContext + callbackMap[callbackNum] = callbackContext callbackMapLock.Unlock() callbackHandle, err := vmcompute.HcsRegisterComputeSystemCallback(ctx, computeSystem.handle, - notificationWatcherCallback, callbackNumber) + notificationWatcherCallback, uintptr(callbackNum)) if err != nil { return err } callbackContext.handle = callbackHandle - computeSystem.callbackNumber = callbackNumber + computeSystem.callbackNumber = callbackNum return nil } +// Requires holding [System.handleLock]. func (computeSystem *System) unregisterCallback(ctx context.Context) error { - callbackNumber := computeSystem.callbackNumber + callbackNum := computeSystem.callbackNumber + + log.G(ctx).WithFields(logrus.Fields{ + logfields.SystemID: computeSystem.id, + logfields.CallbackNumber: callbackNum, + }).Trace("unregister computer system callback") callbackMapLock.RLock() - callbackContext := callbackMap[callbackNumber] + callbackContext := callbackMap[callbackNum] callbackMapLock.RUnlock() if callbackContext == nil { @@ -824,7 +843,7 @@ func (computeSystem *System) unregisterCallback(ctx context.Context) error { handle := callbackContext.handle - if handle == 0 { + if winapi.IsInvalidHandle(handle) { return nil } @@ -838,7 +857,7 @@ func (computeSystem *System) unregisterCallback(ctx context.Context) error { closeChannels(callbackContext.channels) callbackMapLock.Lock() - delete(callbackMap, callbackNumber) + delete(callbackMap, callbackNum) callbackMapLock.Unlock() handle = 0 //nolint:ineffassign @@ -853,7 +872,7 @@ func (computeSystem *System) Modify(ctx context.Context, config interface{}) err operation := "hcs::System::Modify" - if computeSystem.handle == 0 { + if winapi.IsInvalidHandle(computeSystem.handle) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } diff --git a/internal/hcs/utils.go b/internal/hcs/utils.go index 76eb2be7cf..9261274cb7 100644 --- a/internal/hcs/utils.go +++ b/internal/hcs/utils.go @@ -7,11 +7,14 @@ import ( "io" "syscall" + "github.com/pkg/errors" + "golang.org/x/sys/windows" + "github.com/Microsoft/go-winio" diskutil "github.com/Microsoft/go-winio/vhd" + "github.com/Microsoft/hcsshim/computestorage" - "github.com/pkg/errors" - "golang.org/x/sys/windows" + "github.com/Microsoft/hcsshim/internal/winapi" ) // makeOpenFiles calls winio.NewOpenFile for each handle in a slice but closes all the handles @@ -19,7 +22,7 @@ import ( func makeOpenFiles(hs []syscall.Handle) (_ []io.ReadWriteCloser, err error) { fs := make([]io.ReadWriteCloser, len(hs)) for i, h := range hs { - if h != syscall.Handle(0) { + if !winapi.IsInvalidHandle(h) { if err == nil { fs[i], err = winio.NewOpenFile(windows.Handle(h)) } diff --git a/internal/hcs/waithelper.go b/internal/hcs/waithelper.go index 3a51ed1955..bd9ae5b028 100644 --- a/internal/hcs/waithelper.go +++ b/internal/hcs/waithelper.go @@ -6,44 +6,52 @@ import ( "context" "time" + "github.com/sirupsen/logrus" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" ) func processAsyncHcsResult( ctx context.Context, err error, resultJSON string, - callbackNumber uintptr, + callbackNum callbackNumber, expectedNotification hcsNotification, timeout *time.Duration, ) ([]ErrorEvent, error) { - events := processHcsResult(ctx, resultJSON) if IsPending(err) { - return nil, waitForNotification(ctx, callbackNumber, expectedNotification, timeout) + return waitForNotification(ctx, callbackNum, expectedNotification, timeout) } - return events, err + return processHcsResult(ctx, resultJSON), err } func waitForNotification( ctx context.Context, - callbackNumber uintptr, + callbackNum callbackNumber, expectedNotification hcsNotification, timeout *time.Duration, -) error { +) ([]ErrorEvent, error) { + entry := log.G(ctx).WithFields(logrus.Fields{ + logfields.CallbackNumber: callbackNum, + "notification-type": expectedNotification.String(), + }) + callbackMapLock.RLock() - if _, ok := callbackMap[callbackNumber]; !ok { - callbackMapLock.RUnlock() - log.G(ctx).WithField("callbackNumber", callbackNumber).Error("failed to waitForNotification: callbackNumber does not exist in callbackMap") - return ErrHandleClose - } - channels := callbackMap[callbackNumber].channels + callbackCtx := callbackMap[callbackNum] callbackMapLock.RUnlock() + if callbackCtx == nil { + entry.Error("failed to waitForNotification: callbackNumber does not exist in callbackMap") + return nil, ErrHandleClose + } + channels := callbackCtx.channels + expectedChannel := channels[expectedNotification] if expectedChannel == nil { - log.G(ctx).WithField("type", expectedNotification).Error("unknown notification type in waitForNotification") - return ErrInvalidNotificationType + entry.Error("unknown notification type in waitForNotification") + return nil, ErrInvalidNotificationType } var c <-chan time.Time @@ -56,27 +64,27 @@ func waitForNotification( select { case err, ok := <-expectedChannel: if !ok { - return ErrHandleClose + return nil, ErrHandleClose } - return err + return getEvents(err) case err, ok := <-channels[hcsNotificationSystemExited]: if !ok { - return ErrHandleClose + return nil, ErrHandleClose } // If the expected notification is hcsNotificationSystemExited which of the two selects // chosen is random. Return the raw error if hcsNotificationSystemExited is expected if channels[hcsNotificationSystemExited] == expectedChannel { - return err + return getEvents(err) } - return ErrUnexpectedContainerExit + return nil, ErrUnexpectedContainerExit case _, ok := <-channels[hcsNotificationServiceDisconnect]: if !ok { - return ErrHandleClose + return nil, ErrHandleClose } // hcsNotificationServiceDisconnect should never be an expected notification // it does not need the same handling as hcsNotificationSystemExited - return ErrUnexpectedProcessAbort + return nil, ErrUnexpectedProcessAbort case <-c: - return ErrTimeout + return nil, ErrTimeout } } diff --git a/internal/hvsocket/hvsocket.go b/internal/hvsocket/hvsocket.go index 595834f7c1..8bab750935 100644 --- a/internal/hvsocket/hvsocket.go +++ b/internal/hvsocket/hvsocket.go @@ -6,12 +6,12 @@ package hvsocket import ( "context" "fmt" - "github.com/Microsoft/hcsshim/internal/log" "unsafe" "github.com/Microsoft/go-winio/pkg/guid" "golang.org/x/sys/windows" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/resources" ) diff --git a/internal/interop/interop.go b/internal/interop/interop.go index 0f8f676c58..350a9d45fe 100644 --- a/internal/interop/interop.go +++ b/internal/interop/interop.go @@ -12,11 +12,16 @@ import ( //sys coTaskMemFree(buffer unsafe.Pointer) = api_ms_win_core_com_l1_1_0.CoTaskMemFree func ConvertAndFreeCoTaskMemString(buffer *uint16) string { - str := syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(buffer))[:]) + str := ConvertString(buffer) coTaskMemFree(unsafe.Pointer(buffer)) return str } +// Converts a PWSTR to a string, duplicating the underlying data and leaving the original unmodified. +func ConvertString(buffer *uint16) string { + return syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(buffer))[:]) +} + func Win32FromHresult(hr uintptr) syscall.Errno { if hr&0x1fff0000 == 0x00070000 { return syscall.Errno(hr & 0xffff) diff --git a/internal/log/hook.go b/internal/log/hook.go index da81023b65..6ee833b7cd 100644 --- a/internal/log/hook.go +++ b/internal/log/hook.go @@ -41,6 +41,8 @@ type Hook struct { // AddSpanContext adds [logfields.TraceID] and [logfields.SpanID] fields to // the entry from the span context stored in [logrus.Entry.Context], if it exists. + // + // Default is true. AddSpanContext bool } diff --git a/internal/logfields/fields.go b/internal/logfields/fields.go index cceb3e2d18..6cf9174045 100644 --- a/internal/logfields/fields.go +++ b/internal/logfields/fields.go @@ -14,6 +14,9 @@ const ( ProcessID = "pid" TaskID = "tid" UVMID = "uvm-id" + SystemID = "system-id" + + CallbackNumber = "callback-number" // networking and IO diff --git a/internal/uvm/create_wcow.go b/internal/uvm/create_wcow.go index f922017b44..e1551cac5e 100644 --- a/internal/uvm/create_wcow.go +++ b/internal/uvm/create_wcow.go @@ -69,7 +69,7 @@ type OptionsWCOW struct { // AdditionalRegistryKeys are Registry keys and their values to additionally add to the uVM. AdditionalRegistryKeys []hcsschema.RegistryValue - OutputHandlerCreator OutputHandlerCreator // Creates an [OutputHandler] that controls how output received over HVSocket from the UVM is handled. Defaults to parsing output as ETW Log events + OutputHandlerCreator OutputHandlerCreator `json:"-"` // Creates an [OutputHandler] that controls how output received over HVSocket from the UVM is handled. Defaults to parsing output as ETW Log events LogSources string // ETW providers to be set for the logging service ForwardLogs bool // Whether to forward logs to the host or not } diff --git a/internal/uvm/start.go b/internal/uvm/start.go index 781bc3c417..d51185e294 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -172,7 +172,7 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { // save parent context, without timeout to use in terminate pCtx := ctx ctx, cancel := context.WithTimeout(pCtx, timeout.GCSConnectionTimeout) - log.G(ctx).Debugf("using gcs connection timeout: %s\n", timeout.GCSConnectionTimeout) + log.G(ctx).Debugf("using gcs connection timeout: %s", timeout.GCSConnectionTimeout) g, gctx := errgroup.WithContext(ctx) defer func() { @@ -215,7 +215,7 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { switch uvm.operatingSystem { case "windows": // Windows specific handling - // For windows, the Listener can recieve a connection later, so we + // For windows, the Listener can receive a connection later, so we // start the output handler in a goroutine with a non-timeout context. // This allows the output handler to run independently of the UVM Create's // lifecycle. The approach potentially allows to wait for reconnections too, diff --git a/internal/winapi/doc.go b/internal/winapi/doc.go index 9acc0bfc17..0bfaf47a33 100644 --- a/internal/winapi/doc.go +++ b/internal/winapi/doc.go @@ -1,3 +1,5 @@ // Package winapi contains various low-level bindings to Windows APIs. It can // be thought of as an extension to golang.org/x/sys/windows. package winapi + +//go:generate go tool github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go diff --git a/internal/winapi/utils.go b/internal/winapi/utils.go index de16750d76..8372560900 100644 --- a/internal/winapi/utils.go +++ b/internal/winapi/utils.go @@ -10,6 +10,11 @@ import ( "golang.org/x/sys/windows" ) +// IsInvalidHandle returns true if the Handle is zero or [windows.InvalidHandle]. +func IsInvalidHandle[H ~uintptr](h H) bool { + return h == 0 || uintptr(h) == uintptr(windows.InvalidHandle) +} + // Uint16BufferToSlice wraps a uint16 pointer-and-length into a slice // for easier interop with Go APIs func Uint16BufferToSlice(buffer *uint16, bufferLength int) (result []uint16) { diff --git a/internal/winapi/winapi.go b/internal/winapi/winapi.go deleted file mode 100644 index 009e70ab19..0000000000 --- a/internal/winapi/winapi.go +++ /dev/null @@ -1,3 +0,0 @@ -package winapi - -//go:generate go tool github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go diff --git a/test/functional/main_test.go b/test/functional/main_test.go index 18bd6a72f4..4360487c06 100644 --- a/test/functional/main_test.go +++ b/test/functional/main_test.go @@ -164,6 +164,7 @@ func runTests(m *testing.M) error { return fmt.Errorf("tests must be run in an elevated context") } + logrus.AddHook(log.NewHook()) trace.ApplyConfig(trace.Config{DefaultSampler: oc.DefaultSampler}) trace.RegisterExporter(&oc.LogrusExporter{}) diff --git a/test/gcs/main_test.go b/test/gcs/main_test.go index ce399e0767..4832494781 100644 --- a/test/gcs/main_test.go +++ b/test/gcs/main_test.go @@ -21,6 +21,7 @@ import ( "github.com/Microsoft/hcsshim/internal/guest/runtime/runc" "github.com/Microsoft/hcsshim/internal/guest/transport" "github.com/Microsoft/hcsshim/internal/guestpath" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/pkg/securitypolicy" @@ -109,6 +110,7 @@ func TestMain(m *testing.M) { func setup() (err error) { _ = os.MkdirAll(guestpath.LCOWRootPrefixInUVM, 0755) + logrus.AddHook(log.NewHook()) trace.ApplyConfig(trace.Config{DefaultSampler: oc.DefaultSampler}) trace.RegisterExporter(&oc.LogrusExporter{}) diff --git a/test/internal/cmd/cmd.go b/test/internal/cmd/cmd.go index d7bcc78a38..cdf7aea195 100644 --- a/test/internal/cmd/cmd.go +++ b/test/internal/cmd/cmd.go @@ -23,16 +23,17 @@ const CopyAfterExitTimeout = time.Second const ForcedKilledExitCode = 137 func desc(c *cmd.Cmd) string { - desc := "init command" - if c.Spec != nil { - if c.Spec.CommandLine != "" { - desc = c.Spec.CommandLine - } else { - desc = strings.Join(c.Spec.Args, " ") - } + switch { + case c == nil: + return "" + case c.Spec == nil: + return "init command" + case c.Spec.CommandLine != "": + return c.Spec.CommandLine + default: } - return desc + return strings.Join(c.Spec.Args, " ") } func Create(ctx context.Context, _ testing.TB, c cow.ProcessHost, p *specs.Process, io *BufferedIO) *cmd.Cmd { @@ -51,8 +52,12 @@ func Create(ctx context.Context, _ testing.TB, c cow.ProcessHost, p *specs.Proce func Start(_ context.Context, tb testing.TB, c *cmd.Cmd) { tb.Helper() + + d := desc(c) + tb.Logf("starting command: %q", d) + if err := c.Start(); err != nil { - tb.Fatalf("failed to start %q: %v", desc(c), err) + tb.Fatalf("failed to start %q: %v", d, err) } } @@ -64,14 +69,21 @@ func Run(ctx context.Context, tb testing.TB, c *cmd.Cmd) int { func Wait(_ context.Context, tb testing.TB, c *cmd.Cmd) int { tb.Helper() + + d := desc(c) + tb.Logf("waiting on process: %q", d) + // todo, wait on context.Done if err := c.Wait(); err != nil { ee := &cmd.ExitError{} if errors.As(err, &ee) { - return ee.ExitCode() + ec := ee.ExitCode() + tb.Logf("process exit code: %d", ec) + return ec } - tb.Fatalf("failed to wait on %q: %v", desc(c), err) + tb.Fatalf("failed to wait on %q: %v", d, err) } + return 0 } @@ -84,10 +96,14 @@ func WaitExitCode(ctx context.Context, tb testing.TB, c *cmd.Cmd, e int) { func Kill(ctx context.Context, tb testing.TB, c *cmd.Cmd) { tb.Helper() + + d := desc(c) + tb.Logf("kill process: %q", d) + ok, err := c.Process.Kill(ctx) if !ok { - tb.Fatalf("could not deliver kill to %q", desc(c)) + tb.Fatalf("could not deliver kill to %q", d) } else if err != nil { - tb.Fatalf("could not kill %q: %v", desc(c), err) + tb.Fatalf("could not kill %q: %v", d, err) } } diff --git a/test/internal/container/container.go b/test/internal/container/container.go index 0d3a9a0a5b..e1236feb68 100644 --- a/test/internal/container/container.go +++ b/test/internal/container/container.go @@ -31,6 +31,7 @@ func Create( name, owner string, ) (c cow.Container, r *resources.Resources, _ func()) { tb.Helper() + tb.Logf("creating container: %q", name) if spec.Windows == nil || spec.Windows.Network == nil || spec.Windows.LayerFolders == nil { tb.Fatalf("improperly configured windows spec for container %q: %#+v", name, spec.Windows) @@ -72,6 +73,8 @@ func Create( } f := func() { + tb.Logf("cleaning up container: %q", name) + if err := resources.ReleaseResources(ctx, r, vm, true); err != nil { tb.Errorf("failed to release container resources: %v", err) } @@ -99,6 +102,8 @@ func Start(ctx context.Context, tb testing.TB, c cow.Container, io *testcmd.Buff func StartWithSpec(ctx context.Context, tb testing.TB, c cow.Container, p *specs.Process, io *testcmd.BufferedIO) *cmd.Cmd { tb.Helper() + tb.Logf("starting container: %q", c.ID()) + if err := c.Start(ctx); err != nil { tb.Fatalf("could not start %q: %v", c.ID(), err) } @@ -111,6 +116,8 @@ func StartWithSpec(ctx context.Context, tb testing.TB, c cow.Container, p *specs func Wait(_ context.Context, tb testing.TB, c cow.Container) { tb.Helper() + tb.Logf("waiting on container: %q", c.ID()) + // todo: add wait on ctx.Done if err := c.Wait(); err != nil { tb.Fatalf("could not wait on container %q: %v", c.ID(), err) @@ -119,6 +126,8 @@ func Wait(_ context.Context, tb testing.TB, c cow.Container) { func Kill(ctx context.Context, tb testing.TB, c cow.Container) { tb.Helper() + tb.Logf("kill container: %q", c.ID()) + if err := c.Shutdown(ctx); err != nil { tb.Fatalf("could not terminate container %q: %v", c.ID(), err) } diff --git a/test/pkg/uvm/lcow.go b/test/pkg/uvm/lcow.go index 4e17122349..19d1537315 100644 --- a/test/pkg/uvm/lcow.go +++ b/test/pkg/uvm/lcow.go @@ -75,8 +75,16 @@ func CreateAndStartLCOWFromOpts(ctx context.Context, tb testing.TB, opts *uvm.Op return vm } +//nolint:staticcheck // SA5011: staticcheck thinks `opts` may be nil, even though we fail if it is func CreateLCOW(ctx context.Context, tb testing.TB, opts *uvm.OptionsLCOW) (*uvm.UtilityVM, CleanupFn) { tb.Helper() + + if opts == nil { + tb.Fatalf("opts cannot be nil bet set with BootFiles") + } + + tb.Logf("create LCOW uVM: %q", opts.ID) + vm, err := uvm.CreateLCOW(ctx, opts) if err != nil { tb.Fatalf("could not create LCOW UVM: %v", err) diff --git a/test/pkg/uvm/uvm.go b/test/pkg/uvm/uvm.go index 84c6fbf030..45e35aa2a0 100644 --- a/test/pkg/uvm/uvm.go +++ b/test/pkg/uvm/uvm.go @@ -60,15 +60,17 @@ func CreateAndStart(ctx context.Context, tb testing.TB, opts any) *uvm.UtilityVM func Start(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { tb.Helper() - err := vm.Start(ctx) + tb.Logf("start uVM: %q", vm.ID()) - if err != nil { + if err := vm.Start(ctx); err != nil { tb.Fatalf("could not start UVM: %v", err) } } func Wait(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { tb.Helper() + tb.Logf("waiting on container: %q", vm.ID()) + if err := vm.WaitCtx(ctx); err != nil { tb.Fatalf("could not wait for uvm %q: %v", vm.ID(), err) } @@ -76,6 +78,8 @@ func Wait(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { func Kill(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { tb.Helper() + tb.Logf("kill uVM: %q", vm.ID()) + if err := vm.Terminate(ctx); err != nil { tb.Fatalf("could not kill uvm %q: %v", vm.ID(), err) } @@ -83,6 +87,8 @@ func Kill(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { func Close(ctx context.Context, tb testing.TB, vm *uvm.UtilityVM) { tb.Helper() + tb.Logf("close uVM: %q", vm.ID()) + if err := vm.CloseCtx(ctx); err != nil { tb.Fatalf("could not close uvm %q: %v", vm.ID(), err) } diff --git a/test/pkg/uvm/wcow.go b/test/pkg/uvm/wcow.go index 2f4644cae8..5caa809587 100644 --- a/test/pkg/uvm/wcow.go +++ b/test/pkg/uvm/wcow.go @@ -34,10 +34,18 @@ func CreateWCOWUVM(ctx context.Context, tb testing.TB, id, image string) (*uvm.U } // CreateWCOW creates a WCOW utility VM with the passed opts. +// +//nolint:staticcheck // SA5011: staticcheck thinks `opts` may be nil, even though we fail if it is func CreateWCOW(ctx context.Context, tb testing.TB, opts *uvm.OptionsWCOW) (*uvm.UtilityVM, CleanupFn) { tb.Helper() - if opts == nil || opts.BootFiles == nil { + if opts == nil { + tb.Fatalf("opts cannot be nil bet set with BootFiles") + } + + tb.Logf("create WCOW uVM: %q", opts.ID) + + if opts.BootFiles == nil { tb.Fatalf("opts must bet set with BootFiles") }