- Make sure the socket is still open before entering oskittcp

- Remove an unused parameter from OskitTCPBind
 - Return a status value from OskitTCPGetAddress
 - Add debug print for unhandled error codes

svn path=/trunk/; revision=43864
This commit is contained in:
Cameron Gutman 2009-10-31 01:05:31 +00:00
parent 82bd7858d1
commit a06fc827da
4 changed files with 67 additions and 25 deletions

View file

@ -90,7 +90,6 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port)); TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port));
Status = TCPTranslateError( OskitTCPBind( Connection->SocketContext, Status = TCPTranslateError( OskitTCPBind( Connection->SocketContext,
Connection,
&AddressToBind, &AddressToBind,
sizeof(AddressToBind) ) ); sizeof(AddressToBind) ) );

View file

@ -555,7 +555,7 @@ NTSTATUS TCPShutdown(VOID)
} }
NTSTATUS TCPTranslateError( int OskitError ) { NTSTATUS TCPTranslateError( int OskitError ) {
NTSTATUS Status = STATUS_UNSUCCESSFUL; NTSTATUS Status;
switch( OskitError ) { switch( OskitError ) {
case 0: Status = STATUS_SUCCESS; break; case 0: Status = STATUS_SUCCESS; break;
@ -565,7 +565,10 @@ NTSTATUS TCPTranslateError( int OskitError ) {
case OSK_ECONNRESET: Status = STATUS_REMOTE_NOT_LISTENING; break; case OSK_ECONNRESET: Status = STATUS_REMOTE_NOT_LISTENING; break;
case OSK_EINPROGRESS: case OSK_EINPROGRESS:
case OSK_EAGAIN: Status = STATUS_PENDING; break; case OSK_EAGAIN: Status = STATUS_PENDING; break;
default: Status = STATUS_INVALID_CONNECTION; break; default:
DbgPrint("OskitTCP returned unhandled error code: %d\n", OskitError);
Status = STATUS_INVALID_CONNECTION;
break;
} }
TI_DbgPrint(DEBUG_TCP,("Error %d -> %x\n", OskitError, Status)); TI_DbgPrint(DEBUG_TCP,("Error %d -> %x\n", OskitError, Status));
@ -621,7 +624,6 @@ NTSTATUS TCPConnect
Status = TCPTranslateError Status = TCPTranslateError
( OskitTCPBind( Connection->SocketContext, ( OskitTCPBind( Connection->SocketContext,
Connection,
&AddressToBind, &AddressToBind,
sizeof(AddressToBind) ) ); sizeof(AddressToBind) ) );
@ -703,6 +705,8 @@ NTSTATUS TCPClose
DrainSignals(); DrainSignals();
Status = TCPTranslateError( OskitTCPClose( Connection->SocketContext ) ); Status = TCPTranslateError( OskitTCPClose( Connection->SocketContext ) );
if (Status == STATUS_SUCCESS)
Connection->SocketContext = NULL;
TI_DbgPrint(DEBUG_TCP,("TCPClose finished %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("TCPClose finished %x\n", Status));
@ -867,13 +871,15 @@ NTSTATUS TCPGetSockAddress
OSK_UINT LocalAddress, RemoteAddress; OSK_UINT LocalAddress, RemoteAddress;
OSK_UI16 LocalPort, RemotePort; OSK_UI16 LocalPort, RemotePort;
PTA_IP_ADDRESS AddressIP = (PTA_IP_ADDRESS)Address; PTA_IP_ADDRESS AddressIP = (PTA_IP_ADDRESS)Address;
NTSTATUS Status;
ASSERT_LOCKED(&TCPLock); ASSERT_LOCKED(&TCPLock);
OskitTCPGetAddress Status = TCPTranslateError(OskitTCPGetAddress(Connection->SocketContext,
( Connection->SocketContext,
&LocalAddress, &LocalPort, &LocalAddress, &LocalPort,
&RemoteAddress, &RemotePort ); &RemoteAddress, &RemotePort));
if (!NT_SUCCESS(Status))
return Status;
AddressIP->TAAddressCount = 1; AddressIP->TAAddressCount = 1;
AddressIP->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP; AddressIP->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
@ -881,7 +887,7 @@ NTSTATUS TCPGetSockAddress
AddressIP->Address[0].Address[0].sin_port = GetRemote ? RemotePort : LocalPort; AddressIP->Address[0].Address[0].sin_port = GetRemote ? RemotePort : LocalPort;
AddressIP->Address[0].Address[0].in_addr = GetRemote ? RemoteAddress : LocalAddress; AddressIP->Address[0].Address[0].in_addr = GetRemote ? RemoteAddress : LocalAddress;
return STATUS_SUCCESS; return Status;
} }
VOID TCPRemoveIRP( PCONNECTION_ENDPOINT Endpoint, PIRP Irp ) { VOID TCPRemoveIRP( PCONNECTION_ENDPOINT Endpoint, PIRP Irp ) {

View file

@ -127,7 +127,7 @@ extern int OskitTCPConnect( void *socket, void *connection,
void *nam, OSK_UINT namelen ); void *nam, OSK_UINT namelen );
extern int OskitTCPClose( void *socket ); extern int OskitTCPClose( void *socket );
extern int OskitTCPBind( void *socket, void *connection, extern int OskitTCPBind( void *socket,
void *nam, OSK_UINT namelen ); void *nam, OSK_UINT namelen );
extern int OskitTCPAccept( void *socket, void **new_socket, extern int OskitTCPAccept( void *socket, void **new_socket,
@ -144,7 +144,7 @@ extern int OskitTCPRecv( void *connection,
OSK_UINT *OutLen, OSK_UINT *OutLen,
OSK_UINT Flags ); OSK_UINT Flags );
void OskitTCPGetAddress( void *socket, int OskitTCPGetAddress( void *socket,
OSK_UINT *LocalAddress, OSK_UINT *LocalAddress,
OSK_UI16 *LocalPort, OSK_UI16 *LocalPort,
OSK_UINT *RemoteAddress, OSK_UINT *RemoteAddress,

View file

@ -76,9 +76,6 @@ void TimerOskitTCP( int FastTimer, int SlowTimer ) {
void RegisterOskitTCPEventHandlers( POSKITTCP_EVENT_HANDLERS EventHandlers ) { void RegisterOskitTCPEventHandlers( POSKITTCP_EVENT_HANDLERS EventHandlers ) {
memcpy( &OtcpEvent, EventHandlers, sizeof(OtcpEvent) ); memcpy( &OtcpEvent, EventHandlers, sizeof(OtcpEvent) );
if( OtcpEvent.PacketSend )
OS_DbgPrint(OSK_MID_TRACE,("SendPacket handler registered: %x\n",
OtcpEvent.PacketSend));
} }
void OskitDumpBuffer( OSK_PCHAR Data, OSK_UINT Len ) void OskitDumpBuffer( OSK_PCHAR Data, OSK_UINT Len )
@ -169,7 +166,7 @@ int OskitTCPRecv( void *connection,
return error; return error;
} }
int OskitTCPBind( void *socket, void *connection, int OskitTCPBind( void *socket,
void *nam, OSK_UINT namelen ) { void *nam, OSK_UINT namelen ) {
int error = EFAULT; int error = EFAULT;
struct socket *so = socket; struct socket *so = socket;
@ -178,6 +175,9 @@ int OskitTCPBind( void *socket, void *connection,
OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket)); OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
if (!socket)
return OSK_ESHUTDOWN;
if( nam ) if( nam )
addr = *((struct sockaddr *)nam); addr = *((struct sockaddr *)nam);
@ -243,11 +243,18 @@ done:
} }
int OskitTCPShutdown( void *socket, int disconn_type ) { int OskitTCPShutdown( void *socket, int disconn_type ) {
if (!socket)
return OSK_ESHUTDOWN;
return soshutdown( socket, disconn_type ); return soshutdown( socket, disconn_type );
} }
int OskitTCPClose( void *socket ) { int OskitTCPClose( void *socket ) {
struct socket *so = socket; struct socket *so = socket;
if (!socket)
return OSK_ESHUTDOWN;
so->so_connection = 0; so->so_connection = 0;
soclose( so ); soclose( so );
return 0; return 0;
@ -259,6 +266,9 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
struct uio uio; struct uio uio;
struct iovec iov; struct iovec iov;
if (!socket)
return OSK_ESHUTDOWN;
iov.iov_len = Len; iov.iov_len = Len;
iov.iov_base = (char *)Data; iov.iov_base = (char *)Data;
uio.uio_iov = &iov; uio.uio_iov = &iov;
@ -293,6 +303,12 @@ int OskitTCPAccept( void *socket,
struct inpcb *inp; struct inpcb *inp;
int namelen = 0, error = 0, s; int namelen = 0, error = 0, s;
if (!socket)
return OSK_ESHUTDOWN;
if (!new_socket || !AddrOut)
return OSK_EINVAL;
OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Doing accept (Finish %d)\n", OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Doing accept (Finish %d)\n",
FinishAccepting)); FinishAccepting));
@ -436,6 +452,9 @@ int OskitTCPSetSockOpt(void *socket,
{ {
struct mbuf *m; struct mbuf *m;
if (!socket)
return OSK_ESHUTDOWN;
if (size >= MLEN) if (size >= MLEN)
return OSK_EINVAL; return OSK_EINVAL;
@ -460,6 +479,9 @@ int OskitTCPGetSockOpt(void *socket,
int error, oldsize = *size; int error, oldsize = *size;
struct mbuf *m; struct mbuf *m;
if (!socket)
return OSK_ESHUTDOWN;
error = sogetopt(socket, level, optname, &m); error = sogetopt(socket, level, optname, &m);
if (!error) if (!error)
{ {
@ -482,6 +504,9 @@ int OskitTCPGetSockOpt(void *socket,
int OskitTCPListen( void *socket, int backlog ) { int OskitTCPListen( void *socket, int backlog ) {
int error; int error;
if (!socket)
return OSK_ESHUTDOWN;
OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket)); OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
error = solisten( socket, backlog ); error = solisten( socket, backlog );
OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error)); OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
@ -489,32 +514,44 @@ int OskitTCPListen( void *socket, int backlog ) {
return error; return error;
} }
void OskitTCPSetAddress( void *socket, int OskitTCPSetAddress( void *socket,
OSK_UINT LocalAddress, OSK_UINT LocalAddress,
OSK_UI16 LocalPort, OSK_UI16 LocalPort,
OSK_UINT RemoteAddress, OSK_UINT RemoteAddress,
OSK_UI16 RemotePort ) { OSK_UI16 RemotePort ) {
struct socket *so = socket; struct socket *so = socket;
struct inpcb *inp = (struct inpcb *)so->so_pcb; struct inpcb *inp;
if (!socket)
return OSK_ESHUTDOWN;
inp = (struct inpcb *)so->so_pcb;
inp->inp_laddr.s_addr = LocalAddress; inp->inp_laddr.s_addr = LocalAddress;
inp->inp_lport = LocalPort; inp->inp_lport = LocalPort;
inp->inp_faddr.s_addr = RemoteAddress; inp->inp_faddr.s_addr = RemoteAddress;
inp->inp_fport = RemotePort; inp->inp_fport = RemotePort;
return 0;
} }
void OskitTCPGetAddress( void *socket, int OskitTCPGetAddress( void *socket,
OSK_UINT *LocalAddress, OSK_UINT *LocalAddress,
OSK_UI16 *LocalPort, OSK_UI16 *LocalPort,
OSK_UINT *RemoteAddress, OSK_UINT *RemoteAddress,
OSK_UI16 *RemotePort ) { OSK_UI16 *RemotePort ) {
struct socket *so = socket; struct socket *so = socket;
struct inpcb *inp = so ? (struct inpcb *)so->so_pcb : NULL; struct inpcb *inp;
if( inp ) {
if (!socket)
return OSK_ESHUTDOWN;
inp = (struct inpcb *)so->so_pcb;
*LocalAddress = inp->inp_laddr.s_addr; *LocalAddress = inp->inp_laddr.s_addr;
*LocalPort = inp->inp_lport; *LocalPort = inp->inp_lport;
*RemoteAddress = inp->inp_faddr.s_addr; *RemoteAddress = inp->inp_faddr.s_addr;
*RemotePort = inp->inp_fport; *RemotePort = inp->inp_fport;
}
return 0;
} }
struct ifaddr *ifa_iffind(struct sockaddr *addr, int type) struct ifaddr *ifa_iffind(struct sockaddr *addr, int type)