Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pkg/monitor/sqsevent/rebalance-recommendation-event.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event *EventBridg
Description: fmt.Sprintf("Rebalance recommendation event received. Instance %s will be cordoned at %s \n", rebalanceRecDetail.InstanceID, event.getTime()),
}
interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error {
// Use provider ID to resolve the actual Kubernetes node name if UseProviderId is configured
nthConfig := n.GetNthConfig()
nodeName := interruptionEvent.NodeName
if nthConfig.UseProviderId && interruptionEvent.ProviderID != "" {
resolvedNodeName, err := n.GetNodeNameFromProviderID(interruptionEvent.ProviderID)
if err != nil {
log.Warn().Err(err).Str("provider_id", interruptionEvent.ProviderID).Msg("Failed to resolve node name from provider ID, falling back to NodeName from event")
} else {
nodeName = resolvedNodeName
}
}

// Remove the draining condition from the node
if err := n.RemoveDrainingCondition(nodeName); err != nil {
log.Err(err).Str("node_name", nodeName).Msg("Unable to remove draining condition from node")
}

errs := m.deleteMessages([]*sqs.Message{message})
if errs != nil {
return errs[0]
Expand All @@ -90,6 +107,11 @@ func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event *EventBridg
}
}

// Set the draining condition on the node
if err := n.SetDrainingCondition(nodeName, "RebalanceRecommendation", interruptionEvent.Description); err != nil {
log.Err(err).Str("node_name", nodeName).Msg("Unable to set draining condition on node")
}

err := n.TaintRebalanceRecommendation(nodeName, interruptionEvent.EventID)
if err != nil {
log.Err(err).Msgf("Unable to taint node with taint %s:%s", node.RebalanceRecommendationTaint, interruptionEvent.EventID)
Expand Down
22 changes: 22 additions & 0 deletions pkg/monitor/sqsevent/spot-itn-event.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event *EventBridgeEven
Description: fmt.Sprintf("Spot Interruption notice for instance %s was sent at %s \n", spotInterruptionDetail.InstanceID, event.getTime()),
}
interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error {
// Use provider ID to resolve the actual Kubernetes node name if UseProviderId is configured
nthConfig := n.GetNthConfig()
nodeName := interruptionEvent.NodeName
if nthConfig.UseProviderId && interruptionEvent.ProviderID != "" {
resolvedNodeName, err := n.GetNodeNameFromProviderID(interruptionEvent.ProviderID)
if err != nil {
log.Warn().Err(err).Str("provider_id", interruptionEvent.ProviderID).Msg("Failed to resolve node name from provider ID, falling back to NodeName from event")
} else {
nodeName = resolvedNodeName
}
}

// Remove the draining condition from the node
if err := n.RemoveDrainingCondition(nodeName); err != nil {
log.Err(err).Str("node_name", nodeName).Msg("Unable to remove draining condition from node")
}

errs := m.deleteMessages([]*sqs.Message{message})
if errs != nil {
return errs[0]
Expand All @@ -92,6 +109,11 @@ func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event *EventBridgeEven
}
}

// Set the draining condition on the node
if err := n.SetDrainingCondition(nodeName, "SpotInterruption", interruptionEvent.Description); err != nil {
log.Err(err).Str("node_name", nodeName).Msg("Unable to set draining condition on node")
}

