feat: 增加了集合方法 差集、对称差集、交集 (#77)

This commit is contained in:
libin
2024-07-13 16:15:57 +08:00
committed by GitHub
parent 37d3ed1753
commit 9f72a1e7ad
4 changed files with 280 additions and 0 deletions

View File

@@ -43,4 +43,13 @@ func SliceToMap[T any, K comparable](items []T, keyFunc func(T) K) map[K]T
// ExtractKeys 字段提取
func ExtractKeys[T any, K comparable](items []T, keyFunc func(T) K) []K
// Diff 计算差集
func Diff[T comparable](s []T, slices ...[]T) []T
// SymmetricDiff 计算对称差集
func SymmetricDiff[T comparable](slices ...[]T) []T
// Intersect 计算交集
func Intersect[T comparable](slices ...[]T) []T
```

70
sliceUtil/setof.go Normal file
View File

@@ -0,0 +1,70 @@
package sliceUtil
// Diff 计算差集
func Diff[T comparable](s []T, slices ...[]T) []T {
seen := make(map[T]bool)
for _, slice := range slices {
for _, elem := range slice {
seen[elem] = true
}
}
var result []T
for _, elem := range s {
if !seen[elem] {
result = append(result, elem)
}
}
return result
}
// SymmetricDiff 计算对称差集
func SymmetricDiff[T comparable](slices ...[]T) []T {
// 判断当前元素是否存在于其他切片中
containsAllNotMe := func(item T, idx int, slices ...[]T) bool {
for i, slice := range slices {
if i == idx {
continue
}
for _, val := range slice {
if val == item {
return true
}
}
}
return false
}
var result []T
for idx, slice := range slices {
for _, item := range slice {
if !containsAllNotMe(item, idx, slices...) {
result = append(result, item)
}
}
}
return UniqueSlice(result)
}
// Intersect 计算交集
func Intersect[T comparable](slices ...[]T) []T {
// 判断元素是否存在于其他切片中
containsAll := func(item T, slices [][]T) bool {
for _, slice := range slices {
if !InSlice(item, slice) {
return false
}
}
return true
}
var result []T
for _, item := range slices[0] {
if containsAll(item, slices[1:]) {
result = append(result, item)
}
}
return UniqueSlice(result)
}

View File

@@ -0,0 +1,39 @@
package sliceUtil
import "fmt"
func ExampleDiff() {
s1 := []int64{1, 2, 5, 7}
s2 := []int64{5, 6, 7}
diff := Diff(s1, s2)
fmt.Println(diff)
// Output:
// [1 2]
}
// 演示 SymmetricDiff 函数的用法
func ExampleSymmetricDiff() {
s1 := []int64{1, 5, 7}
s2 := []int64{5, 6, 7, 8, 9}
s3 := []int64{9, 10, 11}
diff := SymmetricDiff(s1, s2, s3)
fmt.Println(diff)
// Output:
// [1 6 8 10 11]
}
// 演示 ExampleIntersect 函数的用法
func ExampleIntersect() {
s1 := []int64{1, 5, 7}
s2 := []int64{5, 6, 7, 8, 9}
diff := Intersect(s1, s2)
fmt.Println(diff)
// Output:
// [5 7]
}

162
sliceUtil/setof_test.go Normal file
View File

@@ -0,0 +1,162 @@
package sliceUtil
import (
"reflect"
"testing"
)
func TestDiff(t *testing.T) {
testCases := []struct {
name string
s []int
slices [][]int
want []int
}{
{
name: "第一个切片与空切片列表的比较",
s: []int{1, 2, 3},
slices: [][]int{},
want: []int{1, 2, 3},
},
{
name: "第一个切片与包含部分相同元素的切片列表的比较",
s: []int{1, 2, 3, 4, 5},
slices: [][]int{{4, 5}, {6, 7}},
want: []int{1, 2, 3},
},
{
name: "第一个切片与包含全部相同元素的切片列表的比较",
s: []int{1, 2, 3},
slices: [][]int{{1, 2, 3}},
want: []int{},
},
{
name: "第一个切片与不包含任何相同元素的切片列表的比较",
s: []int{1, 2, 3},
slices: [][]int{{4, 5, 6}},
want: []int{1, 2, 3},
},
{
name: "空切片与任何切片列表的比较",
s: []int{},
slices: [][]int{{1, 2, 3}},
want: []int{},
},
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
got := Diff(tc.s, tc.slices...)
if len(got) == 0 && len(tc.want) == 0 {
return
}
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("Diff(%v, %v) = %v, want %v", tc.s, tc.slices, got, tc.want)
}
})
}
}
func TestSymmetricDiff(t *testing.T) {
tests := []struct {
name string
slices [][]int
expected []int
}{
{
name: "两个切片计算对称差集",
slices: [][]int{{1, 2, 3}, {2, 3}},
expected: []int{1},
},
{
name: "多个切片计算对称差集",
slices: [][]int{{2, 4}, {3}, {1, 2, 3, 4, 5}},
expected: []int{1, 5},
},
{
name: "No Symmetric Difference",
slices: [][]int{{1, 2, 3}, {1, 2, 3}},
expected: []int{},
},
{
name: "Empty Base Slice",
slices: [][]int{{}, {1, 2, 3}},
expected: []int{1, 2, 3},
},
{
name: "Empty Comparison Slices",
slices: [][]int{{1, 2, 3}, {}, {}},
expected: []int{1, 2, 3},
},
{
name: "All Empty Slices",
slices: [][]int{{}, {}, {}},
expected: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SymmetricDiff(tt.slices...)
if len(got) == 0 && len(tt.expected) == 0 {
return
}
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("Diff() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIntersect(t *testing.T) {
tests := []struct {
name string
slices [][]int
expected []int
}{
{
name: "Single Intersection",
slices: [][]int{{1, 2, 3}, {2, 3, 4}},
expected: []int{2, 3},
},
{
name: "Multiple Intersections",
slices: [][]int{{2, 3, 4}, {2, 3, 5}, {1, 2, 2, 3}},
expected: []int{2, 3},
},
{
name: "No Intersection",
slices: [][]int{{1, 2, 3}, {4, 5, 6}},
expected: []int{},
},
{
name: "Empty Base Slice",
slices: [][]int{{}, {1, 2, 3}},
expected: []int{},
},
{
name: "Empty Comparison Slices",
slices: [][]int{{1, 2, 3}, {}, {}},
expected: []int{},
},
{
name: "All Empty Slices",
slices: [][]int{{}, {}, {}},
expected: []int{},
},
{
name: "All Same Elements",
slices: [][]int{{1, 2, 3}, {1, 2, 3}, {1, 2, 3}},
expected: []int{1, 2, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Intersect(tt.slices...)
if len(got) == 0 && len(tt.expected) == 0 {
return
}
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("Intersect() = %v, want %v", got, tt.expected)
}
})
}
}