package sessionbundle import ( "archive/tar" "bytes" "compress/gzip" "encoding/base64" "fmt" "io" "os" "path/filepath" "strings" ) const maxBundleBytes = 4 * 1024 * 1024 var bundleFiles = map[string]struct{}{ "cache/id": {}, "cache/user": {}, "cache/quota": {}, "cache/config.json": {}, } type Result struct { Configured bool `json:"configured"` Restored bool `json:"restored"` Source string `json:"source,omitempty"` WorkDir string `json:"work_dir,omitempty"` Files []string `json:"files,omitempty"` Message string `json:"message,omitempty"` } func Restore(workDir string, inline string, filePath string) (Result, error) { result := Result{ Configured: strings.TrimSpace(inline) != "" || strings.TrimSpace(filePath) != "", WorkDir: strings.TrimSpace(workDir), } if !result.Configured { result.Message = "session bundle not configured" return result, nil } if result.WorkDir == "" { return result, fmt.Errorf("session bundle restore requires a work dir") } bundle, source, err := resolveBundle(inline, filePath) if err != nil { return result, err } raw, err := DecodeBundle(bundle) if err != nil { return result, err } files, err := ApplyBundleToWorkDir(result.WorkDir, raw) if err != nil { return result, err } result.Restored = true result.Source = source result.Files = files result.Message = fmt.Sprintf("restored %d files", len(files)) return result, nil } func DecodeBundle(b64 string) ([]byte, error) { b64 = strings.TrimSpace(b64) if b64 == "" { return nil, fmt.Errorf("empty bundle") } raw, err := base64.StdEncoding.DecodeString(b64) if err != nil { return nil, fmt.Errorf("invalid base64: %w", err) } if len(raw) > maxBundleBytes { return nil, fmt.Errorf("bundle too large: %d bytes", len(raw)) } return raw, nil } func ApplyBundleToWorkDir(workDir string, raw []byte) ([]string, error) { if len(raw) > maxBundleBytes { return nil, fmt.Errorf("bundle too large: %d bytes", len(raw)) } if err := os.MkdirAll(workDir, 0o755); err != nil { return nil, err } gz, err := gzip.NewReader(bytes.NewReader(raw)) if err != nil { return nil, fmt.Errorf("open bundle gzip: %w", err) } defer gz.Close() reader := tar.NewReader(gz) restored := make([]string, 0, len(bundleFiles)) var total int64 for { header, err := reader.Next() if err == io.EOF { break } if err != nil { return nil, fmt.Errorf("read bundle tar: %w", err) } if !isSafeMember(header) { continue } total += header.Size if total > maxBundleBytes { return nil, fmt.Errorf("bundle expanded too large: %d bytes", total) } dest := filepath.Join(workDir, filepath.FromSlash(header.Name)) if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { return nil, err } file, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fileModeFor(header.Name)) if err != nil { return nil, err } if _, err := io.Copy(file, reader); err != nil { file.Close() return nil, err } if err := file.Close(); err != nil { return nil, err } restored = append(restored, header.Name) } return restored, nil } func resolveBundle(inline string, filePath string) (string, string, error) { if value := strings.TrimSpace(inline); value != "" { return value, "inline", nil } path := strings.TrimSpace(filePath) if path == "" { return "", "", fmt.Errorf("session bundle file path is empty") } expanded := expandHome(path) body, err := os.ReadFile(expanded) if err != nil { return "", "", fmt.Errorf("read bundle file: %w", err) } return strings.TrimSpace(string(body)), filepath.Base(expanded), nil } func isSafeMember(header *tar.Header) bool { if header == nil { return false } if _, ok := bundleFiles[header.Name]; !ok { return false } if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA { return false } if header.Name == "" || strings.HasPrefix(header.Name, "/") { return false } for _, part := range strings.Split(header.Name, "/") { if part == ".." { return false } } return true } func fileModeFor(name string) os.FileMode { if strings.HasSuffix(name, "/user") { return 0o600 } return 0o644 } func expandHome(path string) string { path = strings.TrimSpace(path) if path == "" || path[0] != '~' { return path } home, err := os.UserHomeDir() if err != nil || home == "" { return path } if path == "~" { return home } if len(path) > 1 && (path[1] == '/' || path[1] == '\\') { return filepath.Join(home, path[2:]) } return path }