Kaynağa Gözat

simplify csum, added tcp option parser

wangyu- 7 yıl önce
ebeveyn
işleme
2e8294ab88
5 değiştirilmiş dosya ile 181 ekleme ve 18 silme
  1. 73 0
      common.cpp
  2. 14 0
      common.h
  3. 4 0
      main.cpp
  4. 1 0
      misc.cpp
  5. 89 18
      network.cpp

+ 73 - 0
common.cpp

@@ -314,6 +314,46 @@ u64_t hton64(u64_t a)
 
 }
 
+void write_u16(char * p,u16_t w)
+{
+	*(unsigned char*)(p + 1) = (w & 0xff);
+	*(unsigned char*)(p + 0) = (w >> 8);
+}
+u16_t read_u16(char * p)
+{
+	u16_t res;
+	res = *(const unsigned char*)(p + 0);
+	res = *(const unsigned char*)(p + 1) + (res << 8);
+	return res;
+}
+
+void write_u32(char * p,u32_t l)
+{
+	*(unsigned char*)(p + 3) = (unsigned char)((l >>  0) & 0xff);
+	*(unsigned char*)(p + 2) = (unsigned char)((l >>  8) & 0xff);
+	*(unsigned char*)(p + 1) = (unsigned char)((l >> 16) & 0xff);
+	*(unsigned char*)(p + 0) = (unsigned char)((l >> 24) & 0xff);
+}
+u32_t read_u32(char * p)
+{
+	u32_t res;
+	res = *(const unsigned char*)(p + 0);
+	res = *(const unsigned char*)(p + 1) + (res << 8);
+	res = *(const unsigned char*)(p + 2) + (res << 8);
+	res = *(const unsigned char*)(p + 3) + (res << 8);
+	return res;
+}
+
+void write_u64(char * s,u64_t a)
+{
+	assert(0==1);
+}
+u64_t read_u64(char * s)
+{
+	assert(0==1);
+	return 0;
+}
+
 void setnonblocking(int sock) {
 	int opts;
 	opts = fcntl(sock, F_GETFL);
@@ -358,6 +398,39 @@ unsigned short csum(const unsigned short *ptr,int nbytes) {//works both for big
     return(answer);
 }
 
+unsigned short csum_with_header(char* header,int hlen,const unsigned short *ptr,int nbytes) {//works both for big and little endian
+
+    long sum;
+    unsigned short oddbyte;
+    short answer;
+
+    assert(hlen%2==0);
+
+    sum=0;
+	unsigned short * tmp= (unsigned short *)header;
+	for(int i=0;i<hlen/2;i++)
+	{
+		sum+=*tmp++;
+	}
+
+
+    while(nbytes>1) {
+        sum+=*ptr++;
+        nbytes-=2;
+    }
+    if(nbytes==1) {
+        oddbyte=0;
+        *((u_char*)&oddbyte)=*(u_char*)ptr;
+        sum+=oddbyte;
+    }
+
+    sum = (sum>>16)+(sum & 0xffff);
+    sum = sum + (sum>>16);
+    answer=(short)~sum;
+
+    return(answer);
+}
+
 int set_buf_size(int fd,int socket_buf_size)
 {
 	if(force_socket_buf)

+ 14 - 0
common.h

@@ -64,6 +64,9 @@ typedef long long i64_t;
 typedef unsigned int u32_t;
 typedef int i32_t;
 
+typedef unsigned short u16_t;
+typedef short i16_t;
+
 typedef u32_t id_t;
 
 typedef u64_t iv_t;
@@ -250,6 +253,16 @@ u32_t get_true_random_number();
 u32_t get_true_random_number_nz();
 u64_t ntoh64(u64_t a);
 u64_t hton64(u64_t a);
+
+void write_u16(char *,u16_t a);// network order
+u16_t read_u16(char *);
+
+void write_u32(char *,u32_t a);// network order
+u32_t read_u32(char *);
+
+void write_u64(char *,u64_t a);
+u64_t read_uu64(char *);
+
 bool larger_than_u16(uint16_t a,uint16_t b);
 bool larger_than_u32(u32_t a,u32_t b);
 void setnonblocking(int sock);
@@ -258,6 +271,7 @@ int set_buf_size(int fd,int socket_buf_size);
 void myexit(int a);
 
 unsigned short csum(const unsigned short *ptr,int nbytes);
+unsigned short csum_with_header(char* header,int hlen,const unsigned short *ptr,int nbytes);
 
 int numbers_to_char(id_t id1,id_t id2,id_t id3,char * &data,int &len);
 int char_to_numbers(const char * data,int len,id_t &id1,id_t &id2,id_t &id3);

+ 4 - 0
main.cpp

@@ -12,6 +12,8 @@ char hb_buf[buf_len];
 
 int on_epoll_recv_event=0;  //TODO, just a flag to help detect epoll infinite shoot
 
+int server_easytcp=0;//currently only for test
+
 int client_on_timer(conn_info_t &conn_info) //for client. called when a timer is ready in epoll
 {
 	//keep_iptables_rule();
@@ -1011,6 +1013,8 @@ int server_on_raw_recv_multi() //called when server received an raw packet
 			{
 				return 0;
 			}
+			if(server_easytcp!=0)
+				return 0;
 			raw_info_t &raw_info=tmp_raw_info;
 			packet_info_t &send_info=raw_info.send_info;
 			packet_info_t &recv_info=raw_info.recv_info;

+ 1 - 0
misc.cpp

@@ -157,6 +157,7 @@ void print_help()
 	printf("    --disable-bpf                         disable the kernel space filter,most time its not necessary\n");
 	printf("                                          unless you suspect there is a bug\n");
 //	printf("\n");
+	printf("    --dev                 <string>        bind raw socket to a device, not necessary but improves performance\n");
 	printf("    --sock-buf            <number>        buf size for socket,>=10 and <=10240,unit:kbyte,default:1024\n");
 	printf("    --force-sock-buf                      bypass system limitation while setting sock-buf\n");
 	printf("    --seq-mode            <number>        seq increase mode for faketcp:\n");

+ 89 - 18
network.cpp

@@ -910,13 +910,12 @@ int send_raw_tcp(raw_info_t &raw_info,const char * payload, int payloadlen) {
 	char send_raw_tcp_buf[buf_len];
 	//char *send_raw_tcp_buf=send_raw_tcp_buf0;
 
-	struct tcphdr *tcph = (struct tcphdr *) (send_raw_tcp_buf
-			+ sizeof(struct pseudo_header));
-
+	struct tcphdr *tcph = (struct tcphdr *) (send_raw_tcp_buf);
 
 	memset(tcph,0,sizeof(tcphdr));
 
-	struct pseudo_header *psh = (struct pseudo_header *) (send_raw_tcp_buf);
+	pseudo_header tmp_header={0};
+	struct pseudo_header *psh = &tmp_header;
 
 	//TCP Header
 	tcph->source = htons(send_info.src_port);
@@ -933,7 +932,7 @@ int send_raw_tcp(raw_info_t &raw_info,const char * payload, int payloadlen) {
 
 	if (tcph->syn == 1) {
 		tcph->doff = 10;  //tcp header size
-		int i = sizeof(pseudo_header)+sizeof(tcphdr);
+		int i = sizeof(tcphdr);
 		send_raw_tcp_buf[i++] = 0x02;  //mss
 		send_raw_tcp_buf[i++] = 0x04;
 		send_raw_tcp_buf[i++] = 0x05;
@@ -969,7 +968,7 @@ int send_raw_tcp(raw_info_t &raw_info,const char * payload, int payloadlen) {
 		send_raw_tcp_buf[i++] = wscale;
 	} else {
 		tcph->doff = 8;
-		int i = sizeof(pseudo_header)+sizeof(tcphdr);
+		int i = sizeof(tcphdr);
 
 		send_raw_tcp_buf[i++] = 0x01;
 		send_raw_tcp_buf[i++] = 0x01;
@@ -1000,7 +999,7 @@ int send_raw_tcp(raw_info_t &raw_info,const char * payload, int payloadlen) {
 	tcph->check = 0; //leave checksum 0 now, filled later by pseudo header
 	tcph->urg_ptr = 0;
 
-	char *tcp_data = send_raw_tcp_buf+sizeof(struct pseudo_header) + tcph->doff * 4;
+	char *tcp_data = send_raw_tcp_buf+ + tcph->doff * 4;
 
 	memcpy(tcp_data, payload, payloadlen);
 
@@ -1010,13 +1009,13 @@ int send_raw_tcp(raw_info_t &raw_info,const char * payload, int payloadlen) {
 	psh->protocol = IPPROTO_TCP;
 	psh->tcp_length = htons(tcph->doff * 4 + payloadlen);
 
-	int csum_size = sizeof(struct pseudo_header) + tcph->doff*4 + payloadlen;
+	int csum_size = tcph->doff*4 + payloadlen;
 
-	tcph->check = csum( (unsigned short*) send_raw_tcp_buf, csum_size);
+	tcph->check = csum_with_header((char *)psh,sizeof(pseudo_header), (unsigned short*) send_raw_tcp_buf, csum_size);
 
 	int tcp_totlen=tcph->doff*4 + payloadlen;
 
-	if(send_raw_ip(raw_info,send_raw_tcp_buf+ sizeof(struct pseudo_header),tcp_totlen)!=0)
+	if(send_raw_ip(raw_info,send_raw_tcp_buf,tcp_totlen)!=0)
 	{
 		return -1;
 	}
@@ -1335,13 +1334,74 @@ int recv_raw_udp(raw_info_t &raw_info, char *&payload, int &payloadlen)
 
     return 0;
 }
+int parse_tcp_option(char * option_begin,char * option_end,packet_info_t &recv_info)
+{
+    recv_info.has_ts=0;
+    recv_info.ts=0;
+
+    char *ptr=option_begin;
+    //char *option_end=tcp_begin+tcp_hdr_len;
+    while(ptr<option_end)
+    {
+    	if(*ptr==0)
+    	{
+    		return  0;
+    	}
+    	else if(*ptr==1)
+    	{
+    		ptr++;
+    	}
+    	else if(*ptr==8)
+    	{
+    		if(ptr+1>=option_end)
+    		{
+    			mylog(log_debug,"invaild option ptr+1==option_end,for ts\n");
+    			return -1;
+    		}
+    		if(*(ptr+1)!=10)
+    		{
+    			mylog(log_debug,"invaild ts len\n");
+    			return -1;
+    		}
+    		if(ptr+10>option_end)
+    		{
+    			mylog(log_debug,"ptr+10>option_end for ts\n");
+    			return -1;
+    		}
+
+    		recv_info.has_ts=1;
+
+    		recv_info.ts= read_u32(ptr+2);
+    		recv_info.ts_ack=read_u32(ptr+6);
+
+    		//printf("<%d %d>!\n",recv_info.ts,recv_info.ts_ack);
+
+    		//return 0;//we currently only parse ts, so just return after its found
+    		ptr+=8;
+    	}
+    	else
+    	{
+    		if(ptr+1>=option_end)
+    		{
+    			mylog(log_debug,"invaild option ptr+1==option_end\n");
+    			return -1;
+    		}
+    		else
+    		{
+    			//omit check
+    			ptr+=*(ptr+1);
+    		}
+    	}
+    }
 
+	return 0;
+}
 int recv_raw_tcp(raw_info_t &raw_info,char * &payload,int &payloadlen)
 {
 	const packet_info_t &send_info=raw_info.send_info;
 	packet_info_t &recv_info=raw_info.recv_info;
 
-	static char recv_raw_tcp_buf[buf_len];
+	//static char recv_raw_tcp_buf[buf_len];
 
 	char * ip_payload;
 	int ip_payloadlen;
@@ -1376,9 +1436,10 @@ int recv_raw_tcp(raw_info_t &raw_info,char * &payload,int &payloadlen)
     	return -1;
     }
 
-    memcpy(recv_raw_tcp_buf+ sizeof(struct pseudo_header) , ip_payload , ip_payloadlen);
+   // memcpy(recv_raw_tcp_buf+ sizeof(struct pseudo_header) , ip_payload , ip_payloadlen);
 
-    struct pseudo_header *psh=(pseudo_header *)recv_raw_tcp_buf ;
+    pseudo_header tmp_header;
+    struct pseudo_header *psh=&tmp_header ;
 
     psh->source_address = recv_info.src_ip;
     psh->dest_address = recv_info.dst_ip;
@@ -1386,8 +1447,14 @@ int recv_raw_tcp(raw_info_t &raw_info,char * &payload,int &payloadlen)
     psh->protocol = IPPROTO_TCP;
     psh->tcp_length = htons(ip_payloadlen);
 
-    int csum_len=sizeof(struct pseudo_header)+ip_payloadlen;
-    uint16_t tcp_chk = csum( (unsigned short*) recv_raw_tcp_buf, csum_len);
+    int csum_len=ip_payloadlen;
+    uint16_t tcp_chk = csum_with_header((char *)psh,sizeof(pseudo_header), (unsigned short*) ip_payload, csum_len);
+
+    /*for(int i=0;i<csum_len;i++)
+    {
+    	printf("<%d>",int(ip_payload[i]));
+    }
+    printf("\n");*/
 
     if(tcp_chk!=0)
     {
@@ -1397,10 +1464,12 @@ int recv_raw_tcp(raw_info_t &raw_info,char * &payload,int &payloadlen)
 
     }
 
-    char *tcp_begin=recv_raw_tcp_buf+sizeof(struct pseudo_header);  //ip packet's data part
+    char *tcp_begin=ip_payload;  //ip packet's data part
 
-    char *tcp_option=recv_raw_tcp_buf+sizeof(struct pseudo_header)+sizeof(tcphdr);
+    char *tcp_option=ip_payload+sizeof(tcphdr);
 
+    /*
+    //old ts parse code
     recv_info.has_ts=0;
     recv_info.ts=0;
     if(tcph->doff==10)
@@ -1445,7 +1514,9 @@ int recv_raw_tcp(raw_info_t &raw_info,char * &payload,int &payloadlen)
     {
     	//mylog(log_info,"tcph->doff= %u\n",tcph->doff);
     }
-
+    printf("<%d %d>\n",recv_info.ts,recv_info.ts_ack);
+    */
+    parse_tcp_option(tcp_option,tcp_begin+tcph->doff*4,recv_info);
 
     recv_info.ack=tcph->ack;
     recv_info.syn=tcph->syn;