More WebSocket (ws) transport cleanup

Signed-off-by: Martin Sustrik <sustrik@250bpm.com>
diff --git a/src/transports/ws/sws.c b/src/transports/ws/sws.c
index 31099c9..ef9fb4e 100644
--- a/src/transports/ws/sws.c
+++ b/src/transports/ws/sws.c
@@ -67,7 +67,6 @@
 
 /*  WebSocket protocol header bit masks as per RFC 6455. */
 #define NN_SWS_FRAME_BITMASK_MASKED 0x80
-#define NN_SWS_FRAME_BITMASK_NOT_MASKED 0x00
 #define NN_SWS_FRAME_BITMASK_LENGTH 0x7F
 
 /*  WebSocket Close Status Codes (1004-1006 and 1015 are reserved). */
@@ -123,7 +122,7 @@
 static void nn_sws_fail_conn (struct nn_sws *self, int code, char *reason);
 static int nn_sws_recv_hdr (struct nn_sws *self);
 static void nn_sws_mask_payload (uint8_t *payload, size_t payload_len,
-    const uint8_t *mask, size_t mask_len, int *mask_start_pos);
+    uint8_t mask [4], int *mask_start_pos);
 static void nn_sws_validate_close_handshake (struct nn_sws *self);
 
 void nn_sws_init (struct nn_sws *self, int src,
@@ -348,38 +347,18 @@
 
                     sws->is_final_frame = sws->inhdr [0] &
                         NN_SWS_FRAME_BITMASK_FIN;
-                    sws->masked = sws->inhdr [1] &
-                        NN_SWS_FRAME_BITMASK_MASKED;
-
-                    switch (sws->mode) {
-                    case NN_WS_SERVER:
-                        /*  Require mask bit to be set from client. */
-                        if (sws->masked) {
-                            /*  Continue receiving header for this frame. */
-                            sws->ext_hdr_len = 4;
-                            break;
-                        }
-                        else {
-                            nn_sws_fail_conn (sws, NN_SWS_CLOSE_ERR_PROTO,
-                                "Server expects MASK bit to be set.");
-                            return;
-                        }
-                    case NN_WS_CLIENT:
-                        /*  Require mask bit to be unset from server. */
-                        if (sws->masked) {
-                            nn_sws_fail_conn (sws, NN_SWS_CLOSE_ERR_PROTO,
-                                "Client expects MASK bit to be unset.");
-                            return;
-                        }
-                        else {
-                            /*  Continue receiving header for this frame. */
-                            sws->ext_hdr_len = 0;
-                            break;
-                        }
-                    default:
-                        /*  Only two modes of this endpoint are expected. */
-                        nn_assert (0);
-                        return;
+     
+                    /*  Communication from client to server must be masked.
+                        Communication from server to client must be unmasked. */
+                    if (sws->mode == NN_WS_SERVER) {
+                        nn_assert (sws->inhdr [1] &
+                            NN_SWS_FRAME_BITMASK_MASKED);
+                        sws->ext_hdr_len = 4;
+                    }
+                    else {
+                        nn_assert (!(sws->inhdr [1] &
+                            NN_SWS_FRAME_BITMASK_MASKED));
+                        sws->ext_hdr_len = 0;
                     }
 
                     sws->opcode = sws->inhdr [0] &
@@ -581,36 +560,28 @@
 
                     if (sws->payload_ctl <= 0x7d) {
                         sws->inmsg_current_chunk_len = sws->payload_ctl;
-                        if (sws->masked) {
-                            sws->mask = sws->inhdr + NN_SWS_FRAME_SIZE_INITIAL;
-                        }
-                        else {
-                            sws->mask = NULL;
+                        if (sws->mode == NN_WS_SERVER) {
+                            memcpy (sws->mask,
+                               sws->inhdr + NN_SWS_FRAME_SIZE_INITIAL, 4);
                         }
                     }
                     else if (sws->payload_ctl == 0xffff) {
                         sws->inmsg_current_chunk_len =
                             nn_gets (sws->inhdr + NN_SWS_FRAME_SIZE_INITIAL);
-                        if (sws->masked) {
-                            sws->mask = sws->inhdr +
+                        if (sws->mode == NN_WS_SERVER) {
+                            memcpy (sws->mask, sws->inhdr +
                                 NN_SWS_FRAME_SIZE_INITIAL +
-                                NN_SWS_FRAME_SIZE_PAYLOAD_16;
-                        }
-                        else {
-                            sws->mask = NULL;
+                                NN_SWS_FRAME_SIZE_PAYLOAD_16, 4);
                         }
                     }
                     else {
                         sws->inmsg_current_chunk_len =
                             (size_t) nn_getll (sws->inhdr +
                             NN_SWS_FRAME_SIZE_INITIAL);
-                        if (sws->masked) {
-                            sws->mask = sws->inhdr +
+                        if (sws->mode == NN_WS_SERVER) {
+                            memcpy (sws->mask, sws->inhdr +
                                 NN_SWS_FRAME_SIZE_INITIAL +
-                                NN_SWS_FRAME_SIZE_PAYLOAD_63;
-                        }
-                        else {
-                            sws->mask = NULL;
+                                NN_SWS_FRAME_SIZE_PAYLOAD_63, 4);
                         }
                     }
 
@@ -655,23 +626,14 @@
                 case NN_SWS_INSTATE_RECV_PAYLOAD:
 
                     /*  Unmask if necessary. */
-                    if (sws->masked) {
+                    if (sws->mode == NN_WS_SERVER) {
                         nn_sws_mask_payload (sws->inmsg_current_chunk_buf,
-                            sws->inmsg_current_chunk_len, sws->mask, 4, NULL);
+                            sws->inmsg_current_chunk_len, sws->mask, NULL);
                     }
 
                     switch (sws->opcode) {
 
                     case NN_WS_OPCODE_BINARY:
-                        if (sws->is_final_frame) {
-                            sws->instate = NN_SWS_INSTATE_RECVD_CHUNKED;
-                            nn_pipebase_received (&sws->pipebase);
-                        }
-                        else {
-                            nn_sws_recv_hdr (sws);
-                        }
-                        return;
-
                     case NN_WS_OPCODE_FRAGMENT:
                         if (sws->is_final_frame) {
                             sws->instate = NN_SWS_INSTATE_RECVD_CHUNKED;
@@ -683,6 +645,7 @@
                         return;
 
                     case NN_WS_OPCODE_CLOSE:
+
                         /*  If the payload is not even long enough for the
                             required 2-octet Close Code, the connection
                             should have been failed upstream. */
@@ -837,22 +800,22 @@
 
 /*  Mask or unmask message payload. */
 static void nn_sws_mask_payload (uint8_t *payload, size_t payload_len,
-    const uint8_t *mask, size_t mask_len, int *mask_start_pos)
+    uint8_t mask [4], int *mask_start_pos)
 {
     unsigned i;
 
     if (mask_start_pos) {
         for (i = 0; i < payload_len; i++) {
-            payload [i] ^= mask [(i + *mask_start_pos) % mask_len];
+            payload [i] ^= mask [(i + *mask_start_pos) % 4];
         }
 
-        *mask_start_pos = (i + *mask_start_pos) % mask_len;
+        *mask_start_pos = (i + *mask_start_pos) % 4;
 
         return;
     }
     else {
         for (i = 0; i < payload_len; i++) {
-            payload [i] ^= mask [i % mask_len];
+            payload [i] ^= mask [i % 4];
         }
         return;
     }
@@ -940,11 +903,9 @@
                    operations, masking only as much data at a time. */
         mask_pos = 0;
         nn_sws_mask_payload (nn_chunkref_data (&sws->outmsg.sphdr),
-            nn_chunkref_size (&sws->outmsg.sphdr),
-            mask, 4, &mask_pos);
+            nn_chunkref_size (&sws->outmsg.sphdr), mask, &mask_pos);
         nn_sws_mask_payload (nn_chunkref_data (&sws->outmsg.body),
-            nn_chunkref_size (&sws->outmsg.body),
-            mask, 4, &mask_pos);
+            nn_chunkref_size (&sws->outmsg.body), mask, &mask_pos);
     }
 
     /*  Start async sending. */
@@ -1154,23 +1115,13 @@
 
     self->fail_msg_len = NN_SWS_FRAME_SIZE_INITIAL;
 
-    if (self->mode == NN_WS_SERVER) {
-        self->fail_msg [1] |= NN_SWS_FRAME_BITMASK_NOT_MASKED;
-    }
-    else if (self->mode == NN_WS_CLIENT) {
+    /*  Generate 32-bit mask as per RFC 6455 5.3. */
+    if (self->mode == NN_WS_CLIENT) {
         self->fail_msg [1] |= NN_SWS_FRAME_BITMASK_MASKED;
-
-        /*  Generate 32-bit mask as per RFC 6455 5.3. */
         nn_random_generate (mask, sizeof (mask));
-        
         memcpy (&self->fail_msg [NN_SWS_FRAME_SIZE_INITIAL], mask, 4);
-
         self->fail_msg_len += 4;
     }
-    else {
-        /*  Developer error. */
-        nn_assert (0);
-    }
 
     payload_pos = &self->fail_msg [self->fail_msg_len];
     
@@ -1183,7 +1134,7 @@
 
     /*  If this is a client, apply mask. */
     if (self->mode == NN_WS_CLIENT) {
-        nn_sws_mask_payload (payload_pos, payload_len, mask, 4, NULL);
+        nn_sws_mask_payload (payload_pos, payload_len, mask, NULL);
     }
 
     self->fail_msg_len += payload_len;
diff --git a/src/transports/ws/sws.h b/src/transports/ws/sws.h
index 3c32325..5c07e3f 100644
--- a/src/transports/ws/sws.h
+++ b/src/transports/ws/sws.h
@@ -83,11 +83,13 @@
     /*  Buffer used to store the framing of incoming message. */
     uint8_t inhdr [NN_SWS_FRAME_MAX_HDR_LEN];
 
+    /*  On the server side this field contains mask of the incoming message.
+        On the client side it is unused. */
+    uint8_t mask [4];
+
     /*  Parsed header frames. */
     uint8_t opcode;
     uint8_t payload_ctl;
-    uint8_t masked;
-    uint8_t *mask;
     size_t ext_hdr_len;
     int is_final_frame;
     int is_control_frame;