# Patch summary (free text placed before the first "diff --git" header; the hunks below are unchanged apart from comment wording on added lines).
# - ipinfo/cmd_tool_aggregate.go: the `tool aggregate` help text now describes IPv4 IPs and CIDRs only and drops the IP-range example.
# - lib/cidr.go (new file): CIDR pairs the net.IP and *net.IPNet returned by net.ParseCIDR; list() builds CIDRs with newCidr and sorts them by IP bytes, then by mask length.
# - lib/cmd_tool_aggregate.go: parseInput returns CIDR strings instead of net.IPNet values; rows containing ',' or '-' are skipped without a message; a CIDR row is kept only when net.ParseCIDR succeeds and IsCIDRIPv4 accepts it; a plain IP row is kept only when IsIPv4 accepts it (IsCIDRIPv4 / IsIPv4 are not defined in this patch).
# - Aggregation is now combineAdjacent(stripOverlapping(list(parsedCIDRs))); IPs not contained in any resulting CIDR are printed after the CIDRs.
# NOTE(review): overlaps and adjacent evaluate `1 << (32 - MaskLen())` / `2 << (32 - MaskLen())` next to uint32 operands; for a /0 mask (and a /1 mask in adjacent) the shifted value looks like it wraps to 0, which would make the division / modulo panic - confirm and guard.
# NOTE(review): newCidr panics on a parse error; in this patch it is only reached with strings that parseInput already passed through net.ParseCIDR or that combineAdjacent builds itself - confirm no other caller feeds it raw input.
# NOTE(review): PrefixUint32 reads c.IP, presumably the address as written rather than the masked network address, so an input such as 1.1.1.5/24 may not be treated like 1.1.1.0/24 by adjacent - verify.
# NOTE(review): comment wording on some added lines of lib/cidr.go was edited, so the post-image hash on its `index` line no longer describes the resulting blob exactly.
diff --git a/ipinfo/cmd_tool_aggregate.go b/ipinfo/cmd_tool_aggregate.go index 5445a0e2..b3da67ed 100644 --- a/ipinfo/cmd_tool_aggregate.go +++ b/ipinfo/cmd_tool_aggregate.go @@ -23,9 +23,7 @@ func printHelpToolAggregate() { `Usage: %s tool aggregate [] Description: - Accepts IPs, IP ranges, and CIDRs, aggregating them efficiently. - Input can be IPs, IP ranges, CIDRs, and/or filepath to a file - containing any of these. Works for both IPv4 and IPv6. + Accepts IPv4 IPs and CIDRs, aggregating them efficiently. If input contains single IPs, it tries to merge them into the input CIDRs, otherwise they are printed to the output as they are. @@ -37,9 +35,6 @@ Examples: # Aggregate two CIDRs. $ %[1]s tool aggregate 1.1.1.0/30 1.1.1.0/28 - # Aggregate IP range and CIDR. - $ %[1]s tool aggregate 1.1.1.0-1.1.1.244 1.1.1.0/28 - # Aggregate enteries from 2 files. $ %[1]s tool aggregate /path/to/file1.txt /path/to/file2.txt diff --git a/lib/cidr.go b/lib/cidr.go new file mode 100644 index 00000000..06b96d81 --- /dev/null +++ b/lib/cidr.go @@ -0,0 +1,68 @@ +package lib + +import ( + "bytes" + "encoding/binary" + "math" + "net" + "sort" +) + +// CIDR represents a Classless Inter-Domain Routing structure. +type CIDR struct { + IP net.IP + Network *net.IPNet +} + +// newCidr creates a new CIDR structure from s and panics if net.ParseCIDR rejects s. +func newCidr(s string) *CIDR { + ip, ipnet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return &CIDR{ + IP: ip, + Network: ipnet, + } +} + +func (c *CIDR) String() string { + return c.Network.String() +} + +// MaskLen returns the number of leading ones in the network mask. +func (c *CIDR) MaskLen() uint32 { + i, _ := c.Network.Mask.Size() + return uint32(i) +} + +// PrefixUint32 returns the 4-byte form of c.IP as a big-endian uint32. +func (c *CIDR) PrefixUint32() uint32 { + return binary.BigEndian.Uint32(c.IP.To4()) +} + +// Size returns the number of addresses in the CIDR range, computed as 2^(bits-ones).
+func (c *CIDR) Size() int { + ones, bits := c.Network.Mask.Size() + return int(math.Pow(2, float64(bits-ones))) +} + +// list builds a CIDR from each string with newCidr and returns them sorted by IP, then by mask length. +func list(s []string) []*CIDR { + out := make([]*CIDR, 0) + for _, c := range s { + out = append(out, newCidr(c)) + } + sort.Sort(cidrSort(out)) + return out +} + +type cidrSort []*CIDR + +func (s cidrSort) Len() int { return len(s) } +func (s cidrSort) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func (s cidrSort) Less(i, j int) bool { + cmp := bytes.Compare(s[i].IP, s[j].IP) + return cmp < 0 || (cmp == 0 && s[i].MaskLen() < s[j].MaskLen()) +} diff --git a/lib/cmd_tool_aggregate.go b/lib/cmd_tool_aggregate.go index 4dbf8bc4..4be9a686 100644 --- a/lib/cmd_tool_aggregate.go +++ b/lib/cmd_tool_aggregate.go @@ -2,12 +2,10 @@ package lib import ( "bufio" - "bytes" "fmt" "io" "net" "os" - "sort" "strings" "github.com/spf13/pflag" @@ -55,72 +53,21 @@ func CmdToolAggregate( return nil } - // Parses a list of CIDRs. - parseCIDRs := func(cidrs []string) []net.IPNet { - parsedCIDRs := make([]net.IPNet, 0) - for _, cidrStr := range cidrs { - _, ipNet, err := net.ParseCIDR(cidrStr) - if err != nil { - if !f.Quiet { - fmt.Printf("Invalid CIDR: %s\n", cidrStr) - } - continue - } - parsedCIDRs = append(parsedCIDRs, *ipNet) - } - - return parsedCIDRs - } - // Input parser.
- parseInput := func(rows []string) ([]net.IPNet, []net.IP) { - parsedCIDRs := make([]net.IPNet, 0) + parseInput := func(rows []string) ([]string, []net.IP) { + parsedCIDRs := make([]string, 0) parsedIPs := make([]net.IP, 0) - var separator string for _, rowStr := range rows { if strings.ContainsAny(rowStr, ",-") { - if delim := strings.ContainsRune(rowStr, ','); delim { - separator = "," - } else { - separator = "-" - } - - ipRange := strings.Split(rowStr, separator) - if len(ipRange) != 2 { - if !f.Quiet { - fmt.Printf("Invalid IP range: %s\n", rowStr) - } - continue - } - - if strings.ContainsRune(rowStr, ':') { - cidrs, err := CIDRsFromIP6RangeStrRaw(rowStr) - if err == nil { - parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...) - continue - } else { - if !f.Quiet { - fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err) - } - continue - } - } else { - cidrs, err := CIDRsFromIPRangeStrRaw(rowStr) - if err == nil { - parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...) - continue - } else { - if !f.Quiet { - fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err) - } - continue - } - } + continue } else if strings.ContainsRune(rowStr, '/') { - parsedCIDRs = append(parsedCIDRs, parseCIDRs([]string{rowStr})...) + _, ipnet, err := net.ParseCIDR(rowStr) + if err == nil && IsCIDRIPv4(ipnet) { + parsedCIDRs = append(parsedCIDRs, []string{rowStr}...) + } continue } else { - if ip := net.ParseIP(rowStr); ip != nil { + if ip := net.ParseIP(rowStr); IsIPv4(ip) { parsedIPs = append(parsedIPs, ip) } else { if !f.Quiet { @@ -165,7 +112,7 @@ func CmdToolAggregate( } // Vars to contain CIDRs/IPs from all input sources. - parsedCIDRs := make([]net.IPNet, 0) + parsedCIDRs := make([]string, 0) parsedIPs := make([]net.IP, 0) // Collect CIDRs/IPs from stdin. @@ -187,30 +134,35 @@ func CmdToolAggregate( rows := scanrdr(file) file.Close() cidrs, ips := parseInput(rows) + parsedCIDRs = append(parsedCIDRs, cidrs...) parsedIPs = append(parsedIPs, ips...)
} - // Sort and merge collected CIDRs and IPs. - aggregatedCIDRs := aggregateCIDRs(parsedCIDRs) + adjacentCombined := combineAdjacent(stripOverlapping(list(parsedCIDRs))) + outlierIPs := make([]net.IP, 0) - length := len(aggregatedCIDRs) - for _, ip := range parsedIPs { - for i, cidr := range aggregatedCIDRs { - if cidr.Contains(ip) { - break - } else if i == length-1 { - outlierIPs = append(outlierIPs, ip) + length := len(adjacentCombined) + if length != 0 { + for _, ip := range parsedIPs { + for i, cidr := range adjacentCombined { + if cidr.Network.Contains(ip) { + break + } else if i == length-1 { + outlierIPs = append(outlierIPs, ip) + } } } + } else { + outlierIPs = append(outlierIPs, parsedIPs...) } // Print the aggregated CIDRs. - for _, r := range aggregatedCIDRs { + for _, r := range adjacentCombined { fmt.Println(r.String()) } - // Print outliers. + // Print the outlierIPs. for _, r := range outlierIPs { fmt.Println(r.String()) } @@ -218,62 +170,70 @@ func CmdToolAggregate( return nil } -// Helper function to aggregate IP ranges. -func aggregateCIDRs(cidrs []net.IPNet) []net.IPNet { - aggregatedCIDRs := make([]net.IPNet, 0) - - // Sort CIDRs by starting IP. - sortCIDRs(cidrs) - - for _, r := range cidrs { - if len(aggregatedCIDRs) == 0 { - aggregatedCIDRs = append(aggregatedCIDRs, r) +// stripOverlapping returns a slice of CIDR structures with overlapping ranges +// stripped. +func stripOverlapping(s []*CIDR) []*CIDR { + l := len(s) + for i := 0; i < l-1; i++ { + if s[i] == nil { continue } - - last := len(aggregatedCIDRs) - 1 - prev := aggregatedCIDRs[last] - - if canAggregate(prev, r) { - // Merge overlapping CIDRs. - aggregatedCIDRs[last] = aggregateCIDR(prev, r) - } else { - aggregatedCIDRs = append(aggregatedCIDRs, r) + for j := i + 1; j < l; j++ { + if overlaps(s[j], s[i]) { + s[j] = nil + } } } - - return aggregatedCIDRs -} - -// Helper function to sort IP ranges by starting IP.
-func sortCIDRs(ipRanges []net.IPNet) { - sort.SliceStable(ipRanges, func(i, j int) bool { - return bytes.Compare(ipRanges[i].IP, ipRanges[j].IP) < 0 - }) + return filter(s) } -// Helper function to check if two CIDRs can be aggregated. -func canAggregate(r1, r2 net.IPNet) bool { - return r1.Contains(r2.IP) || r2.Contains(r1.IP) +func overlaps(a, b *CIDR) bool { + return (a.PrefixUint32() / (1 << (32 - b.MaskLen()))) == + (b.PrefixUint32() / (1 << (32 - b.MaskLen()))) } -// Helper function to aggregate two CIDRs. -func aggregateCIDR(r1, r2 net.IPNet) net.IPNet { - mask1, _ := r1.Mask.Size() - mask2, _ := r2.Mask.Size() - - ipLen := net.IPv6len * 8 - if r1.IP.To4() != nil { - ipLen = net.IPv4len * 8 - } +// combineAdjacent returns a slice of CIDR structures with adjacent ranges +// combined. +func combineAdjacent(s []*CIDR) []*CIDR { + for { + found := false + l := len(s) + for i := 0; i < l-1; i++ { + if s[i] == nil { + continue + } + for j := i + 1; j < l; j++ { + if s[j] == nil { + continue + } + if adjacent(s[i], s[j]) { + c := fmt.Sprintf("%s/%d", s[i].IP.String(), s[i].MaskLen()-1) + s[i] = newCidr(c) + s[j] = nil + found = true + } + } + } - // Find the common prefix length - commonPrefixLen := mask1 - if mask2 < commonPrefixLen { - commonPrefixLen = mask2 + if !found { + break + } } + return filter(s) +} - commonPrefix := r1.IP.Mask(net.CIDRMask(commonPrefixLen, ipLen)) +func adjacent(a, b *CIDR) bool { + return (a.MaskLen() == b.MaskLen()) && + (a.PrefixUint32()%(2<<(32-b.MaskLen())) == 0) && + (b.PrefixUint32()-a.PrefixUint32() == (1 << (32 - a.MaskLen()))) +} - return net.IPNet{IP: commonPrefix, Mask: net.CIDRMask(commonPrefixLen, ipLen)} +func filter(s []*CIDR) []*CIDR { + out := s[:0] + for _, x := range s { + if x != nil { + out = append(out, x) + } + } + return out }