wswrapper: wrap existing server using LD_PRELOAD.

wswrapper.so is LD_PRELOAD shared library that interposes and turns
a generic TCP socket into a WebSockets service.

This current version works but will only allow work for a single
connection, subsequent connections will not be wrapped. In addition
the wrapper interposes on the first incoming network connection. It
should read an environment variable to determine the port to interpose
on. Also, should limit origin based on another environment variable.
Then there should be a wswrap setup script that allows easier
invocation.
This commit is contained in:
Joel Martin 2010-12-02 22:11:02 -06:00
parent 16d1493fce
commit 7ac02c951d
2 changed files with 619 additions and 3 deletions

View File

@ -1,11 +1,21 @@
TARGETS=wsproxy wswrapper.so
CFLAGS += -fPIC
all: $(TARGETS)
wsproxy: wsproxy.o websocket.o md5.o wsproxy: wsproxy.o websocket.o md5.o
$(CC) $(LDFLAGS) $^ -l ssl -lcrypto -l resolv -o $@ $(CC) $(LDFLAGS) $^ -lssl -lcrypto -lresolv -o $@
wswrapper.so: wswrapper.o md5.o
$(CC) $(LDFLAGS) $^ -shared -fPIC -ldl -lresolv -o $@
websocket.o: websocket.c websocket.h md5.h websocket.o: websocket.c websocket.h md5.h
wsproxy.o: wsproxy.c websocket.h wsproxy.o: wsproxy.c websocket.h
wswrapper.o: wswrapper.c
$(CC) -c $(CFLAGS) -o $@ $*.c
md5.o: md5.c md5.h md5.o: md5.c md5.h
$(CC) $(CFLAGS) -c -o $@ $*.c -DHAVE_MEMCPY -DSTDC_HEADERS $(CC) -c $(CFLAGS) -o $@ $*.c -DHAVE_MEMCPY -DSTDC_HEADERS
clean: clean:
rm -f wsproxy wsproxy.o websocket.o md5.o rm -f wsproxy wswrapper.so *.o

606
utils/wswrapper.c Normal file
View File

