diff --git a/traversal.go b/traversal.go index 758a63b..56bcb3c 100644 --- a/traversal.go +++ b/traversal.go @@ -178,6 +178,12 @@ func (this *Selection) ParentsFilteredUntilNodes(filterSelector string, nodes .. return pushStack(this, n) } +// Siblings() gets the siblings of each element in the Selection. It returns +// a new Selection object containing the matched elements. +func (this *Selection) Siblings() *Selection { + return pushStack(this, getSiblingNodes(this.Nodes)) +} + // Internal implementation to get all parent nodes, stopping at the specified // node (or nil if no stop). func getParentsNodes(nodes []*html.Node, stopSelector string, stopNodes []*html.Node) []*html.Node { @@ -201,6 +207,20 @@ func getParentsNodes(nodes []*html.Node, stopSelector string, stopNodes []*html. }) } +// Internal implementation of sibling nodes that return a raw slice of matches. +func getSiblingNodes(nodes []*html.Node) []*html.Node { + return mapNodes(nodes, func(i int, n *html.Node) (result []*html.Node) { + if p := n.Parent; p != nil { + for _, c := range p.Child { + if c != n && c.Type == html.ElementNode { + result = append(result, c) + } + } + } + return + }) +} + // Internal implementation of parent nodes that return a raw slice of Nodes. func getParentNodes(nodes []*html.Node) []*html.Node { return mapNodes(nodes, func(i int, n *html.Node) []*html.Node { @@ -247,6 +267,7 @@ func findWithContext(selector string, nodes ...*html.Node) []*html.Node { // Return the child nodes of each node in the Selection object, without // duplicates. func getSelectionChildren(s *Selection, elemOnly bool) (result []*html.Node) { + // TODO : Refactor to use mapNodes? for _, n := range s.Nodes { result = appendWithoutDuplicates(result, getChildren(n, elemOnly)) } diff --git a/traversal_test.go b/traversal_test.go index f76906b..102714f 100644 --- a/traversal_test.go +++ b/traversal_test.go @@ -112,3 +112,13 @@ func TestParentsFilteredUntilNodes(t *testing.T) { sel = sel.ParentsFilteredUntilNodes("body", sel2.Nodes...) AssertLength(t, sel.Nodes, 1) } + +func TestSiblings(t *testing.T) { + sel := Doc().Root.Find("h1").Siblings() + AssertLength(t, sel.Nodes, 1) +} + +func TestSiblings2(t *testing.T) { + sel := Doc().Root.Find(".pvk-gutter").Siblings() + AssertLength(t, sel.Nodes, 9) +}