[amber_ctl] Discourage enabling multiple sources
When adding an enabled source or enabling a specific source, disable all
other enabled sources to prevent amber from picking one at random during
package installs or OTA updates. This change also adds a "-x" flag to
disable this feature, for those that like to live dangerously.
PKG-382 #done
Test: manual, add enabled/disabled sources, enable sources, with and
without -x flag
Change-Id: I1554230a3e19c9ecbe960c5f2a7e2a88de71f36d
diff --git a/go/src/amber/cmd/ctl/amber-control.go b/go/src/amber/cmd/ctl/amber-control.go
index 179f61e..32db611 100644
--- a/go/src/amber/cmd/ctl/amber-control.go
+++ b/go/src/amber/cmd/ctl/amber-control.go
@@ -47,6 +47,7 @@
-n: name of the update source (optional, with URL)
-f: file path or url to a source config file
-h: SHA256 hash of source config file (optional, with URL)
+ -x: do not disable other active sources (if the provided source is enabled)
rm_src - remove a source, if it exists
-n: name of the update source
@@ -55,6 +56,7 @@
enable_src
-n: name of the update source
+ -x: do not disable other active sources
disable_src
-n: name of the update source
@@ -65,14 +67,15 @@
`
var (
- fs = flag.NewFlagSet("default", flag.ExitOnError)
- pkgFile = fs.String("f", "", "Path to a source config file")
- hash = fs.String("h", "", "SHA256 hash of source config file (required if -f is a URL, ignored otherwise)")
- name = fs.String("n", "", "Name of a source or package")
- version = fs.String("v", "", "Version of a package")
- blobID = fs.String("i", "", "Content ID of the blob")
- noWait = fs.Bool("nowait", false, "Return once installation has started, package will not yet be available.")
- merkle = fs.String("m", "", "Merkle root of the desired update.")
+ fs = flag.NewFlagSet("default", flag.ExitOnError)
+ pkgFile = fs.String("f", "", "Path to a source config file")
+ hash = fs.String("h", "", "SHA256 hash of source config file (required if -f is a URL, ignored otherwise)")
+ name = fs.String("n", "", "Name of a source or package")
+ version = fs.String("v", "", "Version of a package")
+ blobID = fs.String("i", "", "Content ID of the blob")
+ noWait = fs.Bool("nowait", false, "Return once installation has started, package will not yet be available.")
+ merkle = fs.String("m", "", "Merkle root of the desired update.")
+ nonExclusive = fs.Bool("x", false, "When adding or enabling a source, do not disable other sources.")
)
type ErrGetFile string
@@ -201,6 +204,12 @@
return fmt.Errorf("IPC encountered an error: %s", err)
}
+ if isSourceConfigEnabled(&cfg) && !*nonExclusive {
+ if err := disableAllSources(a, cfg.Id); err != nil {
+ return err
+ }
+ }
+
return nil
}
@@ -273,9 +282,6 @@
}
func setSourceEnablement(a *amber.ControlInterface, id string, enabled bool) error {
- if id == "" {
- return fmt.Errorf("no source id provided")
- }
status, err := a.SetSrcEnabled(id, enabled)
if err != nil {
return fmt.Errorf("call failure attempting to change source status: %s", err)
@@ -283,6 +289,36 @@
if status != amber.StatusOk {
return fmt.Errorf("failure changing source status")
}
+
+ return nil
+}
+
+func isSourceConfigEnabled(cfg *amber.SourceConfig) bool {
+ if cfg.StatusConfig == nil {
+ return true
+ }
+ return cfg.StatusConfig.Enabled
+}
+
+func disableAllSources(a *amber.ControlInterface, except string) error {
+ errorIds := []string{}
+ cfgs, err := a.ListSrcs()
+ if err != nil {
+ return err
+ }
+ for _, cfg := range cfgs {
+ if cfg.Id != except && isSourceConfigEnabled(&cfg) {
+ if err := setSourceEnablement(a, cfg.Id, false); err != nil {
+ log.Printf("error disabling %q: %s", cfg.Id, err)
+ errorIds = append(errorIds, fmt.Sprintf("%q", cfg.Id))
+ } else {
+ fmt.Printf("Source %q disabled\n", cfg.Id)
+ }
+ }
+ }
+ if len(errorIds) > 0 {
+ return fmt.Errorf("could not disable %s", strings.Join(errorIds, ", "))
+ }
return nil
}
@@ -349,13 +385,27 @@
}
fmt.Printf("On your computer go to:\n\n\t%v\n\nand enter\n\n\t%v\n\n", device.VerificationUrl, device.UserCode)
case "enable_src":
+ if *name == "" {
+ log.Printf("Error enabling source: no source id provided")
+ return 1
+ }
err := setSourceEnablement(proxy, *name, true)
if err != nil {
log.Printf("Error enabling source: %s", err)
return 1
}
fmt.Printf("Source %q enabled\n", *name)
+ if !*nonExclusive {
+ if err := disableAllSources(proxy, *name); err != nil {
+ log.Printf("Error disabling sources: %s", err)
+ return 1
+ }
+ }
case "disable_src":
+ if *name == "" {
+ log.Printf("Error disabling source: no source id provided")
+ return 1
+ }
err := setSourceEnablement(proxy, *name, false)
if err != nil {
log.Printf("Error disabling source: %s", err)