Files
go-easy-utils/cryptox/rsa_test.go
2025-07-08 15:26:54 +08:00

255 lines
6.1 KiB
Go

package cryptox
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"github.com/stretchr/testify/assert"
"testing"
)
// 伪造一个读取器,用来模拟产生错误的情况
type badRandomReader struct{}
func (r *badRandomReader) Read([]byte) (int, error) {
return 0, errors.New("fake error")
}
func TestGenerateRSAKeys(t *testing.T) {
testCases := []struct {
name string
before func(t *testing.T)
after func(t *testing.T)
wantPrivateKey bool
wantPublicKey bool
wantErr error
}{
{
name: "生成成功",
before: func(t *testing.T) {},
after: func(t *testing.T) {},
wantPrivateKey: true,
wantPublicKey: true,
wantErr: nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.before(t)
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
assert.Equal(t, tc.wantErr, err)
assert.Equal(t, tc.wantPrivateKey, privateKeyPEM != "")
assert.Equal(t, tc.wantPublicKey, publicKeyPEM != "")
})
}
}
// 测试 rsa.GenerateKey 生产失败
func TestGenerateRSAKeysError(t *testing.T) {
// 保存原始的 rand.Reader
originalRandReader := rand.Reader
// 替换全局的 rand.Reader 为一个模拟错误的 Reader
rand.Reader = &badRandomReader{}
// 恢复原始的 rand.Reader
defer func() {
rand.Reader = originalRandReader
}()
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
// 检查返回的错误是否符合预期
expectedErr := errors.New("fake error")
if err == nil || err.Error() != expectedErr.Error() {
t.Errorf("Expected error '%v' but got '%v'", expectedErr, err)
}
// 检查返回的 PEM 字符串是否为空
if privateKeyPEM != "" {
t.Error("Expected empty private key PEM string")
}
if publicKeyPEM != "" {
t.Error("Expected empty public key PEM string")
}
}
func TestGenerateRSAKeys2(t *testing.T) {
// 测试 GenerateRSAKeys 是否能够正常工作
privateKeyPEM, publicKeyPEM, err := GenerateRSAKeys()
if err != nil {
t.Errorf("GenerateRSAKeys() error = %v", err)
return
}
// 检查返回的 PEM 字符串是否有效
if privateKeyPEM == "" {
t.Error("Expected non-empty private key PEM string")
}
if publicKeyPEM == "" {
t.Error("Expected non-empty public key PEM string")
}
// 尝试解析生成的私钥
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil || block.Type != "RSA PRIVATE KEY" {
t.Errorf("Invalid private key PEM")
return
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
t.Errorf("Failed to parse private key: %v", err)
return
}
// 尝试解析生成的公钥
block, _ = pem.Decode([]byte(publicKeyPEM))
if block == nil || block.Type != "RSA PUBLIC KEY" {
t.Errorf("Invalid public key PEM")
return
}
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
t.Errorf("Failed to parse public key: %v", err)
return
}
rsaPublicKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
t.Errorf("Failed to convert public key to RSA public key")
return
}
// 比较生成的密钥与解析的密钥
if !privateKey.PublicKey.Equal(rsaPublicKey) {
t.Errorf("Generated private key does not match parsed public key")
return
}
}
func TestEncryptRSA(t *testing.T) {
_, pubKeyStr, _ := GenerateRSAKeys()
testCases := []struct {
name string
publicKeyStr string
message []byte
expectedErr bool
}{
{
name: "正常加密",
publicKeyStr: pubKeyStr,
message: []byte("test message"),
expectedErr: false,
},
{
name: "无效公钥",
publicKeyStr: "invalid public key",
message: []byte("test message"),
expectedErr: true,
},
{
name: "无效公钥格式",
publicKeyStr: "-----BEGIN PUBLIC KEY-----\nInvalidKey\n-----END PUBLIC KEY-----",
message: []byte("test message"),
expectedErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := EncryptRSA(tc.publicKeyStr, tc.message)
if tc.expectedErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestDecryptRSA(t *testing.T) {
privateKey, publicKey, _ := GenerateRSAKeys()
// 使用公钥加密一条消息,以便测试解密
message := []byte("test message")
ciphertext, err := EncryptRSA(publicKey, message)
if err != nil {
t.Fatalf("failed to encrypt message: %v", err)
}
testCases := []struct {
name string
privateKeyStr string
ciphertext []byte
expected []byte
expectedErr bool
}{
{
name: "正常解密",
privateKeyStr: privateKey,
ciphertext: ciphertext,
expected: message,
expectedErr: false,
},
{
name: "无效私钥",
privateKeyStr: "invalid private key",
ciphertext: ciphertext,
expected: nil,
expectedErr: true,
},
{
name: "无效私钥格式",
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\nInvalidKey\n-----END RSA PRIVATE KEY-----",
ciphertext: ciphertext,
expected: nil,
expectedErr: true,
},
{
name: "无效密文",
privateKeyStr: privateKey,
ciphertext: []byte("invalid ciphertext"),
expected: nil,
expectedErr: true,
},
{
name: "无效私钥数据",
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\n" + "A" + "\n-----END RSA PRIVATE KEY-----",
ciphertext: ciphertext,
expected: nil,
expectedErr: true,
},
{
name: "不完整的私钥",
privateKeyStr: "-----BEGIN RSA PRIVATE KEY-----\n" +
"MIICWwIBAAKBgQDEkzKS0u5p6kwl9m0g3g4mMI09S8QOAbW5aBMbDWZ5R0pUtH5h" +
"J9mQFt8Uu4FJ8Yc9C5ZiM5F9pV5J2V4SeKk3RbKjFG2iD6rzO/OMrMZ3/1H8n02" +
"eZ/D14SvnPBNhYnb8Ysdd4kS8A==\n-----END RSA PRIVATE KEY-----",
ciphertext: ciphertext,
expected: nil,
expectedErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := DecryptRSA(tc.privateKeyStr, tc.ciphertext)
if tc.expectedErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expected, result)
}
})
}
}