# -*- coding: utf-8 -*-

#-------------------------------------------------------------------------
# drawElements Quality Program utilities
# --------------------------------------
#
# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#-------------------------------------------------------------------------

import os
import re
import sys
import argparse
import threading
import subprocess

from build_apk import findSDK
from build_apk import getDefaultBuildRoot
from build_apk import getPackageAndLibrariesForTarget
from build_apk import getBuildRootRelativeAPKPath
from build_apk import parsePackageName

# Import from <root>/scripts
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from build.common import *

class Device:
	"""Identity of one attached device: its serial plus product, model and device names."""

	def __init__(self, serial, product, model, device):
		self.serial = serial
		self.product = product
		self.model = model
		self.device = device

	def __str__ (self):
		# Human-readable one-liner: "<serial>: {product: ..., model: ..., device: ...}"
		fields = (self.serial, self.product, self.model, self.device)
		return "%s: {product: %s, model: %s, device: %s}" % fields

def getDevices (adbPath):
	"""Return the devices reported by `<adbPath> devices -l`.

	The first output line is skipped (presumably adb's list header) along
	with blank lines. Every remaining line must start with a serial and carry
	`product:`, `model:` and `device:` fields; lines that do not are reported
	with a warning on stdout and left out of the result.

	Args:
		adbPath: Path to the adb executable.

	Returns:
		List of Device objects, in the order adb listed them.

	Raises:
		Exception: If adb exits with a non-zero status.
	"""
	proc = subprocess.Popen([adbPath, 'devices', '-l'], stdout=subprocess.PIPE)
	stdout = proc.communicate()[0]

	if proc.returncode != 0:
		raise Exception("adb devices -l failed, got %d" % proc.returncode)

	ptrn = re.compile(r'^([a-zA-Z0-9\.\-:]+)\s+.*product:([^\s]+)\s+model:([^\s]+)\s+device:([^\s]+)')
	devices = []
	for rawLine in stdout.splitlines()[1:]:
		if len(rawLine.strip()) == 0:
			continue

		# Decode before matching so the warning below shows the text itself
		# instead of the bytes repr (b'...').
		line = rawLine.decode('utf-8')
		m = ptrn.match(line)
		if m is None:
			print("WARNING: Failed to parse device info '%s'" % line)
			continue

		devices.append(Device(m.group(1), m.group(2), m.group(3), m.group(4)))

	return devices

def execWithPrintPrefix (args, linePrefix="", failOnNonZeroExit=True):
	"""Run a command and echo its output with a prefix on every line.

	The child's stdout is forwarded to sys.stdout and its stderr to
	sys.stderr, each by its own reader thread so neither pipe can fill up
	and block the child. All output has been forwarded by the time this
	function returns or raises.

	Args:
		args: Argument list passed to subprocess.Popen.
		linePrefix: Text written in front of every forwarded line.
		failOnNonZeroExit: When true, a non-zero exit status raises.

	Raises:
		Exception: If failOnNonZeroExit is set and the command exits non-zero.
	"""

	def readApplyPrefixAndPrint (source, prefix, sink):
		# readline() yields an empty bytes object only at EOF
		for line in iter(source.readline, b''):
			# Replace undecodable bytes rather than raising: a dead reader
			# thread would stop draining the pipe and could stall the child.
			sink.write(prefix + line.decode('utf-8', errors='replace'))

	process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
	stdoutJob = threading.Thread(target=readApplyPrefixAndPrint, args=(process.stdout, linePrefix, sys.stdout))
	stderrJob = threading.Thread(target=readApplyPrefixAndPrint, args=(process.stderr, linePrefix, sys.stderr))
	stdoutJob.start()
	stderrJob.start()
	retcode = process.wait()

	# The readers may still hold buffered output after the child has exited;
	# wait for them so later prints cannot overtake the command's output.
	stdoutJob.join()
	stderrJob.join()
	process.stdout.close()
	process.stderr.close()

	if failOnNonZeroExit and retcode != 0:
		raise Exception("Failed to execute '%s', got %d" % (str(args), retcode))

def serialApply (f, argsList):
	"""Call f once per entry of argsList, one after another, unpacking each entry as positional arguments."""
	pending = iter(argsList)
	for callArgs in pending:
		f(*callArgs)

