|
14 | 14 |
|
15 | 15 | namespace simple_socket { |
16 | 16 |
|
| 17 | + enum : uint8_t { |
| 18 | + WS_CONT = 0x0, |
| 19 | + WS_TEXT = 0x1, |
| 20 | + WS_BIN = 0x2, |
| 21 | + WS_CLOSE = 0x8, |
| 22 | + WS_PING = 0x9, |
| 23 | + WS_PONG = 0xA |
| 24 | + }; |
| 25 | + |
17 | 26 | struct WebSocketCallbacks { |
18 | 27 | std::function<void(WebSocketConnection*)>& onOpen; |
19 | 28 | std::function<void(WebSocketConnection*)>& onClose; |
@@ -63,6 +72,12 @@ namespace simple_socket { |
63 | 72 | conn_->write(frame); |
64 | 73 | } |
65 | 74 |
|
| 75 | + void send(const uint8_t* message, size_t len) override { |
| 76 | + const auto frame = buildBin(message, len, role_); |
| 77 | + std::lock_guard lg(tx_mtx_); |
| 78 | + conn_->write(frame); |
| 79 | + } |
| 80 | + |
66 | 81 | void close(bool self) { |
67 | 82 | if (closed_.exchange(true)) return; |
68 | 83 |
|
@@ -148,100 +163,157 @@ namespace simple_socket { |
148 | 163 | } |
149 | 164 |
|
150 | 165 | static std::vector<uint8_t> buildText(const std::string& s, Role role) { |
151 | | - return buildFrame(0x1, reinterpret_cast<const uint8_t*>(s.data()), s.size(), role); |
| 166 | + return buildFrame(WS_TEXT, reinterpret_cast<const uint8_t*>(s.data()), s.size(), role); |
| 167 | + } |
| 168 | + |
| 169 | + static std::vector<uint8_t> buildBin(const uint8_t* msg, size_t len, Role role) { |
| 170 | + return buildFrame(WS_BIN, msg, len, role); |
152 | 171 | } |
153 | 172 |
|
154 | 173 | static std::vector<uint8_t> buildClose(uint16_t code, Role role) { |
155 | 174 | uint8_t p[2] = {static_cast<uint8_t>(code >> 8), static_cast<uint8_t>(code & 0xFF)}; |
156 | | - return buildFrame(0x8, p, 2, role); |
| 175 | + return buildFrame(WS_CLOSE, p, 2, role); |
157 | 176 | } |
158 | 177 |
|
159 | 178 | static std::vector<uint8_t> buildPong(const std::vector<uint8_t>& payload, Role role) { |
160 | | - return buildFrame(0xA, payload.data(), payload.size(), role); |
| 179 | + return buildFrame(WS_PONG, payload.data(), payload.size(), role); |
161 | 180 | } |
162 | 181 |
|
163 | 182 | void listen() { |
| 183 | + std::vector<uint8_t> rx; // accumulated bytes from socket |
| 184 | + std::vector<uint8_t> message;// assembling fragmented messages |
| 185 | + bool continued = false; |
| 186 | + uint8_t startOpcode = 0; |
| 187 | + |
164 | 188 | while (!closed_) { |
165 | | - const auto recv = conn_->read(buffer); |
| 189 | + const int recv = conn_->read(buffer); |
166 | 190 | if (recv <= 0) break; |
167 | 191 |
|
168 | | - std::vector<uint8_t> frame{buffer.begin(), buffer.begin() + recv}; |
169 | | - if (frame.size() < 2) continue; |
170 | | - |
171 | | - const uint8_t b0 = frame[0]; |
172 | | - const uint8_t b1 = frame[1]; |
173 | | - const uint8_t opcode = b0 & 0x0F; |
174 | | - const bool isMasked = (b1 & 0x80) != 0; |
175 | | - uint64_t payloadLen = (b1 & 0x7F); |
176 | | - |
177 | | - // Role‑based masking validation |
178 | | - if (role_ == Role::Server && !isMasked) { |
179 | | - // Protocol error: client must mask; tear down |
180 | | - std::lock_guard lg(tx_mtx_); |
181 | | - conn_->write(buildClose(1002, role_));// 1002 protocol error |
182 | | - break; |
183 | | - } |
184 | | - if (role_ == Role::Client && isMasked) { |
185 | | - // Protocol error: server must not mask |
186 | | - std::lock_guard lg(tx_mtx_); |
187 | | - conn_->write(buildClose(1002, role_)); |
188 | | - break; |
189 | | - } |
| 192 | + rx.insert(rx.end(), buffer.begin(), buffer.begin() + recv); |
190 | 193 |
|
191 | | - size_t pos = 2; |
192 | | - if (payloadLen == 126) { |
193 | | - if (frame.size() < 4) continue; |
194 | | - payloadLen = (static_cast<uint64_t>(frame[2]) << 8) | frame[3]; |
195 | | - pos += 2; |
196 | | - } else if (payloadLen == 127) { |
197 | | - if (frame.size() < 10) continue; |
198 | | - payloadLen = 0; |
199 | | - for (int i = 0; i < 8; ++i) payloadLen = (payloadLen << 8) | frame[2 + i]; |
200 | | - pos += 8; |
201 | | - } |
| 194 | + size_t pos = 0; |
| 195 | + while (true) { |
| 196 | + if (rx.size() - pos < 2) break; |
202 | 197 |
|
203 | | - if (frame.size() < pos + (isMasked ? 4 : 0) + payloadLen) continue; |
| 198 | + const uint8_t b0 = rx[pos]; |
| 199 | + const uint8_t b1 = rx[pos + 1]; |
| 200 | + const bool fin = (b0 & 0x80) != 0; |
| 201 | + const uint8_t opcode = static_cast<uint8_t>(b0 & 0x0F); |
| 202 | + const bool isMasked = (b1 & 0x80) != 0; |
| 203 | + uint64_t payloadLen = (b1 & 0x7F); |
204 | 204 |
|
205 | | - uint8_t mask[4] = {0, 0, 0, 0}; |
206 | | - if (isMasked) |
207 | | - for (int i = 0; i < 4; ++i) mask[i] = frame[pos++]; |
| 205 | + // Masking rules |
| 206 | + if (role_ == Role::Server && !isMasked) { |
| 207 | + std::lock_guard lg(tx_mtx_); |
| 208 | + conn_->write(buildClose(1002, role_)); |
| 209 | + close(false); |
| 210 | + return; |
| 211 | + } |
| 212 | + if (role_ == Role::Client && isMasked) { |
| 213 | + std::lock_guard lg(tx_mtx_); |
| 214 | + conn_->write(buildClose(1002, role_)); |
| 215 | + close(false); |
| 216 | + return; |
| 217 | + } |
208 | 218 |
|
209 | | - std::vector<uint8_t> payload(frame.begin() + pos, frame.begin() + pos + payloadLen); |
210 | | - if (isMasked) { |
211 | | - for (size_t i = 0; i < payload.size(); ++i) { |
212 | | - payload[i] ^= mask[i & 0x03]; |
| 219 | + size_t hdr = 2; |
| 220 | + if (payloadLen == 126) { |
| 221 | + if (rx.size() - pos < hdr + 2) break; |
| 222 | + payloadLen = (static_cast<uint64_t>(rx[pos + 2]) << 8) | rx[pos + 3]; |
| 223 | + hdr += 2; |
| 224 | + } else if (payloadLen == 127) { |
| 225 | + if (rx.size() - pos < hdr + 8) break; |
| 226 | + payloadLen = 0; |
| 227 | + for (int i = 0; i < 8; ++i) payloadLen = (payloadLen << 8) | rx[pos + 2 + i]; |
| 228 | + hdr += 8; |
213 | 229 | } |
214 | | - } |
215 | 230 |
|
216 | | - switch (opcode) { |
217 | | - case 0x1: {// Text |
218 | | - std::string message(payload.begin(), payload.end()); |
219 | | - if (callbacks_.onMessage) callbacks_.onMessage(this, message); |
220 | | - } break; |
| 231 | + const size_t need = hdr + (isMasked ? 4 : 0) + payloadLen; |
| 232 | + if (rx.size() - pos < need) break; |
| 233 | + |
| 234 | + size_t off = pos + hdr; |
| 235 | + uint8_t mask[4] = {0, 0, 0, 0}; |
| 236 | + if (isMasked) { |
| 237 | + for (int i = 0; i < 4; ++i) mask[i] = rx[off + i]; |
| 238 | + off += 4; |
| 239 | + } |
| 240 | + |
| 241 | + std::vector<uint8_t> chunk; |
| 242 | + chunk.resize(payloadLen); |
| 243 | + for (size_t i = 0; i < chunk.size(); ++i) { |
| 244 | + uint8_t b = rx[off + i]; |
| 245 | + if (isMasked) b ^= mask[i & 0x03]; |
| 246 | + chunk[i] = b; |
| 247 | + } |
221 | 248 |
|
222 | | - case 0x8: {// Close |
223 | | - // Echo a close if we didn't initiate one |
224 | | - std::vector<uint8_t> closeResp = buildClose(1000, role_); |
| 249 | + pos = off + payloadLen; |
| 250 | + |
| 251 | + if (opcode == WS_CLOSE) { |
225 | 252 | { |
226 | 253 | std::lock_guard lg(tx_mtx_); |
227 | | - conn_->write(closeResp); |
| 254 | + conn_->write(buildClose(1000, role_)); |
228 | 255 | } |
229 | 256 | close(false); |
230 | | - } break; |
231 | | - |
232 | | - case 0x9: {// Ping |
233 | | - std::vector<uint8_t> pongFrame = buildPong(payload, role_); |
| 257 | + return; |
| 258 | + } else if (opcode == WS_PING) { |
234 | 259 | std::lock_guard lg(tx_mtx_); |
235 | | - conn_->write(pongFrame); |
236 | | - } break; |
237 | | - |
238 | | - case 0xA:// Pong |
239 | | - break; |
240 | | - |
241 | | - default: |
242 | | - std::cerr << "Unsupported opcode: " << static_cast<int>(opcode) << std::endl; |
243 | | - break; |
| 260 | + conn_->write(buildPong(chunk, role_)); |
| 261 | + continue; |
| 262 | + } else if (opcode == WS_PONG) { |
| 263 | + continue; |
| 264 | + } else if (opcode == WS_CONT) { |
| 265 | + if (!continued) { |
| 266 | + std::lock_guard lg(tx_mtx_); |
| 267 | + conn_->write(buildClose(1002, role_)); |
| 268 | + close(false); |
| 269 | + return; |
| 270 | + } |
| 271 | + message.insert(message.end(), chunk.begin(), chunk.end()); |
| 272 | + if (fin) { |
| 273 | + // deliver completed fragmented message |
| 274 | + if (startOpcode == WS_TEXT) { |
| 275 | + std::string s(message.begin(), message.end()); |
| 276 | + if (callbacks_.onMessage) callbacks_.onMessage(this, s); |
| 277 | + } else if (startOpcode == WS_BIN) { |
| 278 | + std::string s(reinterpret_cast<const char*>(message.data()), message.size()); |
| 279 | + if (callbacks_.onMessage) callbacks_.onMessage(this, s); |
| 280 | + } |
| 281 | + message.clear(); |
| 282 | + continued = false; |
| 283 | + startOpcode = 0; |
| 284 | + } |
| 285 | + continue; |
| 286 | + } else if (opcode == WS_TEXT || opcode == WS_BIN) { |
| 287 | + if (continued) { |
| 288 | + std::lock_guard lg(tx_mtx_); |
| 289 | + conn_->write(buildClose(1002, role_)); |
| 290 | + close(false); |
| 291 | + return; |
| 292 | + } |
| 293 | + if (fin) { |
| 294 | + // single-frame message |
| 295 | + if (opcode == WS_TEXT) { |
| 296 | + std::string s(chunk.begin(), chunk.end()); |
| 297 | + if (callbacks_.onMessage) callbacks_.onMessage(this, s); |
| 298 | + } else { |
| 299 | + std::string s(reinterpret_cast<const char*>(chunk.data()), chunk.size()); |
| 300 | + if (callbacks_.onMessage) callbacks_.onMessage(this, s); |
| 301 | + } |
| 302 | + } else { |
| 303 | + // start fragmented message |
| 304 | + message = std::move(chunk); |
| 305 | + continued = true; |
| 306 | + startOpcode = opcode; |
| 307 | + } |
| 308 | + continue; |
| 309 | + } else { |
| 310 | + // Unknown opcode: ignore this frame |
| 311 | + continue; |
| 312 | + } |
244 | 313 | } |
| 314 | + |
| 315 | + // discard parsed bytes, keep remainder for next read |
| 316 | + if (pos > 0) rx.erase(rx.begin(), rx.begin() + pos); |
245 | 317 | } |
246 | 318 | } |
247 | 319 | }; |
|
0 commit comments