From b7f13578307e3ec6fc42115715b1db598757288e Mon Sep 17 00:00:00 2001 From: Myzel394 <50424412+Myzel394@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:41:51 +0200 Subject: [PATCH] feat: Add IP host set --- utils/ip-host.go | 144 ++++++++++++++++++++++++++++++++++++++++++ utils/ip-host_test.go | 68 ++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 utils/ip-host.go create mode 100644 utils/ip-host_test.go diff --git a/utils/ip-host.go b/utils/ip-host.go new file mode 100644 index 0000000..4fab166 --- /dev/null +++ b/utils/ip-host.go @@ -0,0 +1,144 @@ +package utils + +import ( + "context" + "errors" + net "net/netip" +) + +type iPv4Tree struct { + TrueNode *iPv4Tree + FalseNode *iPv4Tree + Context context.Context +} + +func (t *iPv4Tree) addHostBits( + hostBits []bool, + ctx context.Context, +) { + if len(hostBits) == 0 { + t.Context = ctx + return + } + + if hostBits[0] { + if t.TrueNode == nil { + t.TrueNode = &iPv4Tree{} + } + t.TrueNode.addHostBits(hostBits[1:], ctx) + } else { + if t.FalseNode == nil { + t.FalseNode = &iPv4Tree{} + } + t.FalseNode.addHostBits(hostBits[1:], ctx) + } +} + +func (t *iPv4Tree) getFromHostBits(hostBits []bool) *context.Context { + if t.Context != nil || len(hostBits) == 0 { + return &t.Context + } + + if hostBits[0] { + if t.TrueNode == nil { + return nil + } + + return t.TrueNode.getFromHostBits(hostBits[1:]) + } else { + if t.FalseNode == nil { + return nil + } + + return t.FalseNode.getFromHostBits(hostBits[1:]) + } +} + +func createIPv4Tree( + hostBits []bool, + ctx context.Context, +) iPv4Tree { + tree := iPv4Tree{} + tree.addHostBits(hostBits, ctx) + + return tree +} + +type IPv4HostSet struct { + tree iPv4Tree +} + +func CreateIPv4HostSet() IPv4HostSet { + return IPv4HostSet{ + tree: iPv4Tree{}, + } +} + +// Add a new ip to the host set +// `hostAmount`: Amount of host bits +// Return: (, ) +func (h *IPv4HostSet) AddIP( + ip net.Prefix, + ctx context.Context, +) (bool, error) { + hostBits, err := ipToHostBits(ip) + + if err != nil { + return false, err + } + + if h.tree.getFromHostBits(hostBits) != nil { + return false, nil + } + + h.tree.addHostBits(hostBits, ctx) + + return true, nil +} + +func (h IPv4HostSet) ContainsIP( + ip net.Prefix, +) (*context.Context, error) { + hostBits, err := ipToHostBits(ip) + + if err != nil { + ctx := context.Background() + return &ctx, err + } + + ctx := h.tree.getFromHostBits(hostBits) + + return ctx, nil +} + +func ipToHostBits(ip net.Prefix) ([]bool, error) { + if !ip.Addr().Is4() { + return nil, errors.New("Only IPv4 is supported currently") + } + + ipv4 := ip.Addr().As4() + allHostBits := [32]bool{} + for i, b := range ipv4 { + bits := byteToBits(b) + for j, bit := range bits { + allHostBits[i*8+j] = bit + } + } + + hostBits := allHostBits[:ip.Bits()] + + return hostBits, nil +} + +func byteToBits(b byte) [8]bool { + return [8]bool{ + (b>>0)&1 != 0, + (b>>1)&1 != 0, + (b>>2)&1 != 0, + (b>>3)&1 != 0, + (b>>4)&1 != 0, + (b>>5)&1 != 0, + (b>>6)&1 != 0, + (b>>7)&1 != 0, + } +} diff --git a/utils/ip-host_test.go b/utils/ip-host_test.go new file mode 100644 index 0000000..509978f --- /dev/null +++ b/utils/ip-host_test.go @@ -0,0 +1,68 @@ +package utils + +import ( + "context" + "net/netip" + "testing" +) + +func TestFullHostIpAddresses(t *testing.T) { + // Test the full host IP address + hostSet := CreateIPv4HostSet() + + hostSet.AddIP(netip.MustParsePrefix("10.0.0.1/32"), context.Background()) + hostSet.AddIP(netip.MustParsePrefix("10.0.0.2/32"), context.Background()) + hostSet.AddIP(netip.MustParsePrefix("10.0.0.3/32"), context.Background()) + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("10.0.0.1/32")); ctx == nil { + t.Fatalf("Expected to find 10.0.0.1/32 in the host set") + } + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("10.0.0.5/32")); ctx != nil { + t.Fatalf("Expected NOT to find 10.0.0.5/32 in the host set") + } +} + +func TestPartialHostIpAddresses(t *testing.T) { + // Test the partial host IP address + hostSet := CreateIPv4HostSet() + + hostSet.AddIP(netip.MustParsePrefix("10.0.0.1/32"), context.Background()) + hostSet.AddIP(netip.MustParsePrefix("10.0.0.2/32"), context.Background()) + hostSet.AddIP(netip.MustParsePrefix("10.0.0.3/32"), context.Background()) + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("10.0.0.1/16")); ctx == nil { + t.Fatalf("Expected to find 10.0.0.1/16 in the host set") + } + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("192.168.0.1/16")); ctx != nil { + t.Fatalf("Expected NOT to find 192.168.0.1/16 in the host set") + } +} + +func TestMixedHostIpAddresses(t *testing.T) { + // Test the mixed host IP address + hostSet := CreateIPv4HostSet() + + hostSet.AddIP(netip.MustParsePrefix("10.0.0.1/16"), context.Background()) + hostSet.AddIP(netip.MustParsePrefix("192.168.0.1/32"), context.Background()) + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("10.0.0.2/32")); ctx == nil { + t.Fatalf("Expected to find 10.0.0.3/32 in the host set") + } + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("192.168.0.2/32")); ctx != nil { + t.Fatalf("Expected NOT to find 192.168.0.2/32 in the host set") + } +} + +func TestSimpleExactCheck(t *testing.T) { + // Test the real example + hostSet := CreateIPv4HostSet() + + hostSet.AddIP(netip.MustParsePrefix("10.0.0.1/16"), context.Background()) + + if ctx, _ := hostSet.ContainsIP(netip.MustParsePrefix("10.0.0.1/16")); ctx == nil { + t.Fatalf("Expected to find 10.0.0.1/16 in the host set") + } +}