00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00026 #ifndef OW32_AsyncSecureSocketClient_h
00027 #define OW32_AsyncSecureSocketClient_h
00028
00029 #include <OW32/AsyncSecureSocket.h>
00030 #include <OW32/SSLInit.h>
00031
00032
00033 namespace OW32
00034 {
00035
00037 template <class T>
00038 class CAsyncSecureSocketClient : public CAsyncSecureSocket<T>
00039 {
00040 private:
00041 CAsyncSecureSocketClient(const CAsyncSecureSocketClient& );
00042 CAsyncSecureSocketClient& operator=(const CAsyncSecureSocketClient& );
00043
00044 protected:
00045
00046 public:
00051 CAsyncSecureSocketClient(CAsyncSocketCallback* pCallback) :
00052 CAsyncSecureSocket<T>(pCallback),
00053 m_fDoRead(false)
00054 {
00055 m_szServerName[0] = _T('\0');
00056 }
00057
00062 CAsyncSecureSocketClient(CAsyncSocketCallback* pCallback, SOCKET s) :
00063 CAsyncSecureSocket<T>(pCallback,s),
00064 m_fDoRead(false)
00065 {
00066 m_szServerName[0] = _T('\0');
00067 }
00068
00072 virtual int shutdown(int how);
00073
00077 void setServerName(LPCTSTR lpszServerName)
00078 {
00079 const int nServerBuf = sizeof(m_szServerName)/sizeof(m_szServerName[0]);
00080 _tcsncpy(m_szServerName, lpszServerName, nServerBuf);
00081 m_szServerName[nServerBuf-1]=_T('\0');
00082 }
00083
00088 SECURITY_STATUS setClientCertificate(PCCERT_CONTEXT pCertContext,
00089 DWORD dwEnabledProtocols=0);
00090
00094 SECURITY_STATUS createNullCredentials(DWORD dwEnabledProtocols=0);
00095
00101 static SECURITY_STATUS createCredentialsFromCertificate(
00102 CredHandle* phCreds, PCCERT_CONTEXT pCertContext, DWORD dwEnabledProtocols=0)
00103 {
00104 return CAsyncSecureSocket<T>::createCredentialsFromCertificate(phCreds, pCertContext,
00105 SECPKG_CRED_OUTBOUND, dwEnabledProtocols);
00106 }
00107
00108 protected:
00114 virtual void negotiateLoop();
00115
00116 private:
00117 TCHAR m_szServerName[260];
00118 int disconnectFromServer();
00119 virtual void onSendCompletion(BOOL bRet, DWORD cbBytesSent);
00120 virtual void onReadCompletion(BOOL bRet, DWORD cbBytesReceived);
00121 virtual void onConnectCompletion(BOOL bRet);
00122 virtual void onTimeout();
00123 int handshakeSend(const char* data, int length);
00124
00125 void NegotiateError();
00126 void ReadCompletion(BOOL bRet, DWORD cbReceived);
00127 void SendCompletion(BOOL bRet, DWORD cbSent);
00128 bool ProcessDecryptedData();
00129
00130 bool m_fDoRead;
00131 SECURITY_STATUS m_scRet;
00132 SecBuffer m_InBuffers[2];
00133 SecBufferDesc m_InBuffer;
00134 SecBuffer m_OutBuffers[1];
00135 SecBufferDesc m_OutBuffer;
00136
00137 enum { IO_BUFFER_SIZE = 16*1024 + 5 + 16 };
00138 auto_array_ptr<char> m_handshakeIo;
00139 DWORD m_cbHandshakeIo;
00140 };
00141
00142 template <class T>
00143 void CAsyncSecureSocketClient<T>::NegotiateError()
00144 {
00145 switch (m_State)
00146 {
00147 case State_Negotiate:
00148 m_pCallback->onConnectCompletion(FALSE);
00149 break;
00150 case State_Renegotiate:
00151 case State_Connected:
00152 m_pCallback->onReadCompletion(FALSE,0);
00153 break;
00154 case State_Shutdown:
00155 m_pCallback->onCloseCompletion(FALSE);
00156 break;
00157 }
00158 }
00159
00160 template <class T>
00161 void CAsyncSecureSocketClient<T>::onConnectCompletion(BOOL bRet)
00162 {
00163 if (!bRet) {
00164 m_pCallback->onConnectCompletion(bRet);
00165 return;
00166 }
00167 (void)negotiate();
00168 }
00169
00170 template <class T>
00171 void CAsyncSecureSocketClient<T>::onTimeout()
00172 {
00173 if (m_State == State_Connected) {
00174 m_pCallback->onTimeout();
00175 return;
00176 }
00177 SetLastError(ERROR_TIMEOUT);
00178 NegotiateError();
00179 }
00180
00181 template <class T>
00182 int CAsyncSecureSocketClient<T>::shutdown(int how)
00183 {
00184 if (m_State != State_Connected) {
00185 return CAsyncSecureSocket<T>::shutdown(how);
00186 }
00187
00188
00189
00190
00191
00192 DWORD dwType = SCHANNEL_SHUTDOWN;
00193
00194 SecBuffer OutBuffers[1];
00195 OutBuffers[0].pvBuffer = &dwType;
00196 OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00197 OutBuffers[0].cbBuffer = sizeof(dwType);
00198
00199 SecBufferDesc OutBuffer;
00200 OutBuffer.cBuffers = 1;
00201 OutBuffer.pBuffers = OutBuffers;
00202 OutBuffer.ulVersion = SECBUFFER_VERSION;
00203
00204 SECURITY_STATUS Status = g_SecurityFunc.ApplyControlToken(&m_hContext, &OutBuffer);
00205 if (FAILED(Status))
00206 {
00207 SetLastError(Status);
00208 return SOCKET_ERROR;
00209 }
00210
00211
00212
00213
00214 DWORD dwSSPIFlags =
00215 ISC_REQ_SEQUENCE_DETECT |
00216 ISC_REQ_REPLAY_DETECT |
00217 ISC_REQ_CONFIDENTIALITY |
00218 ISC_RET_EXTENDED_ERROR |
00219 ISC_REQ_ALLOCATE_MEMORY |
00220 ISC_REQ_STREAM;
00221
00222 OutBuffers[0].pvBuffer = NULL;
00223 OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00224 OutBuffers[0].cbBuffer = 0;
00225
00226 OutBuffer.cBuffers = 1;
00227 OutBuffer.pBuffers = OutBuffers;
00228 OutBuffer.ulVersion = SECBUFFER_VERSION;
00229
00230 DWORD dwSSPIOutFlags;
00231 TimeStamp tsExpiry;
00232 Status = g_SecurityFunc.InitializeSecurityContext(
00233 &m_hCreds,
00234 &m_hContext,
00235 NULL,
00236 dwSSPIFlags,
00237 0,
00238 SECURITY_NATIVE_DREP,
00239 NULL,
00240 0,
00241 &m_hContext,
00242 &OutBuffer,
00243 &dwSSPIOutFlags,
00244 &tsExpiry);
00245
00246 if (FAILED(Status))
00247 {
00248 SetLastError(Status);
00249 return SOCKET_ERROR;
00250 }
00251
00252 PBYTE pbMessage = (unsigned char *)OutBuffers[0].pvBuffer;
00253 DWORD cbMessage = OutBuffers[0].cbBuffer;
00254
00255
00256
00257
00258 m_State = State_Shutdown;
00259 int ret = handshakeSend((const char *)pbMessage, (int)cbMessage);
00260 if (ret == SOCKET_ERROR)
00261 return SOCKET_ERROR;
00262 return 0;
00263 }
00264
00265 template <class T>
00266 void CAsyncSecureSocketClient<T>::negotiateLoop()
00267 {
00268 if (m_State == State_Negotiate)
00269 {
00270 if (!SecIsValidHandle(&m_hCreds)) {
00271 if (createNullCredentials() != SEC_E_OK)
00272 {
00273 NegotiateError();
00274 return ;
00275 }
00276 }
00277
00278 DWORD dwSSPIFlags =
00279 ISC_REQ_SEQUENCE_DETECT |
00280 ISC_REQ_REPLAY_DETECT |
00281 ISC_REQ_CONFIDENTIALITY |
00282 ISC_RET_EXTENDED_ERROR |
00283 ISC_REQ_ALLOCATE_MEMORY |
00284 ISC_REQ_STREAM;
00285
00286
00287
00288
00289
00290 m_OutBuffer.cBuffers = 1;
00291 m_OutBuffer.pBuffers = m_OutBuffers;
00292 m_OutBuffer.ulVersion = SECBUFFER_VERSION;
00293
00294 m_OutBuffers[0].pvBuffer = NULL;
00295 m_OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00296 m_OutBuffers[0].cbBuffer = 0;
00297
00298
00299 DWORD dwSSPIOutFlags=0;
00300 TimeStamp tsExpiry={0};
00301 SECURITY_STATUS scRet = g_SecurityFunc.InitializeSecurityContext(
00302 &m_hCreds,
00303 NULL,
00304 m_szServerName,
00305 dwSSPIFlags,
00306 0,
00307 SECURITY_NATIVE_DREP,
00308 NULL,
00309 0,
00310 &m_hContext,
00311 &m_OutBuffer,
00312 &dwSSPIOutFlags,
00313 &tsExpiry);
00314
00315 if (scRet != SEC_I_CONTINUE_NEEDED)
00316 {
00317 SetLastError(scRet);
00318 NegotiateError();
00319 return;
00320 }
00321 }
00322
00323
00324 m_scRet = SEC_I_CONTINUE_NEEDED;
00325 m_fDoRead = m_State==State_Negotiate;
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338 m_handshakeIo.reset(new char[IO_BUFFER_SIZE]);
00339 m_cbHandshakeIo = 0;
00340
00341
00342
00343 if (m_ExtraCount > 0)
00344 {
00345
00346 if (m_ExtraCount > IO_BUFFER_SIZE)
00347 {
00348 SetLastError((DWORD)E_UNEXPECTED);
00349 NegotiateError();
00350 return;
00351 }
00352 CopyMemory(m_handshakeIo.get(), m_Extra.get(), m_ExtraCount);
00353 m_cbHandshakeIo = m_ExtraCount;
00354 }
00355
00356 if (m_State==State_Negotiate) {
00357
00358 if (m_OutBuffers[0].cbBuffer != 0 && m_OutBuffers[0].pvBuffer != NULL)
00359 {
00360 if (handshakeSend((const char *)m_OutBuffers[0].pvBuffer, m_OutBuffers[0].cbBuffer) != 0)
00361 {
00362 NegotiateError();
00363 }
00364 return;
00365 }
00366 }
00367 ReadCompletion(TRUE, 0);
00368 }
00369
00370 template <class T>
00371 int CAsyncSecureSocketClient<T>::handshakeSend(const char* data, int length)
00372 {
00373 m_sendData = data;
00374 m_sendLength = length;
00375 m_sendProcessed = 0;
00376 return T::send(data, length);
00377 }
00378
00379 template <class T>
00380 void CAsyncSecureSocketClient<T>::onSendCompletion(BOOL bRet, DWORD cbSent)
00381 {
00382 if (m_State == State_Connected)
00383 {
00384 CAsyncSecureSocket<T>::onSendCompletion(bRet, cbSent);
00385 return;
00386 }
00387
00388
00389 if (!bRet || cbSent == 0) {
00390 if (cbSent == 0 && bRet)
00391 SetLastError(ERROR_INVALID_FUNCTION);
00392 NegotiateError();
00393 return;
00394 }
00395 SendCompletion(bRet, cbSent);
00396 }
00397
00398 template <class T>
00399 void CAsyncSecureSocketClient<T>::SendCompletion(BOOL , DWORD cbSent)
00400 {
00401 m_sendProcessed += cbSent;
00402
00403 if (m_sendProcessed < m_sendLength)
00404 {
00405 if (T::send(m_sendData + m_sendProcessed, m_sendLength - m_sendProcessed) != 0)
00406 NegotiateError();
00407 return;
00408 }
00409
00410 if (m_sendData)
00411 g_SecurityFunc.FreeContextBuffer( (PVOID)m_sendData );
00412 m_sendData = NULL;
00413
00414 switch (m_State)
00415 {
00416 case State_Negotiate:
00417 case State_Renegotiate:
00418 if (ProcessDecryptedData())
00419 ReadCompletion(TRUE, 0);
00420 break;
00421 case State_Shutdown:
00422
00423 if (CAsyncSecureSocket<T>::shutdown(SD_SEND) == SOCKET_ERROR)
00424 NegotiateError();
00425 break;
00426 }
00427 }
00428
00429 template <class T>
00430 void CAsyncSecureSocketClient<T>::onReadCompletion(BOOL bRet, DWORD cbReceived)
00431 {
00432 if (m_State == State_Connected)
00433 {
00434 CAsyncSecureSocket<T>::onReadCompletion(bRet, cbReceived);
00435 return;
00436 }
00437
00438
00439
00440 if (!bRet || cbReceived == 0) {
00441 if (cbReceived == 0 && bRet)
00442 SetLastError(ERROR_INVALID_FUNCTION);
00443 NegotiateError();
00444 return;
00445 }
00446 ReadCompletion(bRet, cbReceived);
00447 }
00448
// ReadCompletion: drives the client handshake loop.
// cbReceived is the number of bytes the last recv (issued below into the free
// tail of m_handshakeIo) delivered; callers that only (re)enter the loop pass 0.
// While m_scRet is SEC_I_CONTINUE_NEEDED, SEC_E_INCOMPLETE_MESSAGE or
// SEC_I_INCOMPLETE_CREDENTIALS, the buffered bytes are fed to
// InitializeSecurityContext. The function returns early whenever it starts an
// asynchronous recv (resumed by onReadCompletion) or send (resumed by
// SendCompletion, which calls ProcessDecryptedData and then re-enters here).
// Errors are reported through NegotiateError().
00449 template <class T>
00450 void CAsyncSecureSocketClient<T>::ReadCompletion(BOOL , DWORD cbReceived)
00451 {
00452 // Request flags for every InitializeSecurityContext call in this loop.
00453 DWORD dwSSPIFlags =
00454 ISC_REQ_SEQUENCE_DETECT |
00455 ISC_REQ_REPLAY_DETECT |
00456 ISC_REQ_CONFIDENTIALITY |
00457 ISC_RET_EXTENDED_ERROR | // NOTE(review): ISC_RET_ flag in a request mask; same value as ISC_REQ_EXTENDED_ERROR in sspi.h — confirm.
00458 ISC_REQ_ALLOCATE_MEMORY |
00459 ISC_REQ_STREAM;
00460 // Account for the bytes just received at the end of the buffered handshake data.
00461 m_cbHandshakeIo += cbReceived;
00462
00463 // Keep stepping the handshake until it completes, fails, or has to wait for
00464 // an asynchronous recv/send (each of which returns from this function).
00465
00466
00467 while ( m_scRet == SEC_I_CONTINUE_NEEDED ||
00468 m_scRet == SEC_E_INCOMPLETE_MESSAGE ||
00469 m_scRet == SEC_I_INCOMPLETE_CREDENTIALS)
00470 {
00471 // Input is needed when nothing is buffered, or when SChannel reported that
00472 // the buffered message is incomplete.
00473
00474
00475 if (0 == m_cbHandshakeIo || m_scRet == SEC_E_INCOMPLETE_MESSAGE)
00476 {
00477 if (m_fDoRead) {
00478 // No room left to receive into: the message does not fit in IO_BUFFER_SIZE bytes.
00479 if (m_cbHandshakeIo >= IO_BUFFER_SIZE) {
00480 SetLastError((DWORD)E_UNEXPECTED);
00481 NegotiateError();
00482 return;
00483 }
00484 // Reset the status so that, when the recv completes, the new data is
00485 // submitted instead of triggering another read. onReadCompletion resumes
00486 // this loop with the received byte count.
00487 m_scRet = SEC_I_CONTINUE_NEEDED;
00488 if (T::recv(&m_handshakeIo[m_cbHandshakeIo], IO_BUFFER_SIZE-m_cbHandshakeIo) != 0)
00489 NegotiateError();
00490 return;
00491 } else {
00492 m_fDoRead = true; // Skip the read once: submit what is buffered (possibly nothing); later passes read first.
00493 }
00494 }
00495
00496 // Input: buffer 0 holds all buffered handshake bytes; buffer 1 is an empty
00497 // slot in which SChannel can report unconsumed bytes (SECBUFFER_EXTRA,
00498 // handled by ProcessDecryptedData).
00499
00500
00501
00502 m_InBuffers[0].pvBuffer = m_handshakeIo.get();
00503 m_InBuffers[0].cbBuffer = m_cbHandshakeIo;
00504 m_InBuffers[0].BufferType = SECBUFFER_TOKEN;
00505
00506 m_InBuffers[1].pvBuffer = NULL;
00507 m_InBuffers[1].cbBuffer = 0;
00508 m_InBuffers[1].BufferType = SECBUFFER_EMPTY;
00509
00510 m_InBuffer.cBuffers = 2;
00511 m_InBuffer.pBuffers = m_InBuffers;
00512 m_InBuffer.ulVersion = SECBUFFER_VERSION;
00513
00514 // Output: one token buffer. With ISC_REQ_ALLOCATE_MEMORY SChannel allocates
00515 // pvBuffer; it is released with FreeContextBuffer (below, or by
00516 // SendCompletion once the token has been sent).
00517
00518
00519
00520 m_OutBuffer.cBuffers = 1;
00521 m_OutBuffer.pBuffers = m_OutBuffers;
00522 m_OutBuffer.ulVersion = SECBUFFER_VERSION;
00523
00524 m_OutBuffers[0].pvBuffer = NULL;
00525 m_OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00526 m_OutBuffers[0].cbBuffer = 0;
00527
00528 // Receives the attributes SChannel actually granted; tested for
00529 // ISC_RET_EXTENDED_ERROR below.
00530
00531 DWORD dwSSPIOutFlags = 0;
00532 TimeStamp tsExpiry = {0};
00533
00534 // Structured-exception guard: an exception raised inside the provider becomes
00535 // a failure status (E_UNEXPECTED if the exception code is not a failure HRESULT).
00536 __try {
00537 m_scRet = g_SecurityFunc.InitializeSecurityContext(
00538 &m_hCreds,
00539 &m_hContext,
00540 NULL,
00541 dwSSPIFlags,
00542 0,
00543 SECURITY_NATIVE_DREP,
00544 &m_InBuffer,
00545 0,
00546 &m_hContext,
00547 &m_OutBuffer,
00548 &dwSSPIOutFlags,
00549 &tsExpiry);
00550 }
00551 __except(EXCEPTION_EXECUTE_HANDLER)
00552 {
00553 m_scRet = GetExceptionCode();
00554 if (!FAILED(m_scRet)) m_scRet = E_UNEXPECTED;
00555 }
00556
00557 // On success, on continuation, or on a failure for which SChannel produced an
00558 // error token (extended error), send any output token to the server. The
00559 // loop then resumes from SendCompletion.
00560
00561
00562
00563 if(m_scRet == SEC_E_OK ||
00564 m_scRet == SEC_I_CONTINUE_NEEDED ||
00565 (FAILED(m_scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR)))
00566 {
00567 if (m_OutBuffers[0].cbBuffer != 0 &&
00568 m_OutBuffers[0].pvBuffer != NULL )
00569 {
00570 // handshakeSend records the token in m_sendData; SendCompletion frees it
00571 // once it has been fully sent.
00572
00573
00574 if (handshakeSend((const char *)m_OutBuffers[0].pvBuffer, m_OutBuffers[0].cbBuffer) != 0)
00575 {
00576 NegotiateError();
00577 }
00578 return;
00579 } // NOTE(review): a non-NULL token with cbBuffer == 0 is never freed on this path — confirm whether SChannel can return one.
00580 } else if (m_OutBuffers[0].pvBuffer != NULL) {
00581 // Any other status: discard whatever token SChannel allocated.
00582
00583
00584 g_SecurityFunc.FreeContextBuffer( m_OutBuffers[0].pvBuffer );
00585 m_OutBuffers[0].pvBuffer = NULL;
00586 }
00587 // Interpret m_scRet: finish, fail, retry, or compact the buffer and loop again.
00588 ProcessDecryptedData();
00589 }
00590 }
00591
00592 template <class T>
00593 bool CAsyncSecureSocketClient<T>::ProcessDecryptedData()
00594 {
00595
00596
00597
00598
00599
00600 if ( m_scRet == SEC_E_INCOMPLETE_MESSAGE )
00601 {
00602 return true;
00603 }
00604
00605
00606
00607
00608 if ( m_scRet == SEC_E_OK )
00609 {
00610
00611 SECURITY_STATUS scRet = querySizes();
00612 if (scRet != SEC_E_OK)
00613 {
00614 SetLastError(scRet);
00615 NegotiateError();
00616 return false;
00617 }
00618 assert(m_Sizes.cbMaximumMessage != -1 && m_Sizes.cbHeader != -1 &&
00619 m_Sizes.cbTrailer != -1 && m_Sizes.cbBlockSize != -1 && m_Sizes.cBuffers != -1);
00620
00621
00622
00623 m_Extra.reset(0);
00624
00625
00626
00627 DWORD dwExtraSize = m_Sizes.cbMaximumMessage+m_Sizes.cbHeader+m_Sizes.cbTrailer;
00628 if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00629 {
00630
00631 if (dwExtraSize < m_InBuffers[1].cbBuffer)
00632 dwExtraSize = m_InBuffers[1].cbBuffer;
00633
00634 m_ExtraCount = m_InBuffers[1].cbBuffer;
00635 m_Extra.reset(new char[dwExtraSize]);
00636 CopyMemory(m_Extra.get(), (LPBYTE) (m_handshakeIo.get() +
00637 (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)), m_ExtraCount);
00638 }
00639
00640
00641 ConnectionState PrevState = m_State;
00642 m_State = State_Connected;
00643 if (PrevState == State_Negotiate) m_pCallback->onConnectCompletion(TRUE);
00644 else CAsyncSecureSocket<T>::ReadCompletion(TRUE, 0);
00645 return false;
00646 }
00647
00648 if (FAILED(m_scRet))
00649 {
00650 SetLastError(m_scRet);
00651 NegotiateError();
00652 return false;
00653 }
00654
00655
00656
00657
00658
00659
00660 if ( m_scRet == SEC_I_INCOMPLETE_CREDENTIALS )
00661 {
00662
00663
00664
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674 m_fDoRead = false;
00675
00676 m_scRet = SEC_I_CONTINUE_NEEDED;
00677 return true;
00678 }
00679
00680
00681
00682
00683
00684
00685 if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00686 {
00687
00688
00689 MoveMemory(m_handshakeIo.get(),
00690 (LPBYTE) (m_handshakeIo.get() + (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)),
00691 m_InBuffers[1].cbBuffer);
00692 m_cbHandshakeIo = m_InBuffers[1].cbBuffer;
00693 }
00694 else
00695 {
00696
00697
00698
00699 m_cbHandshakeIo = 0;
00700 }
00701 return true;
00702 }
00703
00704 template <class T>
00705 SECURITY_STATUS CAsyncSecureSocketClient<T>::createNullCredentials(
00706 DWORD dwEnabledProtocols)
00707 {
00708 freeCredentials();
00709 SECURITY_STATUS scRet = createCredentialsFromCertificate(&m_hCreds, NULL, dwEnabledProtocols);
00710 if (SUCCEEDED(scRet))
00711 m_ownCredentials = true;
00712 return scRet;
00713 }
00714
00715 template <class T>
00716 SECURITY_STATUS CAsyncSecureSocketClient<T>::setClientCertificate(
00717 PCCERT_CONTEXT pCertContext, DWORD dwEnabledProtocols)
00718 {
00719 freeCredentials();
00720 SECURITY_STATUS scRet = createCredentialsFromCertificate(
00721 &m_hCreds, pCertContext, dwEnabledProtocols);
00722 if (SUCCEEDED(scRet))
00723 m_ownCredentials = true;
00724 return scRet;
00725 }
00726
00727 }
00728
00729 #endif // OW32_AsyncSecureSocketClient_h