Skip to content

Commit 505ec19

Browse files
committed
update websocket
1 parent f46f9f2 commit 505ec19

3 files changed

Lines changed: 147 additions & 68 deletions

File tree

include/simple_socket/ws/WebSocket.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace simple_socket {
1717

1818
virtual void send(const std::string& msg) = 0;
1919

20+
virtual void send(const uint8_t* msg, size_t len) = 0;
21+
2022
virtual ~WebSocketConnection() = default;
2123

2224
private:

src/simple_socket/ws/WebSocketClient.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ struct WebSocketClient::Impl {
214214
conn->send(message);
215215
}
216216

217+
void send(const uint8_t* message, size_t len) {
218+
219+
conn->send(message, len);
220+
}
221+
217222
void close() {
218223

219224
conn->close(true);

src/simple_socket/ws/WebSocketConnection.hpp

Lines changed: 140 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414

1515
namespace simple_socket {
1616

17+
enum : uint8_t {
18+
WS_CONT = 0x0,
19+
WS_TEXT = 0x1,
20+
WS_BIN = 0x2,
21+
WS_CLOSE = 0x8,
22+
WS_PING = 0x9,
23+
WS_PONG = 0xA
24+
};
25+
1726
struct WebSocketCallbacks {
1827
std::function<void(WebSocketConnection*)>& onOpen;
1928
std::function<void(WebSocketConnection*)>& onClose;
@@ -63,6 +72,12 @@ namespace simple_socket {
6372
conn_->write(frame);
6473
}
6574

75+
void send(const uint8_t* message, size_t len) override {
76+
const auto frame = buildBin(message, len, role_);
77+
std::lock_guard lg(tx_mtx_);
78+
conn_->write(frame);
79+
}
80+
6681
void close(bool self) {
6782
if (closed_.exchange(true)) return;
6883

@@ -148,100 +163,157 @@ namespace simple_socket {
148163
}
149164

150165
static std::vector<uint8_t> buildText(const std::string& s, Role role) {
151-
return buildFrame(0x1, reinterpret_cast<const uint8_t*>(s.data()), s.size(), role);
166+
return buildFrame(WS_TEXT, reinterpret_cast<const uint8_t*>(s.data()), s.size(), role);
167+
}
168+
169+
static std::vector<uint8_t> buildBin(const uint8_t* msg, size_t len, Role role) {
170+
return buildFrame(WS_BIN, msg, len, role);
152171
}
153172

154173
static std::vector<uint8_t> buildClose(uint16_t code, Role role) {
155174
uint8_t p[2] = {static_cast<uint8_t>(code >> 8), static_cast<uint8_t>(code & 0xFF)};
156-
return buildFrame(0x8, p, 2, role);
175+
return buildFrame(WS_CLOSE, p, 2, role);
157176
}
158177

159178
static std::vector<uint8_t> buildPong(const std::vector<uint8_t>& payload, Role role) {
160-
return buildFrame(0xA, payload.data(), payload.size(), role);
179+
return buildFrame(WS_PONG, payload.data(), payload.size(), role);
161180
}
162181

163182
void listen() {
183+
std::vector<uint8_t> rx; // accumulated bytes from socket
184+
std::vector<uint8_t> message;// assembling fragmented messages
185+
bool continued = false;
186+
uint8_t startOpcode = 0;
187+
164188
while (!closed_) {
165-
const auto recv = conn_->read(buffer);
189+
const int recv = conn_->read(buffer);
166190
if (recv <= 0) break;
167191

168-
std::vector<uint8_t> frame{buffer.begin(), buffer.begin() + recv};
169-
if (frame.size() < 2) continue;
170-
171-
const uint8_t b0 = frame[0];
172-
const uint8_t b1 = frame[1];
173-
const uint8_t opcode = b0 & 0x0F;
174-
const bool isMasked = (b1 & 0x80) != 0;
175-
uint64_t payloadLen = (b1 & 0x7F);
176-
177-
// Role‑based masking validation
178-
if (role_ == Role::Server && !isMasked) {
179-
// Protocol error: client must mask; tear down
180-
std::lock_guard lg(tx_mtx_);
181-
conn_->write(buildClose(1002, role_));// 1002 protocol error
182-
break;
183-
}
184-
if (role_ == Role::Client && isMasked) {
185-
// Protocol error: server must not mask
186-
std::lock_guard lg(tx_mtx_);
187-
conn_->write(buildClose(1002, role_));
188-
break;
189-
}
192+
rx.insert(rx.end(), buffer.begin(), buffer.begin() + recv);
190193

191-
size_t pos = 2;
192-
if (payloadLen == 126) {
193-
if (frame.size() < 4) continue;
194-
payloadLen = (static_cast<uint64_t>(frame[2]) << 8) | frame[3];
195-
pos += 2;
196-
} else if (payloadLen == 127) {
197-
if (frame.size() < 10) continue;
198-
payloadLen = 0;
199-
for (int i = 0; i < 8; ++i) payloadLen = (payloadLen << 8) | frame[2 + i];
200-
pos += 8;
201-
}
194+
size_t pos = 0;
195+
while (true) {
196+
if (rx.size() - pos < 2) break;
202197

203-
if (frame.size() < pos + (isMasked ? 4 : 0) + payloadLen) continue;
198+
const uint8_t b0 = rx[pos];
199+
const uint8_t b1 = rx[pos + 1];
200+
const bool fin = (b0 & 0x80) != 0;
201+
const uint8_t opcode = static_cast<uint8_t>(b0 & 0x0F);
202+
const bool isMasked = (b1 & 0x80) != 0;
203+
uint64_t payloadLen = (b1 & 0x7F);
204204

205-
uint8_t mask[4] = {0, 0, 0, 0};
206-
if (isMasked)
207-
for (int i = 0; i < 4; ++i) mask[i] = frame[pos++];
205+
// Masking rules
206+
if (role_ == Role::Server && !isMasked) {
207+
std::lock_guard lg(tx_mtx_);
208+
conn_->write(buildClose(1002, role_));
209+
close(false);
210+
return;
211+
}
212+
if (role_ == Role::Client && isMasked) {
213+
std::lock_guard lg(tx_mtx_);
214+
conn_->write(buildClose(1002, role_));
215+
close(false);
216+
return;
217+
}
208218

209-
std::vector<uint8_t> payload(frame.begin() + pos, frame.begin() + pos + payloadLen);
210-
if (isMasked) {
211-
for (size_t i = 0; i < payload.size(); ++i) {
212-
payload[i] ^= mask[i & 0x03];
219+
size_t hdr = 2;
220+
if (payloadLen == 126) {
221+
if (rx.size() - pos < hdr + 2) break;
222+
payloadLen = (static_cast<uint64_t>(rx[pos + 2]) << 8) | rx[pos + 3];
223+
hdr += 2;
224+
} else if (payloadLen == 127) {
225+
if (rx.size() - pos < hdr + 8) break;
226+
payloadLen = 0;
227+
for (int i = 0; i < 8; ++i) payloadLen = (payloadLen << 8) | rx[pos + 2 + i];
228+
hdr += 8;
213229
}
214-
}
215230

216-
switch (opcode) {
217-
case 0x1: {// Text
218-
std::string message(payload.begin(), payload.end());
219-
if (callbacks_.onMessage) callbacks_.onMessage(this, message);
220-
} break;
231+
const size_t need = hdr + (isMasked ? 4 : 0) + payloadLen;
232+
if (rx.size() - pos < need) break;
233+
234+
size_t off = pos + hdr;
235+
uint8_t mask[4] = {0, 0, 0, 0};
236+
if (isMasked) {
237+
for (int i = 0; i < 4; ++i) mask[i] = rx[off + i];
238+
off += 4;
239+
}
240+
241+
std::vector<uint8_t> chunk;
242+
chunk.resize(payloadLen);
243+
for (size_t i = 0; i < chunk.size(); ++i) {
244+
uint8_t b = rx[off + i];
245+
if (isMasked) b ^= mask[i & 0x03];
246+
chunk[i] = b;
247+
}
221248

222-
case 0x8: {// Close
223-
// Echo a close if we didn't initiate one
224-
std::vector<uint8_t> closeResp = buildClose(1000, role_);
249+
pos = off + payloadLen;
250+
251+
if (opcode == WS_CLOSE) {
225252
{
226253
std::lock_guard lg(tx_mtx_);
227-
conn_->write(closeResp);
254+
conn_->write(buildClose(1000, role_));
228255
}
229256
close(false);
230-
} break;
231-
232-
case 0x9: {// Ping
233-
std::vector<uint8_t> pongFrame = buildPong(payload, role_);
257+
return;
258+
} else if (opcode == WS_PING) {
234259
std::lock_guard lg(tx_mtx_);
235-
conn_->write(pongFrame);
236-
} break;
237-
238-
case 0xA:// Pong
239-
break;
240-
241-
default:
242-
std::cerr << "Unsupported opcode: " << static_cast<int>(opcode) << std::endl;
243-
break;
260+
conn_->write(buildPong(chunk, role_));
261+
continue;
262+
} else if (opcode == WS_PONG) {
263+
continue;
264+
} else if (opcode == WS_CONT) {
265+
if (!continued) {
266+
std::lock_guard lg(tx_mtx_);
267+
conn_->write(buildClose(1002, role_));
268+
close(false);
269+
return;
270+
}
271+
message.insert(message.end(), chunk.begin(), chunk.end());
272+
if (fin) {
273+
// deliver completed fragmented message
274+
if (startOpcode == WS_TEXT) {
275+
std::string s(message.begin(), message.end());
276+
if (callbacks_.onMessage) callbacks_.onMessage(this, s);
277+
} else if (startOpcode == WS_BIN) {
278+
std::string s(reinterpret_cast<const char*>(message.data()), message.size());
279+
if (callbacks_.onMessage) callbacks_.onMessage(this, s);
280+
}
281+
message.clear();
282+
continued = false;
283+
startOpcode = 0;
284+
}
285+
continue;
286+
} else if (opcode == WS_TEXT || opcode == WS_BIN) {
287+
if (continued) {
288+
std::lock_guard lg(tx_mtx_);
289+
conn_->write(buildClose(1002, role_));
290+
close(false);
291+
return;
292+
}
293+
if (fin) {
294+
// single-frame message
295+
if (opcode == WS_TEXT) {
296+
std::string s(chunk.begin(), chunk.end());
297+
if (callbacks_.onMessage) callbacks_.onMessage(this, s);
298+
} else {
299+
std::string s(reinterpret_cast<const char*>(chunk.data()), chunk.size());
300+
if (callbacks_.onMessage) callbacks_.onMessage(this, s);
301+
}
302+
} else {
303+
// start fragmented message
304+
message = std::move(chunk);
305+
continued = true;
306+
startOpcode = opcode;
307+
}
308+
continue;
309+
} else {
310+
// Unknown opcode: ignore this frame
311+
continue;
312+
}
244313
}
314+
315+
// discard parsed bytes, keep remainder for next read
316+
if (pos > 0) rx.erase(rx.begin(), rx.begin() + pos);
245317
}
246318
}
247319
};

0 commit comments

Comments
 (0)