diff --git a/src/internal.c b/src/internal.c index 249462e31..b0e264e37 100644 --- a/src/internal.c +++ b/src/internal.c @@ -8345,7 +8345,8 @@ static int DoUserAuthRequest(WOLFSSH* ssh, { word32 begin; int ret = WS_SUCCESS; - byte authNameId; + byte authNameId = ID_UNKNOWN; + int serviceValid = 1; WS_UserAuthData authData; WLOG(WS_LOG_DEBUG, "Entering DoUserAuthRequest()"); @@ -8356,37 +8357,32 @@ static int DoUserAuthRequest(WOLFSSH* ssh, if (ret == WS_SUCCESS) { begin = *idx; WMEMSET(&authData, 0, sizeof(authData)); - ret = GetSize(&authData.usernameSz, buf, len, &begin); - } - - if (ret == WS_SUCCESS) { - authData.username = buf + begin; - begin += authData.usernameSz; - - ret = GetUint32(&authData.serviceNameSz, buf, len, &begin); + ret = GetStringRef(&authData.usernameSz, &authData.username, + buf, len, &begin); } if (ret == WS_SUCCESS) { - ret = wolfSSH_SetUsernameRaw(ssh, authData.username, authData.usernameSz); + ret = GetStringRef(&authData.serviceNameSz, &authData.serviceName, + buf, len, &begin); } if (ret == WS_SUCCESS) { - if (authData.serviceNameSz > len - begin) { - ret = WS_BUFFER_E; + if (NameToId((const char*)authData.serviceName, authData.serviceNameSz) + != ID_SERVICE_CONNECTION) { + WLOG(WS_LOG_DEBUG, "DUAR: Invalid service name"); + serviceValid = 0; + ret = SendUserAuthFailure(ssh, 0); + /* Consume all remaining data */ + *idx = len; + } + else { + ret = GetStringRef(&authData.authNameSz, &authData.authName, + buf, len, &begin); } } - if (ret == WS_SUCCESS) { - authData.serviceName = buf + begin; - begin += authData.serviceNameSz; - - ret = GetSize(&authData.authNameSz, buf, len, &begin); - } - - if (ret == WS_SUCCESS) { - authData.authName = buf + begin; - begin += authData.authNameSz; - authNameId = NameToId((char*)authData.authName, authData.authNameSz); + if (ret == WS_SUCCESS && serviceValid) { + authNameId = NameToId((const char*)authData.authName, authData.authNameSz); ssh->authId = authNameId; if (authNameId == ID_USERAUTH_PASSWORD) @@ -8409,11 +8405,14 @@ static int DoUserAuthRequest(WOLFSSH* ssh, #endif else { WLOG(WS_LOG_DEBUG, - "invalid userauth type: %s", IdToName(authNameId)); + "DUAR: invalid userauth type: %s", IdToName(authNameId)); ret = SendUserAuthFailure(ssh, 0); + /* Consume all remaining data */ + begin = len; } if (ret == WS_SUCCESS) { + /* Set the username for valid service only */ ret = wolfSSH_SetUsernameRaw(ssh, authData.username, authData.usernameSz); } @@ -17976,6 +17975,12 @@ int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL* channel, byte* data, return ChannelPutData(channel, data, dataSz); } +int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf, word32 len, + word32* idx) +{ + return DoUserAuthRequest(ssh, buf, len, idx); +} + #ifndef WOLFSSH_NO_DH_GEX_SHA256 int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len, diff --git a/tests/unit.c b/tests/unit.c index 9bbd2cc29..4afc26f5f 100644 --- a/tests/unit.c +++ b/tests/unit.c @@ -1181,13 +1181,25 @@ static int test_ChannelPutData(void) return result; } +/* Plaintext SSH packet from IoSend (before encryption/MAC): LENGTH_SZ, + * PAD_LENGTH_SZ, then payload starting with the message ID (RFC 4253; + * wolfSSH PreparePacket/BundlePacket). Not for encrypted payloads or + * arbitrary truncated chunks. */ +static int CaptureMsgId(const byte* buf, word32 len) +{ + word32 off = LENGTH_SZ + PAD_LENGTH_SZ; + + if (len <= off) + return -1; + return (int)buf[off]; +} + /* Verify DoChannelRequest sends CHANNEL_SUCCESS for known types and * CHANNEL_FAILURE for unrecognized ones (RFC 4254 Section 5.4). * * A custom IoSend callback captures the outgoing packet in plaintext - * (no cipher negotiated on a fresh session). The SSH packet layout is: - * [4-byte packet_length][1-byte padding_length][1-byte msg_id]... - * so the message ID lives at byte offset 5. */ + * (no cipher negotiated on a fresh session). Message ID is read via + * CaptureMsgId() using LENGTH_SZ + PAD_LENGTH_SZ. */ static byte s_chanReqCapture[256]; static word32 s_chanReqCaptureSz = 0; @@ -1289,20 +1301,204 @@ static int test_DoChannelRequest(void) goto done; } - if (s_chanReqCaptureSz <= 5) { - printf("DoChannelRequest[%s]: captured packet too short (%u)\n", - cases[i].label, s_chanReqCaptureSz); - result = -410 - i; + { + int capMsgId = CaptureMsgId(s_chanReqCapture, s_chanReqCaptureSz); + + if (capMsgId < 0) { + printf("DoChannelRequest[%s]: captured packet too short (%u)\n", + cases[i].label, s_chanReqCaptureSz); + result = -410 - i; + goto done; + } + + if (capMsgId != (int)cases[i].expectMsgId) { + printf("DoChannelRequest[%s]: msg_id=0x%02x, expected=0x%02x\n", + cases[i].label, + capMsgId, cases[i].expectMsgId); + result = -420 - i; + goto done; + } + } + } + +done: + wolfSSH_free(ssh); + wolfSSH_CTX_free(ctx); + return result; +} + +/* Capture buffer for the service-name unit test. Separate from the channel- + * request capture so the two tests can run independently in any order. */ +static byte s_authSvcCapture[256]; +static word32 s_authSvcCaptureSz = 0; +static word32 s_authSvcSendCount = 0; + +static int CaptureIoSendAuthSvc(WOLFSSH* ssh, void* buf, word32 sz, void* ctx) +{ + (void)ssh; (void)ctx; + s_authSvcCaptureSz = (sz < (word32)sizeof(s_authSvcCapture)) + ? sz : (word32)sizeof(s_authSvcCapture); + WMEMCPY(s_authSvcCapture, buf, s_authSvcCaptureSz); + s_authSvcSendCount++; + return (int)sz; +} + +/* Verify DoUserAuthRequest rejects non-"ssh-connection" service names per + * RFC 4252 Section 5. For each case we assert: + * 1. ret == WS_SUCCESS (connection stays open for retry) + * 2. SSH_MSG_USERAUTH_FAILURE is actually sent (see CaptureMsgId(): + * LENGTH_SZ + PAD_LENGTH_SZ then msg id) + * 3. *idx == len (entire payload consumed; buffer stays aligned) + * + * For invalid-service cases the auth-method field is intentionally omitted + * from the payload. DoUserAuthRequest must short-circuit at the service-name + * check and still satisfy all three assertions — proving it never tries to + * parse the missing auth-method field. If the short-circuit were absent, + * GetSize() for authNameSz would hit end-of-buffer and return WS_BUFFER_E, + * failing assertion 1. + * + * For the valid-service case, auth method "xyz-unknown" (always unsupported + * regardless of compile-time options) is included. The function reaches + * auth-method dispatch, falls to the unknown-method else-branch, and sends + * USERAUTH_FAILURE via that normal path. + * + * A second valid-service row appends fake password-style bytes after the + * method name. That proves DoUserAuthRequest() consumes trailing + * method-specific payload (begin = len in the unknown-method branch); without + * it, DoReceive() could advance inputBuffer.idx short of the packet end and + * misalign decoding. */ +static const byte s_unknownAuthTrailingFakePassword[] = { + 0x00, /* "change password" FALSE */ + 0x00, 0x00, 0x00, 0x08, + 'p', 'a', 's', 's', 'w', 'o', 'r', 'd', +}; + +static int test_DoUserAuthRequest_serviceName(void) +{ + WOLFSSH_CTX* ctx = NULL; + WOLFSSH* ssh = NULL; + int result = 0; + struct { + const char* svcName; + word32 svcNameSz; + const char* authMethod; /* NULL = omit field (proves short-circuit) */ + word32 authMethodSz; + int expectRet; + const char* label; + const byte* authTrailing; /* bytes after auth method; NULL if none */ + word32 authTrailingSz; + } cases[] = { + /* valid service: auth dispatch fires, fails on unknown method */ + { "ssh-connection", 14, "xyz-unknown", 11, WS_SUCCESS, + "valid svc unknown auth", NULL, 0 }, + /* same but trailing junk must be skipped so *idx reaches len */ + { "ssh-connection", 14, "xyz-unknown", 11, WS_SUCCESS, + "valid svc unknown auth trailing junk", + s_unknownAuthTrailingFakePassword, + (word32)sizeof(s_unknownAuthTrailingFakePassword) }, + /* invalid service: short-circuit, auth-method field absent */ + { "ssh-agent", 9, NULL, 0, WS_SUCCESS, + "invalid ssh-agent svc", NULL, 0 }, + { "bad", 3, NULL, 0, WS_SUCCESS, + "invalid bad svc", NULL, 0 }, + /* zero-length service name: NameToId("",0)==ID_UNKNOWN, must reject */ + { "", 0, NULL, 0, WS_SUCCESS, + "zero-length svc", NULL, 0 }, + /* ssh-userauth: NameToId returns ID_SERVICE_USERAUTH, not + * ID_SERVICE_CONNECTION, so must also be rejected */ + { "ssh-userauth", 12, NULL, 0, WS_SUCCESS, + "invalid ssh-userauth svc", NULL, 0 }, + }; + int i; + + ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_SERVER, NULL); + if (ctx == NULL) return -500; + wolfSSH_SetIOSend(ctx, CaptureIoSendAuthSvc); + + for (i = 0; i < (int)(sizeof(cases)/sizeof(cases[0])); i++) { + byte buf[128]; + word32 len = 0, idx = 0; + word32 snsz = cases[i].svcNameSz; + int ret; + + ssh = wolfSSH_new(ctx); + if (ssh == NULL) { result = -501; goto done; } + + s_authSvcCaptureSz = 0; + s_authSvcSendCount = 0; + WMEMSET(s_authSvcCapture, 0, sizeof(s_authSvcCapture)); + + /* username: "user" */ + buf[len++] = 0; buf[len++] = 0; buf[len++] = 0; buf[len++] = 4; + WMEMCPY(buf + len, "user", 4); len += 4; + + /* service name */ + buf[len++] = (byte)(snsz >> 24); buf[len++] = (byte)(snsz >> 16); + buf[len++] = (byte)(snsz >> 8); buf[len++] = (byte)snsz; + if (snsz > 0) { WMEMCPY(buf + len, cases[i].svcName, snsz); } + len += snsz; + + /* auth method: omit for invalid-service cases to prove short-circuit */ + if (cases[i].authMethod != NULL) { + word32 amsz = cases[i].authMethodSz; + buf[len++] = (byte)(amsz >> 24); buf[len++] = (byte)(amsz >> 16); + buf[len++] = (byte)(amsz >> 8); buf[len++] = (byte)amsz; + WMEMCPY(buf + len, cases[i].authMethod, amsz); len += amsz; + if (cases[i].authTrailingSz > 0U) { + WMEMCPY(buf + len, cases[i].authTrailing, + cases[i].authTrailingSz); + len += cases[i].authTrailingSz; + } + } + + ret = wolfSSH_TestDoUserAuthRequest(ssh, buf, len, &idx); + + if (s_authSvcSendCount != 1) { + printf("DoUserAuthRequest_svcName[%s]: expected 1 send, got %u\n", + cases[i].label, s_authSvcSendCount); + result = -540 - i; + goto done; + } + + if (ret != cases[i].expectRet) { + printf("DoUserAuthRequest_svcName[%s]: ret=%d expected=%d\n", + cases[i].label, ret, cases[i].expectRet); + result = -502 - i; + goto done; + } + + /* MSGID_USERAUTH_FAILURE must be in the captured packet. */ + { + int capMsgId = CaptureMsgId(s_authSvcCapture, s_authSvcCaptureSz); + + if (capMsgId < 0 || capMsgId != MSGID_USERAUTH_FAILURE) { + printf("DoUserAuthRequest_svcName[%s]: USERAUTH_FAILURE not " + "sent (capSz=%u msg_id=0x%02x)\n", cases[i].label, + s_authSvcCaptureSz, + capMsgId >= 0 ? capMsgId : 0); + result = -520 - i; + goto done; + } + } + + /* All cases must consume the entire payload. */ + if (idx != len) { + printf("DoUserAuthRequest_svcName[%s]: idx=%u expected len=%u\n", + cases[i].label, idx, len); + result = -510 - i; goto done; } - if (s_chanReqCapture[5] != cases[i].expectMsgId) { - printf("DoChannelRequest[%s]: msg_id=0x%02x, expected=0x%02x\n", - cases[i].label, - s_chanReqCapture[5], cases[i].expectMsgId); - result = -420 - i; + /* Invalid-service cases must NOT record the username. */ + if (cases[i].authMethod == NULL && ssh->userName != NULL) { + printf("DoUserAuthRequest_svcName[%s]: userName set on invalid " + "service (expected NULL)\n", cases[i].label); + result = -530 - i; goto done; } + + wolfSSH_free(ssh); + ssh = NULL; } done: @@ -1609,6 +1805,11 @@ int wolfSSH_UnitTest(int argc, char** argv) unitResult = test_ChannelPutData(); printf("ChannelPutData: %s\n", (unitResult == 0 ? "SUCCESS" : "FAILED")); testResult = testResult || unitResult; + + unitResult = test_DoUserAuthRequest_serviceName(); + printf("DoUserAuthRequest_serviceName: %s\n", + (unitResult == 0 ? "SUCCESS" : "FAILED")); + testResult = testResult || unitResult; #endif #ifdef WOLFSSH_KEYGEN diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 01ac5bd67..2260120f7 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -1350,6 +1350,8 @@ enum WS_MessageIdLimits { WOLFSSH_API int wolfSSH_TestDoKexDhReply(WOLFSSH* ssh, byte* buf, word32 len, word32* idx); WOLFSSH_API int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL*, byte*, word32); + WOLFSSH_API int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf, + word32 len, word32* idx); #ifndef WOLFSSH_NO_DH_GEX_SHA256 WOLFSSH_API int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len, word32* idx);