Files
Archive/shadowsocks-windows/shadowsocks-csharp/Util/Sockets/WrappedSocket.cs
2024-12-31 19:31:59 +01:00

269 lines
8.1 KiB
C#

using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
using System.Threading;
namespace Shadowsocks.Util.Sockets
{
/// <summary>
/// A wrapped socket class that supports both IPv4 and IPv6, depending on the
/// remote endpoint it ends up connected to.
/// </summary>
/// <remarks>
/// If the server address is a host name, it may resolve to both IPv4 and IPv6
/// addresses. Instead of resolving and choosing an address ourselves, we let
/// <see cref="Socket.ConnectAsync(SocketType, ProtocolType, SocketAsyncEventArgs)"/>
/// handle it: it tries each address internally and returns an established
/// socket connection.
/// </remarks>
public class WrappedSocket
{
    /// <summary>
    /// Local endpoint of the connected socket, or <c>null</c> before a connection is established.
    /// </summary>
    public EndPoint LocalEndPoint => _activeSocket?.LocalEndPoint;

    // Only used during connection and close, so it won't cost too much.
    // SpinLock is a mutable struct: never make this field readonly or copy it,
    // otherwise Enter/Exit would operate on a copy instead of the shared lock.
    private SpinLock _socketSyncLock = new SpinLock();

    private bool _disposed;

    private bool Connected => _activeSocket != null;

    // Assigned once, when the first connect attempt succeeds; never reset to null.
    private Socket _activeSocket;

    /// <summary>
    /// Starts an asynchronous TCP connection to <paramref name="remoteEP"/>. The underlying
    /// socket is created by
    /// <see cref="Socket.ConnectAsync(SocketType, ProtocolType, SocketAsyncEventArgs)"/>.
    /// </summary>
    /// <param name="remoteEP">Endpoint to connect to; may be a host-name endpoint.</param>
    /// <param name="callback">
    /// Invoked once when the attempt finishes (unless this method itself throws);
    /// call <see cref="EndConnect"/> from it to observe the outcome.
    /// </param>
    /// <param name="state">User state exposed as <see cref="IAsyncResult.AsyncState"/>.</param>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">A connection is already established.</exception>
    public void BeginConnect(EndPoint remoteEP, AsyncCallback callback, object state)
    {
        ThrowIfDisposed();
        if (Connected)
        {
            throw new SocketException((int) SocketError.IsConnected);
        }

        var arg = new SocketAsyncEventArgs
        {
            RemoteEndPoint = remoteEP,
            UserToken = new TcpUserToken(callback, state)
        };
        arg.Completed += OnTcpConnectCompleted;

        bool pending;
        try
        {
            pending = Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, arg);
        }
        catch
        {
            // The operation never started, so no completion path will dispose the args.
            arg.Completed -= OnTcpConnectCompleted;
            arg.Dispose();
            throw;
        }

        if (!pending)
        {
            // Finished inline: the Completed event will not be raised for this operation.
            CompleteConnect(arg, true);
        }
    }

    /// <summary>
    /// Minimal <see cref="IAsyncResult"/> handed to the connect callback. It is always
    /// complete and carries the connect failure, if any, for <see cref="EndConnect"/> to rethrow.
    /// </summary>
    private sealed class FakeAsyncResult : IAsyncResult
    {
        public bool IsCompleted { get; } = true;

        // No wait handle is provided; completion is signalled only via the callback.
        public WaitHandle AsyncWaitHandle { get; } = null;

        public object AsyncState { get; set; }

        public bool CompletedSynchronously { get; set; }

        public Exception InternalException { get; set; } = null;
    }

    /// <summary>
    /// User callback and state carried through <see cref="SocketAsyncEventArgs.UserToken"/>.
    /// </summary>
    private sealed class TcpUserToken
    {
        public AsyncCallback Callback { get; }

        public object AsyncState { get; }

        public TcpUserToken(AsyncCallback callback, object state)
        {
            Callback = callback;
            AsyncState = state;
        }
    }

    /// <summary>
    /// <see cref="SocketAsyncEventArgs.Completed"/> handler for connects that finish asynchronously.
    /// </summary>
    private void OnTcpConnectCompleted(object sender, SocketAsyncEventArgs args)
    {
        CompleteConnect(args, false);
    }

    /// <summary>
    /// Finishes a connect attempt: adopts or discards the connected socket, disposes
    /// <paramref name="args"/>, then reports the outcome to the user callback.
    /// </summary>
    /// <param name="args">The event args passed to ConnectAsync; disposed here.</param>
    /// <param name="completedSynchronously">Whether ConnectAsync finished inline in <see cref="BeginConnect"/>.</param>
    private void CompleteConnect(SocketAsyncEventArgs args, bool completedSynchronously)
    {
        TcpUserToken token;
        Exception error;
        using (args)
        {
            args.Completed -= OnTcpConnectCompleted;
            token = (TcpUserToken) args.UserToken;
            error = args.SocketError == SocketError.Success
                ? AdoptConnectedSocket(args.ConnectSocket)
                : (args.ConnectByNameError ?? new SocketException((int) args.SocketError));
        }

        // The callback is user code: it runs only after the spin lock has been released,
        // so a slow or re-entrant callback cannot keep other threads spinning in Dispose.
        token.Callback?.Invoke(new FakeAsyncResult
        {
            AsyncState = token.AsyncState,
            CompletedSynchronously = completedSynchronously,
            InternalException = error
        });
    }

