From 2a809a7d4cef092e9c15192d5a166fe35c4f4343 Mon Sep 17 00:00:00 2001 From: Myzel394 <50424412+Myzel394@users.noreply.github.com> Date: Sun, 18 Aug 2024 17:05:11 +0200 Subject: [PATCH] feat(wireguard): Improve analyzer --- handlers/fstab/fstab_test.go | 2 +- handlers/wireguard/analyzer.go | 70 ++++++++++++++++++---- handlers/wireguard/analyzer_test.go | 53 ++++++++++++++++ handlers/wireguard/documentation-fields.go | 6 +- 4 files changed, 116 insertions(+), 15 deletions(-) create mode 100644 handlers/wireguard/analyzer_test.go diff --git a/handlers/fstab/fstab_test.go b/handlers/fstab/fstab_test.go index 14d4261..52585d3 100644 --- a/handlers/fstab/fstab_test.go +++ b/handlers/fstab/fstab_test.go @@ -71,7 +71,7 @@ func TestValidBasicExample(t *testing.T) { t.Fatal("getCompletion failed to return correct number of completions. Got:", len(completions), "but expected:", 4) } - if completions[0].Label != "UUID" { + if completions[0].Label != "UUID" && completions[0].Label != "PARTUID" { t.Fatal("getCompletion failed to return correct label. Got:", completions[0].Label, "but expected:", "UUID") } } diff --git a/handlers/wireguard/analyzer.go b/handlers/wireguard/analyzer.go index 38f8afd..d03f6ba 100644 --- a/handlers/wireguard/analyzer.go +++ b/handlers/wireguard/analyzer.go @@ -11,6 +11,12 @@ import ( ) func (p wireguardParser) analyze() []protocol.Diagnostic { + sectionsErrors := p.analyzeSections() + + if len(sectionsErrors) > 0 { + return sectionsErrors + } + validCheckErrors := p.checkIfValuesAreValid() if len(validCheckErrors) > 0 { @@ -24,6 +30,55 @@ func (p wireguardParser) analyze() []protocol.Diagnostic { return diagnostics } +func (p wireguardParser) analyzeSections() []protocol.Diagnostic { + diagnostics := []protocol.Diagnostic{} + + for _, section := range p.Sections { + sectionDiagnostics := section.analyzeSection() + + if len(sectionDiagnostics) > 0 { + diagnostics = append(diagnostics, sectionDiagnostics...) + } + } + + if len(diagnostics) > 0 { + return diagnostics + } + + return p.analyzeOnlyOneInterfaceSectionSpecified() +} + +func (p wireguardParser) analyzeOnlyOneInterfaceSectionSpecified() []protocol.Diagnostic { + diagnostics := []protocol.Diagnostic{} + alreadyFound := false + + for _, section := range p.Sections { + if *section.Name == "Interface" { + if alreadyFound { + severity := protocol.DiagnosticSeverityError + diagnostics = append(diagnostics, protocol.Diagnostic{ + Message: "Only one [Interface] section is allowed", + Severity: &severity, + Range: protocol.Range{ + Start: protocol.Position{ + Line: section.StartLine, + Character: 0, + }, + End: protocol.Position{ + Line: section.StartLine, + Character: 99999999, + }, + }, + }) + } + + alreadyFound = true + } + } + + return diagnostics +} + func (p wireguardParser) analyzeDNSContainsFallback() []protocol.Diagnostic { lineNumber, property := p.fetchPropertyByName("DNS") @@ -56,17 +111,12 @@ func (p wireguardParser) analyzeDNSContainsFallback() []protocol.Diagnostic { return []protocol.Diagnostic{} } +// Check if the values are valid. +// Assumes that sections have been analyzed already. func (p wireguardParser) checkIfValuesAreValid() []protocol.Diagnostic { diagnostics := []protocol.Diagnostic{} for _, section := range p.Sections { - sectionDiagnostics := section.analyzeSection() - - if len(sectionDiagnostics) > 0 { - diagnostics = append(diagnostics, sectionDiagnostics...) - continue - } - for lineNumber, property := range section.Properties { diagnostics = append( diagnostics, @@ -253,12 +303,6 @@ func (p wireguardSection) analyzeDuplicateProperties() []protocol.Diagnostic { return diagnostics } -func (p wireguardSection) analyzeInterfaceSection() []protocol.Diagnostic { - diagnostics := []protocol.Diagnostic{} - - return diagnostics -} - func (p wireguardParser) analyzeAllowedIPIsInRange() []protocol.Diagnostic { diagnostics := []protocol.Diagnostic{} diff --git a/handlers/wireguard/analyzer_test.go b/handlers/wireguard/analyzer_test.go new file mode 100644 index 0000000..cd4c892 --- /dev/null +++ b/handlers/wireguard/analyzer_test.go @@ -0,0 +1,53 @@ +package wireguard + +import "testing" + +func TestMultipleIntefaces(t *testing.T) { + content := dedent(` +[Interface] +PrivateKey = abc + +[Interface] +PrivateKey = def +`) + parser := createWireguardParser() + parser.parseFromString(content) + + diagnostics := parser.analyze() + + if len(diagnostics) == 0 { + t.Errorf("Expected diagnostic errors, got %d", len(diagnostics)) + } +} + +func TestInvalidValue(t *testing.T) { + content := dedent(` +[Interface] +DNS = nope +`) + parser := createWireguardParser() + parser.parseFromString(content) + + diagnostics := parser.analyze() + + if len(diagnostics) == 0 { + t.Errorf("Expected diagnostic errors, got %d", len(diagnostics)) + } +} + +func TestDuplicateProperties(t *testing.T) { + content := dedent(` +[Interface] +PrivateKey = abc +DNS = 1.1.1.1 +PrivateKey = def +`) + parser := createWireguardParser() + parser.parseFromString(content) + + diagnostics := parser.analyze() + + if len(diagnostics) == 0 { + t.Errorf("Expected diagnostic errors, got %d", len(diagnostics)) + } +} diff --git a/handlers/wireguard/documentation-fields.go b/handlers/wireguard/documentation-fields.go index 403fdc9..23133b5 100644 --- a/handlers/wireguard/documentation-fields.go +++ b/handlers/wireguard/documentation-fields.go @@ -49,7 +49,7 @@ You can also specify multiple subnets or IPv6 subnets like so: }, }, "ListenPort": { - Documentation: `When the node is acting as a public bounce server, it should hardcode a port to listen for incoming VPN connections from the public internet. Clients not acting as relays should not set this value. + Documentation: `When the node is acting as a public bounce server, it should hardcode a port to listen for incoming VPN connections from the public internet. Clients not acting as relays should not set this value. If not specified, chosen randomly. ## Examples Using default WireGuard port @@ -214,6 +214,10 @@ Remove the iptables rule that forwards packets on the WireGuard interface `, Value: docvalues.StringValue{}, }, + "FwMark": { + Documentation: "a 32-bit fwmark for outgoing packets. If set to 0 or \"off\", this option is disabled. May be specified in hexadecimal by prepending \"0x\". Optional", + Value: docvalues.StringValue{}, + }, } var interfaceAllowedDuplicateFields = map[string]struct{}{