+
diff --git a/traversal.go b/traversal.go
index 9febae3..416bb7d 100644
--- a/traversal.go
+++ b/traversal.go
@@ -182,7 +182,7 @@ func (this *Selection) ParentsFilteredUntilNodes(filterSelector string, nodes ..
// Siblings() gets the siblings of each element in the Selection. It returns
// a new Selection object containing the matched elements.
func (this *Selection) Siblings() *Selection {
- return pushStack(this, getSiblingNodes(this.Nodes))
+ return pushStack(this, getSiblingNodes(this.Nodes, 0))
}
// SiblingsFiltered() gets the siblings of each element in the Selection
@@ -190,7 +190,26 @@ func (this *Selection) Siblings() *Selection {
// matched elements.
func (this *Selection) SiblingsFiltered(selector string) *Selection {
// Get the Siblings() unfiltered
- n := getSiblingNodes(this.Nodes)
+ n := getSiblingNodes(this.Nodes, 0)
+ // Create a temporary Selection to filter using winnow
+ sel := &Selection{n, this.document, nil}
+ // Filter based on selector
+ n = winnow(sel, selector, true)
+ return pushStack(this, n)
+}
+
+// Next() gets the immediately following sibling of each element in the
+// Selection. It returns a new Selection object containing the matched elements.
+func (this *Selection) Next() *Selection {
+ return pushStack(this, getSiblingNodes(this.Nodes, 1))
+}
+
+// NextFiltered() gets the immediately following sibling of each element in the
+// Selection filtered by a selector. It returns a new Selection object
+// containing the matched elements.
+func (this *Selection) NextFiltered(selector string) *Selection {
+ // Get the Next() unfiltered
+ n := getSiblingNodes(this.Nodes, 1)
// Create a temporary Selection to filter using winnow
sel := &Selection{n, this.document, nil}
// Filter based on selector
@@ -222,12 +241,36 @@ func getParentsNodes(nodes []*html.Node, stopSelector string, stopNodes []*html.
}
// Internal implementation of sibling nodes that return a raw slice of matches.
-func getSiblingNodes(nodes []*html.Node) []*html.Node {
+func getSiblingNodes(nodes []*html.Node, siblingType int) []*html.Node {
+ // Sibling type means:
+ // -1 : previous node only
+ // 1 : next node only
+ // 0 : all but itself
return mapNodes(nodes, func(i int, n *html.Node) (result []*html.Node) {
+ var prev *html.Node
+
+ // Get the parent and loop through all children
if p := n.Parent; p != nil {
for _, c := range p.Child {
- if c != n && c.Type == html.ElementNode {
- result = append(result, c)
+ // Care only about elements
+ if c.Type == html.ElementNode {
+ // Is it the existing node?
+ if c == n && siblingType == -1 {
+ // We want the previous node only, so append it and return
+ if prev != nil {
+ result = append(result, prev)
+ }
+ return
+ } else if prev == n && siblingType == 1 {
+ // We want only the next node and this is it, so append it and return
+ result = append(result, c)
+ return
+ }
+ prev = c
+ if c != n && siblingType == 0 {
+ // This is not the original node, so append it
+ result = append(result, c)
+ }
}
}
}
diff --git a/traversal_test.go b/traversal_test.go
index fdccfcf..1d1de47 100644
--- a/traversal_test.go
+++ b/traversal_test.go
@@ -132,3 +132,28 @@ func TestSiblingsFiltered(t *testing.T) {
sel := Doc().Root.Find(".pvk-gutter").SiblingsFiltered(".pvk-content")
AssertLength(t, sel.Nodes, 3)
}
+
+func TestNext(t *testing.T) {
+ sel := Doc().Root.Find("h1").Next()
+ AssertLength(t, sel.Nodes, 1)
+}
+
+func TestNext2(t *testing.T) {
+ sel := Doc().Root.Find(".close").Next()
+ AssertLength(t, sel.Nodes, 1)
+}
+
+func TestNextNone(t *testing.T) {
+ sel := Doc().Root.Find("small").Next()
+ AssertLength(t, sel.Nodes, 0)
+}
+
+func TestNextFiltered(t *testing.T) {
+ sel := Doc().Root.Find(".container-fluid").NextFiltered("div")
+ AssertLength(t, sel.Nodes, 2)
+}
+
+func TestNextFiltered2(t *testing.T) {
+ sel := Doc().Root.Find(".container-fluid").NextFiltered("[ng-view]")
+ AssertLength(t, sel.Nodes, 1)
+}