00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00026 #ifndef OW32_AsyncSecureSocket_h
00027 #define OW32_AsyncSecureSocket_h
00028
00029 #include <OW32/windows.h>
00030 #include <OW32/AsyncSocket.h>
00031 #include <OW32/SSLInit.h>
00032 #include <OW32/auto_array_ptr.h>
00033 #include <tchar.h>
00034 #include <cassert>
00035
00036
00037 namespace OW32
00038 {
00039
00041 template <class T>
00042 class CAsyncSecureSocket : public T
00043 {
00044 private:
00045 CAsyncSecureSocket(const CAsyncSecureSocket& );
00046 CAsyncSecureSocket& operator= (const CAsyncSecureSocket& );
00047 void initialise();
00048
00049 public:
00051 void SetCallback(CAsyncSocketCallback* pCallback) { m_pCallback = pCallback; }
00053 CAsyncSocketCallback* GetCallback() const { return m_pCallback; }
00054
00055 virtual void onConnectCompletion(BOOL ) {}
00056 virtual void onCloseCompletion(BOOL ) {}
00057 virtual void onSendCompletion(BOOL bRet, DWORD cbSent);
00058 virtual void onReadCompletion(BOOL bRet, DWORD cbReceived);
00059 virtual void onTimeout() = 0;
00060
00065 CAsyncSecureSocket(CAsyncSocketCallback* pCallback, SOCKET s) :
00066 T(&m_callback,s),
00067 m_pCallback(pCallback)
00068 {
00069 initialise();
00070 }
00071
00075 CAsyncSecureSocket(CAsyncSocketCallback* pCallback) :
00076 T(&m_callback,s),
00077 m_pCallback(pCallback)
00078 {
00079 initialise();
00080 }
00081
00084 void negotiate() { m_State = State_Negotiate; negotiateLoop();}
00085
00088 ~CAsyncSecureSocket();
00089
00090 virtual int send(const char* buf, int len);
00091 virtual int recv(char* buf, int len);
00092
00096 void setCredentials(CredHandle hCreds) { m_hCreds = hCreds; }
00097
00101 void setRequireClientAuth(bool bRequireClientAuth) { m_bRequireClientAuth = bRequireClientAuth; }
00102
00109 static SECURITY_STATUS createCredentialsFromCertificate(CredHandle* phCreds, PCCERT_CONTEXT pCertContext,
00110 DWORD dwDirection, DWORD dwEnabledProtocols = 0);
00111
00115 SECURITY_STATUS getRemoteCert(PCCERT_CONTEXT* pRemoteCertContext);
00116
00119 void freeCredentials();
00120
00121 protected:
00123 virtual void negotiateLoop()=0;
00124
00125
00126 class CAsyncSecureSocketCallback : public CAsyncSocketCallback
00127 {
00128 virtual void onReadCompletion(BOOL bRet, DWORD cbReceived) {
00129 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char*)this - offsetof(CAsyncSecureSocket,m_callback));
00130 pParent->onReadCompletion(bRet,cbReceived);
00131 }
00132 virtual void onSendCompletion(BOOL bRet, DWORD cbSent) {
00133 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char*)this - offsetof(CAsyncSecureSocket,m_callback));
00134 pParent->onSendCompletion(bRet,cbSent);
00135 }
00136 virtual void onConnectCompletion(BOOL bRet) {
00137 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char*)this - offsetof(CAsyncSecureSocket,m_callback));
00138 pParent->onConnectCompletion(bRet);
00139 }
00140 virtual void onCloseCompletion(BOOL bRet) {
00141 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char*)this - offsetof(CAsyncSecureSocket,m_callback));
00142 pParent->onCloseCompletion(bRet);
00143 }
00144 virtual void onTimeout() {
00145 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char *)this - offsetof(CAsyncSecureSocket,m_callback));
00146 pParent->GetCallback()->onTimeout();
00147 }
00148 virtual void onAllNotificationsProcessed() {
00149 CAsyncSecureSocket* pParent = (CAsyncSecureSocket*)((char *)this - offsetof(CAsyncSecureSocket,m_callback));
00150 pParent->GetCallback()->onAllNotificationsProcessed();
00151 }
00152 };
00153 friend CAsyncSecureSocketCallback;
00154 CAsyncSecureSocketCallback m_callback;
00155
00156 void ReadCompletion(BOOL bRet, DWORD cbReceived);
00157 void SendCompletion(BOOL bRet, DWORD cbSent);
00158 SECURITY_STATUS querySizes();
00159
00160 SecPkgContext_StreamSizes m_Sizes;
00161 CtxtHandle m_hContext;
00162 CredHandle m_hCreds;
00163
00164 auto_array_ptr<char> m_Extra;
00165 int m_ExtraCount;
00166 int m_ExtraDecrypted,m_ExtraDecryptedPos;
00167 int m_recvLen;
00168 char* m_recvBuf;
00169
00170 auto_array_ptr<char> m_SendBuf;
00171 const char* m_sendData;
00172 int m_bufferSent,m_totalBuffered;
00173 int m_sendProcessed,m_sendLength;
00174
00175 bool m_bRequireClientAuth;
00176 bool m_ownCredentials;
00177 bool m_fUseIoCompletion;
00178
00179 enum ConnectionState
00180 {
00181 State_Initial,
00182 State_Negotiate,
00183 State_Connected,
00184 State_Renegotiate,
00185 State_Shutdown
00186 };
00187 ConnectionState m_State;
00188
00189 CAsyncSocketCallback* m_pCallback;
00190 };
00191
00192 template <class T>
00193 void CAsyncSecureSocket<T>::initialise()
00194 {
00195 m_bRequireClientAuth = false;
00196 m_ownCredentials = false;
00197 m_State = State_Initial;
00198
00199 m_ExtraCount = 0;
00200 m_ExtraDecrypted = 0;
00201 m_ExtraDecryptedPos = 0;
00202
00203 m_sendData = 0;
00204 m_bufferSent = m_totalBuffered = 0;
00205 m_sendProcessed = m_sendLength = 0;
00206
00207 SecInvalidateHandle(&m_hContext);
00208 SecInvalidateHandle(&m_hCreds);
00209
00210
00211 m_Sizes.cbMaximumMessage = (DWORD) -1;
00212 m_Sizes.cbHeader = (DWORD)-1;
00213 m_Sizes.cbTrailer = (DWORD)-1;
00214 m_Sizes.cbBlockSize = (DWORD)-1;
00215 m_Sizes.cBuffers = (DWORD)-1;
00216 }
00217
00218 template <class T>
00219 CAsyncSecureSocket<T>::~CAsyncSecureSocket()
00220 {
00221 if (SecIsValidHandle(&m_hContext))
00222 g_SecurityFunc.DeleteSecurityContext(&m_hContext);
00223 freeCredentials();
00224 }
00225
00226 template <class T>
00227 void CAsyncSecureSocket<T>::freeCredentials()
00228 {
00229
00230 if (m_ownCredentials && SecIsValidHandle(&m_hCreds)) {
00231 g_SecurityFunc.DeleteSecurityContext(&m_hContext);
00232 SecInvalidateHandle(&m_hCreds);
00233 m_ownCredentials = false;
00234 }
00235 }
00236
00237 template <class T>
00238 SECURITY_STATUS CAsyncSecureSocket<T>::getRemoteCert(PCCERT_CONTEXT* pRemoteCertContext)
00239 {
00240 return g_SecurityFunc.QueryContextAttributes(&m_hContext,
00241 SECPKG_ATTR_REMOTE_CERT_CONTEXT,
00242 (PVOID)pRemoteCertContext);
00243 }
00244
00245 template <class T>
00246 int CAsyncSecureSocket<T>::send(const char* buf, int len)
00247 {
00248 if (m_State == State_Initial)
00249 {
00250 return T::send(buf,len);
00251 }
00252 m_bufferSent = m_totalBuffered = m_sendProcessed = 0;
00253 m_sendLength = len;
00254 m_sendData = buf;
00255 SendCompletion(TRUE, 0);
00256 return 0;
00257 }
00258
00259 template <class T>
00260 void CAsyncSecureSocket<T>::onSendCompletion(BOOL bRet, DWORD cbSent)
00261 {
00262 if (!bRet || cbSent == 0) {
00263 m_pCallback->onSendCompletion(bRet, cbSent);
00264 return;
00265 }
00266 SendCompletion(bRet, cbSent);
00267 }
00268
00269 template <class T>
00270 void CAsyncSecureSocket<T>::SendCompletion(BOOL , DWORD cbSent)
00271 {
00272
00273
00274 m_bufferSent += cbSent;
00275
00276 SecBuffer Buffers[4];
00277 SecBufferDesc Message;
00278
00279 Message.ulVersion = SECBUFFER_VERSION;
00280 Message.cBuffers = 4;
00281 Message.pBuffers = Buffers;
00282
00283
00284
00285
00286 assert(m_Sizes.cbMaximumMessage != -1 && m_Sizes.cbHeader != -1 &&
00287 m_Sizes.cbTrailer != -1 && m_Sizes.cbBlockSize != -1 && m_Sizes.cBuffers != -1);
00288 if (!m_SendBuf.get())
00289 m_SendBuf.reset(new char[m_Sizes.cbMaximumMessage+m_Sizes.cbHeader+m_Sizes.cbTrailer]);
00290
00291 for (;;)
00292 {
00293
00294 if (m_bufferSent < m_totalBuffered)
00295 {
00296 int ret = T::send(m_SendBuf.get() + m_bufferSent, m_totalBuffered - m_bufferSent);
00297 if (ret != 0)
00298 m_pCallback->onSendCompletion(FALSE, 0);
00299 return;
00300 }
00301
00302
00303 m_totalBuffered = m_bufferSent = 0;
00304
00305
00306 if (m_sendProcessed >= m_sendLength)
00307 {
00308 m_pCallback->onSendCompletion(TRUE, m_sendProcessed);
00309 return;
00310 }
00311
00312
00313 Buffers[0].pvBuffer = m_SendBuf.get();
00314 Buffers[0].cbBuffer = m_Sizes.cbHeader;
00315 Buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
00316
00317
00318 DWORD chunksize = (m_sendLength-m_sendProcessed);
00319 if (chunksize > m_Sizes.cbMaximumMessage)
00320 chunksize = m_Sizes.cbMaximumMessage;
00321 CopyMemory(m_SendBuf.get()+m_Sizes.cbHeader,
00322 m_sendData+m_sendProcessed, chunksize);
00323
00324 Buffers[1].pvBuffer = m_SendBuf.get()+m_Sizes.cbHeader;
00325 Buffers[1].cbBuffer = chunksize;
00326 Buffers[1].BufferType = SECBUFFER_DATA;
00327
00328 Buffers[2].pvBuffer = m_SendBuf.get()+m_Sizes.cbHeader+chunksize;
00329 Buffers[2].cbBuffer = m_Sizes.cbTrailer;
00330 Buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
00331
00332 Buffers[3].BufferType = SECBUFFER_EMPTY;
00333
00334 SECURITY_STATUS scRet = g_SecurityFunc.EncryptMessage(&m_hContext, 0, &Message, 0);
00335 if (FAILED(scRet))
00336 {
00337 SetLastError(scRet);
00338 m_pCallback->onSendCompletion(FALSE, 0);
00339 return;
00340 }
00341
00342 m_sendProcessed += chunksize;
00343 m_totalBuffered = m_Sizes.cbHeader+chunksize+m_Sizes.cbTrailer;
00344
00345
00346 }
00347 }
00348
00349 template <class T>
00350 int CAsyncSecureSocket<T>::recv(char* buf, int len)
00351 {
00352 if (m_State == State_Initial)
00353 {
00354 return T::recv(buf,len);
00355 }
00356
00357
00358
00359
00360 if (len == 0) {
00361 assert(0);
00362 return 0;
00363 }
00364 m_recvBuf = buf;
00365 m_recvLen = len;
00366 ReadCompletion(TRUE, 0);
00367 return 0;
00368 }
00369
00370 template <class T>
00371 void CAsyncSecureSocket<T>::onReadCompletion(BOOL bRet, DWORD cbIoBuffer)
00372 {
00373
00374
00375 if (!bRet || cbIoBuffer == 0) {
00376 m_pCallback->onReadCompletion(bRet, cbIoBuffer);
00377 return;
00378 }
00379 ReadCompletion(bRet, cbIoBuffer);
00380 }
00381
00382 template <class T>
00383 void CAsyncSecureSocket<T>::ReadCompletion(BOOL , DWORD cbIoBuffer)
00384 {
00385
00386
00387 SecBuffer Buffers[4];
00388 SecBufferDesc Message;
00389
00390 Message.ulVersion = SECBUFFER_VERSION;
00391 Message.cBuffers = 4;
00392 Message.pBuffers = Buffers;
00393
00394
00395 SECURITY_STATUS scRet = SEC_E_OK;
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407 int DecryptedCount = m_ExtraDecrypted-m_ExtraDecryptedPos;
00408 if (DecryptedCount > 0) {
00409 m_recvLen = min(DecryptedCount, m_recvLen);
00410
00411 CopyMemory(m_recvBuf, m_Extra.get()+m_ExtraDecryptedPos, m_recvLen);
00412 m_ExtraDecryptedPos += m_recvLen;
00413 m_pCallback->onReadCompletion(TRUE, m_recvLen);
00414 return;
00415 }
00416
00417 if (!m_Extra.get())
00418 m_Extra.reset(new char[m_Sizes.cbMaximumMessage+m_Sizes.cbHeader+m_Sizes.cbTrailer]);
00419
00420
00421 if (m_ExtraCount)
00422 {
00423 if (m_ExtraDecrypted)
00424 MoveMemory(m_Extra.get(), m_Extra.get()+m_ExtraDecrypted, m_ExtraCount);
00425 cbIoBuffer += m_ExtraCount;
00426 }
00427
00428
00429
00430 m_ExtraCount = 0;
00431 m_ExtraDecryptedPos = 0;
00432 m_ExtraDecrypted = 0;
00433
00434 do
00435 {
00436
00437
00438
00439 if (cbIoBuffer == 0 || scRet == SEC_E_INCOMPLETE_MESSAGE)
00440 {
00441
00442 int get = m_Sizes.cbHeader + m_Sizes.cbMaximumMessage + m_Sizes.cbTrailer - cbIoBuffer;
00443 if (get <= 0) {
00444 SetLastError((DWORD)E_UNEXPECTED);
00445 m_pCallback->onReadCompletion(FALSE, 0);
00446 return;
00447 }
00448
00449 int ret = T::recv(m_Extra.get() + cbIoBuffer, get);
00450 m_ExtraCount = cbIoBuffer;
00451 if (ret != 0)
00452 m_pCallback->onReadCompletion(FALSE, 0);
00453 return;
00454 }
00455
00456 Buffers[0].pvBuffer = m_Extra.get();
00457 Buffers[0].cbBuffer = cbIoBuffer;
00458 Buffers[0].BufferType = SECBUFFER_DATA;
00459
00460 Buffers[1].BufferType = SECBUFFER_EMPTY;
00461 Buffers[2].BufferType = SECBUFFER_EMPTY;
00462 Buffers[3].BufferType = SECBUFFER_EMPTY;
00463
00464 scRet = g_SecurityFunc.DecryptMessage(&m_hContext, &Message, 0, NULL);
00465 }
00466 while (scRet == SEC_E_INCOMPLETE_MESSAGE);
00467
00468
00469 if (scRet == SEC_I_CONTEXT_EXPIRED || scRet == SEC_E_CONTEXT_EXPIRED) {
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487 m_pCallback->onReadCompletion(TRUE, 0);
00488 return ;
00489 }
00490
00491
00492 if (scRet != SEC_E_OK && scRet != SEC_I_RENEGOTIATE)
00493 {
00494 SetLastError(scRet);
00495 m_pCallback->onReadCompletion(FALSE, m_recvLen);
00496 return;
00497 }
00498
00499
00500
00501 DWORD cbGot = 0;
00502 if (Buffers[1].BufferType == SECBUFFER_DATA) {
00503 cbGot = min(m_recvLen, (int)Buffers[1].cbBuffer);
00504 CopyMemory(m_recvBuf, Buffers[1].pvBuffer, cbGot);
00505
00506
00507 if ((int)Buffers[1].cbBuffer > cbGot) {
00508 MoveMemory(m_Extra.get(), (BYTE*)Buffers[1].pvBuffer + cbGot, Buffers[1].cbBuffer - cbGot);
00509 m_ExtraDecrypted = Buffers[1].cbBuffer - cbGot;
00510 }
00511 }
00512
00513
00514
00515 if (Buffers[3].BufferType == SECBUFFER_EXTRA && Buffers[3].cbBuffer > 0) {
00516 m_ExtraCount = Buffers[3].cbBuffer;
00517 MoveMemory(m_Extra.get() + m_ExtraDecrypted,
00518 m_Extra.get() + cbIoBuffer - m_ExtraCount, m_ExtraCount);
00519 }
00520
00521
00522 if (scRet == SEC_I_RENEGOTIATE) {
00523
00524
00525
00526 m_State = State_Renegotiate;
00527 negotiateLoop();
00528 return;
00529 }
00530 m_pCallback->onReadCompletion(TRUE, cbGot);
00531 }
00532
00533 template <class T>
00534 SECURITY_STATUS CAsyncSecureSocket<T>::querySizes()
00535 {
00536
00537 #if 0
00538
00539
00540
00541 return g_SecurityFunc.QueryContextAttributes(&m_hContext, SECPKG_ATTR_STREAM_SIZES, &m_Sizes);
00542 #else
00543 m_Sizes.cbMaximumMessage = 16384;
00544 m_Sizes.cbHeader = 5;
00545 m_Sizes.cbTrailer = 16;
00546 m_Sizes.cBuffers = 4;
00547 m_Sizes.cbBlockSize = 1;
00548 return S_OK;
00549 #endif
00550 }
00551
00552 template <class T>
00553 SECURITY_STATUS CAsyncSecureSocket<T>::createCredentialsFromCertificate(
00554 CredHandle* phCredHandle,
00555 PCCERT_CONTEXT pCertContext,
00556 DWORD dwDirection,
00557 DWORD dwEnabledProtocols)
00558 {
00559
00560 SCHANNEL_CRED SchannelCred;
00561 ZeroMemory(&SchannelCred, sizeof(SchannelCred));
00562
00563 SchannelCred.dwVersion = SCHANNEL_CRED_VERSION;
00564 SchannelCred.grbitEnabledProtocols = dwEnabledProtocols;
00565
00566 OSVERSIONINFO osVer;
00567 ZeroMemory(&osVer, sizeof(osVer));
00568 osVer.dwOSVersionInfoSize = sizeof(osVer);
00569 if (!GetVersionEx(&osVer))
00570 return HRESULT_FROM_WIN32(::GetLastError());
00571
00572 if (osVer.dwMajorVersion >= 5) {
00573 if (dwDirection == SECPKG_CRED_INBOUND) {
00574 SchannelCred.dwFlags |= SCH_CRED_NO_SYSTEM_MAPPER;
00575 } else {
00576 SchannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
00577 SchannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
00578 }
00579 }
00580
00581 if (pCertContext) {
00582 SchannelCred.cCreds = 1;
00583 SchannelCred.paCred = &pCertContext;
00584 }
00585
00586
00587 TimeStamp tsExpiry;
00588
00589 SECURITY_STATUS Status = g_SecurityFunc.AcquireCredentialsHandle(
00590 NULL,
00591 UNISP_NAME,
00592 dwDirection,
00593 NULL,
00594 &SchannelCred,
00595 NULL,
00596 NULL,
00597 phCredHandle,
00598 &tsExpiry);
00599 return Status;
00600 }
00601
00602 }
00603
00604 #endif // OW32_AsyncSecureSocket_h