一些常规优化

This commit is contained in:
snltty
2024-11-29 00:49:22 +08:00
parent 75e66acebb
commit fb7e2d4f4e
34 changed files with 322 additions and 235 deletions

View File

@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Net;
using System.Security.Cryptography;
using System.Text;
@@ -13,7 +14,7 @@ namespace linker.libs.websocket
public static class WebSocketParser
{
private readonly static SHA1 sha1 = SHA1.Create();
private readonly static Memory<byte> magicCode = Encoding.ASCII.GetBytes("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
private readonly static Memory<byte> magicCode = Encoding.UTF8.GetBytes("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
/// <summary>
/// 构建连接数据
/// </summary>
@@ -21,21 +22,23 @@ namespace linker.libs.websocket
/// <returns></returns>
public static byte[] BuildConnectData(WebsocketHeaderInfo header)
{
string path = header.Path.Length == 0 ? "/" : Encoding.UTF8.GetString(header.Path.Span);
string path = header.Path.Length == 0 ? "/" : header.Path;
header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketKey,out string key);
StringBuilder sb = new StringBuilder(10);
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append($"Upgrade: websocket\r\n");
sb.Append($"Connection: Upgrade\r\n");
sb.Append($"Sec-WebSocket-Version: 13\r\n");
sb.Append($"Sec-WebSocket-Key: {Encoding.UTF8.GetString(header.SecWebSocketKey.Span)}\r\n");
if (header.SecWebSocketProtocol.Length > 0)
sb.Append($"Sec-WebSocket-Key: {key}\r\n");
if (header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketProtocol, out string protocol))
{
sb.Append($"Sec-WebSocket-Protocol: {Encoding.UTF8.GetString(header.SecWebSocketProtocol.Span)}\r\n");
sb.Append($"Sec-WebSocket-Protocol: {protocol}\r\n");
}
if (header.SecWebSocketExtensions.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketExtensions, out string extensions))
{
sb.Append($"Sec-WebSocket-Extensions: {Encoding.UTF8.GetString(header.SecWebSocketExtensions.Span)}\r\n");
sb.Append($"Sec-WebSocket-Extensions: {extensions}\r\n");
}
sb.Append("\r\n");
@@ -48,30 +51,31 @@ namespace linker.libs.websocket
/// <returns></returns>
public static byte[] BuildConnectResponseData(WebsocketHeaderInfo header)
{
string acceptStr = BuildSecWebSocketAccept(header.SecWebSocketKey);
header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketKey,out string key);
string acceptStr = BuildSecWebSocketAccept(key);
StringBuilder sb = new StringBuilder(10);
sb.Append($"HTTP/1.1 {(int)header.StatusCode} {AddSpace(header.StatusCode)}\r\n");
sb.Append($"Sec-WebSocket-Accept: {acceptStr}\r\n");
if (header.Connection.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.Connection, out string str1))
{
sb.Append($"Connection: {Encoding.UTF8.GetString(header.Connection.Span)}\r\n");
sb.Append($"Connection: {str1}\r\n");
}
if (header.Upgrade.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.Upgrade, out str1))
{
sb.Append($"Upgrade: {Encoding.UTF8.GetString(header.Upgrade.Span)}\r\n");
sb.Append($"Upgrade: {str1}\r\n");
}
if (header.SecWebSocketVersion.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketVersion, out str1))
{
sb.Append($"Sec-WebSocket-Version: {Encoding.UTF8.GetString(header.SecWebSocketVersion.Span)}\r\n");
sb.Append($"Sec-Websocket-Version: {str1}\r\n");
}
if (header.SecWebSocketProtocol.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketProtocol, out str1))
{
sb.Append($"Sec-WebSocket-Protocol: {Encoding.UTF8.GetString(header.SecWebSocketProtocol.Span)}\r\n");
sb.Append($"Sec-WebSocket-Protocol: {str1}\r\n");
}
if (header.SecWebSocketExtensions.Length > 0)
if (header.TryGetHeaderValue(WebsocketHeaderKey.SecWebSocketExtensions, out str1))
{
sb.Append($"Sec-WebSocket-Extensions: {Encoding.UTF8.GetString(header.SecWebSocketExtensions.Span)}\r\n");
sb.Append($"Sec-WebSocket-Extensions: {str1}\r\n");
}
sb.Append("\r\n");
@@ -81,7 +85,7 @@ namespace linker.libs.websocket
/// 生成随机key
/// </summary>
/// <returns></returns>
public static byte[] BuildSecWebSocketKey()
public static string BuildSecWebSocketKey()
{
byte[] bytes = new byte[16];
Random random = new Random(DateTime.Now.Ticks.GetHashCode());
@@ -89,8 +93,7 @@ namespace linker.libs.websocket
{
bytes[i] = (byte)random.Next(0, 255);
}
byte[] res = Encoding.UTF8.GetBytes(Convert.ToBase64String(bytes));
return res;
return Convert.ToBase64String(bytes);
}
/// <summary>
/// 构建mask数据
@@ -113,12 +116,12 @@ namespace linker.libs.websocket
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
private static string BuildSecWebSocketAccept(Memory<byte> key)
private static string BuildSecWebSocketAccept(string key)
{
int keyLength = key.Length + magicCode.Length;
byte[] acceptBytes = new byte[keyLength];
key.CopyTo(acceptBytes);
Encoding.UTF8.GetBytes(key).AsMemory().CopyTo(acceptBytes);
magicCode.CopyTo(acceptBytes.AsMemory(key.Length));
string acceptStr = Convert.ToBase64String(sha1.ComputeHash(acceptBytes, 0, keyLength));
@@ -131,10 +134,10 @@ namespace linker.libs.websocket
/// <param name="key"></param>
/// <param name="accept"></param>
/// <returns></returns>
public static bool VerifySecWebSocketAccept(Memory<byte> key, Memory<byte> accept)
public static bool VerifySecWebSocketAccept(string key, string accept)
{
string acceptStr = BuildSecWebSocketAccept(key);
return acceptStr == Encoding.UTF8.GetString(accept.Span);
return acceptStr == accept;
}
/// <summary>
@@ -536,113 +539,108 @@ namespace linker.libs.websocket
/// </summary>
public sealed class WebsocketHeaderInfo
{
static byte[][] bytes = new byte[][] {
Encoding.ASCII.GetBytes("Connection: "),
Encoding.ASCII.GetBytes("Upgrade: "),
Encoding.ASCII.GetBytes("Origin: "),
Encoding.ASCII.GetBytes("Sec-WebSocket-Version: "),
Encoding.ASCII.GetBytes("Sec-WebSocket-Key: "),
Encoding.ASCII.GetBytes("Sec-WebSocket-Extensions: "),
Encoding.ASCII.GetBytes("Sec-WebSocket-Protocol: "),
Encoding.ASCII.GetBytes("Sec-WebSocket-Accept: ")
};
static byte[] httpBytes = Encoding.UTF8.GetBytes("HTTP/");
static byte[] endBytes = Encoding.UTF8.GetBytes("\r\n");
static byte[] splitBytes = Encoding.UTF8.GetBytes(": ");
public HttpStatusCode StatusCode { get; set; } = HttpStatusCode.SwitchingProtocols;
public Memory<byte> Method { get; private set; }
private string _pathSet { get; set; }
/// <summary>
/// 用这个设置path值
/// 状态码
/// </summary>
public string PathSet
public HttpStatusCode StatusCode { get; set; } = HttpStatusCode.SwitchingProtocols;
/// <summary>
/// 方法
/// </summary>
public string Method { get; private set; }
/// <summary>
/// 路径
/// </summary>
public string Path { get; set; }
/// <summary>
/// 请求头
/// </summary>
public Dictionary<string, string> Headers { get; private set; } = new Dictionary<string, string>();
/// <summary>
/// 获取请求头
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public bool TryGetHeaderValue(string key, out string value)
{
get
{
return _pathSet;
}
set
{
_pathSet = value;
Path = Encoding.UTF8.GetBytes(_pathSet);
}
return Headers.TryGetValue(key, out value) && string.IsNullOrWhiteSpace(value) == false;
}
/// <summary>
/// 如果 仅1个字符那就是 /
/// 设置请求头
/// </summary>
public Memory<byte> Path { get; private set; }
public Memory<byte> Connection { get; private set; }
public Memory<byte> Upgrade { get; private set; }
public Memory<byte> Origin { get; private set; }
public Memory<byte> SecWebSocketVersion { get; private set; }
public Memory<byte> SecWebSocketKey { get; set; }
public Memory<byte> SecWebSocketExtensions { get; set; }
public Memory<byte> SecWebSocketProtocol { get; set; }
public Memory<byte> SecWebSocketAccept { get; set; }
/// <param name="key"></param>
/// <param name="value"></param>
public void SetHeaderValue(string key, string value)
{
Headers[key] = value;
}
/// <summary>
/// 解析websocket请求头
/// </summary>
/// <param name="header"></param>
/// <returns></returns>
public static WebsocketHeaderInfo Parse(Memory<byte> header)
{
Span<byte> span = header.Span;
int flag = 0xff;
int bit = 0x01;
ulong[] res = new ulong[bytes.Length];
Span<byte> temp = span;
WebsocketHeaderInfo headerInfo = new WebsocketHeaderInfo();
for (int i = 0, len = span.Length; i < len; i++)
//跳过头
temp = temp.Slice(temp.IndexOf(endBytes) + 2);
int splitIndex = 0;
//还有分割线
while ((splitIndex = temp.IndexOf(splitBytes)) >= 0)
{
if (span[i] == 13 && span[i + 1] == 10 && span[i + 2] == 13 && span[i + 3] == 10)
{
break;
}
if (span[i] == 13 && span[i + 1] == 10)
{
int startIndex = i + 2;
for (int k = 0; k < bytes.Length; k++)
{
if ((flag >> k & 1) == 1 && span[startIndex] == bytes[k][0])
{
if (span.Slice(startIndex, bytes[k].Length).SequenceEqual(bytes[k]))
{
int index = span.Slice(startIndex).IndexOf((byte)13);
flag &= ~(bit << k);
//取到key
string key = Encoding.UTF8.GetString(temp.Slice(0, splitIndex)).ToLowerInvariant();
//跳过key
temp = temp.Slice(splitIndex + 2);
#pragma warning disable CS0675 // 对进行了带符号扩展的操作数使用了按位或运算符
res[k] = (ulong)(startIndex + bytes[k].Length) << 32 | (ulong)(index - bytes[k].Length);
#pragma warning restore CS0675 // 对进行了带符号扩展的操作数使用了按位或运算符
//取到value
int endIndex = temp.IndexOf(endBytes);
string value = Encoding.UTF8.GetString(temp.Slice(0, endIndex));
//跳过value
temp = temp.Slice(endIndex + 2);
i += index + 1;
break;
}
}
}
}
headerInfo.Headers[key] = value;
}
WebsocketHeaderInfo headerInfo = new WebsocketHeaderInfo
{
Connection = header.Slice((int)(res[0] >> 32), (int)(res[0] & 0xffffffff)),
Upgrade = header.Slice((int)(res[1] >> 32), (int)(res[1] & 0xffffffff)),
Origin = header.Slice((int)(res[2] >> 32), (int)(res[2] & 0xffffffff)),
SecWebSocketVersion = header.Slice((int)(res[3] >> 32), (int)(res[3] & 0xffffffff)),
SecWebSocketKey = header.Slice((int)(res[4] >> 32), (int)(res[4] & 0xffffffff)),
SecWebSocketExtensions = header.Slice((int)(res[5] >> 32), (int)(res[5] & 0xffffffff)),
SecWebSocketProtocol = header.Slice((int)(res[6] >> 32), (int)(res[6] & 0xffffffff)),
SecWebSocketAccept = header.Slice((int)(res[7] >> 32), (int)(res[7] & 0xffffffff)),
};
int pathIndex = span.IndexOf((byte)32);
int pathIndex1 = span.Slice(pathIndex + 1).IndexOf((byte)32);
//响应的,获取状态码
if (header.Slice(0, httpBytes.Length).Span.SequenceEqual(httpBytes))
{
int code = int.Parse(Encoding.UTF8.GetString(header.Slice(pathIndex + 1, pathIndex1).Span));
headerInfo.StatusCode = (HttpStatusCode)code;
}
//请求的,获取路径和方法
else
{
headerInfo.Path = header.Slice(pathIndex + 1, pathIndex1);
headerInfo.Method = header.Slice(0, pathIndex);
headerInfo.Path = Encoding.UTF8.GetString(span.Slice(pathIndex + 1, pathIndex1));
headerInfo.Method = Encoding.UTF8.GetString(span.Slice(0, pathIndex));
}
return headerInfo;
}
}
public sealed class WebsocketHeaderKey
{
public static string Connection = "connection";
public static string Upgrade = "upgrade";
public static string Origin = "origin";
public static string SecWebSocketVersion = "sec-websocket-version";
public static string SecWebSocketKey = "sec-websocket-key";
public static string SecWebSocketExtensions = "sec-websocket-extensions";
public static string SecWebSocketProtocol = "sec-websocket-protocol";
public static string SecWebSocketAccept = "sec-websocket-accept";
}
}