diff options
Diffstat (limited to 'pkg/copy')
-rw-r--r-- | pkg/copy/copy.go | 67 |
1 files changed, 50 insertions, 17 deletions
diff --git a/pkg/copy/copy.go b/pkg/copy/copy.go index 3993b532e..13893deb2 100644 --- a/pkg/copy/copy.go +++ b/pkg/copy/copy.go @@ -25,31 +25,61 @@ import ( // // **************************************************************************** -// Copy the source item to destination. Use extract to untar the source if -// it's a tar archive. -func Copy(source *CopyItem, destination *CopyItem, extract bool) error { +// Copier copies data from a source to a destination CopyItem. +type Copier struct { + copyFunc func() error + cleanUpFuncs []deferFunc +} + +// cleanUp releases resources the Copier may hold open. +func (c *Copier) cleanUp() { + for _, f := range c.cleanUpFuncs { + f() + } +} + +// Copy data from a source to a destination CopyItem. +func (c *Copier) Copy() error { + defer c.cleanUp() + return c.copyFunc() +} + +// GetCopiers returns a Copier to copy the source item to destination. Use +// extract to untar the source if it's a tar archive. +func GetCopier(source *CopyItem, destination *CopyItem, extract bool) (*Copier, error) { + copier := &Copier{} + // First, do the man-page dance. See podman-cp(1) for details. if err := enforceCopyRules(source, destination); err != nil { - return err + return nil, err } // Destination is a stream (e.g., stdout or an http body). if destination.info.IsStream { // Source is a stream (e.g., stdin or an http body). if source.info.IsStream { - _, err := io.Copy(destination.writer, source.reader) - return err + copier.copyFunc = func() error { + _, err := io.Copy(destination.writer, source.reader) + return err + } + return copier, nil } root, glob, err := source.buildahGlobs() if err != nil { - return err + return nil, err } - return buildahCopiah.Get(root, "", source.getOptions(), []string{glob}, destination.writer) + copier.copyFunc = func() error { + return buildahCopiah.Get(root, "", source.getOptions(), []string{glob}, destination.writer) + } + return copier, nil } // Destination is either a file or a directory. if source.info.IsStream { - return buildahCopiah.Put(destination.root, destination.resolved, source.putOptions(), source.reader) + copier.copyFunc = func() error { + return buildahCopiah.Put(destination.root, destination.resolved, source.putOptions(), source.reader) + } + return copier, nil } tarOptions := &archive.TarOptions{ @@ -71,33 +101,36 @@ func Copy(source *CopyItem, destination *CopyItem, extract bool) error { var tarReader io.ReadCloser if extract && archive.IsArchivePath(source.resolved) { if !destination.info.IsDir { - return errors.Errorf("cannot extract archive %q to file %q", source.original, destination.original) + return nil, errors.Errorf("cannot extract archive %q to file %q", source.original, destination.original) } reader, err := os.Open(source.resolved) if err != nil { - return err + return nil, err } - defer reader.Close() + copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { reader.Close() }) // The stream from stdin may be compressed (e.g., via gzip). decompressedStream, err := archive.DecompressStream(reader) if err != nil { - return err + return nil, err } - defer decompressedStream.Close() + copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { decompressedStream.Close() }) tarReader = decompressedStream } else { reader, err := archive.TarWithOptions(source.resolved, tarOptions) if err != nil { - return err + return nil, err } - defer reader.Close() + copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { reader.Close() }) tarReader = reader } - return buildahCopiah.Put(root, dir, source.putOptions(), tarReader) + copier.copyFunc = func() error { + return buildahCopiah.Put(root, dir, source.putOptions(), tarReader) + } + return copier, nil } // enforceCopyRules enforces the rules for copying from a source to a |