// Package scp copies single files to and from a remote host by running the
// remote scp program in sink (-t) or source (-f) mode over an SSH session.
package scp

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"

	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)

const (
	// fileMode is the permission string sent in the "C" control record for
	// every uploaded file.
	fileMode = "0644"
	// buffSize is the size of the scratch buffer handed to io.CopyBuffer.
	buffSize = 1024 * 256
)

// CopyTo uploads the local file to the remote path over sshClient.
//
// It starts "scp -t <dir of remote>" on the remote side, sends a "C" control
// record carrying fileMode, the local file size and the base name of remote,
// streams the file contents and finishes with ack. It returns the number of
// bytes streamed.
//
// Errors from creating the session, opening or stat-ing the local file and
// writing to the remote are returned. The result of session.Wait is only
// logged at debug level and is never returned.
func CopyTo(sshClient *ssh.Client, local string, remote string) (int64, error) {
	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	// Capture remote output so it can be included in the debug log below.
	stderr := &bytes.Buffer{}
	session.Stderr = stderr
	stdout := &bytes.Buffer{}
	session.Stdout = stdout

	writer, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer writer.Close()

	// NOTE(review): the remote directory is concatenated into the command
	// line unquoted; paths containing spaces or shell metacharacters are not
	// escaped here.
	err = session.Start("scp -t " + filepath.Dir(remote))
	if err != nil {
		return 0, err
	}

	localFile, err := os.Open(local)
	if err != nil {
		return 0, err
	}
	// Release the file descriptor on every return path.
	defer localFile.Close()

	fileInfo, err := localFile.Stat()
	if err != nil {
		return 0, err
	}

	// Control record: C<mode> <size> <name>\n
	_, err = fmt.Fprintf(writer, "C%s %d %s\n", fileMode, fileInfo.Size(), filepath.Base(remote))
	if err != nil {
		return 0, err
	}

	n, err := copyN(writer, localFile, fileInfo.Size())
	if err != nil {
		return 0, err
	}

	err = ack(writer)
	if err != nil {
		return 0, err
	}

	err = session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stdout:%v. stderr:%v", n, fileInfo.Size(), err, stdout, stderr)

	// NOTE: "Process exited with status 1" is not treated as an error here;
	// it is just how scp works (the remote side is still waiting for the next
	// control message when the transfer ends), so the Wait error is only
	// logged above.
	return n, nil
}

// CopyFrom downloads the remote file to the local path over sshClient.
//
// It starts "scp -f <remote>" on the remote side, sends an ack, reads the
// control message announcing the file, sends a second ack, and then writes
// msg.Size bytes from the remote stream into a newly created local file,
// which is synced and closed before returning. It returns the number of
// bytes written.
//
// If the control message is an error or a warning, msg.Error is returned and
// no local file is created. The result of session.Wait is only logged at
// debug level and is never returned.
func CopyFrom(sshClient *ssh.Client, remote string, local string) (int64, error) {
	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	// Capture remote stderr so it can be included in the debug log below.
	stderr := &bytes.Buffer{}
	session.Stderr = stderr

	writer, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer writer.Close()

	reader, err := session.StdoutPipe()
	if err != nil {
		return 0, err
	}

	// NOTE(review): the remote path is concatenated into the command line
	// unquoted; paths containing spaces or shell metacharacters are not
	// escaped here.
	err = session.Start("scp -f " + remote)
	if err != nil {
		return 0, err
	}

	// First ack: sent before reading the control message.
	err = ack(writer)
	if err != nil {
		return 0, err
	}

	msg, err := NewMessageFromReader(reader)
	if err != nil {
		return 0, err
	}
	if msg.Type == ErrorMessage || msg.Type == WarnMessage {
		return 0, msg.Error
	}
	log.Debugf("Receiving %v", msg)

	// Second ack: sent after the control message, before reading the data.
	err = ack(writer)
	if err != nil {
		return 0, err
	}

	outFile, err := os.Create(local)
	if err != nil {
		return 0, err
	}
	// Covers the early-return paths; on success the file is closed
	// explicitly below so that the Close error can be reported.
	defer outFile.Close()

	n, err := copyN(outFile, reader, msg.Size)
	if err != nil {
		return 0, err
	}

	err = outFile.Sync()
	if err != nil {
		return 0, err
	}

	err = outFile.Close()
	if err != nil {
		return 0, err
	}

	err = session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stderr:%v", n, msg.Size, err, stderr)
	return n, nil
}

// ack writes the acknowledgement payload to writer.
//
// It returns the write error, or a new error if fewer bytes than the payload
// length were written.
//
// NOTE(review): the payload is the four bytes 0, 0, '\n', '\r'. It is kept
// byte-for-byte because CopyTo and CopyFrom may rely on this exact sequence;
// confirm against the scp protocol before changing it.
func ack(writer io.Writer) error {
	var msg = []byte{0, 0, 10, 13}
	n, err := writer.Write(msg)
	if err != nil {
		return err
	}
	if n < len(msg) {
		return errors.New("Failed to write ack buffer")
	}
	return nil
}

// copyN copies exactly size bytes from src to writer and returns the number
// of bytes copied.
//
// A size of zero or less copies nothing and returns (0, nil). On any error
// the returned count is 0. If src is exhausted before size bytes have been
// copied, io.ErrUnexpectedEOF is returned instead of retrying forever.
func copyN(writer io.Writer, src io.Reader, size int64) (int64, error) {
	if size <= 0 {
		return 0, nil
	}

	reader := io.LimitReader(src, size)
	// One scratch buffer is reused across passes instead of being
	// reallocated on every iteration.
	buf := make([]byte, buffSize)

	var total int64
	for total < size {
		n, err := io.CopyBuffer(writer, reader, buf)
		log.Debugf("Copied chunk %v total: %v out of %v err: %v ", n, total, size, err)
		if err != nil {
			return 0, err
		}
		if n == 0 {
			// io.CopyBuffer reports EOF as (0, nil); without this guard a
			// source shorter than size would make the loop spin forever.
			return 0, io.ErrUnexpectedEOF
		}
		total += n
	}
	return total, nil
}