def parallelApply (f, argsList):
	"""Call f once per entry of argsList, each call on its own thread.

	Every entry is unpacked as the positional arguments of one call. All
	threads are started, then joined, so every call has finished before this
	function returns or raises.

	Args:
		f: Callable to invoke.
		argsList: Iterable of argument tuples, one per call.

	Raises:
		Whatever f raised. If several calls failed, the failure belonging to
		the earliest entry of argsList is raised and the others are dropped.
	"""
	def applyAndCaptureError (func, args, errors, ndx):
		try:
			func(*args)
		except BaseException as error:
			# BaseException on purpose: the original bare except also carried
			# SystemExit/KeyboardInterrupt from a worker over to the caller.
			errors[ndx] = error

	argsList = list(argsList)
	# One slot per job, so concurrent failures never overwrite each other
	errors = [None] * len(argsList)
	jobs = [
		threading.Thread(target=applyAndCaptureError, args=(f, args, errors, ndx))
		for ndx, args in enumerate(argsList)
	]

	for job in jobs:
		job.start()

	for job in jobs:
		job.join()

	for error in errors:
		if error is not None:
			# Re-raise the captured instance itself (its traceback is already
			# attached). Rebuilding it from its type breaks for exceptions
			# whose constructor takes more than one argument.
			raise error

def uninstall (adbPath, packageName, extraArgs = None, printPrefix=""):
	"""Remove packageName with `adb uninstall`, echoing adb's output with printPrefix.

	A failing uninstall is not treated as an error (the command is run with
	failOnNonZeroExit=False).

	Args:
		adbPath: Path to the adb executable.
		packageName: Package to remove.
		extraArgs: Optional arguments placed between adb and `uninstall`
			(e.g. ['-s', serial]). Defaults to none.
		printPrefix: Text put in front of every line printed.
	"""
	extraArgs = [] if extraArgs is None else list(extraArgs)

	# Each message carries its own newline and is written with end='' so that
	# exactly one line is emitted (the former trailing comma added a second).
	print(printPrefix + "Removing existing %s...\n" % packageName, end='')
	execWithPrintPrefix([adbPath] + extraArgs + [
			'uninstall',
			packageName
		], printPrefix, failOnNonZeroExit=False)
	print(printPrefix + "Remove complete\n", end='')

def install (adbPath, apkPath, extraArgs = None, printPrefix=""):
	"""Install apkPath with `adb install`, echoing adb's output with printPrefix.

	Args:
		adbPath: Path to the adb executable.
		apkPath: Path of the APK file to install.
		extraArgs: Optional arguments placed between adb and `install`
			(e.g. ['-s', serial]). Defaults to none.
		printPrefix: Text put in front of every line printed.

	Raises:
		Exception: Propagated from execWithPrintPrefix when adb exits non-zero.
	"""
	extraArgs = [] if extraArgs is None else list(extraArgs)

	# Each message carries its own newline and is written with end='' so that
	# exactly one line is emitted (the former trailing comma added a second).
	print(printPrefix + "Installing %s...\n" % apkPath, end='')
	execWithPrintPrefix([adbPath] + extraArgs + [
			'install',
			apkPath
		], printPrefix)
	print(printPrefix + "Install complete\n", end='')

def installToDevice (device, adbPath, packageName, apkPath, printPrefix=""):
	"""Replace packageName on one device: uninstall the old copy, then install apkPath.

	Both adb commands are addressed to the device with `-s <serial>`. With an
	empty printPrefix the announcement also names the device model.
	"""
	if len(printPrefix) > 0:
		announcement = printPrefix + "Installing to %s\n" % device.serial
	else:
		announcement = "Installing to %s (%s)...\n" % (device.serial, device.model)
	print(announcement, end='')

	# Remove first, then install
	uninstall(adbPath, packageName, ['-s', device.serial], printPrefix)
	install(adbPath, apkPath, ['-s', device.serial], printPrefix)

def installToDevices (devices, doParallel, adbPath, packageName, apkPath):
	"""Install the package on every device in the list.

	Args:
		devices: List of Device objects to install to. May be empty, in which
			case a note is printed and nothing else happens.
		doParallel: When true, install on all devices concurrently and prefix
			each output line with the device model, padded so lines align.
			Otherwise install on one device after another without a prefix.
		adbPath: Path to the adb executable.
		packageName: Package to replace on each device.
		apkPath: Path of the APK file to install.
	"""
	if len(devices) == 0:
		# Nothing to do; also keeps max() below from raising on an empty list
		print("No devices to install to")
		return

	if doParallel:
		# Pad every prefix to the longest model name plus one space
		padLen = max(len(device.model) for device in devices) + 1
		jobArgs = []
		for device in devices:
			prefix = "(%s):%s" % (device.model, ' ' * (padLen - len(device.model)))
			jobArgs.append((device, adbPath, packageName, apkPath, prefix))
		parallelApply(installToDevice, jobArgs)
	else:
		serialApply(installToDevice, [(device, adbPath, packageName, apkPath) for device in devices])

def installToAllDevices (doParallel, adbPath, packageName, apkPath):
	"""Install the package on every device that getDevices(adbPath) reports."""
	installToDevices(getDevices(adbPath), doParallel, adbPath, packageName, apkPath)

