Skip to content

Commit a976da7

Browse files
committed
Fix backup logic, ensure configurations retain formatting, and sort ListConfigs output
1 parent 5e0fc4d commit a976da7

3 files changed

Lines changed: 117 additions & 16 deletions

File tree

.goreleaser.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ builds:
1515
- arm64
1616
main: ./cmd/sshc
1717
archives:
18-
- format: tar.gz
18+
- formats: [ tar.gz ]
1919
# this name template makes the OS and Arch compatible with what nm and otool expect
2020
name_template: >-
2121
{{ .ProjectName }}_
@@ -27,11 +27,11 @@ archives:
2727
# use zip for windows archives
2828
format_overrides:
2929
- goos: windows
30-
format: zip
30+
formats: [ zip ]
3131
checksum:
3232
name_template: 'checksums.txt'
3333
snapshot:
34-
name_template: "{{ incpatch .Version }}-next"
34+
version_template: "{{ incpatch .Version }}-next"
3535
changelog:
3636
sort: asc
3737
filters:

internal/config/manager.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"os"
66
"path/filepath"
7+
"slices"
78
"strings"
89
)
910

@@ -36,16 +37,22 @@ func (m *Manager) Init() error {
3637
backupFile := configFile + ".backup"
3738

3839
if _, err := os.Stat(configFile); err == nil {
39-
// 1. backup current ssh config
40-
content, err := os.ReadFile(configFile)
41-
if err != nil {
42-
return fmt.Errorf("failed to read ssh config for backup: %w", err)
43-
}
44-
if err := os.WriteFile(backupFile, content, 0600); err != nil {
45-
return fmt.Errorf("failed to create backup: %w", err)
40+
// 1. backup current ssh config if backup doesn't exist
41+
if _, err := os.Stat(backupFile); os.IsNotExist(err) {
42+
content, err := os.ReadFile(configFile)
43+
if err != nil {
44+
return fmt.Errorf("failed to read ssh config for backup: %w", err)
45+
}
46+
if err := os.WriteFile(backupFile, content, 0600); err != nil {
47+
return fmt.Errorf("failed to create backup: %w", err)
48+
}
4649
}
4750

4851
// 2. Create a new ssh config with the include
52+
content, err := os.ReadFile(configFile)
53+
if err != nil {
54+
return fmt.Errorf("failed to read ssh config: %w", err)
55+
}
4956
// Check if Include line already exists
5057
found := false
5158
for line := range strings.SplitSeq(string(content), "\n") {
@@ -128,23 +135,23 @@ func (m *Manager) UpdateConfig(name string, opts ConfigOptions) error {
128135
} else if strings.HasPrefix(lowerTrimmed, "hostname ") {
129136
foundFields["hostname"] = true
130137
if opts.Hostname != "" {
131-
indent := line[:strings.Index(lowerTrimmed, "hostname")]
138+
indent := line[:strings.Index(strings.ToLower(line), "hostname")]
132139
newLines = append(newLines, indent+"Hostname "+opts.Hostname)
133140
} else {
134141
newLines = append(newLines, line)
135142
}
136143
} else if strings.HasPrefix(lowerTrimmed, "user ") {
137144
foundFields["user"] = true
138145
if opts.User != "" {
139-
indent := line[:strings.Index(lowerTrimmed, "user")]
146+
indent := line[:strings.Index(strings.ToLower(line), "user")]
140147
newLines = append(newLines, indent+"User "+opts.User)
141148
} else {
142149
newLines = append(newLines, line)
143150
}
144151
} else if strings.HasPrefix(lowerTrimmed, "port ") {
145152
foundFields["port"] = true
146153
if opts.Port != 0 {
147-
indent := line[:strings.Index(lowerTrimmed, "port")]
154+
indent := line[:strings.Index(strings.ToLower(line), "port")]
148155
newLines = append(newLines, indent+fmt.Sprintf("Port %d", opts.Port))
149156
} else {
150157
newLines = append(newLines, line)
@@ -157,23 +164,23 @@ func (m *Manager) UpdateConfig(name string, opts ConfigOptions) error {
157164
if absPath, err := filepath.Abs(identity); err == nil {
158165
identity = absPath
159166
}
160-
indent := line[:strings.Index(lowerTrimmed, "identityfile")]
167+
indent := line[:strings.Index(strings.ToLower(line), "identityfile")]
161168
newLines = append(newLines, indent+"IdentityFile "+identity)
162169
} else {
163170
newLines = append(newLines, line)
164171
}
165172
} else if strings.HasPrefix(lowerTrimmed, "forwardagent ") {
166173
foundFields["forwardagent"] = true
167174
if opts.ForwardAgent != "" {
168-
indent := line[:strings.Index(lowerTrimmed, "forwardagent")]
175+
indent := line[:strings.Index(strings.ToLower(line), "forwardagent")]
169176
newLines = append(newLines, indent+"ForwardAgent "+opts.ForwardAgent)
170177
} else {
171178
newLines = append(newLines, line)
172179
}
173180
} else if strings.HasPrefix(lowerTrimmed, "proxyjump ") {
174181
foundFields["proxyjump"] = true
175182
if opts.ProxyJump != "" {
176-
indent := line[:strings.Index(lowerTrimmed, "proxyjump")]
183+
indent := line[:strings.Index(strings.ToLower(line), "proxyjump")]
177184
newLines = append(newLines, indent+"ProxyJump "+opts.ProxyJump)
178185
} else {
179186
newLines = append(newLines, line)
@@ -258,5 +265,6 @@ func (m *Manager) ListConfigs() ([]string, error) {
258265
configs = append(configs, entry.Name())
259266
}
260267
}
268+
slices.Sort(configs)
261269
return configs, nil
262270
}

internal/config/manager_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,27 @@ func TestManager_Init_Backup(t *testing.T) {
9090
if count != 1 {
9191
t.Errorf("Include line found %d times, expected 1", count)
9292
}
93+
94+
// Verify that backup was NOT overwritten
95+
// 1. Change the config content
96+
newConfigContent := "Host new\n Hostname updated.com"
97+
if err := os.WriteFile(filepath.Join(tmpDir, "config"), []byte(newConfigContent), 0600); err != nil {
98+
t.Fatal(err)
99+
}
100+
101+
// 2. Run Init again
102+
if err := m.Init(); err != nil {
103+
t.Fatal(err)
104+
}
105+
106+
// 3. Backup should still have the ORIGINAL content, not the newConfigContent
107+
backupContentAfter, err := os.ReadFile(filepath.Join(tmpDir, "config.backup"))
108+
if err != nil {
109+
t.Fatal(err)
110+
}
111+
if string(backupContentAfter) != initialContent {
112+
t.Errorf("Backup was overwritten! Got: %s, Want: %s", string(backupContentAfter), initialContent)
113+
}
93114
}
94115

95116
func TestManager_AddRemoveConfig(t *testing.T) {
@@ -129,6 +150,37 @@ func TestManager_AddRemoveConfig(t *testing.T) {
129150
}
130151
}
131152

153+
func TestManager_ListConfigs_Sorted(t *testing.T) {
154+
tmpDir, err := os.MkdirTemp("", "sshc-test-list")
155+
if err != nil {
156+
t.Fatal(err)
157+
}
158+
defer os.RemoveAll(tmpDir)
159+
160+
m := &Manager{
161+
SshDir: tmpDir,
162+
}
163+
_ = m.Init()
164+
165+
// Add configs in non-alphabetical order
166+
names := []string{"zebra", "apple", "banana"}
167+
for _, name := range names {
168+
if err := m.AddConfig(name, "Host "+name); err != nil {
169+
t.Fatal(err)
170+
}
171+
}
172+
173+
configs, err := m.ListConfigs()
174+
if err != nil {
175+
t.Fatal(err)
176+
}
177+
178+
expected := []string{"apple", "banana", "zebra"}
179+
if !slices.Equal(configs, expected) {
180+
t.Errorf("ListConfigs() not sorted. Got: %v, Want: %v", configs, expected)
181+
}
182+
}
183+
132184
func TestManager_UpdateConfig(t *testing.T) {
133185
tmpDir, err := os.MkdirTemp("", "sshc-test-update")
134186
if err != nil {
@@ -218,3 +270,44 @@ func TestManager_UpdateConfig(t *testing.T) {
218270
t.Errorf("ProxyJump not added")
219271
}
220272
}
273+
274+
func TestManager_UpdateConfig_Indentation(t *testing.T) {
275+
tmpDir, err := os.MkdirTemp("", "sshc-test-indent")
276+
if err != nil {
277+
t.Fatal(err)
278+
}
279+
defer os.RemoveAll(tmpDir)
280+
281+
m := &Manager{
282+
SshDir: tmpDir,
283+
}
284+
_ = m.Init()
285+
286+
name := "indent-test"
287+
// Use 8 spaces for indentation
288+
initialContent := "Host my-host\n Hostname old-hostname\n"
289+
if err := m.AddConfig(name, initialContent); err != nil {
290+
t.Fatal(err)
291+
}
292+
293+
opts := ConfigOptions{
294+
Hostname: "new-hostname",
295+
}
296+
if err := m.UpdateConfig(name, opts); err != nil {
297+
t.Fatal(err)
298+
}
299+
300+
content, err := os.ReadFile(m.GetConfigPath(name))
301+
if err != nil {
302+
t.Fatal(err)
303+
}
304+
305+
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
306+
for _, line := range lines {
307+
if strings.Contains(line, "Hostname") {
308+
if !strings.HasPrefix(line, " ") {
309+
t.Errorf("Indentation lost. Expected 8 spaces at start of line: %q", line)
310+
}
311+
}
312+
}
313+
}

0 commit comments

Comments
 (0)