STUN Server: compliant with the latest RFCs, including 5389, 5769, and 5780
Discover the local host's own external IP address
stunsocketthread.cpp
Go to the documentation of this file.
1 /*
2  Copyright 2011 John Selbie
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7 
8  http://www.apache.org/licenses/LICENSE-2.0
9 
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 */
16 
17 
18 
19 #include "commonincludes.hpp"
20 #include "stuncore.h"
21 #include "stunsocket.h"
22 #include "stunsocketthread.h"
23 #include "recvfromex.h"
24 #include "ratelimiter.h"
25 
26 
28 _arrSendSockets(), // zero-init
29 _fNeedToExit(false),
30 _pthread((pthread_t)-1),
31 _fThreadIsValid(false),
32 _rotation(0),
33 _tsa() // zero-init
34 {
36 }
37 
39 {
40  SignalForStop(true);
42 }
43 
45 {
46  _arrSendSockets = NULL;
47  _socks.clear();
48 }
49 
50 HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets, TransportAddressSet* pTSA, IStunAuth* pAuth, SocketRole rolePrimaryRecv, boost::shared_ptr<RateLimiter>& spLimiter)
51 {
52  HRESULT hr = S_OK;
53 
54  bool fSingleSocketRecv = ::IsValidSocketRole(rolePrimaryRecv);
55 
57 
58  ChkIfA(arrayOfFourSockets == NULL, E_INVALIDARG);
59  ChkIfA(pTSA == NULL, E_INVALIDARG);
60 
61  // if this thread was configured to listen on a single socket (aka "multi-threaded mode"), then
62  // validate that it exists
63  if (fSingleSocketRecv)
64  {
65  ChkIfA(arrayOfFourSockets[rolePrimaryRecv].IsValid()==false, E_UNEXPECTED);
66  }
67 
68  _arrSendSockets = arrayOfFourSockets;
69 
70  // initialize the TSA thing
71  _tsa = *pTSA;
72 
73  if (fSingleSocketRecv)
74  {
75  // only one socket to listen on
76  _socks.push_back(&_arrSendSockets[rolePrimaryRecv]);
77  }
78  else
79  {
80  for (size_t i = 0; i < 4; i++)
81  {
82  if (_arrSendSockets[i].IsValid())
83  {
84  _socks.push_back(&_arrSendSockets[i]);
85  }
86  }
87  }
88 
89 
90 
92 
93  _fNeedToExit = false;
94 
95  _rotation = 0;
96 
97  _spAuth.Attach(pAuth);
98 
99  _spLimiter = spLimiter;
100 
101 Cleanup:
102  return hr;
103 }
104 
106 {
107  HRESULT hr = S_OK;
108 
109  _reader.Reset();
110 
114 
116 
117  _msgIn.fConnectionOriented = false;
120 
121  return hr;
122 }
123 
125 {
126  _reader.Reset();
127  _spBufferReader.reset();
128  _spBufferIn.reset();
129  _spBufferOut.reset();
130 
131  _msgIn.pReader = NULL;
132  _msgOut.spBufferOut.reset();
133 }
134 
135 
137 {
138  HRESULT hr = S_OK;
139  int err = 0;
140 
142 
143  ChkIfA(_socks.size() <= 0, E_FAIL);
144 
145  err = ::pthread_create(&_pthread, NULL, CStunSocketThread::ThreadFunction, this);
146 
147  ChkIfA(err != 0, ERRNO_TO_HRESULT(err));
148  _fThreadIsValid = true;
149 
150 Cleanup:
151  return hr;
152 }
153 
154 
155 
156 
158 {
159 
160  HRESULT hr = S_OK;
161 
162  _fNeedToExit = true;
163 
164  // have the socket send a message to itself
165  // if another thread is sharing the same socket, this may wake that thread up to
166  // but all the threads should be started and shutdown together
167  if (fPostMessages)
168  {
169  for (size_t index = 0; index < _socks.size(); index++)
170  {
171  char data = 'x';
172 
173  ASSERT(_socks[index] != NULL);
174 
175  CSocketAddress addr(_socks[index]->GetLocalAddress());
176 
177 
178  // If no specific adapter was binded to, IP will be 0.0.0.0
179  // Linux evidently treats 0.0.0.0 IP as loopback (and works)
180  // On Windows you can't send to 0.0.0.0. sendto will fail - switch to sending to localhost
181  if (addr.IsIPAddressZero())
182  {
183  CSocketAddress addrLocal;
184  CSocketAddress::GetLocalHost(addr.GetFamily(), &addrLocal);
185  addrLocal.SetPort(addr.GetPort());
186  addr = addrLocal;
187  }
188 
189  ::sendto(_socks[index]->GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength());
190  }
191  }
192 
193  return hr;
194 }
195 
197 {
198  void* pRetValFromThread = NULL;
199 
200  if (_fThreadIsValid)
201  {
202  // now wait for the thread to exit
203  pthread_join(_pthread, &pRetValFromThread);
204  }
205 
206  _fThreadIsValid = false;
207  _pthread = (pthread_t)-1;
208 
210 
212 
213  return S_OK;
214 }
215 
216 // static
218 {
219  ((CStunSocketThread*)pThis)->Run();
220  return NULL;
221 }
222 
223 // returns an index into _socks, not _arrSockets
225 {
226  fd_set set = {};
227  int nHighestSockValue = 0;
228  int ret;
229  CStunSocket* pReadySocket = NULL;
230  UNREFERENCED_VARIABLE(ret); // only referenced in ASSERT
231  size_t nSocketCount = _socks.size();
232 
233 
234  // rotation gives another socket priority in the next loop
235  _rotation = (_rotation + 1) % nSocketCount;
236  ASSERT(_rotation >= 0);
237 
238  FD_ZERO(&set);
239 
240  for (size_t index = 0; index < nSocketCount; index++)
241  {
242  ASSERT(_socks[index] != NULL);
243  int sock = _socks[index]->GetSocketHandle();
244  ASSERT(sock != -1);
245  FD_SET(sock, &set);
246  nHighestSockValue = (sock > nHighestSockValue) ? sock : nHighestSockValue;
247  }
248 
249  // wait indefinitely for a socket
250  ret = ::select(nHighestSockValue+1, &set, NULL, NULL, NULL);
251 
252  ASSERT(ret > 0); // This will be a benign assert, and should never happen. But I will want to know if it does
253 
254  // now figure out which socket just got data on it
255  for (size_t index = 0; index < nSocketCount; index++)
256  {
257  int indexconverted = (index + _rotation) % nSocketCount;
258  int sock = _socks[indexconverted]->GetSocketHandle();
259 
260  ASSERT(sock != -1);
261 
262  if (FD_ISSET(sock, &set))
263  {
264  pReadySocket = _socks[indexconverted];
265  break;
266  }
267  }
268 
269  ASSERT(pReadySocket != NULL);
270 
271  return pReadySocket;
272 }
273 
274 
276 {
277  size_t nSocketCount = _socks.size();
278  bool fMultiSocketMode = (nSocketCount > 1);
279  int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
280  CStunSocket* pSocket = _socks[0];
281  int ret;
282  char szIPRemote[100] = {};
283  char szIPLocal[100] = {};
284  bool allowed_to_pass = true;
285 
286 
287  int sendsocketcount = 0;
288 
289  sendsocketcount += (int)(_tsa.set[RolePP].fValid);
290  sendsocketcount += (int)(_tsa.set[RolePA].fValid);
291  sendsocketcount += (int)(_tsa.set[RoleAP].fValid);
292  sendsocketcount += (int)(_tsa.set[RoleAA].fValid);
293 
294  Logging::LogMsg(LL_DEBUG, "Starting listener thread (%d recv sockets, %d send sockets)", _socks.size(), sendsocketcount);
295 
296  while (_fNeedToExit == false)
297  {
298 
299  if (fMultiSocketMode)
300  {
301  pSocket = WaitForSocketData();
302 
303  if (_fNeedToExit)
304  {
305  break;
306  }
307 
308  ASSERT(pSocket != NULL);
309 
310  if (pSocket == NULL)
311  {
312  // just go back to waiting;
313  continue;
314  }
315  }
316 
317  ASSERT(pSocket != NULL);
318 
319  // now receive the data
320  _spBufferIn->SetSize(0);
321 
322  ret = ::recvfromex(pSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
323 
324  // recvfromex no longer sets the port value on the local address
325  if (ret >= 0)
326  {
328  }
329 
330 
332  {
333  _msgIn.addrRemote.ToStringBuffer(szIPRemote, 100);
334  _msgIn.addrLocal.ToStringBuffer(szIPLocal, 100);
335  }
336  else
337  {
338  szIPRemote[0] = '\0';
339  szIPLocal[0] = '\0';
340  }
341 
342  Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal);
343 
344  allowed_to_pass = (_spLimiter.get() != NULL) ? _spLimiter->RateCheck(_msgIn.addrRemote) : true;
345 
346  if (allowed_to_pass == false)
347  {
348  Logging::LogMsg(LL_VERBOSE, "RateLimiter signals false for packet from %s", szIPRemote);
349  }
350 
351  if ((ret < 0) || (allowed_to_pass == false))
352  {
353  // error
354  continue;
355  }
356 
357  if (_fNeedToExit)
358  {
359  break;
360  }
361 
362  _spBufferIn->SetSize(ret);
363 
364  _msgIn.socketrole = pSocket->GetRole();
365 
366 
367  // --------------------------------------------------------------------
368  // now let's handle this message and get the response back out
369 
371  }
372 
373  Logging::LogMsg(LL_DEBUG, "Thread exiting");
374 }
375 
376 
378 {
379  HRESULT hr = S_OK;
380  int sendret = -1;
381  int sockout = -1;
382 
383  // Reset the reader object and re-attach the buffer
384  _reader.Reset();
385  _spBufferReader->SetSize(0);
387 
388  // Consume the message and just validate that it is a stun message
389  _reader.AddBytes(_spBufferIn->GetData(), _spBufferIn->GetSize());
391 
392  // msgIn and msgOut are already initialized
393 
395 
399  ASSERT(sockout != -1);
400 
401  // find the socket that matches the role specified by msgOut
402  sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
403 
404 
406  {
407  Logging::LogMsg(LL_VERBOSE, "sendto returns %d (err == %d)\n", sendret, errno);
408  }
409 
410 
411 Cleanup:
412  return hr;
413 }
414 
415 
416 
417 
418 
const uint32_t MAX_STUN_MESSAGE_SIZE
Definition: stuntypes.h:178
#define S_OK
Definition: hresult.h:46
#define ASSERT(expr)
bool IsValidSocketRole(SocketRole sr)
Definition: socketrole.h:31
std::vector< CStunSocket * > _socks
const uint32_t LL_DEBUG
Definition: logger.h:24
StunMessageOut _msgOut
CDataStream & GetStream()
Definition: stunreader.cpp:855
CStunMessageReader _reader
CStunSocket * _arrSendSockets
socklen_t GetSockAddrLength() const
CRefCountedBuffer _spBufferOut
SocketRole socketrole
#define Chk(expr)
Definition: chkmacros.h:53
CRefCountedBuffer _spBufferReader
ssize_t recvfromex(int sockfd, void *buf, size_t len, int flags, CSocketAddress *pSrcAddr, CSocketAddress *pDstAddr)
Definition: recvfromex.cpp:44
void LogMsg(uint32_t level, const char *pszFormat,...)
Definition: logger.cpp:44
static HRESULT GetLocalHost(uint16_t family, CSocketAddress *pAddr)
CRefCountedBuffer _spBufferIn
SocketRole socketrole
TransportAddressSet _tsa
HRESULT Init(CStunSocket *arrayOfFourSockets, TransportAddressSet *pTSA, IStunAuth *pAuth, SocketRole rolePrimaryRecv, boost::shared_ptr< RateLimiter > &_spRateLimiter)
uint16_t GetPort() const
#define UNREFERENCED_VARIABLE(unrefparam)
HRESULT ProcessRequestAndSendResponse()
#define ERRNO_TO_HRESULT(err)
Definition: hresult.h:41
#define ChkIf(expr, hrerror)
Definition: chkmacros.h:63
ReaderParseState AddBytes(const uint8_t *pData, uint32_t size)
Definition: stunreader.cpp:750
#define E_UNEXPECTED
Definition: hresult.h:48
Definition: buffer.h:27
CSocketAddress addrRemote
The local IP address the message was received on (useful if the socket is bound to INADDR_ANY) ...
CSocketAddress addrLocal
Which socket ID the message arrived on
boost::shared_ptr< RateLimiter > _spLimiter
void SetPort(uint16_t)
int GetSocketHandle() const
Definition: stunsocket.cpp:82
ReaderParseState GetState()
Definition: stunreader.cpp:820
SocketRole GetRole() const
Definition: stunsocket.cpp:98
static void * ThreadFunction(void *pThis)
SocketRole
Definition: socketrole.h:22
int32_t HRESULT
Definition: hresult.h:22
bool IsValid()
Definition: stunsocket.cpp:51
const sockaddr * GetSockAddr() const
CSocketAddress addrDest
StunMessageIn _msgIn
HRESULT ToStringBuffer(char *pszAddrBytes, size_t length) const
#define E_INVALIDARG
Definition: hresult.h:51
CStunMessageReader * pReader
the address of the node that sent us the message
#define E_FAIL
Definition: hresult.h:56
static HRESULT ProcessRequest(const StunMessageIn &msgIn, StunMessageOut &msgOut, TransportAddressSet *pAddressSet, IStunAuth *pAuth)
void Attach(T *ptr)
const uint32_t LL_VERBOSE
Definition: logger.h:25
uint32_t GetLogLevel()
Definition: logger.cpp:33
CRefCountedPtr< IStunAuth > _spAuth
CStunSocket * WaitForSocketData()
uint16_t GetFamily() const
TransportAddress set[4]
CRefCountedBuffer spBufferOut
boost::shared_ptr< CBuffer > CRefCountedBuffer
Definition: buffer.h:65
void Attach(CRefCountedBuffer &buffer, bool fForWriting)
Definition: datastream.cpp:55
bool IsIPAddressZero() const
HRESULT SignalForStop(bool fPostMessages)
#define ChkIfA(expr, hrerror)
Definition: chkmacros.h:84
bool fConnectionOriented
Reader containing a valid STUN message
const CSocketAddress & GetLocalAddress() const
Definition: stunsocket.cpp:87