00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00026 #ifndef OW32_AsyncSecureSocketServer_h
00027 #define OW32_AsyncSecureSocketServer_h
00028
00029 #include <OW32/AsyncSecureSocket.h>
00030 #include <OW32/SSLInit.h>
00031
00032
00033 namespace OW32
00034 {
00035
00037 template <class T>
00038 class CAsyncSecureSocketServer : public CAsyncSecureSocket<T>
00039 {
00040 private:
00041
00042 CAsyncSecureSocketServer& operator= (const CAsyncSecureSocketServer& );
00043 CAsyncSecureSocketServer(const CAsyncSecureSocketServer& );
00044
00045 public:
00049 CAsyncSecureSocketServer(CAsyncSocketCallback* pCallback) :
00050 CAsyncSecureSocket<T>(pCallback)
00051 {
00052 }
00053
00058 CAsyncSecureSocketServer(CAsyncSocketCallback* pCallback, SOCKET s) :
00059 CAsyncSecureSocket<T>(pCallback, s)
00060 {
00061 }
00062
00067 virtual int shutdown(int how);
00068
00074 SECURITY_STATUS setServerCertificate(PCCERT_CONTEXT pCertContext,
00075 DWORD dwEnabledProtocols=0);
00076
00087 static SECURITY_STATUS createCredentialsFromCertificate(
00088 CredHandle* phCreds, PCCERT_CONTEXT pCertContext, DWORD dwEnabledProtocols=0)
00089 {
00090 return CSecureSocket::createCredentialsFromCertificate(phCreds, pCertContext,
00091 SECPKG_CRED_INBOUND, dwEnabledProtocols);
00092 }
00093
00094 protected:
00097 virtual void negotiateLoop();
00098
00099 private:
00100 void NegotiateError();
00101
00102 void ReadCompletion(BOOL bRet, DWORD cbReceived);
00103 void SendCompletion(BOOL bRet, DWORD cbSent);
00104
00105 virtual void onSendCompletion(BOOL bRet, DWORD cbBytesSent);
00106 virtual void onReadCompletion(BOOL bRet, DWORD cbBytesReceived);
00107 virtual void onTimeout();
00108
00109 bool ProcessDecryptedData();
00110 int handshakeSend(const char* data, int length);
00111
00112 auto_array_ptr<char> m_handshakeIo;
00113 DWORD m_cbHandshakeIo;
00114
00115 SECURITY_STATUS m_scRet;
00116 SecBuffer m_InBuffers[2];
00117 SecBufferDesc m_InBuffer;
00118 SecBuffer m_OutBuffers[1];
00119 SecBufferDesc m_OutBuffer;
00120 int m_shutdownHow;
00121
00122 enum { IO_BUFFER_SIZE = 16*1024 };
00123 };
00124
00125 template <class T>
00126 void CAsyncSecureSocketServer<T>::NegotiateError()
00127 {
00128 switch (m_State)
00129 {
00130 case State_Negotiate:
00131 m_pCallback->onConnectCompletion(FALSE);
00132 break;
00133 case State_Renegotiate:
00134 case State_Connected:
00135 m_pCallback->onReadCompletion(FALSE,0);
00136 break;
00137 case State_Shutdown:
00138 m_pCallback->onCloseCompletion(FALSE);
00139 break;
00140 }
00141 }
00142
00143 template <class T>
00144 void CAsyncSecureSocketServer<T>::onTimeout()
00145 {
00146 if (m_State == State_Connected) {
00147 m_pCallback->onTimeout();
00148 return;
00149 }
00150 SetLastError(ERROR_TIMEOUT);
00151 NegotiateError();
00152 }
00153
00154 template <class T>
00155 void CAsyncSecureSocketServer<T>::negotiateLoop()
00156 {
00157
00158 m_scRet = SEC_I_CONTINUE_NEEDED;
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171 m_handshakeIo.reset(new char[IO_BUFFER_SIZE]);
00172 m_cbHandshakeIo = 0;
00173
00174
00175
00176 if (m_ExtraCount > 0)
00177 {
00178
00179 if (m_ExtraCount > IO_BUFFER_SIZE)
00180 {
00181 SetLastError((DWORD)E_UNEXPECTED);
00182 NegotiateError();
00183 return;
00184 }
00185 CopyMemory(m_handshakeIo.get(), m_Extra.get(), m_ExtraCount);
00186 m_cbHandshakeIo = m_ExtraCount;
00187 }
00188 ReadCompletion(TRUE,0);
00189 }
00190
00191 template <class T>
00192 int CAsyncSecureSocketServer<T>::handshakeSend(const char* data, int length)
00193 {
00194 m_sendData = data;
00195 m_sendLength = length;
00196 m_sendProcessed = 0;
00197 return T::send(data, length);
00198 }
00199
00200 template <class T>
00201 void CAsyncSecureSocketServer<T>::onSendCompletion(BOOL bRet, DWORD cbSent)
00202 {
00203 if (m_State == State_Initial)
00204 {
00205 m_pCallback->onSendCompletion(bRet, cbSent);
00206 return;
00207 }
00208
00209 if (m_State == State_Connected)
00210 {
00211 CAsyncSecureSocket<T>::onSendCompletion(bRet, cbSent);
00212 return;
00213 }
00214
00215
00216
00217 if (!bRet || cbSent == 0) {
00218 if (cbSent == 0 && bRet)
00219 SetLastError(ERROR_INVALID_FUNCTION);
00220 NegotiateError();
00221 return;
00222 }
00223 SendCompletion(bRet, cbSent);
00224 }
00225
00226 template <class T>
00227 void CAsyncSecureSocketServer<T>::SendCompletion(BOOL bRet, DWORD cbSent)
00228 {
00229 bRet;
00230 m_sendProcessed += cbSent;
00231
00232 if (m_sendProcessed < m_sendLength)
00233 {
00234 int ret = T::send(m_sendData + m_sendProcessed, m_sendLength - m_sendProcessed);
00235 if (ret != 0)
00236 {
00237 NegotiateError();
00238 }
00239 return;
00240 }
00241
00242 if (m_sendData)
00243 g_SecurityFunc.FreeContextBuffer( (PVOID)m_sendData );
00244 m_sendData = NULL;
00245
00246 switch (m_State)
00247 {
00248 case State_Negotiate:
00249 case State_Renegotiate:
00250 if (ProcessDecryptedData())
00251 ReadCompletion(TRUE, 0);
00252 break;
00253 case State_Shutdown:
00254
00255 if (CAsyncSecureSocket<T>::shutdown(m_shutdownHow) == SOCKET_ERROR)
00256 NegotiateError();
00257 break;
00258 }
00259 }
00260
00261 template <class T>
00262 void CAsyncSecureSocketServer<T>::onReadCompletion(BOOL bRet, DWORD cbReceived)
00263 {
00264 if (m_State == State_Initial)
00265 {
00266 m_pCallback->onReadCompletion(bRet, cbReceived);
00267 return;
00268 }
00269
00270 if (m_State == State_Connected)
00271 {
00272 CAsyncSecureSocket<T>::onReadCompletion(bRet, cbReceived);
00273 return;
00274 }
00275
00276
00277
00278 if (!bRet || cbReceived == 0) {
00279 if (cbReceived == 0 && bRet)
00280 SetLastError(ERROR_INVALID_FUNCTION);
00281 NegotiateError();
00282 return;
00283 }
00284 ReadCompletion(bRet, cbReceived);
00285 }
00286
00287 template <class T>
00288 void CAsyncSecureSocketServer<T>::ReadCompletion(BOOL bRet, DWORD cbReceived)
00289 {
00290 bRet;
00291 m_cbHandshakeIo += cbReceived;
00292
00293 while ( m_scRet == SEC_I_CONTINUE_NEEDED ||
00294 m_scRet == SEC_E_INCOMPLETE_MESSAGE ||
00295 m_scRet == SEC_I_INCOMPLETE_CREDENTIALS)
00296 {
00297
00298 if (0 == m_cbHandshakeIo || m_scRet == SEC_E_INCOMPLETE_MESSAGE)
00299 {
00300
00301 if (m_cbHandshakeIo >= IO_BUFFER_SIZE) {
00302 SetLastError((DWORD)E_UNEXPECTED);
00303 NegotiateError();
00304 return;
00305 }
00306
00307
00308
00309 m_scRet = SEC_I_CONTINUE_NEEDED;
00310 int ret = T::recv(&m_handshakeIo[m_cbHandshakeIo], IO_BUFFER_SIZE-m_cbHandshakeIo);
00311 if (ret != 0)
00312 NegotiateError();
00313 return;
00314 }
00315
00316
00317
00318
00319
00320
00321 m_InBuffers[0].pvBuffer = m_handshakeIo.get();
00322 m_InBuffers[0].cbBuffer = m_cbHandshakeIo ;
00323 m_InBuffers[0].BufferType = SECBUFFER_TOKEN;
00324
00325 m_InBuffers[1].pvBuffer = NULL;
00326 m_InBuffers[1].cbBuffer = 0;
00327 m_InBuffers[1].BufferType = SECBUFFER_EMPTY;
00328
00329 m_InBuffer.cBuffers = 2;
00330 m_InBuffer.pBuffers = m_InBuffers;
00331 m_InBuffer.ulVersion = SECBUFFER_VERSION;
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341 m_OutBuffer.cBuffers = 1;
00342 m_OutBuffer.pBuffers = m_OutBuffers;
00343 m_OutBuffer.ulVersion = SECBUFFER_VERSION;
00344
00345 m_OutBuffers[0].pvBuffer = NULL;
00346 m_OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00347 m_OutBuffers[0].cbBuffer = 0;
00348
00349 DWORD dwSSPIOutFlags = 0;
00350 DWORD dwSSPIFlags =
00351 ASC_REQ_SEQUENCE_DETECT |
00352 ASC_REQ_REPLAY_DETECT |
00353 ASC_REQ_CONFIDENTIALITY |
00354 ASC_REQ_EXTENDED_ERROR |
00355 ASC_REQ_ALLOCATE_MEMORY |
00356 ASC_REQ_STREAM |
00357 (m_bRequireClientAuth ? ASC_REQ_MUTUAL_AUTH : 0);
00358 TimeStamp tsExpiry = {0};
00359
00360
00361 bool fInitContext = !SecIsValidHandle(&m_hContext);
00362
00363
00364
00365 __try
00366 {
00367 m_scRet = g_SecurityFunc.AcceptSecurityContext(
00368 &m_hCreds,
00369 (fInitContext?NULL:&m_hContext),
00370 &m_InBuffer,
00371 dwSSPIFlags,
00372 SECURITY_NATIVE_DREP,
00373 (fInitContext?&m_hContext:NULL),
00374 &m_OutBuffer,
00375 &dwSSPIOutFlags,
00376 &tsExpiry);
00377 }
00378 __except(EXCEPTION_EXECUTE_HANDLER)
00379 {
00380 m_scRet = GetExceptionCode();
00381 if (!FAILED(m_scRet)) m_scRet = E_UNEXPECTED;
00382 }
00383
00384
00385 if ( m_scRet == SEC_E_OK ||
00386 m_scRet == SEC_I_CONTINUE_NEEDED ||
00387 (FAILED(m_scRet) && (0 != (dwSSPIOutFlags & ASC_RET_EXTENDED_ERROR))))
00388 {
00389 if (m_OutBuffers[0].cbBuffer != 0 &&
00390 m_OutBuffers[0].pvBuffer != NULL )
00391 {
00392
00393
00394
00395
00396 int ret = handshakeSend((const char *)m_OutBuffers[0].pvBuffer, m_OutBuffers[0].cbBuffer);
00397 if (ret != 0)
00398 {
00399 NegotiateError();
00400 }
00401 return;
00402 }
00403 } else if (m_OutBuffers[0].pvBuffer != NULL) {
00404
00405
00406
00407 g_SecurityFunc.FreeContextBuffer( m_OutBuffers[0].pvBuffer );
00408 m_OutBuffers[0].pvBuffer = NULL;
00409 }
00410
00411 ProcessDecryptedData();
00412 }
00413 }
00414
00415 template <class T>
00416 bool CAsyncSecureSocketServer<T>::ProcessDecryptedData()
00417 {
00418
00419 if ( m_scRet == SEC_E_OK )
00420 {
00421
00422 SECURITY_STATUS m_scRet = querySizes();
00423 if (m_scRet != SEC_E_OK)
00424 {
00425 SetLastError(m_scRet);
00426 NegotiateError();
00427 return false;
00428 }
00429
00430
00431
00432 m_Extra.reset(0);
00433
00434
00435
00436 DWORD dwExtraSize = m_Sizes.cbMaximumMessage+m_Sizes.cbHeader+m_Sizes.cbTrailer;
00437 if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00438 {
00439
00440 if (dwExtraSize < m_InBuffers[1].cbBuffer)
00441 dwExtraSize = m_InBuffers[1].cbBuffer;
00442
00443 m_ExtraCount = m_InBuffers[1].cbBuffer;
00444 m_Extra.reset(new char[dwExtraSize]);
00445 CopyMemory(m_Extra.get(), (LPBYTE) (m_handshakeIo.get() +
00446 (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)), m_ExtraCount);
00447 }
00448
00449
00450 ConnectionState PrevState = m_State;
00451 m_State = State_Connected;
00452 switch (PrevState)
00453 {
00454 case State_Negotiate:
00455 m_pCallback->onConnectCompletion(TRUE);
00456 break;
00457 case State_Renegotiate:
00458 ReadCompletion(TRUE, 0);
00459 break;
00460 }
00461 return false;
00462 }
00463
00464 if (FAILED(m_scRet) && (m_scRet != SEC_E_INCOMPLETE_MESSAGE))
00465 {
00466
00467 SetLastError(m_scRet);
00468 NegotiateError();
00469 return false;
00470 }
00471
00472
00473
00474 if ( m_scRet != SEC_E_INCOMPLETE_MESSAGE &&
00475 m_scRet != SEC_I_INCOMPLETE_CREDENTIALS)
00476 {
00477
00478 if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00479 {
00480
00481
00482 MoveMemory(m_handshakeIo.get(),
00483 (LPBYTE) (m_handshakeIo.get() + (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)),
00484 m_InBuffers[1].cbBuffer);
00485 m_cbHandshakeIo = m_InBuffers[1].cbBuffer;
00486 }
00487 else
00488 {
00489
00490
00491
00492 m_cbHandshakeIo = 0;
00493 }
00494 }
00495 return true;
00496 }
00497
00498 template <class T>
00499 int CAsyncSecureSocketServer<T>::shutdown(int how)
00500 {
00501 if (m_State != State_Connected) {
00502 return CAsyncSecureSocket<T>::shutdown(how);
00503 }
00504
00505
00506
00507
00508
00509 DWORD dwType = SCHANNEL_SHUTDOWN;
00510
00511 SecBuffer OutBuffers[1];
00512 OutBuffers[0].pvBuffer = &dwType;
00513 OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00514 OutBuffers[0].cbBuffer = sizeof(dwType);
00515
00516 SecBufferDesc OutBuffer;
00517 OutBuffer.cBuffers = 1;
00518 OutBuffer.pBuffers = OutBuffers;
00519 OutBuffer.ulVersion = SECBUFFER_VERSION;
00520
00521 SECURITY_STATUS Status = g_SecurityFunc.ApplyControlToken(&m_hContext, &OutBuffer);
00522
00523 if (FAILED(Status))
00524 {
00525 SetLastError(Status);
00526 return SOCKET_ERROR;
00527 }
00528
00529
00530
00531
00532
00533 DWORD dwSSPIFlags =
00534 ASC_REQ_SEQUENCE_DETECT |
00535 ASC_REQ_REPLAY_DETECT |
00536 ASC_REQ_CONFIDENTIALITY |
00537 ASC_REQ_EXTENDED_ERROR |
00538 ASC_REQ_ALLOCATE_MEMORY |
00539 ASC_REQ_STREAM;
00540
00541 OutBuffers[0].pvBuffer = NULL;
00542 OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00543 OutBuffers[0].cbBuffer = 0;
00544
00545 OutBuffer.cBuffers = 1;
00546 OutBuffer.pBuffers = OutBuffers;
00547 OutBuffer.ulVersion = SECBUFFER_VERSION;
00548
00549 DWORD dwSSPIOutFlags = 0;
00550 TimeStamp tsExpiry = {0};
00551 Status = g_SecurityFunc.AcceptSecurityContext(
00552 &m_hCreds,
00553 &m_hContext,
00554 NULL,
00555 dwSSPIFlags,
00556 SECURITY_NATIVE_DREP,
00557 NULL,
00558 &OutBuffer,
00559 &dwSSPIOutFlags,
00560 &tsExpiry);
00561
00562 if (FAILED(Status))
00563 {
00564 SetLastError(Status);
00565 return SOCKET_ERROR;
00566 }
00567
00568 PBYTE pbMessage = (unsigned char *)OutBuffers[0].pvBuffer;
00569 DWORD cbMessage = OutBuffers[0].cbBuffer;
00570
00571
00572
00573
00574
00575
00576 if (pbMessage == NULL || cbMessage == 0)
00577 {
00578 return CAsyncSecureSocket<T>::shutdown(how);
00579 }
00580
00581
00582 m_State = State_Shutdown;
00583 m_shutdownHow = how;
00584 if (handshakeSend((const char*)pbMessage, (int)cbMessage) != 0)
00585 {
00586 return SOCKET_ERROR;
00587 }
00588 return 0;
00589 }
00590
00591 template <class T>
00592 SECURITY_STATUS CAsyncSecureSocketServer<T>::setServerCertificate(PCCERT_CONTEXT pCertContext,
00593 DWORD dwEnabledProtocols)
00594 {
00595 SECURITY_STATUS scRet = createCredentialsFromCertificate(&m_hCreds, pCertContext,
00596 dwEnabledProtocols);
00597 if (SUCCEEDED(scRet))
00598 m_ownCredentials = true;
00599 return scRet;
00600 }
00601
00602 }
00603
00604 #endif // OW32_AsyncSecureSocketServer_h