    /// <summary>
    /// Stores <paramref name="socket"/> as the active socket while holding the sync lock.
    /// </summary>
    /// <returns>
    /// <c>null</c> on success. If this instance was disposed meanwhile, the socket is still
    /// stored but closed, and <see cref="EndConnect"/> will throw <see cref="ObjectDisposedException"/>.
    /// Returns a <see cref="SocketException"/> with <see cref="SocketError.IsConnected"/> when
    /// another connect attempt already succeeded; <paramref name="socket"/> is closed in that case.
    /// </returns>
    private Exception AdoptConnectedSocket(Socket socket)
    {
        var lockTaken = false;
        try
        {
            // SpinLock is not reentrant: with owner tracking (enabled on a default-initialized
            // SpinLock) a recursive Enter throws, so skip acquiring if this thread already holds it.
            if (!_socketSyncLock.IsHeldByCurrentThread)
            {
                // Enter waits for the lock. The timeout-less TryEnter returns immediately under
                // contention and would let this code race with Dispose unsynchronized.
                _socketSyncLock.Enter(ref lockTaken);
            }

            if (Connected)
            {
                socket.FullClose();
                return new SocketException((int) SocketError.IsConnected);
            }

            _activeSocket = socket;
            if (_disposed)
            {
                // Dispose ran before this connection completed; don't leak the new socket.
                _activeSocket.FullClose();
            }
            return null;
        }
        finally
        {
            if (lockTaken)
            {
                _socketSyncLock.Exit();
            }
        }
    }

    /// <summary>
    /// Completes a connect started by <see cref="BeginConnect"/>.
    /// </summary>
    /// <param name="asyncResult">The result passed to the connect callback.</param>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="ArgumentException"><paramref name="asyncResult"/> was not produced by <see cref="BeginConnect"/>.</exception>
    /// <exception cref="SocketException">
    /// The connection attempt failed. When connecting by host name, the error reported by
    /// <see cref="SocketAsyncEventArgs.ConnectByNameError"/> is rethrown as-is.
    /// </exception>
    public void EndConnect(IAsyncResult asyncResult)
    {
        ThrowIfDisposed();
        var r = asyncResult as FakeAsyncResult;
        if (r == null)
        {
            throw new ArgumentException("Invalid asyncResult.", nameof(asyncResult));
        }
        if (r.InternalException != null)
        {
            // Rethrow the stored failure without discarding its original stack trace.
            ExceptionDispatchInfo.Capture(r.InternalException).Throw();
        }
    }

    /// <summary>
    /// Marks this instance as disposed and fully closes the active socket, if any.
    /// Subsequent calls are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        var lockTaken = false;
        try
        {
            // Same reentrancy guard as AdoptConnectedSocket; Enter (not TryEnter) so a
            // concurrent connect completion cannot interleave with disposal.
            if (!_socketSyncLock.IsHeldByCurrentThread)
            {
                _socketSyncLock.Enter(ref lockTaken);
            }
            _disposed = true;
            _activeSocket?.FullClose();
        }
        finally
        {
            if (lockTaken)
            {
                _socketSyncLock.Exit();
            }
        }
    }

    /// <summary>Begins an asynchronous send on the connected socket.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">No connection has been established.</exception>
    public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags,
        AsyncCallback callback,
        object state)
    {
        ThrowIfDisposed();
        ThrowIfNotConnected();
        return _activeSocket.BeginSend(buffer, offset, size, socketFlags, callback, state);
    }

    /// <summary>Ends an asynchronous send started by <see cref="BeginSend"/>.</summary>
    /// <returns>The number of bytes sent.</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">No connection has been established.</exception>
    public int EndSend(IAsyncResult asyncResult)
    {
        ThrowIfDisposed();
        ThrowIfNotConnected();
        return _activeSocket.EndSend(asyncResult);
    }

    /// <summary>Begins an asynchronous receive on the connected socket.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">No connection has been established.</exception>
    public IAsyncResult BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags,
        AsyncCallback callback,
        object state)
    {
        ThrowIfDisposed();
        ThrowIfNotConnected();
        return _activeSocket.BeginReceive(buffer, offset, size, socketFlags, callback, state);
    }

    /// <summary>Ends an asynchronous receive started by <see cref="BeginReceive"/>.</summary>
    /// <returns>The number of bytes received.</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">No connection has been established.</exception>
    public int EndReceive(IAsyncResult asyncResult)
    {
        ThrowIfDisposed();
        ThrowIfNotConnected();
        return _activeSocket.EndReceive(asyncResult);
    }

    /// <summary>
    /// Shuts down the connected socket in the given direction; does nothing if not connected.
    /// </summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public void Shutdown(SocketShutdown how)
    {
        ThrowIfDisposed();
        if (!Connected)
        {
            return;
        }
        _activeSocket.Shutdown(how);
    }

    /// <summary>Sets a boolean socket option (mapped to 1/0) on the connected socket.</summary>
    public void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, bool optionValue)
    {
        SetSocketOption(optionLevel, optionName, optionValue ? 1 : 0);
    }

    /// <summary>Sets an integer socket option on the connected socket.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="SocketException">No connection has been established.</exception>
    public void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, int optionValue)
    {
        ThrowIfDisposed();
        ThrowIfNotConnected();
        _activeSocket.SetSocketOption(optionLevel, optionName, optionValue);
    }

    // Shared guard: throws the same ObjectDisposedException every public member used inline.
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(GetType().FullName);
        }
    }

    // Shared guard for operations that need an established connection.
    private void ThrowIfNotConnected()
    {
        if (!Connected)
        {
            throw new SocketException((int) SocketError.NotConnected);
        }
    }
}
}