Unify input data format

The input data is described in YAML or JSON format
Also implemented support of priorities for repositories

Change-Id: I02f11714ba8880dd06c3ceeadf230c1d812ff0be
Implements:  blueprint unify-input-data
This commit is contained in:
Bulat Gaifullin 2016-01-22 17:22:44 +03:00 committed by Bulat Gaifullin
parent bf821ada23
commit 1ce69b4fef
32 changed files with 1223 additions and 1231 deletions

View File

@ -29,5 +29,11 @@ __all__ = [
"RepositoryApi",
]
__version__ = pbr.version.VersionInfo(
'packetary').version_string()
try:
__version__ = pbr.version.VersionInfo(
'packetary').version_string()
except Exception as e:
# when run tests without installing package
# pbr may raise exception.
print("ERROR:", e)
__version__ = "0.0.0"

View File

@ -16,6 +16,7 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from collections import defaultdict
import logging
import six
@ -23,8 +24,8 @@ import six
from packetary.controllers import RepositoryController
from packetary.library.connections import ConnectionsManager
from packetary.library.executor import AsynchronousSection
from packetary.objects import Index
from packetary.objects import PackageRelation
from packetary.objects import PackagesForest
from packetary.objects import PackagesTree
from packetary.objects.statistics import CopyStatistics
@ -111,120 +112,111 @@ class RepositoryApi(object):
context = config if isinstance(config, Context) else Context(config)
return cls(RepositoryController.load(context, repotype, repoarch))
def get_packages(self, origin, debs=None, requirements=None):
def get_packages(self, repos_data, requirements_data=None,
include_mandatory=False):
"""Gets the list of packages from repository(es).
:param origin: The list of repository`s URLs
:param debs: the list of repository`s URL to calculate list of
dependencies, that will be used to filter packages.
:param requirements: the list of package relations,
to resolve the list of mandatory packages.
:param repos_data: The list of repository descriptions
:param requirements_data: The list of package`s requirements
that should be included
:param include_mandatory: if True, all mandatory packages will be
:return: the set of packages
"""
repositories = self._get_repositories(origin)
return self._get_packages(repositories, debs, requirements)
repos = self._load_repositories(repos_data)
requirements = self._load_requirements(requirements_data)
return self._get_packages(repos, requirements, include_mandatory)
def clone_repositories(self, origin, destination, debs=None,
requirements=None, keep_existing=True,
include_source=False, include_locale=False):
def clone_repositories(self, repos_data, requirements_data, destination,
include_source=False, include_locale=False,
include_mandatory=False):
"""Creates the clones of specified repositories in local folder.
:param origin: The list of repository`s URLs
:param repos_data: The list of repository descriptions
:param requirements_data: The list of package`s requirements
that should be included
:param destination: the destination folder path
:param debs: the list of repository`s URL to calculate list of
dependencies, that will be used to filter packages.
:param requirements: the list of package relations,
to resolve the list of mandatory packages.
:param keep_existing: If False - local packages that does not exist
in original repo will be removed.
:param include_source: if True, the source packages
will be copied as well.
:param include_locale: if True, the locales
will be copied as well.
:param include_locale: if True, the locales will be copied as well.
:param include_mandatory: if True, all mandatory packages will be
included
:return: count of copied and total packages.
"""
repositories = self._get_repositories(origin)
packages = self._get_packages(repositories, debs, requirements)
mirrors = self.controller.clone_repositories(
repositories, destination, include_source, include_locale
)
package_groups = dict((x, set()) for x in repositories)
for pkg in packages:
repos = self._load_repositories(repos_data)
reqs = self._load_requirements(requirements_data)
all_packages = self._get_packages(repos, reqs, include_mandatory)
package_groups = defaultdict(set)
for pkg in all_packages:
package_groups[pkg.repository].add(pkg)
stat = CopyStatistics()
mirrors = defaultdict(set)
# group packages by mirror
for repo, packages in six.iteritems(package_groups):
mirror = mirrors[repo]
logger.info("copy packages from - %s", repo)
self.controller.copy_packages(
mirror, packages, keep_existing, stat.on_package_copied
mirror = self.controller.fork_repository(
repo, destination, include_source, include_locale
)
mirrors[mirror].update(packages)
# add new packages to mirrors
for mirror, packages in six.iteritems(mirrors):
self.controller.assign_packages(
mirror, packages, stat.on_package_copied
)
return stat
def get_unresolved_dependencies(self, origin, main=None):
def get_unresolved_dependencies(self, repos_data):
"""Gets list of unresolved dependencies for repository(es).
:param origin: The list of repository`s URLs
:param main: The main repository(es) URL
:param repos_data: The list of repository descriptions
:return: list of unresolved dependencies
"""
packages = PackagesTree()
self.controller.load_packages(
self._get_repositories(origin),
packages.add
)
self._load_packages(self._load_repositories(repos_data), packages.add)
return packages.get_unresolved_dependencies()
if main is not None:
base = Index()
self.controller.load_packages(
self._get_repositories(main),
base.add
)
else:
base = None
return packages.get_unresolved_dependencies(base)
def _get_repositories(self, urls):
"""Gets the set of repositories by url."""
repositories = set()
self.controller.load_repositories(urls, repositories.add)
return repositories
def _get_packages(self, repositories, master, requirements):
"""Gets the list of packages according to master and requirements."""
if master is None and requirements is None:
packages = set()
self.controller.load_packages(repositories, packages.add)
return packages
packages = PackagesTree()
self.controller.load_packages(repositories, packages.add)
if master is not None:
main_index = Index()
self.controller.load_packages(
self._get_repositories(master),
main_index.add
)
else:
main_index = None
return packages.get_minimal_subset(
main_index,
self._parse_requirements(requirements)
)
@staticmethod
def _parse_requirements(requirements):
"""Gets the list of relations from requirements.
:param requirements: the list of requirement in next format:
'name [cmp version]|[alt [cmp version]]'
"""
def _get_packages(self, repos, requirements, include_mandatory):
if requirements is not None:
return set(
PackageRelation.from_args(
*(x.split() for x in r.split("|"))) for r in requirements
)
return set()
forest = PackagesForest()
for repo in repos:
self.controller.load_packages(repo, forest.add_tree().add)
return forest.get_packages(requirements, include_mandatory)
packages = set()
self._load_packages(repos, packages.add)
return packages
def _load_packages(self, repos, consumer):
for repo in repos:
self.controller.load_packages(repo, consumer)
def _load_repositories(self, repos_data):
self._validate_repos_data(repos_data)
return self.controller.load_repositories(repos_data)
def _load_requirements(self, requirements_data):
if requirements_data is None:
return
self._validate_requirements_data(requirements_data)
result = []
for r in requirements_data:
self._validate_requirements_data(r)
versions = r.get('versions', None)
if versions is None:
result.append(PackageRelation.from_args((r['name'],)))
else:
for version in versions:
result.append(PackageRelation.from_args(
([r['name']] + version.split(None, 1))
))
return result
def _validate_repos_data(self, repos_data):
# TODO(bgaifullin) implement me
pass
def _validate_requirements_data(self, requirements_data):
# TODO(bgaifullin) implement me
pass

View File

@ -22,7 +22,7 @@ from cliff import command
import six
from packetary.cli.commands.utils import make_display_attr_getter
from packetary.cli.commands.utils import read_lines_from_file
from packetary.cli.commands.utils import read_from_file
from packetary import RepositoryApi
@ -56,21 +56,15 @@ class BaseRepoCommand(command.Command):
default="x86_64",
help='The target architecture.')
origin_gr = parser.add_mutually_exclusive_group(required=True)
origin_gr.add_argument(
'-o', '--origin-url',
nargs="+",
dest='origins',
type=six.text_type,
metavar='URL',
help='Space separated list of URLs of origin repositories.')
origin_gr.add_argument(
'-O', '--origin-file',
type=read_lines_from_file,
dest='origins',
parser.add_argument(
'-r', '--repositories',
dest='repositories',
type=read_from_file,
metavar='FILENAME',
help='The path to file with URLs of origin repositories.')
required=True,
help="The path to file with list of repositories."
"See documentation about format."
)
return parser
@ -98,6 +92,30 @@ class BaseRepoCommand(command.Command):
"""
class PackagesMixin(object):
"""Added arguments to declare list of packages."""
def get_parser(self, prog_name):
parser = super(PackagesMixin, self).get_parser(prog_name)
parser.add_argument(
"--skip-mandatory",
dest='include_mandatory',
action='store_false',
default=True,
help="Do not copy mandatory packages."
)
parser.add_argument(
"-p", "--packages",
dest='requirements',
type=read_from_file,
metavar='FILENAME',
help="The path to file with list of packages."
"See documentation about format."
)
return parser
class BaseProduceOutputCommand(BaseRepoCommand):
columns = None

View File

@ -17,10 +17,10 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from packetary.cli.commands.base import BaseRepoCommand
from packetary.cli.commands.utils import read_lines_from_file
from packetary.cli.commands.base import PackagesMixin
class CloneCommand(BaseRepoCommand):
class CloneCommand(PackagesMixin, BaseRepoCommand):
"""Clones the specified repository to local folder."""
def get_parser(self, prog_name):
@ -53,52 +53,17 @@ class CloneCommand(BaseRepoCommand):
help="Also copy localisation files."
)
bootstrap_group = parser.add_mutually_exclusive_group(required=False)
bootstrap_group.add_argument(
"-b", "--bootstrap",
nargs='+',
dest='bootstrap',
metavar='PACKAGE [OP VERSION]',
help="Space separated list of package relations, "
"to resolve the list of mandatory packages."
)
bootstrap_group.add_argument(
"-B", "--bootstrap-file",
type=read_lines_from_file,
dest='bootstrap',
metavar='FILENAME',
help="Path to the file with list of package relations, "
"to resolve the list of mandatory packages."
)
requires_group = parser.add_mutually_exclusive_group(required=False)
requires_group.add_argument(
'-r', '--requires-url',
nargs="+",
dest='requires',
metavar='URL',
help="Space separated list of repository`s URL to calculate list "
"of dependencies, that will be used to filter packages")
requires_group.add_argument(
'-R', '--requires-file',
type=read_lines_from_file,
dest='requires',
metavar='FILENAME',
help="The path to the file with list of repository`s URL "
"to calculate list of dependencies, "
"that will be used to filter packages")
return parser
def take_repo_action(self, api, parsed_args):
stat = api.clone_repositories(
parsed_args.origins,
parsed_args.repositories,
parsed_args.requirements,
parsed_args.destination,
parsed_args.requires,
parsed_args.bootstrap,
parsed_args.keep_existing,
parsed_args.sources,
parsed_args.locales
parsed_args.locales,
parsed_args.include_mandatory
)
self.stdout.write(
"Packages copied: {0.copied}/{0.total}.\n".format(stat)

View File

@ -17,10 +17,10 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from packetary.cli.commands.base import BaseProduceOutputCommand
from packetary.cli.commands.utils import read_lines_from_file
from packetary.cli.commands.base import PackagesMixin
class ListOfPackages(BaseProduceOutputCommand):
class ListOfPackages(PackagesMixin, BaseProduceOutputCommand):
"""Gets the list of packages from repository(es)."""
columns = (
@ -35,51 +35,11 @@ class ListOfPackages(BaseProduceOutputCommand):
"requires",
)
def get_parser(self, prog_name):
parser = super(ListOfPackages, self).get_parser(prog_name)
bootstrap_group = parser.add_mutually_exclusive_group(required=False)
bootstrap_group.add_argument(
"-b", "--bootstrap",
nargs='+',
dest='bootstrap',
metavar='PACKAGE [OP VERSION]',
help="Space separated list of package relations, "
"to resolve the list of mandatory packages."
)
bootstrap_group.add_argument(
"-B", "--bootstrap-file",
type=read_lines_from_file,
dest='bootstrap',
metavar='FILENAME',
help="Path to the file with list of package relations, "
"to resolve the list of mandatory packages."
)
requires_group = parser.add_mutually_exclusive_group(required=False)
requires_group.add_argument(
'-r', '--requires-url',
nargs="+",
dest='requires',
metavar='URL',
help="Space separated list of repository`s URL to calculate list "
"of dependencies, that will be used to filter packages")
requires_group.add_argument(
'-R', '--requires-file',
type=read_lines_from_file,
dest='requires',
metavar='FILENAME',
help="The path to the file with list of repository`s URL "
"to calculate list of dependencies, "
"that will be used to filter packages")
return parser
def take_repo_action(self, api, parsed_args):
return api.get_packages(
parsed_args.origins,
parsed_args.requires,
parsed_args.bootstrap,
parsed_args.repositories,
parsed_args.requirements,
parsed_args.include_mandatory
)

View File

@ -17,7 +17,6 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from packetary.cli.commands.base import BaseProduceOutputCommand
from packetary.cli.commands.utils import read_lines_from_file
class ListOfUnresolved(BaseProduceOutputCommand):
@ -29,31 +28,9 @@ class ListOfUnresolved(BaseProduceOutputCommand):
"alternative",
)
def get_parser(self, prog_name):
parser = super(ListOfUnresolved, self).get_parser(prog_name)
main_group = parser.add_mutually_exclusive_group(required=False)
main_group.add_argument(
'-m', '--main-url',
nargs="+",
dest='main',
metavar='URL',
help='Space separated list of URLs of repository(es) '
' that are used to resolve dependencies.')
main_group.add_argument(
'-M', '--main-file',
type=read_lines_from_file,
dest='main',
metavar='FILENAME',
help='The path to the file, that contains '
'list of URLs of repository(es) '
' that are used to resolve dependencies.')
return parser
def take_repo_action(self, api, parsed_args):
return api.get_unresolved_dependencies(
parsed_args.origins,
parsed_args.main,
parsed_args.repositories
)

View File

@ -16,25 +16,44 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import operator
import json
import os
import yaml
import six
def read_lines_from_file(filename):
_PARSERS = {
"": yaml.safe_load,
".json": json.load,
".yaml": yaml.safe_load,
".yml": yaml.safe_load,
}
def read_from_file(filename):
"""Reads lines from file.
Note: the line starts with '#' will be skipped.
:param filename: the path of target file
:return: the list of lines from file
:raise ValuerError: when file-ext is unknown.
"""
if filename is None:
return
file_ext = os.path.splitext(filename)[-1].lower()
try:
parser = _PARSERS[file_ext]
except KeyError:
raise ValueError("Unsupported file format: {0}.\n"
"Please use '.json' or '.yaml' file extension"
.format(file_ext))
with open(filename, 'r') as f:
return [
x
for x in six.moves.map(operator.methodcaller("strip"), f)
if x and not x.startswith("#")
]
return parser(f)
def get_object_attrs(obj, attrs):

View File

@ -22,6 +22,7 @@ import os
import six
import stevedore
from packetary.library import utils
logger = logging.getLogger(__package__)
@ -58,114 +59,86 @@ class RepositoryController(object):
)
return cls(context, driver, repoarch)
def load_repositories(self, urls, consumer):
def load_repositories(self, repositories_data):
"""Loads the repository objects from url.
:param urls: the list of repository urls.
:param consumer: the callback to consume objects
:param repositories_data: the list of repository`s descriptions
:return: the list of repositories sorted according to priority
"""
if isinstance(urls, six.string_types):
urls = [urls]
connection = self.context.connection
for parsed_url in self.driver.parse_urls(urls):
repositories_data.sort(key=self.driver.priority_sort)
repos = []
for repo_data in repositories_data:
self.driver.get_repository(
connection, parsed_url, self.arch, consumer
connection, repo_data, self.arch, repos.append
)
return repos
def load_packages(self, repositories, consumer):
def load_packages(self, repository, consumer):
"""Loads packages from repository.
:param repositories: the repository object
:param repository: the repository object
:param consumer: the callback to consume objects
"""
connection = self.context.connection
for r in repositories:
self.driver.get_packages(connection, r, consumer)
self.driver.get_packages(connection, repository, consumer)
def assign_packages(self, repository, packages, keep_existing=True):
def fork_repository(self, repository, destination, source, locale):
"""Creates copy of repositories.
:param repository: the origin repository
:param destination: the target folder
:param source: If True, the source packages will be copied too.
:param locale: If True, the localisation will be copied too.
:return: the mapping origin to cloned repository.
"""
new_path = os.path.join(
destination,
repository.path or utils.get_path_from_url(repository.url, False)
)
return self.driver.fork_repository(
self.context.connection, repository, new_path, source, locale
)
def assign_packages(self, repository, packages, observer=None):
"""Assigns new packages to the repository.
It replaces the current repository`s packages.
:param repository: the target repository
:param packages: the set of new packages
:param keep_existing:
if True, all existing packages will be kept as is.
if False, all existing packages, that are not included
to new packages will be removed.
:param observer: the package copying process observer
"""
if not isinstance(packages, set):
packages = set(packages)
else:
packages = packages.copy()
if keep_existing:
consume_exist = packages.add
else:
def consume_exist(package):
if package not in packages:
filepath = os.path.join(
package.repository.url, package.filename
)
logger.info("remove package - %s.", filepath)
os.remove(filepath)
self.driver.get_packages(
self.context.connection, repository, consume_exist
self._copy_packages(repository, packages, observer)
self.driver.add_packages(
self.context.connection, repository, packages
)
self.driver.rebuild_repository(repository, packages)
def copy_packages(self, repository, packages, keep_existing, observer):
"""Copies packages to repository.
:param repository: the target repository
:param packages: the set of packages
:param keep_existing: see assign_packages for more details
:param observer: the package copying process observer
"""
def _copy_packages(self, target, packages, observer):
with self.context.async_section() as section:
for package in packages:
section.execute(
self._copy_package, repository, package, observer
self._copy_package, target, package, observer
)
self.assign_packages(repository, packages, keep_existing)
def clone_repositories(self, repositories, destination,
source=False, locale=False):
"""Creates copy of repositories.
:param repositories: the origin repositories
:param destination: the target folder
:param source: If True, the source packages will be copied too.
:param locale: If True, the localisation will be copied too.
:return: the mapping origin to cloned repository.
"""
mirros = dict()
destination = os.path.abspath(destination)
with self.context.async_section(0) as section:
for r in repositories:
section.execute(
self._fork_repository,
r, destination, source, locale, mirros
)
return mirros
def _fork_repository(self, r, destination, source, locale, mirrors):
"""Creates clone of repository and stores it in mirrors."""
new_repository = self.driver.fork_repository(
self.context.connection, r, destination, source, locale
)
mirrors[r] = new_repository
def _copy_package(self, target, package, observer):
"""Synchronises remote file to local fs."""
dst_path = os.path.join(target.url, package.filename)
src_path = urljoin(package.repository.url, package.filename)
bytes_copied = self.context.connection.retrieve(
src_path, dst_path, size=package.filesize
)
if package.filesize < 0:
package.filesize = bytes_copied
observer(bytes_copied)
bytes_copied = 0
if target.url != package.repository.url:
dst_path = os.path.join(
utils.get_path_from_url(target.url), package.filename
)
src_path = urljoin(package.repository.url, package.filename)
bytes_copied = self.context.connection.retrieve(
src_path, dst_path, size=package.filesize
)
if package.filesize < 0:
package.filesize = bytes_copied
if observer:
observer(bytes_copied)

View File

@ -35,18 +35,11 @@ class RepositoryDriverBase(object):
self.logger = logging.getLogger(__package__)
@abc.abstractmethod
def parse_urls(self, urls):
"""Parses the repository url.
:return: the sequence of parsed urls
"""
@abc.abstractmethod
def get_repository(self, connection, url, arch, consumer):
def get_repository(self, connection, repository_data, arch, consumer):
"""Loads the repository meta information from URL.
:param connection: the connection manager instance
:param url: the repository`s url
:param repository_data: the repository`s url
:param arch: the repository`s architecture
:param consumer: the callback to consume result
"""
@ -74,9 +67,19 @@ class RepositoryDriverBase(object):
"""
@abc.abstractmethod
def rebuild_repository(self, repository, packages):
"""Re-builds the repository.
def add_packages(self, connection, repository, packages):
"""Adds new packages to the repository.
:param connection: the connection manager instance
:param repository: the target repository
:param packages: the set of packages
"""
@abc.abstractmethod
def priority_sort(self, repo_data):
"""Key method to sort repositories data by priority.
:param repo_data: the repository`s description
:return: the integer value that is relevant repository`s priority
less number means greater priority
"""

View File

@ -39,11 +39,11 @@ from packetary.objects import Repository
_OPERATORS_MAPPING = {
'>>': 'gt',
'<<': 'lt',
'=': 'eq',
'>=': 'ge',
'<=': 'le',
'>>': '>',
'<<': '<',
'=': '=',
'>=': '>=',
'<=': '<=',
}
_ARCHITECTURES = {
@ -77,46 +77,49 @@ _CHECKSUM_METHODS = (
"SHA256"
)
_DEFAULT_PRIORITY = 500
_checksum_collector = checksum_composite('md5', 'sha1', 'sha256')
class DebRepositoryDriver(RepositoryDriverBase):
def parse_urls(self, urls):
"""Overrides method of superclass."""
for url in urls:
try:
tokens = iter(x for x in url.split(" ") if x)
base, suite = next(tokens), next(tokens)
components = list(tokens)
except StopIteration:
raise ValueError("Invalid url: {0}".format(url))
def priority_sort(self, repo_data):
# DEB repository expects general values from 0 to 1000. 0
# to have lowest priority and 1000 -- the highest. Note that a
# priority above 1000 will allow even downgrades no matter the version
# of the prioritary package
priority = repo_data.get('priority')
if priority is None:
priority = _DEFAULT_PRIORITY
return -priority
base = base.rstrip("/")
if base.endswith("/dists"):
base = base[:-6]
def get_repository(self, connection, repository_data, arch, consumer):
url = utils.normalize_repository_url(repository_data['url'])
suite = repository_data['suite']
components = repository_data.get('section')
path = repository_data.get('path')
name = repository_data.get('name')
# TODO(Flat Repository Format[1])
# [1] https://wiki.debian.org/RepositoryFormat
for component in components:
yield (base, suite, component)
# TODO(bgaifullin) implement support for flat repisotory format [1]
# [1] https://wiki.debian.org/RepositoryFormat#Flat_Repository_Format
if components is None:
raise ValueError("The flat format does not supported.")
def get_repository(self, connection, url, arch, consumer):
"""Overrides method of superclass."""
base, suite, component = url
release = self._get_url_of_metafile(
(base, suite, component, arch), "Release"
)
deb_release = deb822.Release(connection.open_stream(release))
consumer(Repository(
name=(deb_release["Archive"], deb_release["Component"]),
architecture=arch,
origin=deb_release["origin"],
url=base + "/"
))
for component in components:
release = self._get_url_of_metafile(
(url, suite, component, arch), "Release"
)
deb_release = deb822.Release(connection.open_stream(release))
consumer(Repository(
name=name,
architecture=arch,
origin=deb_release["origin"],
url=url,
section=(suite, component),
path=path
))
def get_packages(self, connection, repository, consumer):
"""Overrides method of superclass."""
index = self._get_url_of_metafile(repository, "Packages.gz")
stream = GzipDecompress(connection.open_stream(index))
self.logger.info("loading packages from %s ...", repository)
@ -140,7 +143,8 @@ class DebRepositoryDriver(RepositoryDriverBase):
requires=self._get_relations(
dpkg, "depends", "pre-depends", "recommends"
),
obsoletes=self._get_relations(dpkg, "replaces"),
# The deb does not have obsoletes section
obsoletes=[],
provides=self._get_relations(dpkg, "provides"),
))
except KeyError as e:
@ -153,8 +157,7 @@ class DebRepositoryDriver(RepositoryDriverBase):
self.logger.info("loaded: %d packages from %s.", counter, repository)
def rebuild_repository(self, repository, packages):
"""Overrides method of superclass."""
def add_packages(self, connection, repository, packages):
basedir = utils.get_path_from_url(repository.url)
index_file = utils.get_path_from_url(
self._get_url_of_metafile(repository, "Packages")
@ -162,6 +165,8 @@ class DebRepositoryDriver(RepositoryDriverBase):
utils.ensure_dir_exist(os.path.dirname(index_file))
index_gz = index_file + ".gz"
count = 0
# load existing packages
self.get_packages(connection, repository, packages.add)
with open(index_file, "wb") as fd1:
with closing(gzip.open(index_gz, "wb")) as fd2:
writer = utils.composite_writer(fd1, fd2)
@ -185,7 +190,7 @@ class DebRepositoryDriver(RepositoryDriverBase):
# TODO(download gpk)
# TODO(sources and locales)
new_repo = copy.copy(repository)
new_repo.url = utils.localize_repo_url(destination, repository.url)
new_repo.url = utils.normalize_repository_url(destination)
packages_file = utils.get_path_from_url(
self._get_url_of_metafile(new_repo, "Packages")
)
@ -200,8 +205,8 @@ class DebRepositoryDriver(RepositoryDriverBase):
release = deb822.Release()
release["Origin"] = repository.origin
release["Label"] = repository.origin
release["Archive"] = repository.name[0]
release["Component"] = repository.name[1]
release["Archive"] = repository.section[0]
release["Component"] = repository.section[1]
release["Architecture"] = _ARCHITECTURES[repository.architecture]
with open(release_file, "wb") as fd:
release.dump(fd)
@ -214,7 +219,7 @@ class DebRepositoryDriver(RepositoryDriverBase):
"""Updates the Release file in the suite."""
path = os.path.join(
utils.get_path_from_url(repository.url),
"dists", repository.name[0]
"dists", repository.section[0]
)
release_path = os.path.join(path, "Release")
self.logger.info(
@ -304,7 +309,7 @@ class DebRepositoryDriver(RepositoryDriverBase):
"""
if isinstance(repo_or_comps, Repository):
baseurl = repo_or_comps.url
suite, component = repo_or_comps.name
suite, component = repo_or_comps.section
arch = repo_or_comps.architecture
else:
baseurl, suite, component, arch = repo_or_comps
@ -329,12 +334,12 @@ class DebRepositoryDriver(RepositoryDriverBase):
)
release.setdefault("Origin", repository.origin)
release.setdefault("Label", repository.origin)
release.setdefault("Suite", repository.name[0])
release.setdefault("Codename", repository.name[0].split("-", 1)[0])
release.setdefault("Suite", repository.section[0])
release.setdefault("Codename", repository.section[0].split("-", 1)[0])
release.setdefault("Description", "The packages repository.")
keys = ("Architectures", "Components")
values = (repository.architecture, repository.name[1])
values = (repository.architecture, repository.section[1])
for key, value in six.moves.zip(keys, values):
if key in release:
release[key] = utils.append_token_to_string(

View File

@ -49,6 +49,17 @@ _NAMESPACES = {
"rpm": "http://linux.duke.edu/metadata/rpm"
}
_OPERATORS_MAPPING = {
'GT': '>',
'LT': '<',
'EQ': '=',
'GE': '>=',
'LE': '<=',
}
_DEFAULT_PRIORITY = 10
class CreaterepoCallBack(object):
"""Callback object for createrepo"""
@ -69,21 +80,25 @@ class CreaterepoCallBack(object):
class RpmRepositoryDriver(RepositoryDriverBase):
def parse_urls(self, urls):
"""Overrides method of superclass."""
return (url.rstrip("/") for url in urls)
def priority_sort(self, repo_data):
# DEB repository expects general values from 0 to 1000. 0
# to have lowest priority and 1000 -- the highest. Note that a
# priority above 1000 will allow even downgrades no matter the version
# of the prioritary package
priority = repo_data.get('priority')
if priority is None:
priority = _DEFAULT_PRIORITY
return priority
def get_repository(self, connection, url, arch, consumer):
name = utils.get_path_from_url(url, False)
def get_repository(self, connection, repository_data, arch, consumer):
consumer(Repository(
name=name,
url=url + "/",
name=repository_data['name'],
url=utils.normalize_repository_url(repository_data["url"]),
architecture=arch,
origin=""
))
def get_packages(self, connection, repository, consumer):
"""Overrides method of superclass."""
baseurl = repository.url
repomd = urljoin(baseurl, "repodata/repomd.xml")
self.logger.debug("repomd: %s", repomd)
@ -130,8 +145,7 @@ class RpmRepositoryDriver(RepositoryDriverBase):
counter += 1
self.logger.info("loaded: %d packages from %s.", counter, repository)
def rebuild_repository(self, repository, packages):
"""Overrides method of superclass."""
def add_packages(self, connection, repository, packages):
basepath = utils.get_path_from_url(repository.url)
self.logger.info("rebuild repository in %s", basepath)
md_config = createrepo.MetaDataConfig()
@ -165,12 +179,12 @@ class RpmRepositoryDriver(RepositoryDriverBase):
# TODO(download gpk)
# TODO(sources and locales)
new_repo = copy.copy(repository)
new_repo.url = utils.localize_repo_url(destination, repository.url)
new_repo.url = utils.normalize_repository_url(destination)
self.logger.info(
"clone repository %s to %s", repository, new_repo.url
)
utils.ensure_dir_exist(new_repo.url)
self.rebuild_repository(new_repo, set())
utils.ensure_dir_exist(destination)
self.add_packages(connection, new_repo, set())
return new_repo
def _load_db(self, connection, baseurl, repomd, *aliases):
@ -264,7 +278,7 @@ class RpmRepositoryDriver(RepositoryDriverBase):
return (
attrs['name'],
attrs["flags"].lower(),
_OPERATORS_MAPPING[attrs["flags"]],
self._unparse_version_attrs(attrs)
)

View File

@ -79,7 +79,7 @@ def get_path_from_url(url, ensure_file=True):
:param url: the URL
:param ensure_file: If True, ensure that scheme is "file"
:return: the path component from URL
:raises ValueError
:raise ValueError: if expected local path and schema of URL is not file
"""
comps = urlparse(url, scheme="file")
@ -92,14 +92,27 @@ def get_path_from_url(url, ensure_file=True):
return comps.path
def localize_repo_url(localurl, repo_url):
"""Gets local repository url.
def get_url_from_path(path):
"""Get the URL from local path.
:param localurl: the base local URL
:param repo_url: the origin URL of repository
:return: localurl + get_path_from_url(repo_url)
:param path: the local path
:return: the URL
"""
return localurl.rstrip("/") + urlparse(repo_url).path
path = os.path.abspath(path)
if os.sep != "/":
path = path.replace(os.sep, "/")
return "file://" + path
def normalize_repository_url(url):
"""Convert URL of repository to normal form.
:param url: the origin URL
:return: normalized URL
"""
if url and url[0] in ("/", "."):
url = get_url_from_path(url)
return url.rstrip("/") + "/"
def ensure_dir_exist(path):

View File

@ -22,6 +22,7 @@ from packetary.objects.package import Package
from packetary.objects.package_relation import PackageRelation
from packetary.objects.package_relation import VersionRange
from packetary.objects.package_version import PackageVersion
from packetary.objects.packages_forest import PackagesForest
from packetary.objects.packages_tree import PackagesTree
from packetary.objects.repository import Repository
@ -31,6 +32,7 @@ __all__ = [
"Index",
"Package",
"PackageRelation",
"PackagesForest",
"PackagesTree",
"PackageVersion",
"Repository",

View File

@ -66,16 +66,22 @@ def _lowerbound_end(versions, version, condition):
return result
def _equal(tree, version):
"""Gets the package with specified version."""
if version in tree:
return [tree[version]]
return []
def _equal(versions, version):
"""Gets the package with specified version.
:param versions: the tree of versions.
:param version: the required version
"""
value = versions.get(version, None)
return [] if value is None else [value]
def _any(tree, _):
"""Gets the package with max version."""
return list(tree.values())
def _any(versions, _):
"""Gets the package with max version.
:param versions: the tree of versions.
"""
return list(versions.values())
class Index(object):
@ -91,17 +97,15 @@ class Index(object):
operators = {
None: _any,
"lt": _make_operator(_start_upperbound, operator.lt),
"le": _make_operator(_start_upperbound, operator.le),
"gt": _make_operator(_lowerbound_end, operator.gt),
"ge": _make_operator(_lowerbound_end, operator.ge),
"eq": _equal,
"<": _make_operator(_start_upperbound, operator.lt),
"<=": _make_operator(_start_upperbound, operator.le),
">": _make_operator(_lowerbound_end, operator.gt),
">=": _make_operator(_lowerbound_end, operator.ge),
"=": _equal,
}
def __init__(self):
self.packages = defaultdict(FastRBTree)
self.obsoletes = defaultdict(FastRBTree)
self.provides = defaultdict(FastRBTree)
def __iter__(self):
"""Iterates over all packages including versions."""
@ -115,6 +119,10 @@ class Index(object):
0
)
def __contains__(self, name):
"""Checks that index contains any package with such name."""
return name in self.packages
def get_all(self):
"""Gets sequence from all of packages including versions."""
@ -122,42 +130,15 @@ class Index(object):
for version in versions.values():
yield version
def find(self, name, version):
"""Finds the package by name and range of versions.
:param name: the package`s name.
:param version: the range of versions.
:return: the package if it is found, otherwise None
"""
candidates = self.find_all(name, version)
if len(candidates) > 0:
return candidates[-1]
return None
def find_all(self, name, version):
def find_all(self, name, version_range):
"""Finds the packages by name and range of versions.
:param name: the package`s name.
:param version: the range of versions.
:param version_range: the range of versions.
:return: the list of suitable packages
"""
if name in self.packages:
candidates = self._find_versions(
self.packages[name], version
)
if len(candidates) > 0:
return candidates
if name in self.obsoletes:
return self._resolve_relation(
self.obsoletes[name], version
)
if name in self.provides:
return self._resolve_relation(
self.provides[name], version
)
return self._find_versions(self.packages[name], version_range)
return []
def add(self, package):
@ -166,43 +147,24 @@ class Index(object):
:param package: the package object.
"""
self.packages[package.name][package.version] = package
key = package.name, package.version
for obsolete in package.obsoletes:
self.obsoletes[obsolete.name][key] = obsolete
for provide in package.provides:
self.provides[provide.name][key] = provide
def _resolve_relation(self, relations, version):
"""Resolve relation according to relations index.
:param relations: the index of relations
:param version: the range of versions
:return: package if found, otherwise None
"""
for key, candidate in relations.iter_items(reverse=True):
if candidate.version.has_intersection(version):
return [self.packages[key[0]][key[1]]]
return []
@staticmethod
def _find_versions(versions, version):
def _find_versions(versions, version_range):
"""Searches accurate version.
Search for the highest version out of intersection
of existing and required range of versions.
:param versions: the existing versions
:param version: the required range of versions
:param version_range: the required range of versions
:return: package if found, otherwise None
"""
try:
op = Index.operators[version.op]
op = Index.operators[version_range.op]
except KeyError:
raise ValueError(
"Unsupported operation: {0}"
.format(version.op)
.format(version_range.op)
)
return op(versions, version.edge)
return op(versions, version_range.edge)

View File

@ -59,10 +59,10 @@ class Package(ComparableObject):
return Package(**self.__dict__)
def __str__(self):
return "{0} {1}".format(self.name, self.version)
return "{0} ({1})".format(self.name, self.version)
def __unicode__(self):
return u"{0} {1}".format(self.name, self.version)
return u"{0} ({1})".format(self.name, self.version)
def __hash__(self):
return hash((self.name, self.version))

View File

@ -19,6 +19,16 @@
import operator
_OPERATORS = {
None: lambda x: True,
'=': operator.eq,
'>': operator.gt,
'<': operator.lt,
'>=': operator.ge,
'<=': operator.le,
}
class VersionRange(object):
"""Describes the range of versions.
@ -27,17 +37,24 @@ class VersionRange(object):
equal, greater, less, greater or equal, less or equal.
"""
__slots__ = ["op", "edge"]
__slots__ = ("op", "edge")
def __init__(self, op=None, edge=None):
"""Initialises.
:param op: the name of operator to compare.
:param edge: the edge of versions.
:raise ValueError: if comparison operator is invalid
"""
if op not in _OPERATORS:
raise ValueError("Invalid comparison operator: '{0}'".format(op))
self.op = op
self.edge = edge
def __contains__(self, point):
return _OPERATORS[self.op](point, self.edge)
def __hash__(self):
return hash((self.op, self.edge))
@ -59,7 +76,11 @@ class VersionRange(object):
return u"any"
def has_intersection(self, other):
"""Checks that 2 ranges has intersection."""
"""Checks that 2 ranges has intersection.
:param other: the candidate to check
:return: True if intersection exists, otherwise False
"""
if not isinstance(other, VersionRange):
raise TypeError(
@ -70,28 +91,16 @@ class VersionRange(object):
if self.op is None or other.op is None:
return True
my_op = getattr(operator, self.op)
other_op = getattr(operator, other.op)
if self.op[0] == other.op[0]:
if self.op[0] == 'l':
if self.edge < other.edge:
return my_op(self.edge, other.edge)
return other_op(other.edge, self.edge)
elif self.op[0] == 'g':
if self.edge > other.edge:
return my_op(self.edge, other.edge)
return other_op(other.edge, self.edge)
if self.op == 'eq':
return other_op(self.edge, other.edge)
if other.op == 'eq':
return my_op(other.edge, self.edge)
return (
my_op(other.edge, self.edge) and
other_op(self.edge, other.edge)
)
if self.op == '=':
return self.edge == other.edge
# the intersection is -inf or +inf
return True
if self.edge == other.edge:
# need to cover case < a and >= a
return self.edge in other and other.edge in self
# all other cases
return self.edge in other or other.edge in self
class PackageRelation(object):
@ -101,7 +110,7 @@ class PackageRelation(object):
and range of versions that satisfies requirement.
"""
__slots__ = ["name", "version", "alternative"]
__slots__ = ("name", "version", "alternative")
def __init__(self, name, version=None, alternative=None):
"""Initialises.

View File

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2016 Mirantis, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import logging
from packetary.objects.packages_tree import PackagesTree
logger = logging.getLogger(__package__)
class PackagesForest(object):
"""Helper class to deal with dependency graph."""
def __init__(self):
self.trees = []
def add_tree(self):
"""Add new tree to end of forest.
:return: The added tree
"""
tree = PackagesTree()
self.trees.append(tree)
return tree
def get_packages(self, requirements, include_mandatory=False):
"""Get the packages according requirements.
:param requirements: the list of requirements
:param include_mandatory: if true, the mandatory packages will be
included to result
:return list of packages to copy
"""
# TODO(bgaifullin): use versions intersection instead of union
# now the all versions that fit requirements are selected
# need to select only one version that fits all requirements
resolved = set()
unresolved = set()
stack = [requirements]
if include_mandatory:
for tree in self.trees:
for mandatory in tree.mandatory_packages:
resolved.add(mandatory)
stack.append(mandatory.requires)
while stack:
requirements = stack.pop()
for required in requirements:
for rel in required:
if rel not in unresolved:
candidate = self.find(rel)
if candidate is not None:
if candidate not in resolved:
stack.append(candidate.requires)
resolved.add(candidate)
break
else:
unresolved.add(required)
logger.warning("Unresolved relation: %s", required)
return resolved
def find(self, relation):
"""Finds package in forest.
:param relation: the package relation
:return: the packages from first tree if found otherwise empty list
"""
for tree in self.trees:
candidate = tree.find(relation.name, relation.version)
if candidate is not None:
return candidate

View File

@ -16,119 +16,98 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import warnings
from collections import defaultdict
import six
from packetary.objects.index import Index
from packetary.objects.package_relation import VersionRange
class UnresolvedWarning(UserWarning):
"""Warning about unresolved depends."""
pass
class PackagesTree(Index):
class PackagesTree(object):
"""Helper class to deal with dependency graph."""
def __init__(self):
super(PackagesTree, self).__init__()
self.mandatory_packages = []
self.packages = Index()
self.provides = defaultdict(dict)
self.obsoletes = defaultdict(dict)
def add(self, package):
super(PackagesTree, self).add(package)
# store all mandatory packages in separated list for quick access
if package.mandatory:
self.mandatory_packages.append(package)
def get_unresolved_dependencies(self, base=None):
self.packages.add(package)
key = package.name, package.version
for obsolete in package.obsoletes:
self.obsoletes[obsolete.name][key] = obsolete
for provide in package.provides:
self.provides[provide.name][key] = provide
def find(self, name, version_range):
"""Finds the package by name and range of versions.
:param name: the package`s name.
:param version_range: the range of versions.
:return: the package if it is found, otherwise None
"""
candidates = self.find_all(name, version_range)
if len(candidates) > 0:
return candidates[-1]
return None
def find_all(self, name, version_range):
"""Finds the packages by name and range of versions.
:param name: the package`s name.
:param version_range: the range of versions.
:return: the list of suitable packages
"""
if name in self.packages:
candidates = self.packages.find_all(name, version_range)
if len(candidates) > 0:
return candidates
if name in self.obsoletes:
return self._resolve_relation(self.obsoletes[name], version_range)
if name in self.provides:
return self._resolve_relation(self.provides[name], version_range)
return []
def get_unresolved_dependencies(self):
"""Gets the set of unresolved dependencies.
:param base: the base index to resolve dependencies
:return: the set of unresolved depends.
"""
external = self.__get_unresolved_dependencies(self)
if base is None:
return external
unresolved = set()
for relation in external:
for rel in relation:
if base.find(rel.name, rel.version) is not None:
break
else:
unresolved.add(relation)
return unresolved
def get_minimal_subset(self, main, requirements):
"""Gets the minimal work subset.
:param main: the main index, to complete requirements.
:param requirements: additional requirements.
:return: The set of resolved depends.
"""
unresolved = set()
resolved = set()
if main is None:
def pkg_filter(*_):
pass
else:
pkg_filter = main.find
self.__get_unresolved_dependencies(main, requirements)
stack = list()
stack.append((None, requirements))
# add all mandatory packages
for pkg in self.mandatory_packages:
stack.append((pkg, pkg.requires))
while len(stack) > 0:
pkg, required = stack.pop()
resolved.add(pkg)
for require in required:
for rel in require:
for pkg in self.packages:
for required in pkg.requires:
for rel in required:
if rel not in unresolved:
if pkg_filter(rel.name, rel.version) is not None:
break
# use all packages that meets depends
candidates = self.find_all(rel.name, rel.version)
found = False
for cand in candidates:
if cand == pkg:
continue
found = True
if cand not in resolved:
stack.append((cand, cand.requires))
if found:
if self.find(rel.name, rel.version) is not None:
break
else:
unresolved.add(require)
msg = "Unresolved depends: {0}".format(require)
warnings.warn(UnresolvedWarning(msg))
resolved.remove(None)
return resolved
@staticmethod
def __get_unresolved_dependencies(index, unresolved=None):
"""Gets the set of unresolved dependencies.
:param index: the search index.
:param unresolved: the known list of unresolved packages.
:return: the set of unresolved depends.
"""
if unresolved is None:
unresolved = set()
for pkg in index:
for require in pkg.requires:
for rel in require:
if rel not in unresolved:
candidate = index.find(rel.name, rel.version)
if candidate is not None and candidate != pkg:
break
else:
unresolved.add(require)
unresolved.add(required)
return unresolved
def _resolve_relation(self, relations, version_range):
"""Resolve relation according to relations index.
:param relations: the index of relations
:param version_range: the range of versions
:return: package if found, otherwise None
"""
result = []
for key, candidate in six.iteritems(relations):
if version_range.has_intersection(candidate.version):
result.extend(
self.packages.find_all(key[0], VersionRange('=', key[1]))
)
result.sort(key=lambda x: x.version)
return result

View File

@ -20,29 +20,37 @@
class Repository(object):
"""Structure to describe repository object."""
def __init__(self, name, url, architecture, origin):
def __init__(self, name, url, architecture, origin=None,
path=None, section=None):
"""Initialises.
:param name: the repository`s name, may be tuple of strings
:param url: the repository`s URL
:param architecture: the repository`s architecture
:param origin: the repository`s origin
:param origin: optional, the repository`s origin
:param path: the repository relative path, used for mirroring
:param section: the repository section
"""
self.name = name
self.url = url
self.architecture = architecture
self.origin = origin
self.name = name
self.origin = origin or ""
self.url = url
self.section = section
self.path = path
def __str__(self):
if isinstance(self.name, tuple):
return ".".join(self.name)
return self.name or self.url
if not self.section:
return self.url
def __unicode__(self):
if isinstance(self.name, tuple):
return u".".join(self.name)
return self.name or self.url
if isinstance(self.section, tuple):
section_str = " ".join(self.section)
else:
section_str = self.section
return " ".join((self.url, section_str))
def __copy__(self):
"""Creates shallow copy of package."""
return Repository(**self.__dict__)
def __hash__(self):
return hash((self.url, self.section))

View File

@ -20,9 +20,9 @@ from packetary import objects
def gen_repository(name="test", url="file:///test",
architecture="x86_64", origin="Test"):
architecture="x86_64", origin="Test", **kwargs):
"""Helper to create Repository object with default attributes."""
return objects.Repository(name, url, architecture, origin)
return objects.Repository(name, url, architecture, origin, **kwargs)
def gen_relation(name="test", version=None, alternative=None):
@ -46,7 +46,7 @@ def gen_package(idx=1, **kwargs):
for relation in ("requires", "provides", "obsoletes"):
if relation not in kwargs:
kwargs[relation] = [gen_relation(
"{0}{1}".format(relation, idx), ["le", idx + 1]
"{0}{1}".format(relation, idx), ["<=", idx + 1]
)]
return objects.Package(**kwargs)

View File

@ -24,25 +24,19 @@ import subprocess
# that was removed in 3.5
subprocess.mswindows = False
from packetary.api import RepositoryApi
from packetary.cli.commands import clone
from packetary.cli.commands import packages
from packetary.cli.commands import unresolved
from packetary.objects.statistics import CopyStatistics
from packetary.tests import base
from packetary.tests.stubs.generator import gen_package
from packetary.tests.stubs.generator import gen_relation
from packetary.tests.stubs.generator import gen_repository
from packetary.tests.stubs.helpers import CallbacksAdapter
@mock.patch.multiple(
"packetary.api",
RepositoryController=mock.DEFAULT,
ConnectionsManager=mock.DEFAULT,
AsynchronousSection=mock.MagicMock()
)
@mock.patch(
"packetary.cli.commands.base.BaseRepoCommand.stdout"
)
@mock.patch("packetary.cli.commands.base.BaseRepoCommand.stdout")
@mock.patch("packetary.cli.commands.base.read_from_file")
@mock.patch("packetary.cli.commands.base.RepositoryApi")
class TestCliCommands(base.TestCase):
common_argv = [
"--ignore-errors-num=3",
@ -53,23 +47,24 @@ class TestCliCommands(base.TestCase):
]
clone_argv = [
"-o", "http://localhost/origin",
"-d", ".",
"-r", "http://localhost/requires",
"-b", "test-package",
"-r", "repositories.yaml",
"-p", "packages.yaml",
"-d", "/root",
"-t", "deb",
"-a", "x86_64",
"--clean",
"--skip-mandatory"
]
packages_argv = [
"-o", "http://localhost/origin",
"-r", "repositories.yaml",
"-t", "deb",
"-a", "x86_64"
"-a", "x86_64",
"-c", "name", "filename"
]
unresolved_argv = [
"-o", "http://localhost/origin",
"-r", "repositories.yaml",
"-t", "deb",
"-a", "x86_64"
]
@ -77,76 +72,76 @@ class TestCliCommands(base.TestCase):
def start_cmd(self, cmd, argv):
cmd.debug(argv + self.common_argv)
def check_context(self, context, ConnectionsManager):
self.assertEqual(3, context._ignore_errors_num)
self.assertEqual(8, context._threads_num)
self.assertIs(context._connection, ConnectionsManager.return_value)
ConnectionsManager.assert_called_once_with(
proxy="http://proxy",
secure_proxy="https://proxy",
retries_num=10
)
def check_common_config(self, config):
self.assertEqual("http://proxy", config.http_proxy)
self.assertEqual("https://proxy", config.https_proxy)
self.assertEqual(3, config.ignore_errors_num)
self.assertEqual(8, config.threads_num)
self.assertEqual(10, config.retries_num)
def test_clone_cmd(self, stdout, RepositoryController, **kwargs):
ctrl = RepositoryController.load()
ctrl.copy_packages = CallbacksAdapter()
ctrl.load_repositories = CallbacksAdapter()
ctrl.load_packages = CallbacksAdapter()
ctrl.copy_packages.return_value = [1, 0]
repo = gen_repository()
ctrl.load_repositories.side_effect = [repo, gen_repository()]
ctrl.load_packages.side_effect = [
gen_package(repository=repo),
gen_package()
def test_clone_cmd(self, api_mock, read_file_mock, stdout_mock):
read_file_mock.side_effect = [
[{"name": "repo"}],
[{"name": "package"}],
]
api_instance = mock.MagicMock(spec=RepositoryApi)
api_mock.create.return_value = api_instance
api_instance.clone_repositories.return_value = CopyStatistics()
self.start_cmd(clone, self.clone_argv)
RepositoryController.load.assert_called_with(
api_mock.create.assert_called_once_with(
mock.ANY, "deb", "x86_64"
)
self.check_context(
RepositoryController.load.call_args[0][0], **kwargs
self.check_common_config(api_mock.create.call_args[0][0])
read_file_mock.assert_any_call("repositories.yaml")
read_file_mock.assert_any_call("packages.yaml")
api_instance.clone_repositories.assert_called_once_with(
[{"name": "repo"}], [{"name": "package"}], "/root",
False, False, False, False
)
stdout.write.assert_called_once_with(
"Packages copied: 1/2.\n"
stdout_mock.write.assert_called_once_with(
"Packages copied: 0/0.\n"
)
def test_get_packages_cmd(self, stdout, RepositoryController, **kwargs):
ctrl = RepositoryController.load()
ctrl.load_packages = CallbacksAdapter()
ctrl.load_packages.return_value = gen_package(
name="test1",
filesize=1,
requires=None,
obsoletes=None,
provides=None
)
def test_get_packages_cmd(self, api_mock, read_file_mock, stdout_mock):
read_file_mock.return_value = [{"name": "repo"}]
api_instance = mock.MagicMock(spec=RepositoryApi)
api_mock.create.return_value = api_instance
api_instance.get_packages.return_value = [
gen_package(name="test1", filesize=1, requires=None,
obsoletes=None, provides=None)
]
self.start_cmd(packages, self.packages_argv)
RepositoryController.load.assert_called_with(
read_file_mock.assert_called_with("repositories.yaml")
api_mock.create.assert_called_once_with(
mock.ANY, "deb", "x86_64"
)
self.check_context(
RepositoryController.load.call_args[0][0], **kwargs
self.check_common_config(api_mock.create.call_args[0][0])
api_instance.get_packages.assert_called_once_with(
[{"name": "repo"}], None, True
)
self.assertIn(
"test1; test; 1; test1.pkg; 1;",
stdout.write.call_args_list[3][0][0]
"test1; test1.pkg",
stdout_mock.write.call_args_list[3][0][0]
)
def test_get_unresolved_cmd(self, stdout, RepositoryController, **kwargs):
ctrl = RepositoryController.load()
ctrl.load_packages = CallbacksAdapter()
ctrl.load_packages.return_value = gen_package(
name="test1",
requires=[gen_relation("test2")]
)
def test_get_unresolved_cmd(self, api_mock, read_file_mock, stdout_mock):
read_file_mock.return_value = [{"name": "repo"}]
api_instance = mock.MagicMock(spec=RepositoryApi)
api_mock.create.return_value = api_instance
api_instance.get_unresolved_dependencies.return_value = [
gen_relation(name="test")
]
self.start_cmd(unresolved, self.unresolved_argv)
RepositoryController.load.assert_called_with(
api_mock.create.assert_called_once_with(
mock.ANY, "deb", "x86_64"
)
self.check_context(
RepositoryController.load.call_args[0][0], **kwargs
self.check_common_config(api_mock.create.call_args[0][0])
api_instance.get_unresolved_dependencies.assert_called_once_with(
[{"name": "repo"}]
)
self.assertIn(
"test2; any; -",
stdout.write.call_args_list[3][0][0]
"test; any; -",
stdout_mock.write.call_args_list[3][0][0]
)

View File

@ -28,18 +28,28 @@ class Dummy(object):
class TestCommandUtils(base.TestCase):
@mock.patch("packetary.cli.commands.utils.open")
def test_read_lines_from_file(self, open_mock):
open_mock().__enter__.return_value = [
"line1\n",
" # comment\n",
"line2 \n"
]
def test_read_from_json_file(self, open_mock):
mock.mock_open(open_mock, read_data='{"key": "value"}')
self.assertEqual(
["line1", "line2"],
utils.read_lines_from_file("test.txt")
{"key": "value"},
utils.read_from_file("test.json")
)
@mock.patch("packetary.cli.commands.utils.open")
def test_read_from_yaml_file(self, open_mock):
mock.mock_open(open_mock, read_data='key: value')
self.assertEqual(
{"key": "value"},
utils.read_from_file("test.YAML")
)
def test_read_from_from_file_if_none(self):
self.assertIsNone(utils.read_from_file(None))
def test_read_from_from_file_fails_if_unknown_extension(self):
with self.assertRaisesRegexp(ValueError, "txt"):
utils.read_from_file("test.txt")
def test_get_object_attrs(self):
obj = Dummy()
obj.attr_int = 0

View File

@ -20,9 +20,7 @@ import mock
import os.path as path
import six
from packetary.drivers import deb_driver
from packetary.library.utils import localize_repo_url
from packetary.tests import base
from packetary.tests.stubs.generator import gen_package
from packetary.tests.stubs.generator import gen_repository
@ -30,7 +28,6 @@ from packetary.tests.stubs.helpers import get_compressed
PACKAGES = path.join(path.dirname(__file__), "data", "Packages")
RELEASE = path.join(path.dirname(__file__), "data", "Release")
class TestDebDriver(base.TestCase):
@ -42,75 +39,79 @@ class TestDebDriver(base.TestCase):
def setUp(self):
self.connection = mock.MagicMock()
self.repo = gen_repository(
name="trusty", section=("trusty", "main"), url="file:///repo"
)
def test_parse_urls(self):
self.assertItemsEqual(
[
("http://host", "trusty", "main"),
("http://host", "trusty", "restricted"),
],
self.driver.parse_urls(
["http://host/dists/ trusty main restricted"]
)
)
self.assertItemsEqual(
[("http://host", "trusty", "main")],
self.driver.parse_urls(
["http://host/dists trusty main"]
)
)
self.assertItemsEqual(
[("http://host", "trusty", "main")],
self.driver.parse_urls(
["http://host/ trusty main"]
)
)
self.assertItemsEqual(
[
("http://host", "trusty", "main"),
("http://host2", "trusty", "main"),
],
self.driver.parse_urls(
[
"http://host/ trusty main",
"http://host2/dists/ trusty main",
]
)
def test_priority_sort(self):
repos = [
{"name": "repo0"},
{"name": "repo1", "priority": 0},
{"name": "repo2", "priority": 1000},
{"name": "repo3", "priority": None}
]
repos.sort(key=self.driver.priority_sort)
self.assertEqual(
["repo2", "repo0", "repo3", "repo1"],
[x['name'] for x in repos]
)
def test_get_repository(self):
repos = []
with open(RELEASE, "rb") as stream:
self.connection.open_stream.return_value = stream
self.driver.get_repository(
self.connection,
("http://host", "trusty", "main"),
"x86_64",
repos.append
)
self.connection.open_stream.assert_called_once_with(
repo_data = {
"name": "repo1", "url": "http://host", "suite": "trusty",
"section": ["main", "universe"], "path": "my_path"
}
self.connection.open_stream.return_value = {"Origin": "Ubuntu"}
self.driver.get_repository(
self.connection,
repo_data,
"x86_64",
repos.append
)
self.connection.open_stream.assert_any_call(
"http://host/dists/trusty/main/binary-amd64/Release"
)
self.assertEqual(1, len(repos))
self.connection.open_stream.assert_any_call(
"http://host/dists/trusty/universe/binary-amd64/Release"
)
self.assertEqual(2, len(repos))
repo = repos[0]
self.assertEqual(("trusty", "main"), repo.name)
self.assertEqual("repo1", repo.name)
self.assertEqual(("trusty", "main"), repo.section)
self.assertEqual("Ubuntu", repo.origin)
self.assertEqual("x86_64", repo.architecture)
self.assertEqual("http://host/", repo.url)
self.assertEqual("my_path", repo.path)
repo = repos[1]
self.assertEqual("repo1", repo.name)
self.assertEqual(("trusty", "universe"), repo.section)
self.assertEqual("Ubuntu", repo.origin)
self.assertEqual("x86_64", repo.architecture)
self.assertEqual("http://host/", repo.url)
def test_get_flat_repository(self):
with self.assertRaisesRegexp(ValueError, "does not supported"):
self.driver.get_repository(
self.connection,
{"url": "http://host", "suite": "trusty"},
"x86_64",
lambda x: None
)
def test_get_packages(self):
packages = []
repo = gen_repository(name=("trusty", "main"), url="http://host/")
with open(PACKAGES, "rb") as s:
self.connection.open_stream.return_value = get_compressed(s)
self.driver.get_packages(
self.connection,
repo,
self.repo,
packages.append
)
self.connection.open_stream.assert_called_once_with(
"http://host/dists/trusty/main/binary-amd64/Packages.gz",
"file:///repo/dists/trusty/main/binary-amd64/Packages.gz",
)
self.assertEqual(1, len(packages))
package = packages[0]
@ -132,7 +133,7 @@ class TestDebDriver(base.TestCase):
self.assertItemsEqual(
[
'test-main (any)',
'test2 (ge 0.8.16~exp9) | tes2-old (any)',
'test2 (>= 0.8.16~exp9) | tes2-old (any)',
'test3 (any)'
],
(str(x) for x in package.requires)
@ -142,7 +143,7 @@ class TestDebDriver(base.TestCase):
(str(x) for x in package.provides)
)
self.assertItemsEqual(
["test-old (any)"],
[],
(str(x) for x in package.obsoletes)
)
@ -156,10 +157,8 @@ class TestDebDriver(base.TestCase):
os=mock.DEFAULT,
open=mock.DEFAULT
)
def test_rebuild_repository(self, os, debfile, deb822, fcntl,
gzip, utils, open):
repo = gen_repository(name=("trusty", "main"), url="file:///repo")
package = gen_package(name="test", repository=repo)
def test_add_packages(self, os, debfile, deb822, fcntl, gzip, utils, open):
package = gen_package(name="test", repository=self.repo)
os.path.join = lambda *x: "/".join(x)
utils.get_path_from_url = lambda x: x[7:]
@ -171,7 +170,7 @@ class TestDebDriver(base.TestCase):
mock.MagicMock() # Packages.gz, rb
]
open.side_effect = files
self.driver.rebuild_repository(repo, [package])
self.driver.add_packages(self.connection, self.repo, {package})
open.assert_any_call(
"/repo/dists/trusty/main/binary-amd64/Packages", "wb"
)
@ -186,27 +185,24 @@ class TestDebDriver(base.TestCase):
gzip=mock.DEFAULT,
open=mock.DEFAULT,
os=mock.DEFAULT,
utils=mock.DEFAULT
)
def test_fork_repository(self, deb822, gzip, open, os, utils):
@mock.patch("packetary.drivers.deb_driver.utils.ensure_dir_exist")
def test_fork_repository(self, mkdir_mock, deb822, gzip, open, os):
os.path.sep = "/"
os.path.join = lambda *x: "/".join(x)
utils.get_path_from_url = lambda x: x
utils.localize_repo_url = localize_repo_url
repo = gen_repository(
name=("trusty", "main"), url="http://localhost/test/"
)
files = [
mock.MagicMock(),
mock.MagicMock()
]
open.side_effect = files
new_repo = self.driver.fork_repository(self.connection, repo, "/root")
self.assertEqual(repo.name, new_repo.name)
self.assertEqual(repo.architecture, new_repo.architecture)
self.assertEqual(repo.origin, new_repo.origin)
self.assertEqual("/root/test/", new_repo.url)
utils.ensure_dir_exist.assert_called_once_with(os.path.dirname())
new_repo = self.driver.fork_repository(
self.connection, self.repo, "/root/test"
)
self.assertEqual(self.repo.name, new_repo.name)
self.assertEqual(self.repo.architecture, new_repo.architecture)
self.assertEqual(self.repo.origin, new_repo.origin)
self.assertEqual("file:///root/test/", new_repo.url)
mkdir_mock.assert_called_once_with(os.path.dirname())
open.assert_any_call(
"/root/test/dists/trusty/main/binary-amd64/Release", "wb"
)
@ -225,9 +221,7 @@ class TestDebDriver(base.TestCase):
os=mock.DEFAULT,
utils=mock.DEFAULT
)
def test_update_suite_index(
self, os, fcntl, gzip, open, utils):
repo = gen_repository(name=("trusty", "main"), url="/repo")
def test_update_suite_index(self, os, fcntl, gzip, open, utils):
files = [
mock.MagicMock(), # Release, a+b
mock.MagicMock(), # Packages, rb
@ -254,7 +248,7 @@ class TestDebDriver(base.TestCase):
)
for name in deb_driver._REPOSITORY_FILES
)
self.driver._update_suite_index(repo)
self.driver._update_suite_index(self.repo)
open.assert_any_call("/root/dists/trusty/Release", "a+b")
files[0].seek.assert_called_once_with(0)
files[0].truncate.assert_called_once_with(0)
@ -269,6 +263,5 @@ class TestDebDriver(base.TestCase):
.format(k, k + "_value")
))
open.assert_any_call("/root/dists/trusty/Release", "a+b")
print([x.fileno() for x in files])
fcntl.flock.assert_any_call(files[0].fileno(), fcntl.LOCK_EX)
fcntl.flock.assert_any_call(files[0].fileno(), fcntl.LOCK_UN)

View File

@ -23,79 +23,25 @@ from packetary.objects.index import Index
from packetary import objects
from packetary.tests import base
from packetary.tests.stubs.generator import gen_package
from packetary.tests.stubs.generator import gen_relation
class TestIndex(base.TestCase):
def test_add(self):
index = Index()
index.add(gen_package(version=1))
self.assertIn("package1", index.packages)
self.assertIn(1, index.packages["package1"])
self.assertIn("obsoletes1", index.obsoletes)
self.assertIn("provides1", index.provides)
package1 = gen_package(version=1)
index.add(package1)
self.assertIn(package1.name, index.packages)
self.assertEqual(
[(1, package1)],
list(index.packages[package1.name].items())
)
index.add(gen_package(version=2))
package2 = gen_package(version=2)
index.add(package2)
self.assertEqual(1, len(index.packages))
self.assertIn(1, index.packages["package1"])
self.assertIn(2, index.packages["package1"])
self.assertEqual(1, len(index.obsoletes))
self.assertEqual(1, len(index.provides))
def test_find(self):
index = Index()
p1 = gen_package(version=1)
p2 = gen_package(version=2)
index.add(p1)
index.add(p2)
self.assertIs(
p1,
index.find("package1", objects.VersionRange("eq", 1))
)
self.assertIs(
p2,
index.find("package1", objects.VersionRange())
)
self.assertIsNone(
index.find("package1", objects.VersionRange("gt", 2))
)
def test_find_all(self):
index = Index()
p11 = gen_package(idx=1, version=1)
p12 = gen_package(idx=1, version=2)
p21 = gen_package(idx=2, version=1)
p22 = gen_package(idx=2, version=2)
index.add(p11)
index.add(p12)
index.add(p21)
index.add(p22)
self.assertItemsEqual(
[p11, p12],
index.find_all("package1", objects.VersionRange())
)
self.assertItemsEqual(
[p21, p22],
index.find_all("package2", objects.VersionRange("le", 2))
)
def test_find_newest_package(self):
index = Index()
p1 = gen_package(idx=1, version=2)
p2 = gen_package(idx=2, version=2)
p2.obsoletes.append(
gen_relation(p1.name, ["lt", p1.version])
)
index.add(p1)
index.add(p2)
self.assertIs(
p1, index.find(p1.name, objects.VersionRange("eq", p1.version))
)
self.assertIs(
p2, index.find(p1.name, objects.VersionRange("eq", 1))
self.assertEqual(
[(1, package1), (2, package2)],
list(index.packages[package1.name].items())
)
def test_find_top_down(self):
@ -104,16 +50,17 @@ class TestIndex(base.TestCase):
p2 = gen_package(version=2)
index.add(p1)
index.add(p2)
self.assertIs(
p2,
index.find("package1", objects.VersionRange("le", 2))
self.assertEqual(
[p1, p2],
index.find_all(p1.name, objects.VersionRange("<=", 2))
)
self.assertIs(
p1,
index.find("package1", objects.VersionRange("lt", 2))
self.assertEqual(
[p1],
index.find_all(p1.name, objects.VersionRange("<", 2))
)
self.assertIsNone(
index.find("package1", objects.VersionRange("lt", 1))
self.assertEqual(
[],
index.find_all(p1.name, objects.VersionRange("<", 1))
)
def test_find_down_up(self):
@ -122,56 +69,33 @@ class TestIndex(base.TestCase):
p2 = gen_package(version=2)
index.add(p1)
index.add(p2)
self.assertIs(
p2,
index.find("package1", objects.VersionRange("ge", 2))
self.assertEqual(
[p2],
index.find_all(p1.name, objects.VersionRange(">=", 2))
)
self.assertIs(
p2,
index.find("package1", objects.VersionRange("gt", 1))
self.assertEqual(
[p2],
index.find_all(p1.name, objects.VersionRange(">", 1))
)
self.assertIsNone(
index.find("package1", objects.VersionRange("gt", 2))
self.assertEqual(
[],
index.find_all(p1.name, objects.VersionRange(">", 2))
)
def test_find_accurate(self):
def test_find_with_specified_version(self):
index = Index()
p1 = gen_package(version=1)
p2 = gen_package(version=2)
index.add(p1)
index.add(p2)
self.assertIs(
p1,
index.find("package1", objects.VersionRange("eq", 1))
)
self.assertIsNone(
index.find("package1", objects.VersionRange("eq", 3))
)
def test_find_obsolete(self):
index = Index()
p1 = gen_package(version=1)
index.add(p1)
self.assertIs(
p1, index.find("obsoletes1", objects.VersionRange("le", 2))
)
self.assertIsNone(
index.find("obsoletes1", objects.VersionRange("gt", 2))
)
def test_find_provides(self):
index = Index()
p1 = gen_package(version=1)
p2 = gen_package(version=2)
p1 = gen_package(idx=1, version=1)
p2 = gen_package(idx=1, version=2)
index.add(p1)
index.add(p2)
self.assertIs(
p2, index.find("provides1", objects.VersionRange("ge", 2))
self.assertItemsEqual(
[p1],
index.find_all(p1.name, objects.VersionRange("=", p1.version))
)
self.assertIsNone(
index.find("provides1", objects.VersionRange("lt", 2))
self.assertItemsEqual(
[p2],
index.find_all(p2.name, objects.VersionRange("=", p2.version))
)
def test_len(self):

View File

@ -94,6 +94,29 @@ class TestLibraryUtils(base.TestCase):
utils.get_path_from_url("http://host/f.txt", False)
)
@mock.patch("packetary.library.utils.os")
def test_normalize_repository_url(self, os_mock):
def abs_patch_mock(p):
if p.startswith("/"):
return p
return "/root/" + p[2:]
os_mock.sep = "/"
os_mock.path.abspath.side_effect = abs_patch_mock
cases = [
("file:///repo/", "/repo"),
("file:///root/repo/", "./repo"),
("http://localhost/repo/", "http://localhost/repo"),
("http://localhost/repo/", "http://localhost/repo/"),
]
for expected, url in cases:
self.assertEqual(
expected, utils.normalize_repository_url(url),
"URL: {0}".format(url)
)
@mock.patch("packetary.library.utils.os")
def test_ensure_dir_exist(self, os):
os.makedirs.side_effect = [

View File

@ -32,7 +32,7 @@ class TestObjectBase(base.TestCase):
def check_copy(self, origin):
clone = copy.copy(origin)
self.assertIsNot(origin, clone)
self.assertEqual(origin, clone)
self.assertEqual(origin.name, clone.name)
origin_name = origin.name
origin.name += "1"
self.assertEqual(
@ -91,25 +91,30 @@ class TestPackageObject(TestObjectBase):
)
class TestRepositoryObject(base.TestCase):
class TestRepositoryObject(TestObjectBase):
def test_copy(self):
origin = generator.gen_repository()
clone = copy.copy(origin)
self.assertEqual(clone.name, origin.name)
self.assertEqual(clone.architecture, origin.architecture)
self.check_copy(generator.gen_repository())
def test_hashable(self):
self.check_hashable(
generator.gen_repository(name="test1", url="file:///repo"),
generator.gen_repository(name="test1", url="file:///repo",
section=("a", "b")),
)
def test_str(self):
self.assertEqual(
"a.b",
str(generator.gen_repository(name=("a", "b")))
)
self.assertEqual(
"/a/b/",
str(generator.gen_repository(name="", url="/a/b/"))
str(generator.gen_repository(name="a", url="/a/b/"))
)
self.assertEqual(
"a",
str(generator.gen_repository(name="a", url="/a/b/"))
"/a/b/ c",
str(generator.gen_repository(name="a", url="/a/b/", section="c"))
)
self.assertEqual(
"/a/b/ c d",
str(generator.gen_repository(
name="a", url="/a/b/", section=("c", "d")))
)
@ -124,15 +129,15 @@ class TestRelationObject(TestObjectBase):
def test_hashable(self):
self.check_hashable(
generator.gen_relation(name="test1"),
generator.gen_relation(name="test1", version=["le", 1])
generator.gen_relation(name="test1", version=["<=", 1])
)
def test_from_args(self):
r = PackageRelation.from_args(
("test", "le", 2), ("test2",), ("test3",)
("test", "<=", 2), ("test2",), ("test3",)
)
self.assertEqual("test", r.name)
self.assertEqual("le", r.version.op)
self.assertEqual("<=", r.version.op)
self.assertEqual(2, r.version.edge)
self.assertEqual("test2", r.alternative.name)
self.assertEqual(VersionRange(), r.alternative.version)
@ -142,7 +147,7 @@ class TestRelationObject(TestObjectBase):
def test_iter(self):
it = iter(PackageRelation.from_args(
("test", "le", 2), ("test2", "ge", 3))
("test", "<=", 2), ("test2", ">=", 3))
)
self.assertEqual("test", next(it).name)
self.assertEqual("test2", next(it).name)
@ -153,15 +158,15 @@ class TestRelationObject(TestObjectBase):
class TestVersionRange(TestObjectBase):
def test_equal(self):
self.check_equal(
VersionRange("eq", 1),
VersionRange("eq", 1),
VersionRange("le", 1)
VersionRange("=", 1),
VersionRange("=", 1),
VersionRange("<=", 1)
)
def test_hashable(self):
self.check_hashable(
VersionRange(op="le"),
VersionRange(op="le", edge=3)
VersionRange(op="<="),
VersionRange(op="<=", edge=3)
)
def __check_intersection(self, assertion, cases):
@ -177,28 +182,39 @@ class TestVersionRange(TestObjectBase):
def test_have_intersection(self):
cases = [
(("lt", 2), ("gt", 1)),
(("lt", 3), ("lt", 4)),
(("gt", 3), ("gt", 4)),
(("eq", 1), ("eq", 1)),
(("ge", 1), ("le", 1)),
(("eq", 1), ("lt", 2)),
((None, None), ("le", 10)),
(("=", 2), ("=", 2)),
(("=", 2), ("<", 3)),
(("=", 2), (">", 1)),
(("<", 2), (">", 1)),
(("<", 2), ("<", 3)),
(("<", 2), ("<", 2)),
(("<", 2), ("<=", 2)),
((">", 2), (">", 1)),
((">", 2), ("<", 3)),
((">", 2), (">=", 2)),
((">", 2), (">", 2)),
((">=", 2), ("<=", 2)),
((None, None), ("=", 2)),
]
self.__check_intersection(self.assertTrue, cases)
def test_does_not_have_intersection(self):
cases = [
(("lt", 2), ("gt", 2)),
(("ge", 2), ("lt", 2)),
(("gt", 2), ("le", 2)),
(("gt", 1), ("lt", 1)),
(("=", 2), ("=", 1)),
(("=", 2), ("<", 2)),
(("=", 2), (">", 2)),
(("=", 2), (">", 3)),
(("=", 2), ("<", 1)),
(("<", 2), (">=", 2)),
(("<", 2), (">", 3)),
((">", 2), ("<=", 2)),
((">", 2), ("<", 1)),
]
self.__check_intersection(self.assertFalse, cases)
def test_intersection_is_typesafe(self):
with self.assertRaises(TypeError):
VersionRange("eq", 1).has_intersection(("eq", 1))
VersionRange("=", 1).has_intersection(("=", 1))
class TestPackageVersion(base.TestCase):

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2016 Mirantis, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from packetary.objects import PackagesForest
from packetary.tests import base
from packetary.tests.stubs import generator
class TestPackagesForest(base.TestCase):
def setUp(self):
super(TestPackagesForest, self).setUp()
def _add_packages(self, tree, packages):
for pkg in packages:
tree.add(pkg)
def _generate_packages(self, forest):
packages1 = [
generator.gen_package(
name="package1", version=1, mandatory=True,
requires=None
),
generator.gen_package(
name="package2", version=1,
requires=None
),
generator.gen_package(
name="package3", version=1,
requires=[generator.gen_relation("package5")]
)
]
packages2 = [
generator.gen_package(
name="package4", version=1, mandatory=True,
requires=None
),
generator.gen_package(
name="package5", version=1,
requires=[generator.gen_relation("package2")]
),
]
self._add_packages(forest.add_tree(), packages1)
self._add_packages(forest.add_tree(), packages2)
def test_add_tree(self):
forest = PackagesForest()
tree = forest.add_tree()
self.assertIs(tree, forest.trees[-1])
def test_find(self):
forest = PackagesForest()
p11 = generator.gen_package(name="package1", version=1)
p12 = generator.gen_package(name="package1", version=2)
p21 = generator.gen_package(name="package2", version=1)
p22 = generator.gen_package(name="package2", version=2)
self._add_packages(forest.add_tree(), [p11, p22])
self._add_packages(forest.add_tree(), [p12, p21])
self.assertEqual(
p11, forest.find(generator.gen_relation("package1", [">=", 1]))
)
self.assertEqual(
p12, forest.find(generator.gen_relation("package1", [">", 1]))
)
self.assertEqual(p22, forest.find(generator.gen_relation("package2")))
self.assertEqual(
p21, forest.find(generator.gen_relation("package2", ["<", 2]))
)
def test_get_packages_with_mandatory(self):
forest = PackagesForest()
self._generate_packages(forest)
packages = forest.get_packages(
[generator.gen_relation("package3")], True
)
self.assertItemsEqual(
["package1", "package2", "package3", "package4", "package5"],
(x.name for x in packages)
)
def test_get_packages_without_mandatory(self):
forest = PackagesForest()
self._generate_packages(forest)
packages = forest.get_packages(
[generator.gen_relation("package3")], False
)
self.assertItemsEqual(
["package2", "package3", "package5"],
(x.name for x in packages)
)

View File

@ -16,120 +16,82 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import warnings
from packetary.objects import Index
from packetary.objects import PackagesTree
from packetary.objects import VersionRange
from packetary.tests import base
from packetary.tests.stubs import generator
class TestPackagesTree(base.TestCase):
def setUp(self):
super(TestPackagesTree, self).setUp()
def test_add(self):
tree = PackagesTree()
pkg = generator.gen_package(version=1, mandatory=True)
tree.add(pkg)
self.assertIs(pkg, tree.find(pkg.name, VersionRange('=', pkg.version)))
self.assertIs(
pkg.obsoletes[0],
tree.obsoletes[pkg.obsoletes[0].name][(pkg.name, pkg.version)]
)
self.assertIs(
pkg.provides[0],
tree.provides[pkg.provides[0].name][(pkg.name, pkg.version)]
)
tree.add(generator.gen_package(version=1, mandatory=False))
self.assertItemsEqual([pkg], tree.mandatory_packages)
def test_get_unresolved_dependencies(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(
1, requires=[generator.gen_relation("unresolved")]))
ptree.add(generator.gen_package(2, requires=None))
ptree.add(generator.gen_package(
3, requires=[generator.gen_relation("package1")]
))
ptree.add(generator.gen_package(
4,
requires=[generator.gen_relation("loop")],
obsoletes=[generator.gen_relation("loop", ["le", 1])]
))
def test_find_package(self):
tree = PackagesTree()
p1 = generator.gen_package(idx=1, version=1)
p2 = generator.gen_package(idx=1, version=2)
tree.add(p1)
tree.add(p2)
unresolved = ptree.get_unresolved_dependencies()
self.assertItemsEqual(
["loop", "unresolved"],
(x.name for x in unresolved)
self.assertIs(p1, tree.find(p1.name, VersionRange("<", p2.version)))
self.assertIs(p2, tree.find(p1.name, VersionRange(">=", p1.version)))
self.assertIsNone(tree.find(p1.name, VersionRange(">", p2.version)))
def test_find_obsolete(self):
tree = PackagesTree()
p1 = generator.gen_package(
version=1, obsoletes=[generator.gen_relation('obsolete', ('<', 2))]
)
p2 = generator.gen_package(
version=2, obsoletes=[generator.gen_relation('obsolete', ('<', 2))]
)
tree.add(p1)
tree.add(p2)
self.assertEqual(
[p1, p2], tree.find_all("obsolete", VersionRange("<=", 2))
)
self.assertIsNone(
tree.find("obsolete", VersionRange(">", 2))
)
def test_get_unresolved_dependencies_with_main(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(
def test_find_provides(self):
tree = PackagesTree()
p1 = generator.gen_package(
version=1, obsoletes=[generator.gen_relation('provide', ('<', 2))]
)
tree.add(p1)
self.assertIs(
p1, tree.find("provide", VersionRange("<=", 2))
)
self.assertIsNone(
tree.find("provide", VersionRange(">", 2))
)
def test_get_unresolved_dependencies(self):
tree = PackagesTree()
tree.add(generator.gen_package(
1, requires=[generator.gen_relation("unresolved")]))
ptree.add(generator.gen_package(2, requires=None))
ptree.add(generator.gen_package(
tree.add(generator.gen_package(2, requires=None))
tree.add(generator.gen_package(
3, requires=[generator.gen_relation("package1")]
))
ptree.add(generator.gen_package(
4,
requires=[generator.gen_relation("package5")]
))
main = Index()
main.add(generator.gen_package(5, requires=[
generator.gen_relation("package6")
]))
unresolved = ptree.get_unresolved_dependencies(main)
unresolved = tree.get_unresolved_dependencies()
self.assertItemsEqual(
["unresolved"],
(x.name for x in unresolved)
)
def test_get_minimal_subset_with_master(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(1, requires=None))
ptree.add(generator.gen_package(2, requires=None))
ptree.add(generator.gen_package(3, requires=None))
ptree.add(generator.gen_package(
4, requires=[generator.gen_relation("package1")]
))
master = Index()
master.add(generator.gen_package(1, requires=None))
master.add(generator.gen_package(
5,
requires=[generator.gen_relation(
"package10",
alternative=generator.gen_relation("package4")
)]
))
unresolved = set([generator.gen_relation("package3")])
resolved = ptree.get_minimal_subset(master, unresolved)
self.assertItemsEqual(
["package3", "package4"],
(x.name for x in resolved)
)
def test_get_minimal_subset_without_master(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(1, requires=None))
ptree.add(generator.gen_package(2, requires=None))
ptree.add(generator.gen_package(
3, requires=[generator.gen_relation("package1")]
))
unresolved = set([generator.gen_relation("package3")])
resolved = ptree.get_minimal_subset(None, unresolved)
self.assertItemsEqual(
["package3", "package1"],
(x.name for x in resolved)
)
def test_mandatory_packages_always_included(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(1, requires=None, mandatory=True))
ptree.add(generator.gen_package(2, requires=None))
ptree.add(generator.gen_package(3, requires=None))
unresolved = set([generator.gen_relation("package3")])
resolved = ptree.get_minimal_subset(None, unresolved)
self.assertItemsEqual(
["package3", "package1"],
(x.name for x in resolved)
)
def test_warning_if_unresolved(self):
ptree = PackagesTree()
ptree.add(generator.gen_package(
1, requires=None))
with warnings.catch_warnings(record=True) as log:
ptree.get_minimal_subset(
None, [generator.gen_relation("package2")]
)
self.assertIn("package2", str(log[0]))

View File

@ -16,6 +16,7 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import copy
import mock
from packetary.api import Configuration
@ -27,197 +28,174 @@ from packetary.tests.stubs.helpers import CallbacksAdapter
class TestRepositoryApi(base.TestCase):
def test_get_packages_as_is(self):
controller = CallbacksAdapter()
pkg = generator.gen_package(name="test")
controller.load_packages.side_effect = [
pkg
]
api = RepositoryApi(controller)
packages = api.get_packages("file:///repo1")
self.assertEqual(1, len(packages))
package = packages.pop()
self.assertIs(pkg, package)
def setUp(self):
self.controller = CallbacksAdapter()
self.api = RepositoryApi(self.controller)
self.repo_data = {"name": "repo1", "url": "file:///repo1"}
self.repo = generator.gen_repository(**self.repo_data)
self.controller.load_repositories.return_value = [self.repo]
self._generate_packages()
def test_get_packages_with_depends_resolving(self):
controller = CallbacksAdapter()
controller.load_packages.side_effect = [
[
generator.gen_package(idx=1, requires=None),
generator.gen_package(
idx=2, requires=[generator.gen_relation("package1")]
),
generator.gen_package(
idx=3, requires=[generator.gen_relation("package1")]
),
generator.gen_package(idx=4, requires=None),
generator.gen_package(idx=5, requires=None),
],
def _generate_packages(self):
self.packages = [
generator.gen_package(idx=1, repository=self.repo, requires=None),
generator.gen_package(idx=2, repository=self.repo, requires=None),
generator.gen_package(
idx=6, requires=[generator.gen_relation("package2")]
idx=3, repository=self.repo, mandatory=True,
requires=[generator.gen_relation("package2")]
),
generator.gen_package(
idx=4, repository=self.repo, mandatory=False,
requires=[generator.gen_relation("package1")]
),
generator.gen_package(
idx=5, repository=self.repo,
requires=[generator.gen_relation("package6")])
]
self.controller.load_packages.return_value = self.packages
api = RepositoryApi(controller)
packages = api.get_packages([
"file:///repo1", "file:///repo2"
],
"file:///repo3", ["package4"]
@mock.patch("packetary.api.RepositoryController")
@mock.patch("packetary.api.ConnectionsManager")
def test_create_with_config(self, connection_mock, controller_mock):
config = Configuration(
http_proxy="http://localhost", https_proxy="https://localhost",
retries_num=10, threads_num=8, ignore_errors_num=6
)
RepositoryApi.create(config, "deb", "x86_64")
connection_mock.assert_called_once_with(
proxy="http://localhost",
secure_proxy="https://localhost",
retries_num=10
)
controller_mock.load.assert_called_once_with(
mock.ANY, "deb", "x86_64"
)
@mock.patch("packetary.api.RepositoryController")
@mock.patch("packetary.api.ConnectionsManager")
def test_create_with_context(self, connection_mock, controller_mock):
config = Configuration(
http_proxy="http://localhost", https_proxy="https://localhost",
retries_num=10, threads_num=8, ignore_errors_num=6
)
context = Context(config)
RepositoryApi.create(context, "deb", "x86_64")
connection_mock.assert_called_once_with(
proxy="http://localhost",
secure_proxy="https://localhost",
retries_num=10
)
controller_mock.load.assert_called_once_with(
context, "deb", "x86_64"
)
def test_get_packages_as_is(self):
packages = self.api.get_packages([self.repo_data], None)
self.assertEqual(5, len(packages))
self.assertItemsEqual(
self.packages,
packages
)
def test_get_packages_by_requirements_with_mandatory(self):
packages = self.api.get_packages(
[self.repo_data], [{"name": "package1"}], True
)
self.assertEqual(3, len(packages))
self.assertItemsEqual(
["package1", "package4", "package2"],
["package1", "package2", "package3"],
(x.name for x in packages)
)
controller.load_repositories.assert_any_call(
["file:///repo1", "file:///repo2"]
def test_get_packages_by_requirements_without_mandatory(self):
packages = self.api.get_packages(
[self.repo_data], [{"name": "package4"}], False
)
controller.load_repositories.assert_any_call(
"file:///repo3"
self.assertEqual(2, len(packages))
self.assertItemsEqual(
["package1", "package4"],
(x.name for x in packages)
)
def test_clone_repositories_as_is(self):
controller = CallbacksAdapter()
repo = generator.gen_repository(name="repo1")
packages = [
generator.gen_package(name="test1", repository=repo),
generator.gen_package(name="test2", repository=repo)
]
mirror = generator.gen_repository(name="mirror")
controller.load_repositories.return_value = repo
controller.load_packages.return_value = packages
controller.clone_repositories.return_value = {repo: mirror}
controller.copy_packages.return_value = [0, 1]
api = RepositoryApi(controller)
stats = api.clone_repositories(
["file:///repo1"], "/mirror", keep_existing=True
# return value is used as statistics
mirror = copy.copy(self.repo)
mirror.url = "file:///mirror/repo"
self.controller.fork_repository.return_value = mirror
self.controller.assign_packages.return_value = [0, 1, 1, 1, 0, 6]
stats = self.api.clone_repositories([self.repo_data], None, "/mirror")
self.controller.fork_repository.assert_called_once_with(
self.repo, '/mirror', False, False
)
self.controller.assign_packages.assert_called_once_with(
mirror, set(self.packages)
)
self.assertEqual(6, stats.total)
self.assertEqual(4, stats.copied)
def test_clone_by_requirements_with_mandatory(self):
# return value is used as statistics
mirror = copy.copy(self.repo)
mirror.url = "file:///mirror/repo"
self.controller.fork_repository.return_value = mirror
self.controller.assign_packages.return_value = [0, 1, 1]
stats = self.api.clone_repositories(
[self.repo_data], [{"name": "package1"}],
"/mirror", include_mandatory=True
)
packages = {self.packages[0], self.packages[1], self.packages[2]}
self.controller.fork_repository.assert_called_once_with(
self.repo, '/mirror', False, False
)
self.controller.assign_packages.assert_called_once_with(
mirror, packages
)
self.assertEqual(3, stats.total)
self.assertEqual(2, stats.copied)
def test_clone_by_requirements_without_mandatory(self):
# return value is used as statistics
mirror = copy.copy(self.repo)
mirror.url = "file:///mirror/repo"
self.controller.fork_repository.return_value = mirror
self.controller.assign_packages.return_value = [0, 4]
stats = self.api.clone_repositories(
[self.repo_data], [{"name": "package4"}],
"/mirror", include_mandatory=False
)
packages = {self.packages[0], self.packages[3]}
self.controller.fork_repository.assert_called_once_with(
self.repo, '/mirror', False, False
)
self.controller.assign_packages.assert_called_once_with(
mirror, packages
)
self.assertEqual(2, stats.total)
self.assertEqual(1, stats.copied)
controller.copy_packages.assert_called_once_with(
mirror, set(packages), True
)
def test_copy_minimal_subset_of_repository(self):
controller = CallbacksAdapter()
repo1 = generator.gen_repository(name="repo1")
repo2 = generator.gen_repository(name="repo2")
repo3 = generator.gen_repository(name="repo3")
mirror1 = generator.gen_repository(name="mirror1")
mirror2 = generator.gen_repository(name="mirror2")
pkg_group1 = [
generator.gen_package(
idx=1, requires=None, repository=repo1
),
generator.gen_package(
idx=1, version=2, requires=None, repository=repo1
),
generator.gen_package(
idx=2, requires=None, repository=repo1
)
]
pkg_group2 = [
generator.gen_package(
idx=4,
requires=[generator.gen_relation("package1")],
repository=repo2,
mandatory=True,
)
]
pkg_group3 = [
generator.gen_package(
idx=3, requires=None, repository=repo1
)
]
controller.load_repositories.side_effect = [[repo1, repo2], repo3]
controller.load_packages.side_effect = [
pkg_group1 + pkg_group2 + pkg_group3,
generator.gen_package(
idx=6,
repository=repo3,
requires=[generator.gen_relation("package2")]
)
]
controller.clone_repositories.return_value = {
repo1: mirror1, repo2: mirror2
}
controller.copy_packages.return_value = 1
api = RepositoryApi(controller)
api.clone_repositories(
["file:///repo1", "file:///repo2"], "/mirror",
["file:///repo3"],
keep_existing=True
)
controller.copy_packages.assert_any_call(
mirror1, set(pkg_group1), True
)
controller.copy_packages.assert_any_call(
mirror2, set(pkg_group2), True
)
self.assertEqual(2, controller.copy_packages.call_count)
def test_get_unresolved(self):
controller = CallbacksAdapter()
pkg = generator.gen_package(
name="test", requires=[generator.gen_relation("test2")]
)
controller.load_packages.side_effect = [
pkg
]
api = RepositoryApi(controller)
r = api.get_unresolved_dependencies("file:///repo1")
controller.load_repositories.assert_called_once_with("file:///repo1")
self.assertItemsEqual(
["test2"],
(x.name for x in r)
)
unresolved = self.api.get_unresolved_dependencies([self.repo_data])
self.assertItemsEqual(["package6"], (x.name for x in unresolved))
def test_get_unresolved_with_main(self):
controller = CallbacksAdapter()
pkg1 = generator.gen_package(
name="test1", requires=[
generator.gen_relation("test2"),
generator.gen_relation("test3")
]
)
pkg2 = generator.gen_package(
name="test2", requires=[generator.gen_relation("test4")]
)
controller.load_packages.side_effect = [
pkg1, pkg2
]
api = RepositoryApi(controller)
r = api.get_unresolved_dependencies("file:///repo1", "file:///repo2")
controller.load_repositories.assert_any_call("file:///repo1")
controller.load_repositories.assert_any_call("file:///repo2")
self.assertItemsEqual(
["test3"],
(x.name for x in r)
)
def test_load_requirements(self):
expected = {
generator.gen_relation("test1"),
generator.gen_relation("test2", ["<", "3"]),
generator.gen_relation("test2", [">", "1"]),
}
actual = set(self.api._load_requirements(
[{"name": "test1"}, {"name": "test2", "versions": ["< 3", "> 1"]}]
))
self.assertEqual(expected, actual)
self.assertIsNone(self.api._load_requirements(None))
def test_parse_requirements(self):
requirements = RepositoryApi._parse_requirements(
["p1 le 2 | p2 | p3 ge 2"]
)
def test_validate_repos_data(self):
# TODO(bgaifullin) implement me
pass
expected = generator.gen_relation(
"p1",
["le", '2'],
generator.gen_relation(
"p2",
None,
generator.gen_relation(
"p3",
["ge", '2']
)
)
)
self.assertEqual(1, len(requirements))
self.assertEqual(
list(expected),
list(requirements.pop())
)
def test_validate_requirements_data(self):
# TODO(bgaifullin) implement me
pass
class TestContext(base.TestCase):

View File

@ -18,9 +18,9 @@
import copy
import mock
import six
from packetary.controllers import RepositoryController
from packetary.drivers.base import RepositoryDriverBase
from packetary.tests import base
from packetary.tests.stubs.executor import Executor
from packetary.tests.stubs.generator import gen_package
@ -30,7 +30,7 @@ from packetary.tests.stubs.helpers import CallbacksAdapter
class TestRepositoryController(base.TestCase):
def setUp(self):
self.driver = mock.MagicMock()
self.driver = mock.MagicMock(spec=RepositoryDriverBase)
self.context = mock.MagicMock()
self.context.async_section.return_value = Executor()
self.ctrl = RepositoryController(self.context, self.driver, "x86_64")
@ -53,24 +53,21 @@ class TestRepositoryController(base.TestCase):
self.assertIs(self.driver, controller.driver)
def test_load_repositories(self):
self.driver.parse_urls.return_value = ["test1"]
consumer = mock.MagicMock()
self.ctrl.load_repositories("file:///test1", consumer)
self.driver.parse_urls.assert_called_once_with(["file:///test1"])
repo_data = {"name": "test", "url": "file:///test1"}
repo = gen_repository(**repo_data)
self.driver.get_repository = CallbacksAdapter()
self.driver.get_repository.side_effect = [repo]
repos = self.ctrl.load_repositories([repo_data])
self.driver.get_repository.assert_called_once_with(
self.context.connection, "test1", "x86_64", consumer
self.context.connection, repo_data, self.ctrl.arch
)
for url in [six.u("file:///test1"), ["file:///test1"]]:
self.driver.reset_mock()
self.ctrl.load_repositories(url, consumer)
if not isinstance(url, list):
url = [url]
self.driver.parse_urls.assert_called_once_with(url)
self.assertEqual([repo], repos)
def test_load_packages(self):
repo = mock.MagicMock()
consumer = mock.MagicMock()
self.ctrl.load_packages([repo], consumer)
self.ctrl.load_packages(repo, consumer)
self.driver.get_packages.assert_called_once_with(
self.context.connection, repo, consumer
)
@ -78,30 +75,33 @@ class TestRepositoryController(base.TestCase):
@mock.patch("packetary.controllers.repository.os")
def test_assign_packages(self, os):
repo = gen_repository(url="/test/repo")
packages = [
packages = {
gen_package(name="test1", repository=repo),
gen_package(name="test2", repository=repo)
]
existed_packages = [
gen_package(name="test3", repository=repo),
gen_package(name="test2", repository=repo)
]
}
os.path.join = lambda *x: "/".join(x)
self.driver.get_packages = CallbacksAdapter()
self.driver.get_packages.return_value = existed_packages
self.ctrl.assign_packages(repo, packages, True)
os.remove.assert_not_called()
all_packages = set(packages + existed_packages)
self.driver.rebuild_repository.assert_called_once_with(
repo, all_packages
self.ctrl.assign_packages(repo, packages)
self.driver.add_packages.assert_called_once_with(
self.ctrl.context.connection, repo, packages
)
self.driver.rebuild_repository.reset_mock()
self.ctrl.assign_packages(repo, packages, False)
self.driver.rebuild_repository.assert_called_once_with(
repo, set(packages)
@mock.patch("packetary.controllers.repository.os")
def test_fork_repository(self, os):
os.path.join.side_effect = lambda *args: "".join(args)
repo = gen_repository(name="test1", url="file:///test")
clone = copy.copy(repo)
clone.url = "/root/repo"
self.driver.fork_repository.return_value = clone
self.context.connection.retrieve.side_effect = [0, 10]
self.ctrl.fork_repository(repo, "./repo", False, False)
self.driver.fork_repository.assert_called_once_with(
self.context.connection, repo, "./repo/test", False, False
)
repo.path = "os"
self.ctrl.fork_repository(repo, "./repo/", False, False)
self.driver.fork_repository.assert_called_with(
self.context.connection, repo, "./repo/os", False, False
)
os.remove.assert_called_once_with("/test/repo/test3.pkg")
def test_copy_packages(self):
repo = gen_repository(url="file:///repo/")
@ -112,8 +112,9 @@ class TestRepositoryController(base.TestCase):
target = gen_repository(url="/test/repo")
self.context.connection.retrieve.side_effect = [0, 10]
observer = mock.MagicMock()
self.ctrl.copy_packages(target, packages, True, observer)
observer.assert_has_calls([mock.call(0), mock.call(10)])
self.ctrl._copy_packages(target, packages, observer)
observer.assert_any_call(0)
observer.assert_any_call(10)
self.context.connection.retrieve.assert_any_call(
"file:///repo/test1.pkg",
"/test/repo/test1.pkg",
@ -124,22 +125,13 @@ class TestRepositoryController(base.TestCase):
"/test/repo/test2.pkg",
size=-1
)
self.driver.rebuild_repository.assert_called_once_with(
target, set(packages)
)
@mock.patch("packetary.controllers.repository.os")
def test_clone_repository(self, os):
os.path.abspath.return_value = "/root/repo"
repos = [
gen_repository(name="test1"),
gen_repository(name="test2")
def test_copy_packages_does_not_affect_packages_in_same_repo(self):
repo = gen_repository(url="file:///repo/")
packages = [
gen_package(name="test1", repository=repo, filesize=10),
gen_package(name="test2", repository=repo, filesize=-1)
]
clones = [copy.copy(x) for x in repos]
self.driver.fork_repository.side_effect = clones
mirrors = self.ctrl.clone_repositories(repos, "./repo")
for r in repos:
self.driver.fork_repository.assert_any_call(
self.context.connection, r, "/root/repo", False, False
)
self.assertEqual(mirrors, dict(zip(repos, clones)))
observer = mock.MagicMock()
self.ctrl._copy_packages(repo, packages, observer)
self.assertFalse(self.context.connection.retrieve.called)

View File

@ -22,7 +22,6 @@ import sys
import six
from packetary.library.utils import localize_repo_url
from packetary.objects import FileChecksum
from packetary.tests import base
from packetary.tests.stubs.generator import gen_repository
@ -53,31 +52,33 @@ class TestRpmDriver(base.TestCase):
self.createrepo.reset_mock()
self.connection = mock.MagicMock()
def test_parse_urls(self):
self.assertItemsEqual(
[
"http://host/centos/os",
"http://host/centos/updates"
],
self.driver.parse_urls([
"http://host/centos/os",
"http://host/centos/updates/",
])
def test_priority_sort(self):
repos = [
{"name": "repo0"},
{"name": "repo1", "priority": 1},
{"name": "repo2", "priority": 99},
{"name": "repo3", "priority": None}
]
repos.sort(key=self.driver.priority_sort)
self.assertEqual(
["repo1", "repo0", "repo3", "repo2"],
[x['name'] for x in repos]
)
def test_get_repository(self):
repos = []
repo_data = {"name": "os", "url": "http://host/centos/os/x86_64/"}
self.driver.get_repository(
self.connection,
"http://host/centos/os/x86_64",
repo_data,
"x86_64",
repos.append
)
self.assertEqual(1, len(repos))
repo = repos[0]
self.assertEqual("/centos/os/x86_64", repo.name)
self.assertEqual("os", repo.name)
self.assertEqual("", repo.origin)
self.assertEqual("x86_64", repo.architecture)
self.assertEqual("http://host/centos/os/x86_64/", repo.url)
@ -125,7 +126,7 @@ class TestRpmDriver(base.TestCase):
"Packages/test1.rpm", package.filename
)
self.assertItemsEqual(
['test2 (eq 0-1.1.1.1-1.el7)'],
['test2 (= 0-1.1.1.1-1.el7)'],
(str(x) for x in package.requires)
)
self.assertItemsEqual(
@ -165,7 +166,7 @@ class TestRpmDriver(base.TestCase):
self.assertTrue(package.mandatory)
@mock.patch("packetary.drivers.rpm_driver.shutil")
def test_rebuild_repository(self, shutil):
def test_add_packages(self, shutil):
self.createrepo.MDError = ValueError
self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [
None, self.createrepo.MDError()
@ -174,7 +175,7 @@ class TestRpmDriver(base.TestCase):
self.createrepo.MetaDataConfig().outputdir = "/repo/os/x86_64"
self.createrepo.MetaDataConfig().tempdir = "tmp"
self.driver.rebuild_repository(repo, set())
self.driver.add_packages(self.connection, repo, set())
self.assertEqual(
"/repo/os/x86_64",
@ -189,24 +190,23 @@ class TestRpmDriver(base.TestCase):
.doFinalMove.assert_called_once_with()
with self.assertRaises(RuntimeError):
self.driver.rebuild_repository(repo, set())
self.driver.add_packages(self.connection, repo, set())
shutil.rmtree.assert_called_once_with(
"/repo/os/x86_64/tmp", ignore_errors=True
)
@mock.patch("packetary.drivers.rpm_driver.utils")
def test_fork_repository(self, utils):
@mock.patch("packetary.drivers.rpm_driver.utils.ensure_dir_exist")
def test_fork_repository(self, ensure_dir_exists_mock):
repo = gen_repository("os", url="http://localhost/os/x86_64/")
utils.localize_repo_url = localize_repo_url
self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [None]
new_repo = self.driver.fork_repository(
self.connection,
repo,
"/repo"
"/repo/os/x86_64"
)
utils.ensure_dir_exist.assert_called_once_with("/repo/os/x86_64/")
ensure_dir_exists_mock.assert_called_once_with("/repo/os/x86_64")
self.assertEqual(repo.name, new_repo.name)
self.assertEqual(repo.architecture, new_repo.architecture)
self.assertEqual("/repo/os/x86_64/", new_repo.url)
self.assertEqual("file:///repo/os/x86_64/", new_repo.url)
self.createrepo.MetaDataGenerator()\
.doFinalMove.assert_called_once_with()

View File

@ -1,6 +1,6 @@
[tox]
minversion = 1.6
envlist = py34,py27,py26,pep8
envlist = py34,py27,pep8
skipsdist = True
[testenv]