mirror of
https://github.com/jefferyjob/go-easy-utils.git
synced 2025-09-27 03:15:55 +08:00
255 lines
6.1 KiB
Go
255 lines
6.1 KiB
Go
package cryptox
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"testing"
|
|
)
|
|
|
|
// badRandomReader is a fake reader whose Read always fails. Tests install it
// as crypto/rand.Reader to simulate a broken source of randomness.
type badRandomReader struct{}

// Read never produces any bytes; every call reports a "fake error" error.
func (b *badRandomReader) Read(_ []byte) (n int, err error) {
	err = errors.New("fake error")
	return 0, err
}
|
|
|
|
func TestGenerateRSAKeys(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
before func(t *testing.T)
|
|
after func(t *testing.T)
|
|
wantPrivateKey bool
|
|
wantPublicKey bool
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "生成成功",
|
|
before: func(t *testing.T) {},
|
|
after: func(t *testing.T) {},
|
|
wantPrivateKey: true,
|
|
wantPublicKey: true,
|
|
wantErr: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
tc.before(t)
|
|
|
|
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
|
|
assert.Equal(t, tc.wantErr, err)
|
|
assert.Equal(t, tc.wantPrivateKey, privateKeyPEM != "")
|
|
assert.Equal(t, tc.wantPublicKey, publicKeyPEM != "")
|
|
})
|
|
}
|
|
}
|
|
|
|
// 测试 rsa.GenerateKey 生产失败
|
|
func TestGenerateRSAKeysError(t *testing.T) {
|
|
// 保存原始的 rand.Reader
|
|
originalRandReader := rand.Reader
|
|
|
|
// 替换全局的 rand.Reader 为一个模拟错误的 Reader
|
|
rand.Reader = &badRandomReader{}
|
|
|
|
// 恢复原始的 rand.Reader
|
|
defer func() {
|
|
rand.Reader = originalRandReader
|
|
}()
|
|
|
|
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
|
|
|
|
// 检查返回的错误是否符合预期
|
|
expectedErr := errors.New("fake error")
|
|
if err == nil || err.Error() != expectedErr.Error() {
|
|
t.Errorf("Expected error '%v' but got '%v'", expectedErr, err)
|
|
}
|
|
|
|
// 检查返回的 PEM 字符串是否为空
|
|
if privateKeyPEM != "" {
|
|
t.Error("Expected empty private key PEM string")
|
|
}
|
|
|
|
if publicKeyPEM != "" {
|
|
t.Error("Expected empty public key PEM string")
|
|
}
|
|
}
|
|
|
|
func TestGenerateRSAKeys2(t *testing.T) {
|
|
// 测试 GenerateRSAKeys 是否能够正常工作
|
|
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
|
|
if err != nil {
|
|
t.Errorf("GenerateRSAKeys() error = %v", err)
|
|
return
|
|
}
|
|
|
|
// 检查返回的 PEM 字符串是否有效
|
|
if privateKeyPEM == "" {
|
|
t.Error("Expected non-empty private key PEM string")
|
|
}
|
|
|
|
if publicKeyPEM == "" {
|
|
t.Error("Expected non-empty public key PEM string")
|
|
}
|
|
|
|
// 尝试解析生成的私钥
|
|
block, _ := pem.Decode([]byte(privateKeyPEM))
|
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
|
t.Errorf("Invalid private key PEM")
|
|
return
|
|
}
|
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
t.Errorf("Failed to parse private key: %v", err)
|
|
return
|
|
}
|
|
|
|
// 尝试解析生成的公钥
|
|
block, _ = pem.Decode([]byte(publicKeyPEM))
|
|
if block == nil || block.Type != "RSA PUBLIC KEY" {
|
|
t.Errorf("Invalid public key PEM")
|
|
return
|
|
}
|
|
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
|
if err != nil {
|
|
t.Errorf("Failed to parse public key: %v", err)
|
|
return
|
|
}
|
|
rsaPublicKey, ok := publicKey.(*rsa.PublicKey)
|
|
if !ok {
|
|
t.Errorf("Failed to convert public key to RSA public key")
|
|
return
|
|
}
|
|
|
|
// 比较生成的密钥与解析的密钥
|
|
if !privateKey.PublicKey.Equal(rsaPublicKey) {
|
|
t.Errorf("Generated private key does not match parsed public key")
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestEncryptRSA(t *testing.T) {
|
|
_, pubKeyStr, _ := GenerateRSAKeys()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
publicKeyStr string
|
|
message []byte
|
|
expectedErr bool
|
|
}{
|
|
{
|
|
name: "正常加密",
|
|
publicKeyStr: pubKeyStr,
|
|
message: []byte("test message"),
|
|
expectedErr: false,
|
|
},
|
|
{
|
|
name: "无效公钥",
|
|
publicKeyStr: "invalid public key",
|
|
message: []byte("test message"),
|
|
expectedErr: true,
|
|
},
|
|
{
|
|
name: "无效公钥格式",
|
|
publicKeyStr: "-----BEGIN PUBLIC KEY-----\nInvalidKey\n-----END PUBLIC KEY-----",
|
|
message: []byte("test message"),
|
|
expectedErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result, err := EncryptRSA(tc.publicKeyStr, tc.message)
|
|
if tc.expectedErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDecryptRSA(t *testing.T) {
|
|
privateKey, publicKey, _ := GenerateRSAKeys()
|
|
|
|
// 使用公钥加密一条消息,以便测试解密
|
|
message := []byte("test message")
|
|
ciphertext, err := EncryptRSA(publicKey, message)
|
|
if err != nil {
|
|
t.Fatalf("failed to encrypt message: %v", err)
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
privateKeyStr string
|
|
ciphertext []byte
|
|
expected []byte
|
|
expectedErr bool
|
|
}{
|
|
{
|
|
name: "正常解密",
|
|
privateKeyStr: privateKey,
|
|
ciphertext: ciphertext,
|
|
expected: message,
|
|
expectedErr: false,
|
|
},
|
|
{
|
|
name: "无效私钥",
|
|
privateKeyStr: "invalid private key",
|
|
ciphertext: ciphertext,
|
|
expected: nil,
|
|
expectedErr: true,
|
|
},
|
|
{
|
|
name: "无效私钥格式",
|
|
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\nInvalidKey\n-----END RSA PRIVATE KEY-----",
|
|
ciphertext: ciphertext,
|
|
expected: nil,
|
|
expectedErr: true,
|
|
},
|
|
{
|
|
name: "无效密文",
|
|
privateKeyStr: privateKey,
|
|
ciphertext: []byte("invalid ciphertext"),
|
|
expected: nil,
|
|
expectedErr: true,
|
|
},
|
|
{
|
|
name: "无效私钥数据",
|
|
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\n" + "A" + "\n-----END RSA PRIVATE KEY-----",
|
|
ciphertext: ciphertext,
|
|
expected: nil,
|
|
expectedErr: true,
|
|
},
|
|
{
|
|
name: "不完整的私钥",
|
|
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\n" +
|
|
"MIICWwIBAAKBgQDEkzKS0u5p6kwl9m0g3g4mMI09S8QOAbW5aBMbDWZ5R0pUtH5h" +
|
|
"J9mQFt8Uu4FJ8Yc9C5ZiM5F9pV5J2V4SeKk3RbKjFG2iD6rzO/OMrMZ3/1H8n02" +
|
|
"eZ/D14SvnPBNhYnb8Ysdd4kS8A==\n-----END RSA PRIVATE KEY-----",
|
|
ciphertext: ciphertext,
|
|
expected: nil,
|
|
expectedErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result, err := DecryptRSA(tc.privateKeyStr, tc.ciphertext)
|
|
if tc.expectedErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|