@ -0,0 +1,606 @@
#include <stdio.h>
#include <stdlib.h>
#define __USE_GNU 1 // Pull in RTLD_NEXT
#include <dlfcn.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <errno.h>
#include <string.h>
#include <resolv.h> /* base64 encode/decode */
#include "md5.h"
//#define DO_DEBUG 1
#ifdef DO_DEBUG
#define DEBUG(...) \
if (DO_DEBUG) { \
fprintf(stderr, "wswrapper: "); \
fprintf(stderr, __VA_ARGS__); \
}
#else
#define DEBUG(...)
#endif
#define MSG(...) \
fprintf(stderr, "wswrapper: "); \
fprintf(stderr, __VA_ARGS__);
#define RET_ERROR(eno, ...) \
fprintf(stderr, "wswrapper error: "); \
fprintf(stderr, __VA_ARGS__); \
errno = eno; \
return -1;
const char _WS_response[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
Upgrade: WebSocket\r\n\
Connection: Upgrade\r\n\
%sWebSocket-Origin: %s\r\n\
%sWebSocket-Location: %s://%s%s\r\n\
%sWebSocket-Protocol: sample\r\n\
\r\n%s";
/* WARNING: threading not supported */
int _WS_bufsize = 65536;
char *_WS_rbuf = NULL;
char *_WS_sbuf = NULL;
int _WS_rcarry_cnt = 0;
char _WS_rcarry[3] = "";
int _WS_newframe = 1;
int _WS_sockfd = 0;
int _WS_init() {
if (! (_WS_rbuf = malloc(_WS_bufsize)) ) {
return 0;
}
if (! (_WS_sbuf = malloc(_WS_bufsize)) ) {
return 0;
}
}
int _WS_gen_md5(char *key1, char *key2, char *key3, char *target) {
unsigned int i, spaces1 = 0, spaces2 = 0;
unsigned long num1 = 0, num2 = 0;
unsigned char buf[17];
for (i=0; i < strlen(key1); i++) {
if (key1[i] == ' ') {
spaces1 += 1;
}
if ((key1[i] >= 48) && (key1[i] <= 57)) {
num1 = num1 * 10 + (key1[i] - 48);
}
}
num1 = num1 / spaces1;
for (i=0; i < strlen(key2); i++) {
if (key2[i] == ' ') {
spaces2 += 1;
}
if ((key2[i] >= 48) && (key2[i] <= 57)) {
num2 = num2 * 10 + (key2[i] - 48);
}
}
num2 = num2 / spaces2;
/* Pack it big-endian */
buf[0] = (num1 & 0xff000000) >> 24;
buf[1] = (num1 & 0xff0000) >> 16;
buf[2] = (num1 & 0xff00) >> 8;
buf[3] = num1 & 0xff;
buf[4] = (num2 & 0xff000000) >> 24;
buf[5] = (num2 & 0xff0000) >> 16;
buf[6] = (num2 & 0xff00) >> 8;
buf[7] = num2 & 0xff;
strncpy(buf+8, key3, 8);
buf[16] = '\0';
md5_buffer(buf, 16, target);
target[16] = '\0';
return 1;
}
int _WS_handshake(int sockfd)
{
int sz = 0, len, idx;
int ret = -1, save_errno = EPROTO;
char *last, *start, *end;
long flags;
char handshake[4096], response[4096],
path[1024], prefix[5] = "", scheme[10] = "ws", host[1024],
origin[1024], key1[100], key2[100], key3[9], chksum[17];
static void * (*rfunc)(), * (*wfunc)();
if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
if (!wfunc) wfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
DEBUG("_WS_handshake starting\n");
/* Disable NONBLOCK if set */
flags = fcntl(sockfd, F_GETFL, 0);
if (flags & O_NONBLOCK) {
fcntl(sockfd, F_SETFL, flags^O_NONBLOCK);
}
while (1) {
len = (int) rfunc(sockfd, handshake+sz, 4095, 0);
if (len < 1) {
ret = len;
save_errno = errno;
break;
}
sz += len;
handshake[sz] = '\x00';
if (sz < 4) {
// Not enough yet
continue;
}
if (strstr(handshake, "GET ") != handshake) {
// We got something but it wasn't a WebSockets client
break;
}
last = strstr(handshake, "\r\n\r\n");
if (! last) {
continue;
}
if (! strstr(handshake, "Upgrade: WebSocket\r\n")) {
MSG("Invalid WebSockets handshake\n");
break;
}
// Now parse out the data
start = handshake+4;
end = strstr(start, " HTTP/1.1");
if (!end) { break; }
snprintf(path, end-start+1, "%s", start);
start = strstr(handshake, "\r\nHost: ");
if (!start) { break; }
start += 8;
end = strstr(start, "\r\n");
snprintf(host, end-start+1, "%s", start);
start = strstr(handshake, "\r\nOrigin: ");
if (!start) { break; }
start += 10;
end = strstr(start, "\r\n");
snprintf(origin, end-start+1, "%s", start);
start = strstr(handshake, "\r\n\r\n") + 4;
if (strlen(start) == 8) {
sprintf(prefix, "Sec-");
snprintf(key3, 8+1, "%s", start);
start = strstr(handshake, "\r\nSec-WebSocket-Key1: ");
if (!start) { break; }
start += 22;
end = strstr(start, "\r\n");
snprintf(key1, end-start+1, "%s", start);
start = strstr(handshake, "\r\nSec-WebSocket-Key2: ");
if (!start) { break; }
start += 22;
end = strstr(start, "\r\n");
snprintf(key2, end-start+1, "%s", start);
_WS_gen_md5(key1, key2, key3, chksum);
//DEBUG("Got handshake (v76): %s\n", handshake);
MSG("Got handshake (v76)\n");
} else {
sprintf(prefix, "");
sprintf(key1, "");
sprintf(key2, "");
sprintf(key3, "");
sprintf(chksum, "");
//DEBUG("Got handshake (v75): %s\n", handshake);
MSG("Got handshake (v75)\n");
}
sprintf(response, _WS_response, prefix, origin, prefix, scheme,
host, path, prefix, chksum);
//DEBUG("Handshake response: %s\n", response);
wfunc(sockfd, response, strlen(response), 0);
save_errno = 0;
ret = 0;
break;
}
/* Re-enable NONBLOCK if it was set */
if (flags & O_NONBLOCK) {
fcntl(sockfd, F_SETFL, flags);
}
errno = save_errno;
return ret;
}
ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
size_t len, int flags)
{
int rawcount, deccount, left, rawlen, retlen, decodelen;
int sockflags;
int i;
char * fstart, * fend, * cstart;
static void * (*rfunc)(), * (*rfunc2)();
if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
if (!rfunc2) rfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "read");
if (len == 0) {
return 0;
}
if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
// Not our file descriptor, just pass through
if (recvf) {
return (ssize_t) rfunc(sockfd, buf, len, flags);
} else {
return (ssize_t) rfunc2(sockfd, buf, len);
}
}
DEBUG("_WS_recv(%d, _, %d) called\n", sockfd, len);
sockflags = fcntl(sockfd, F_GETFL, 0);
left = len;
retlen = 0;
// first copy in any carry-over bytes
if (_WS_rcarry_cnt) {
if (_WS_rcarry_cnt == 1) {
DEBUG("Using carry byte: %u (", _WS_rcarry[0]);
} else if (_WS_rcarry_cnt == 2) {
DEBUG("Using carry bytes: %u,%u (", _WS_rcarry[0],
_WS_rcarry[1]);
} else {
RET_ERROR(EIO, "Too many carry-over bytes\n");
}
if (len <= _WS_rcarry_cnt) {
DEBUG("final)\n");
memcpy((char *) buf, _WS_rcarry, len);
_WS_rcarry_cnt -= len;
return len;
} else {
DEBUG("prepending)\n");
memcpy((char *) buf, _WS_rcarry, _WS_rcarry_cnt);
retlen += _WS_rcarry_cnt;
left -= _WS_rcarry_cnt;
_WS_rcarry_cnt = 0;
}
}
// Determine the number of base64 encoded bytes needed
rawcount = (left * 4) / 3 + 3;
rawcount -= rawcount%4;
if (rawcount > _WS_bufsize - 1) {
RET_ERROR(ENOMEM, "recv of %d bytes is larger than buffer\n", rawcount);
}
i = 0;
while (1) {
// Peek at everything available
rawlen = (int) rfunc(sockfd, _WS_rbuf, _WS_bufsize-1,
flags | MSG_PEEK);
if (rawlen <= 0) {
DEBUG("_WS_recv: returning because rawlen %d\n", rawlen);
return (ssize_t) rawlen;
}
fstart = _WS_rbuf;
/*
while (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
fstart += 2;
rawlen -= 2;
}
*/
if (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
rawlen = (int) rfunc(sockfd, _WS_rbuf, 2, flags);
if (rawlen != 2) {
RET_ERROR(EIO, "Could not strip empty frame headers\n");
}
continue;
}
fstart[rawlen] = '\x00';
if (rawlen - _WS_newframe >= 4) {
// We have enough to base64 decode at least 1 byte
break;
}
// Not enough to base64 decode
if (sockflags & O_NONBLOCK) {
// Just tell the caller to call again
DEBUG("_WS_recv: returning because O_NONBLOCK, rawlen %d\n", rawlen);
errno = EAGAIN;
return -1;
}
// Repeat until at least 1 byte (4 raw bytes) to decode
i++;
if (i > 1000000) {
MSG("Could not send final part of frame\n");
}
}
/*
DEBUG("_WS_recv, left: %d, len: %d, rawlen: %d, newframe: %d, raw: ",
left, len, rawlen, _WS_newframe);
for (i = 0; i < rawlen; i++) {
DEBUG("%u,", (unsigned char) ((char *) fstart)[i]);
}
DEBUG("\n");
*/
if (_WS_newframe) {
if (fstart[0] != '\x00') {
RET_ERROR(EPROTO, "Missing frame start\n");
}
fstart++;
rawlen--;
_WS_newframe = 0;
}
fend = memchr(fstart, '\xff', rawlen);
if (fend) {
_WS_newframe = 1;
if ((fend - fstart) % 4) {
RET_ERROR(EPROTO, "Frame length is not multiple of 4\n");
}
} else {
fend = fstart + rawlen - (rawlen % 4);
if (fend - fstart < 4) {
RET_ERROR(EPROTO, "Frame too short\n");
}
}
// How much should we consume
if (rawcount < fend - fstart) {
_WS_newframe = 0;
deccount = rawcount;
} else {
deccount = fend - fstart;
}
// Now consume what we processed
if (flags & MSG_PEEK) {
MSG("*** Got MSG_PEEK ***\n");
} else {
rfunc(sockfd, _WS_rbuf, fstart - _WS_rbuf + deccount + _WS_newframe, flags);
}
fstart[deccount] = '\x00'; // base64 terminator
// Do direct base64 decode, instead of decode()
decodelen = b64_pton(fstart, (char *) buf + retlen, deccount);
if (decodelen <= 0) {
RET_ERROR(EPROTO, "Base64 decode error\n");
}
if (decodelen <= left) {
retlen += decodelen;
} else {
retlen += left;
if (! (flags & MSG_PEEK)) {
// Add anything left over to the carry-over
_WS_rcarry_cnt = decodelen - left;
if (_WS_rcarry_cnt > 2) {
RET_ERROR(EPROTO, "Got too much base64 data\n");
}
memcpy(_WS_rcarry, buf + retlen, _WS_rcarry_cnt);
if (_WS_rcarry_cnt == 1) {
DEBUG("Saving carry byte: %u\n", _WS_rcarry[0]);
} else if (_WS_rcarry_cnt == 2) {
DEBUG("Saving carry bytes: %u,%u\n", _WS_rcarry[0],
_WS_rcarry[1]);
} else {
MSG("Waah2!\n");
}
}
}
((char *) buf)[retlen] = '\x00';
/*
DEBUG("*** recv %s as ", fstart);
for (i = 0; i < retlen; i++) {
DEBUG("%u,", (unsigned char) ((char *) buf)[i]);
}
DEBUG(" (%d -> %d): %d\n", deccount, decodelen, retlen);
*/
return retlen;
}
ssize_t _WS_send(int sendf, int sockfd, const void *buf,
size_t len, int flags)
{
int rawlen, enclen, rlen, over, left, clen, retlen, dbufsize;
int sockflags;
char * target;
int i;
static void * (*sfunc)(), * (*sfunc2)();
if (!sfunc) sfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
if (!sfunc2) sfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "write");
if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
// Not our file descriptor, just pass through
if (sendf) {
return (ssize_t) sfunc(sockfd, buf, len, flags);
} else {
return (ssize_t) sfunc2(sockfd, buf, len);
}
}
DEBUG("_WS_send(%d, _, %d) called\n", sockfd, len);
sockflags = fcntl(sockfd, F_GETFL, 0);
dbufsize = (_WS_bufsize * 3)/4 - 2;
if (len > dbufsize) {
RET_ERROR(ENOMEM, "send of %d bytes is larger than send buffer\n", len);
}
// base64 encode and add frame markers
rawlen = 0;
_WS_sbuf[rawlen++] = '\x00';
enclen = b64_ntop(buf, len, _WS_sbuf+rawlen, _WS_bufsize-rawlen);
if (enclen < 0) {
RET_ERROR(EPROTO, "Base64 encoding error\n");
}
rawlen += enclen;
_WS_sbuf[rawlen++] = '\xff';
rlen = (int) sfunc(sockfd, _WS_sbuf, rawlen, flags);
if (rlen <= 0) {
return rlen;
} else if (rlen < rawlen) {
// Spin until we can send a whole base64 chunck and frame end
over = (rlen - 1) % 4;
left = (4 - over) % 4 + 1; // left to send
DEBUG("_WS_send: rlen: %d (over: %d, left: %d), rawlen: %d\n", rlen, over, left, rawlen);
rlen += left;
_WS_sbuf[rlen-1] = '\xff';
i = 0;
do {
i++;
clen = (int) sfunc(sockfd, _WS_sbuf + rlen - left, left, flags);
if (clen > 0) {
left -= clen;
} else if (clen == 0) {
MSG("_WS_send: got clen %d\n", clen);
} else if (!(sockflags & O_NONBLOCK)) {
MSG("_WS_send: clen %d\n", clen);
return clen;
}
if (i > 1000000) {
MSG("Could not send final part of frame\n");
}
} while (left > 0);
DEBUG("_WS_send: spins until finished %d\n", i);
}
/*
* Report back the number of original characters sent,
* not the raw number sent
*/
// Adjust for framing
retlen = rlen - 2;
// Adjust for base64 padding
if (_WS_sbuf[rlen-1] == '=') { retlen --; }
if (_WS_sbuf[rlen-2] == '=') { retlen --; }
// Adjust for base64 encoding
retlen = (retlen*3)/4;
/*
DEBUG("*** send ");
for (i = 0; i < retlen; i++) {
DEBUG("%u,", (unsigned char) ((char *)buf)[i]);
}
DEBUG(" as '%s' (%d)\n", _WS_sbuf+1, rlen);
*/
return (ssize_t) retlen;
}
/* Override network routines */
/*
int socket(int domain, int type, int protocol)
{
static void * (*func)();
if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "socket");
DEBUG("socket(_, %d, _) called\n", type);
return (int) func(domain, type, protocol);
}
int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
{
static void * (*func)();
if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "bind");
DEBUG("bind(%d, _, %d) called\n", sockfd, addrlen);
return (int) func(sockfd, addr, addrlen);
}
*/
int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
{
int fd, ret, envfd;
static void * (*func)();
if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "accept");
DEBUG("accept(%d, _, _) called\n", sockfd);
fd = (int) func(sockfd, addr, addrlen);
if (_WS_sockfd == 0) {
_WS_sockfd = fd;
if (!_WS_rbuf) {
if (! _WS_init()) {
RET_ERROR(ENOMEM, "Could not allocate interposer buffer\n");
}
}
ret = _WS_handshake(_WS_sockfd);
if (ret < 0) {
errno = EPROTO;
return ret;
}
MSG("interposing on fd %d\n", _WS_sockfd);
} else {
DEBUG("already interposing on fd %d\n", _WS_sockfd);
}
return fd;
}
int close(int fd)
{
static void * (*func)();
if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "close");
if ((_WS_sockfd != 0) && (_WS_sockfd == fd)) {
MSG("finished interposing on fd %d\n", _WS_sockfd);
_WS_sockfd = 0;
}
return (int) func(fd);
}
ssize_t read(int fd, void *buf, size_t count)
{
//DEBUG("read(%d, _, %d) called\n", fd, count);
return (ssize_t) _WS_recv(0, fd, buf, count, 0);
}
ssize_t write(int fd, const void *buf, size_t count)
{
//DEBUG("write(%d, _, %d) called\n", fd, count);
return (ssize_t) _WS_send(0, fd, buf, count, 0);
}
ssize_t recv(int sockfd, void *buf, size_t len, int flags)
{
//DEBUG("recv(%d, _, %d, %d) called\n", sockfd, len, flags);
return (ssize_t) _WS_recv(1, sockfd, buf, len, flags);
}
ssize_t send(int sockfd, const void *buf, size_t len, int flags)
{
//DEBUG("send(%d, _, %d, %d) called\n", sockfd, len, flags);
return (ssize_t) _WS_send(1, sockfd, buf, len, flags);
}