From bc529ddfe8f0eae71a401e3931831a5f2eee67b5 Mon Sep 17 00:00:00 2001 From: Myzel394 <50424412+Myzel394@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:33:42 +0200 Subject: [PATCH] chore(wireguard): Improvements --- handlers/wireguard/analyzer.go | 94 ++++++---------------- handlers/wireguard/documentation-fields.go | 5 +- handlers/wireguard/fetch-code-actions.go | 2 +- handlers/wireguard/wg-parser.go | 12 +++ handlers/wireguard/wg-property.go | 15 ++++ handlers/wireguard/wg-section.go | 27 ++++++- 6 files changed, 80 insertions(+), 75 deletions(-) diff --git a/handlers/wireguard/analyzer.go b/handlers/wireguard/analyzer.go index 9c587e5..3194e06 100644 --- a/handlers/wireguard/analyzer.go +++ b/handlers/wireguard/analyzer.go @@ -53,28 +53,17 @@ func (p wireguardParser) analyzeOnlyOneInterfaceSectionSpecified() []protocol.Di diagnostics := []protocol.Diagnostic{} alreadyFound := false - for _, section := range p.Sections { - if *section.Name == "Interface" { - if alreadyFound { - severity := protocol.DiagnosticSeverityError - diagnostics = append(diagnostics, protocol.Diagnostic{ - Message: "Only one [Interface] section is allowed", - Severity: &severity, - Range: protocol.Range{ - Start: protocol.Position{ - Line: section.StartLine, - Character: 0, - }, - End: protocol.Position{ - Line: section.StartLine, - Character: 99999999, - }, - }, - }) - } - - alreadyFound = true + for _, section := range p.getSectionsByName("Interface") { + if alreadyFound { + severity := protocol.DiagnosticSeverityError + diagnostics = append(diagnostics, protocol.Diagnostic{ + Message: "Only one [Interface] section is allowed", + Severity: &severity, + Range: section.getHeaderLineRange(), + }) } + + alreadyFound = true } return diagnostics @@ -116,28 +105,15 @@ func (p wireguardParser) analyzeDNSContainsFallback() []protocol.Diagnostic { func (p wireguardParser) analyzeKeepAliveIsSet() []protocol.Diagnostic { diagnostics := make([]protocol.Diagnostic, 0) - for _, section := range p.Sections { - if section.Name != nil && *section.Name == "Peer" { - // If an endpoint is set, then we should only check for the keepalive property - if section.fetchFirstProperty("Endpoint") != nil { - if section.fetchFirstProperty("PersistentKeepalive") == nil { - severity := protocol.DiagnosticSeverityHint - diagnostics = append(diagnostics, protocol.Diagnostic{ - Message: "PersistentKeepalive is not set. It is recommended to set this property, as it helps to maintain the connection when users are behind NAT", - Severity: &severity, - Range: protocol.Range{ - Start: protocol.Position{ - Line: section.StartLine, - Character: 0, - }, - End: protocol.Position{ - Line: section.StartLine, - Character: 99999999, - }, - }, - }) - } - } + for _, section := range p.getSectionsByName("Peer") { + // If an endpoint is set, then we should only check for the keepalive property + if section.existsProperty("Endpoint") && !section.existsProperty("PersistentKeepalive") { + severity := protocol.DiagnosticSeverityHint + diagnostics = append(diagnostics, protocol.Diagnostic{ + Message: "PersistentKeepalive is not set. It is recommended to set this property, as it helps to maintain the connection when users are behind NAT", + Severity: &severity, + Range: section.getRange(), + }) } } @@ -161,36 +137,27 @@ func (p wireguardParser) checkIfValuesAreValid() []protocol.Diagnostic { return diagnostics } -func (p wireguardSection) analyzeSection() []protocol.Diagnostic { +func (s wireguardSection) analyzeSection() []protocol.Diagnostic { diagnostics := []protocol.Diagnostic{} - if p.Name == nil { + if s.Name == nil { // No section name severity := protocol.DiagnosticSeverityError diagnostics = append(diagnostics, protocol.Diagnostic{ Message: "This section is missing a name", Severity: &severity, - Range: p.getRange(), + Range: s.getRange(), }) return diagnostics } - if _, found := optionsHeaderMap[*p.Name]; !found { + if _, found := optionsHeaderMap[*s.Name]; !found { // Unknown section severity := protocol.DiagnosticSeverityError diagnostics = append(diagnostics, protocol.Diagnostic{ - Message: fmt.Sprintf("Unknown section '%s'. It must be one of: [Interface], [Peer]", *p.Name), + Message: fmt.Sprintf("Unknown section '%s'. It must be one of: [Interface], [Peer]", *s.Name), Severity: &severity, - Range: protocol.Range{ - Start: protocol.Position{ - Line: p.StartLine, - Character: 0, - }, - End: protocol.Position{ - Line: p.StartLine, - Character: 99999999, - }, - }, + Range: s.getHeaderLineRange(), }) return diagnostics @@ -238,16 +205,7 @@ func (p wireguardProperty) analyzeProperty( { Message: "Property is missing a value", Severity: &severity, - Range: protocol.Range{ - Start: protocol.Position{ - Line: propertyLine, - Character: 0, - }, - End: protocol.Position{ - Line: propertyLine, - Character: 99999999, - }, - }, + Range: p.getLineRange(propertyLine), }, } } diff --git a/handlers/wireguard/documentation-fields.go b/handlers/wireguard/documentation-fields.go index 23133b5..9ca858d 100644 --- a/handlers/wireguard/documentation-fields.go +++ b/handlers/wireguard/documentation-fields.go @@ -44,8 +44,9 @@ You can also specify multiple subnets or IPv6 subnets like so: Address = 192.0.2.1/24,2001:DB8::/64 `, Value: docvalues.IPAddressValue{ - AllowIPv4: true, - AllowIPv6: true, + AllowIPv4: true, + AllowIPv6: true, + AllowRange: true, }, }, "ListenPort": { diff --git a/handlers/wireguard/fetch-code-actions.go b/handlers/wireguard/fetch-code-actions.go index 6601ba7..0d4b5a0 100644 --- a/handlers/wireguard/fetch-code-actions.go +++ b/handlers/wireguard/fetch-code-actions.go @@ -10,7 +10,7 @@ func getKeepaliveCodeActions( for index, section := range parser.Sections { if section.StartLine >= line && line <= section.EndLine && section.Name != nil && *section.Name == "Peer" { - if section.fetchFirstProperty("Endpoint") != nil && section.fetchFirstProperty("PersistentKeepalive") == nil { + if section.existsProperty("Endpoint") && !section.existsProperty("PersistentKeepalive") { commandID := "wireguard." + codeActionAddKeepalive command := protocol.Command{ Title: "Add PersistentKeepalive", diff --git a/handlers/wireguard/wg-parser.go b/handlers/wireguard/wg-parser.go index a035185..b93e73f 100644 --- a/handlers/wireguard/wg-parser.go +++ b/handlers/wireguard/wg-parser.go @@ -343,3 +343,15 @@ func (p *wireguardParser) getPropertyByLine(line uint32) (*wireguardSection, *wi return section, property } + +func (p *wireguardParser) getSectionsByName(name string) []*wireguardSection { + var sections []*wireguardSection + + for _, section := range p.Sections { + if section.Name != nil && *section.Name == name { + sections = append(sections, section) + } + } + + return sections +} diff --git a/handlers/wireguard/wg-property.go b/handlers/wireguard/wg-property.go index 59ed3eb..5913c92 100644 --- a/handlers/wireguard/wg-property.go +++ b/handlers/wireguard/wg-property.go @@ -5,6 +5,8 @@ import ( "config-lsp/utils" "regexp" "strings" + + protocol "github.com/tliron/glsp/protocol_3_16" ) var linePattern = regexp.MustCompile(`^\s*(?P.+?)\s*(?P=)\s*(?P\S.*?)?\s*(?:(?:;|#).*)?\s*$`) @@ -37,6 +39,19 @@ func (p wireguardProperty) String() string { return p.Key.Name + "=" + p.Value.Value } +func (p wireguardProperty) getLineRange(line uint32) protocol.Range { + return protocol.Range{ + Start: protocol.Position{ + Line: line, + Character: p.Key.Location.Start, + }, + End: protocol.Position{ + Line: line, + Character: p.Key.Location.End, + }, + } +} + func createWireguardProperty(line string) (*wireguardProperty, error) { if !strings.Contains(line, "=") { indexes := utils.GetTrimIndex(line) diff --git a/handlers/wireguard/wg-section.go b/handlers/wireguard/wg-section.go index 62b65b8..1f09deb 100644 --- a/handlers/wireguard/wg-section.go +++ b/handlers/wireguard/wg-section.go @@ -37,6 +37,19 @@ type wireguardSection struct { Properties wireguardProperties } +func (s wireguardSection) getHeaderLineRange() protocol.Range { + return protocol.Range{ + Start: protocol.Position{ + Line: s.StartLine, + Character: 0, + }, + End: protocol.Position{ + Line: s.StartLine, + Character: 99999999, + }, + } +} + func (s wireguardSection) getRange() protocol.Range { return protocol.Range{ Start: protocol.Position{ @@ -62,14 +75,20 @@ func (s wireguardSection) String() string { return fmt.Sprintf("[%s]; %d-%d: %v", name, s.StartLine, s.EndLine, s.Properties) } -func (s *wireguardSection) fetchFirstProperty(name string) *wireguardProperty { - for _, property := range s.Properties { +func (s *wireguardSection) fetchFirstProperty(name string) (*uint32, *wireguardProperty) { + for line, property := range s.Properties { if property.Key.Name == name { - return &property + return &line, &property } } - return nil + return nil, nil +} + +func (s *wireguardSection) existsProperty(name string) bool { + _, property := s.fetchFirstProperty(name) + + return property != nil } func (s *wireguardSection) findProperty(lineNumber uint32) (*wireguardProperty, error) {