diff --git a/resources/psutil/process.go b/resources/psutil/process.go index 13ebccbf..d0e0f0a6 100644 --- a/resources/psutil/process.go +++ b/resources/psutil/process.go @@ -36,6 +36,7 @@ type process struct { cpuLimit uint64 ncpu float64 proc *psprocess.Process + procfs Procfs stopTicker context.CancelFunc @@ -57,6 +58,7 @@ func (u *util) Process(pid int32) (Process, error) { cpuLimit: u.cpuLimit, ncpu: u.ncpu, gpu: u.gpu, + procfs: u.procfs, } proc, err := psprocess.NewProcess(pid) @@ -117,13 +119,10 @@ func (p *process) collectCPU() cpuTimesStat { func (p *process) collectCPUFromChildren(proc *psprocess.Process) *cpuTimesStat { stat := cpuTimesStat{} - children, err := proc.Children() - if err != nil { - return &stat - } + children := p.procfs.AllChildren(proc.Pid) - for _, child := range children { - cstat, err := cpuTimes(child.Pid) + for _, pid := range children { + cstat, err := cpuTimes(pid) if err != nil { continue } @@ -133,14 +132,6 @@ func (p *process) collectCPUFromChildren(proc *psprocess.Process) *cpuTimesStat stat.user += cstat.user stat.idle += cstat.idle stat.other += cstat.other - - cstat = p.collectCPUFromChildren(child) - - stat.total += cstat.total - stat.system += cstat.system - stat.user += cstat.user - stat.idle += cstat.idle - stat.other += cstat.other } return &stat @@ -178,22 +169,22 @@ func (p *process) collectMemory() uint64 { } func (p *process) collectMemoryFromChildren(proc *psprocess.Process) uint64 { - children, err := proc.Children() - if err != nil { - return 0 - } + children := p.procfs.AllChildren(proc.Pid) rss := uint64(0) - for _, child := range children { + for _, pid := range children { + child, err := psprocess.NewProcess(pid) + if err != nil { + continue + } + info, err := child.MemoryInfo() if err != nil { continue } rss += info.RSS - - rss += p.collectMemoryFromChildren(child) } return rss diff --git a/resources/psutil/procfs.go b/resources/psutil/procfs.go new file mode 100644 index 00000000..db3563e8 --- /dev/null +++ b/resources/psutil/procfs.go @@ -0,0 +1,146 @@ +package psutil + +import ( + "bytes" + "context" + "io/fs" + "os" + "regexp" + "slices" + "strconv" + "sync" + "time" +) + +type Procfs interface { + // Children returns all direct children of a process + Children(ppid int32) []int32 + + // AllChildren returns all children of a process + AllChildren(ppid int32) []int32 +} + +type procfs struct { + children map[int32][]int32 + + lock sync.RWMutex +} + +func NewProcfs(ctx context.Context, interval time.Duration) (Procfs, error) { + p := &procfs{ + children: map[int32][]int32{}, + } + + children, err := p.createChildrenMap() + if err != nil { + return p, err + } + + p.children = children + + go p.ticker(ctx, interval) + + return p, nil +} + +func (p *procfs) Children(ppid int32) []int32 { + p.lock.RLock() + defer p.lock.RUnlock() + + pids, ok := p.children[ppid] + if !ok { + return []int32{} + } + + return slices.Clone(pids) +} + +func (p *procfs) AllChildren(ppid int32) []int32 { + children := p.Children(ppid) + + allchildren := slices.Clone(children) + + for _, child := range children { + allchildren = append(allchildren, p.AllChildren(child)...) + } + + return allchildren +} + +func (p *procfs) ticker(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + children, err := p.createChildrenMap() + if err == nil { + p.lock.Lock() + p.children = children + p.lock.Unlock() + } + } + } +} + +func (p *procfs) createChildrenMap() (map[int32][]int32, error) { + children := map[int32][]int32{} + re := regexp.MustCompile(`^[0-9]+$`) + + proc := os.Getenv("HOST_PROC") + if proc == "" { + proc = "/proc" + } + + fs := os.DirFS(proc).(fs.ReadDirFS) + dirents, err := fs.ReadDir(".") + if err != nil { + return nil, err + } + + for _, d := range dirents { + if !d.IsDir() { + continue + } + + name := d.Name() + + if !re.MatchString(name) { + continue + } + + data, err := os.ReadFile(proc + "/" + name + "/stat") + if err != nil { + continue + } + + fields := bytes.Split(data, []byte{' '}) + if len(fields) < 4 { + continue + } + + var pid int32 = 0 + var ppid int32 = 0 + + if x, err := strconv.ParseInt(string(fields[3]), 10, 32); err == nil { + ppid = int32(x) + } + + if x, err := strconv.ParseInt(name, 10, 32); err == nil { + pid = int32(x) + } + + if pid == 0 { + continue + } + + c := children[ppid] + c = append(c, pid) + children[ppid] = c + } + + return children, nil +} diff --git a/resources/psutil/procfs_test.go b/resources/psutil/procfs_test.go new file mode 100644 index 00000000..8e271d2d --- /dev/null +++ b/resources/psutil/procfs_test.go @@ -0,0 +1,41 @@ +package psutil + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestChildren(t *testing.T) { + p := &procfs{ + children: map[int32][]int32{ + 0: {1}, + 1: {2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + } + + require.Equal(t, []int32{1}, p.Children(0)) + require.Equal(t, []int32{2, 3}, p.Children(1)) + require.Equal(t, []int32{4, 5, 6}, p.Children(2)) + require.Equal(t, []int32{7, 8, 9}, p.Children(3)) + require.Equal(t, []int32{}, p.Children(4)) +} + +func TestAllChildren(t *testing.T) { + p := &procfs{ + children: map[int32][]int32{ + 0: {1}, + 1: {2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + } + + require.Equal(t, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9}, p.AllChildren(0)) + require.Equal(t, []int32{2, 3, 4, 5, 6, 7, 8, 9}, p.AllChildren(1)) + require.Equal(t, []int32{4, 5, 6}, p.AllChildren(2)) + require.Equal(t, []int32{7, 8, 9}, p.AllChildren(3)) + require.Equal(t, []int32{}, p.AllChildren(4)) +} diff --git a/resources/psutil/psutil.go b/resources/psutil/psutil.go index 89070762..9a4a6016 100644 --- a/resources/psutil/psutil.go +++ b/resources/psutil/psutil.go @@ -137,6 +137,8 @@ type util struct { mem MemoryInfo gpu psutilgpu.GPU + + procfs Procfs } // New returns a new util, it will be started automatically @@ -184,6 +186,9 @@ func New(root string, gpu psutilgpu.GPU) (Util, error) { go u.tickCPU(ctx, time.Second) go u.tickMemory(ctx, time.Second) + procfs, _ := NewProcfs(ctx, 5*time.Second) + u.procfs = procfs + u.stopOnce = sync.Once{} return u, nil @@ -224,7 +229,8 @@ func (u *util) detectCgroupVersion() int { } func (u *util) cgroupCPULimit(version int) (uint64, float64) { - if version == 1 { + switch version { + case 1: lines, err := u.readFile("cpu/cpu.cfs_quota_us") if err != nil { return 0, 0 @@ -248,7 +254,7 @@ func (u *util) cgroupCPULimit(version int) (uint64, float64) { return uint64(1e6/period*quota) * 1e3, quota / period // nanoseconds } - } else if version == 2 { + case 2: lines, err := u.readFile("cpu.max") if err != nil { return 0, 0 @@ -437,7 +443,8 @@ func (u *util) CPU() (*CPUInfo, error) { func (u *util) cgroupCPUTimes(version int) (*cpuTimesStat, error) { info := &cpuTimesStat{} - if version == 1 { + switch version { + case 1: lines, err := u.readFile("cpuacct/cpuacct.usage") if err != nil { return nil, err @@ -449,7 +456,7 @@ func (u *util) cgroupCPUTimes(version int) (*cpuTimesStat, error) { } info.system = usage - } else if version == 2 { + case 2: lines, err := u.readFile("cpu.stat") if err != nil { return nil, err @@ -523,7 +530,8 @@ func (u *util) Memory() (*MemoryInfo, error) { func (u *util) cgroupVirtualMemory(version int) (*MemoryInfo, error) { info := &MemoryInfo{} - if version == 1 { + switch version { + case 1: lines, err := u.readFile("memory/memory.limit_in_bytes") if err != nil { return nil, err @@ -547,7 +555,7 @@ func (u *util) cgroupVirtualMemory(version int) (*MemoryInfo, error) { info.Total = total info.Available = total - used info.Used = used - } else if version == 2 { + case 2: lines, err := u.readFile("memory.max") if err != nil { return nil, err