diff --git a/pkg/compose/image_loader.go b/pkg/compose/image_loader.go index 0a262d1..da1062b 100644 --- a/pkg/compose/image_loader.go +++ b/pkg/compose/image_loader.go @@ -5,12 +5,13 @@ import ( "context" "encoding/json" "fmt" + "io" + "path" + "github.com/docker/docker/client" "github.com/docker/docker/pkg/jsonmessage" "github.com/foundriesio/composeapp/internal/progress" v1 "github.com/opencontainers/image-spec/specs-go/v1" - "io" - "path" ) type ( @@ -67,6 +68,12 @@ type ( r *io.PipeReader w *io.PipeWriter } + + imageURI2RefCounter struct { + URI string + // reference counter to this image URI, this is how many times "dockerd" reports this image is loaded + refCounter int + } ) const ( @@ -200,7 +207,7 @@ func LoadImages(ctx context.Context, curLayerID := "" p := &LoadImageProgress{ State: ImageLoadStateImageWaiting, - ImageID: imageURIs[curImageIndex], + ImageID: imageURIs[curImageIndex].URI, } for { @@ -222,18 +229,18 @@ func LoadImages(ctx context.Context, case ImageLoadStateImageWaiting: { if jm.Progress == nil { - p.State = ImageLoadStateImageExist - reportProgressIfEnabled(options, p) - - curImageIndex++ - curLayerID = "" - p.State = ImageLoadStateImageWaiting - if curImageIndex < len(imageURIs) { - p.ImageID = imageURIs[curImageIndex] + if imageURIs[curImageIndex].refCounter--; imageURIs[curImageIndex].refCounter == 0 { + p.State = ImageLoadStateImageExist + reportProgressIfEnabled(options, p) + + curImageIndex++ + curLayerID = "" + p.State = ImageLoadStateImageWaiting + p.ImageID = getImageID(imageURIs, curImageIndex) } } else { curLayerID = jm.ID - p.ImageID = imageURIs[curImageIndex] + p.ImageID = getImageID(imageURIs, curImageIndex) if _, ok := layersMap[curLayerID]; ok { p.ID = layersMap[curLayerID][:7] } else { @@ -260,15 +267,16 @@ func LoadImages(ctx context.Context, case ImageLoadStateLayerSyncing: { if jm.Progress == nil { - p.State = ImageLoadStateImageLoaded - reportProgressIfEnabled(options, p) - - curImageIndex++ - curLayerID = "" - p.State = ImageLoadStateImageWaiting - if curImageIndex < len(imageURIs) { - p.ImageID = imageURIs[curImageIndex] + if imageURIs[curImageIndex].refCounter--; imageURIs[curImageIndex].refCounter == 0 { + p.State = ImageLoadStateImageLoaded + reportProgressIfEnabled(options, p) + + curImageIndex++ + curLayerID = "" + p.ImageID = getImageID(imageURIs, curImageIndex) + p.State = ImageLoadStateImageWaiting } + } else if curLayerID != jm.ID { p.State = ImageLoadStateLayerLoaded reportProgressIfEnabled(options, p) @@ -302,7 +310,7 @@ func reportProgressIfEnabled(opts *LoadImageOptions, p *LoadImageProgress) { opts.ProgressReporter.Update(*p) } -func generateImageLoadingManifestForApp(ctx context.Context, app App, blobsRoot string, options *LoadImageOptions) (imageLoadManifests []*imageLoadManifest, imageURIs []string, layersMap map[string]string, err error) { +func generateImageLoadingManifestForApp(ctx context.Context, app App, blobsRoot string, options *LoadImageOptions) (imageLoadManifests []*imageLoadManifest, imageURIs []imageURI2RefCounter, layersMap map[string]string, err error) { layersMap = make(map[string]string) // Generate the image load manifests for _, imageRoot := range app.GetComposeRoot().Children { @@ -313,7 +321,16 @@ func generateImageLoadingManifestForApp(ctx context.Context, app App, blobsRoot break } imageLoadManifests = append(imageLoadManifests, manifest) - imageURIs = append(imageURIs, imageRoot.Ref()) + var refCounter int + if options.RefWithDigest { + refCounter = 1 + } else { + refCounter = len(manifest.RepoTags) + } + imageURIs = append(imageURIs, imageURI2RefCounter{ + URI: imageRoot.Ref(), + refCounter: refCounter, + }) for index, diffID := range imageConfig.RootFS.DiffIDs { // The first 12 characters of the diffID are used as a key to the layers map, // the map between the diffID and the layer distribution hash. @@ -404,3 +421,12 @@ func generateImageLoadManifest( return loadManifest, &imageConfig, nil } + +func getImageID(imageURIs []imageURI2RefCounter, imageIndex int) string { + if imageIndex < len(imageURIs) { + return imageURIs[imageIndex].URI + } else { + fmt.Printf("Warning: image index %d is out of range (max %d)\n", imageIndex, len(imageURIs)) + return "unknown" + } +}