err := n.TaintSpotItn(nodeName, interruptionEvent.EventID)
if err != nil {
log.Err(err).Msgf("Unable to taint node with taint %s:%s", node.SpotInterruptionTaint, interruptionEvent.EventID)
Expand Down
153 changes: 153 additions & 0 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ const (
ExcludeFromLoadBalancersLabelValue = "aws-node-termination-handler"
)

const (
// TerminationHandlerDrainingConditionType is a node condition type indicating the node is being drained
TerminationHandlerDrainingConditionType = "TerminationHandlerDraining"
)

const (
// SpotInterruptionTaint is a taint used to make spot instance unschedulable
SpotInterruptionTaint = "aws-node-termination-handler/spot-itn"
Expand Down Expand Up @@ -583,6 +588,154 @@ func (n Node) RemoveNTHTaints(nodeName string) error {
return nil
}

// SetDrainingCondition adds a condition to the node indicating it is being drained by NTH
func (n Node) SetDrainingCondition(nodeName string, reason string, message string) error {
if n.nthConfig.DryRun {
log.Info().Str("node_name", nodeName).Str("reason", reason).Msg("Would have set draining condition on node, but dry-run flag was set")
return nil
}

k8sNode, err := n.fetchKubernetesNode(nodeName)
if err != nil {
return fmt.Errorf("Unable to fetch kubernetes node from API: %w", err)
}

return n.setNodeCondition(k8sNode, TerminationHandlerDrainingConditionType, corev1.ConditionTrue, reason, message)
}

// RemoveDrainingCondition removes the draining condition from the node
func (n Node) RemoveDrainingCondition(nodeName string) error {
if n.nthConfig.DryRun {
log.Info().Str("node_name", nodeName).Msg("Would have removed draining condition from node, but dry-run flag was set")
return nil
}

k8sNode, err := n.fetchKubernetesNode(nodeName)
if err != nil {
return fmt.Errorf("Unable to fetch kubernetes node from API: %w", err)
}

return n.removeNodeCondition(k8sNode, TerminationHandlerDrainingConditionType)
}

// setNodeCondition adds or updates a condition on the node
func (n Node) setNodeCondition(node *corev1.Node, conditionType corev1.NodeConditionType, status corev1.ConditionStatus, reason string, message string) error {
retryDeadline := time.Now().Add(maxRetryDeadline)
freshNode := node.DeepCopy()
client := n.drainHelper.Client
var err error
refresh := false

for {
if refresh {
freshNode, err = client.CoreV1().Nodes().Get(context.TODO(), node.Name, metav1.GetOptions{})
if err != nil || freshNode == nil {
return fmt.Errorf("failed to get node %v: %w", node.Name, err)
}
}

now := metav1.Now()
newCondition := corev1.NodeCondition{
Type: conditionType,
Status: status,
Reason: reason,
Message: message,
LastTransitionTime: now,
LastHeartbeatTime: now,
}

conditionExists := false
for i, condition := range freshNode.Status.Conditions {
if condition.Type == conditionType {
freshNode.Status.Conditions[i] = newCondition
conditionExists = true
break
}
}
if !conditionExists {
freshNode.Status.Conditions = append(freshNode.Status.Conditions, newCondition)
}

_, err = client.CoreV1().Nodes().UpdateStatus(context.TODO(), freshNode, metav1.UpdateOptions{})
if err != nil && errors.IsConflict(err) && time.Now().Before(retryDeadline) {
refresh = true
time.Sleep(conflictRetryInterval)
continue
}

if err != nil {
log.Err(err).
Str("condition_type", string(conditionType)).
Str("node_name", node.Name).
Msg("Error while setting condition on node")
return err
}
log.Info().
Str("condition_type", string(conditionType)).
Str("reason", reason).
Str("node_name", node.Name).
Msg("Successfully set condition on node")
return nil
}
}

// removeNodeCondition removes a condition from the node
func (n Node) removeNodeCondition(node *corev1.Node, conditionType corev1.NodeConditionType) error {
retryDeadline := time.Now().Add(maxRetryDeadline)
freshNode := node.DeepCopy()
client := n.drainHelper.Client
var err error
refresh := false

for {
if refresh {
freshNode, err = client.CoreV1().Nodes().Get(context.TODO(), node.Name, metav1.GetOptions{})
if err != nil || freshNode == nil {
return fmt.Errorf("failed to get node %v: %w", node.Name, err)
}
}

newConditions := make([]corev1.NodeCondition, 0)
found := false
for _, condition := range freshNode.Status.Conditions {
if condition.Type == conditionType {
found = true
continue
}
newConditions = append(newConditions, condition)
}

if !found {
if !refresh {
refresh = true
continue
}
return nil
}

freshNode.Status.Conditions = newConditions
_, err = client.CoreV1().Nodes().UpdateStatus(context.TODO(), freshNode, metav1.UpdateOptions{})
if err != nil && errors.IsConflict(err) && time.Now().Before(retryDeadline) {
refresh = true
time.Sleep(conflictRetryInterval)
continue
}

if err != nil {
log.Err(err).
Str("condition_type", string(conditionType)).
Str("node_name", node.Name).
Msg("Error while removing condition from node")
return err
}
log.Info().
Str("condition_type", string(conditionType)).
Str("node_name", node.Name).
Msg("Successfully removed condition from node")
return nil
}
}

// IsLabeledWithAction will return true if the current node is labeled with NTH action labels
func (n Node) IsLabeledWithAction(nodeName string) (bool, error) {
k8sNode, err := n.fetchKubernetesNode(nodeName)
Expand Down
Loading