def getAPKPath (buildRootPath, target):
	"""Return the APK location for target: buildRootPath joined with the package's build-root-relative APK path."""
	# The first element returned for the target is the package
	packageAndLibraries = getPackageAndLibrariesForTarget(target)
	relativeAPKPath = getBuildRootRelativeAPKPath(packageAndLibraries[0])
	return os.path.join(buildRootPath, relativeAPKPath)

def getPackageName (target):
	"""Return the package name parsed from the AndroidManifest.xml of the target's app directory."""
	appDirName = getPackageAndLibrariesForTarget(target)[0].appDirName
	# Manifest location: <DEQP_DIR>/android/<appDirName>/AndroidManifest.xml
	manifestLocation = os.path.join(DEQP_DIR, "android", appDirName, "AndroidManifest.xml")
	return parsePackageName(manifestLocation)

def findADB ():
	"""Locate an adb binary.

	Uses the result of which("adb") when there is one; otherwise looks for
	platform-tools/adb inside the SDK returned by findSDK(). Returns None
	when neither yields an adb.
	"""
	fromPath = which("adb")
	if fromPath is not None:
		return fromPath

	sdkRoot = findSDK()
	if sdkRoot is None:
		return None

	candidate = os.path.join(sdkRoot, "platform-tools", "adb")
	return candidate if os.path.isfile(candidate) else None

def parseArgs ():
	"""Parse the command line of this script.

	Returns the argparse namespace with buildRoot, adbPath, target,
	doParallel, serial (list of serials or None) and all.
	"""
	adbDefault = findADB()
	buildRootDefault = getDefaultBuildRoot()

	parser = argparse.ArgumentParser(
		os.path.basename(__file__),
		formatter_class=argparse.ArgumentDefaultsHelpFormatter)

	parser.add_argument('--build-root', dest='buildRoot',
		help="Root build directory",
		default=buildRootDefault)

	# When no adb could be located there is no default, so --adb is mandatory
	parser.add_argument('--adb', dest='adbPath',
		help="ADB binary path",
		default=adbDefault,
		required=adbDefault is None)

	parser.add_argument('--target', dest='target',
		help='Build target',
		default='deqp',
		choices=['deqp', 'openglcts'])

	parser.add_argument('-p', '--parallel', dest='doParallel',
		help="Install package in parallel",
		action="store_true")

	# One or more serial numbers may follow -s
	parser.add_argument('-s', '--serial', dest='serial',
		help="Install package to device with serial number",
		type=str,
		nargs='+')

	parser.add_argument('-a', '--all', dest='all',
		help="Install to all devices",
		action="store_true")

	return parser.parse_args()

if __name__ == "__main__":
	# Entry point: resolve the package name and APK of the requested target,
	# then install to all devices (--all), to the devices listed with
	# --serial, or to a single device (asking which one when several are
	# connected).
	args		= parseArgs()
	packageName	= getPackageName(args.target)
	apkPath		= getAPKPath(args.buildRoot, args.target)

	if not os.path.isfile(apkPath):
		die("%s does not exist" % apkPath)

	if args.all:
		installToAllDevices(args.doParallel, args.adbPath, packageName, apkPath)
	elif args.serial is None:
		devices = getDevices(args.adbPath)
		if len(devices) == 0:
			die('No devices connected')
		elif len(devices) == 1:
			installToDevice(devices[0], args.adbPath, packageName, apkPath)
		else:
			print("More than one device connected:")
			for ndx, device in enumerate(devices, 1):
				print("%3d: %16s %s" % (ndx, device.serial, device.model))

			choice = input("Choose device (1-%d): " % len(devices))
			try:
				deviceNdx = int(choice)
			except ValueError:
				deviceNdx = 0 # not a number; rejected by the range check below

			# 0 and negative numbers must be rejected explicitly, since
			# devices[deviceNdx-1] would wrap around to the end of the list.
			if 1 <= deviceNdx <= len(devices):
				installToDevice(devices[deviceNdx-1], args.adbPath, packageName, apkPath)
			else:
				die("Invalid device selection '%s'" % choice)
	else:
		devices = getDevices(args.adbPath)

		# Keep only the connected devices that were asked for, and report
		# every requested serial that is not connected.
		devices = [dev for dev in devices if dev.serial in args.serial]
		devSerials = [dev.serial for dev in devices]
		notFounds = [serial for serial in args.serial if serial not in devSerials]

		for notFound in notFounds:
			print("Couldn't find device matching serial '%s'" % notFound)

		if len(devices) == 0:
			die("No devices matching the given serials")
		else:
			installToDevices(devices, args.doParallel, args.adbPath, packageName, apkPath)
