CNK's Blog

Import files into Wagtail

I am building a site that is replacing an older site and I want to preserve a substantial number of PDF files. So I wrote a manage.py command to import all the files in a nested set of directories into corresponding nested collections in Wagtail. For example, given the following local directory:

  archive
      - some-file.pdf
      - 2020
        - file1.pdf
        - file2.pdf
      - 2021
        - file3.pdf

My script will create collections for 2020 and 2021 and then import the 4 PDF files into the correct collections and subcollections.

  # core/management/commands/import_documents_from_directory.py

  from django.core.exceptions import ObjectDoesNotExist
  from django.core.management import BaseCommand, CommandError
  from wagtail.models import Collection, get_root_collection_id

  from core.jobs.document_importer import DocumentImporter

  class Command(BaseCommand):
      """
      Management command that imports a local directory tree of files into
      Wagtail documents, delegating the actual work to ``DocumentImporter``.
      """

      # Written as adjacent string literals inside parentheses: a plain quoted
      # string may not span two source lines.
      help = (
          "Imports all files nested under `pdf-directory` into "
          "corresponding collection under the given base collection."
      )

      def add_arguments(self, parser):
          """Register the ``--pdf-directory``, ``--base-collection`` and ``--dry-run`` options."""
          parser.add_argument(
              '--pdf-directory',
              dest='pdf_directory',
              default='/tmp/documents',
              help="Path to the local directory where the PDFs are located"
          )

          parser.add_argument(
              '--base-collection',
              dest='base_collection',
              required=False,
              help="Which collection should get these files? Will use the root collection if this is missing."
          )

          parser.add_argument(
              '--dry-run',
              action='store_true',
              dest='dry_run',
              default=False,
              help='Try not to change the database; just show what would have been done.',
          )

      def handle(self, *args, **options):
          """
          Resolve the target collection and run the import.

          Raises:
              CommandError: if ``--base-collection`` names a collection that
                  does not exist, or a name shared by several collections.
          """
          base_name = options['base_collection']
          if base_name:
              try:
                  base_collection = Collection.objects.get(name=base_name)
              except ObjectDoesNotExist:
                  raise CommandError(f"Base collection \"{base_name}\" does not exist") from None
              except Collection.MultipleObjectsReturned:
                  # The lookup is by name only, so an ambiguous name cannot be
                  # resolved here; report it instead of letting a traceback escape.
                  raise CommandError(
                      f"More than one collection is named \"{base_name}\"; cannot pick a base collection"
                  ) from None
          else:
              # No base collection requested: import directly under the root collection.
              base_collection = Collection.objects.get(pk=get_root_collection_id())

          importer = DocumentImporter()
          importer.import_all(options['pdf_directory'], base_collection, options['dry_run'])
  # core/jobs/document_importer.py

  import hashlib
  import os
  from django.core.files import File

  from wagtail.documents import get_document_model
  from wagtail.models import Collection

  from core.logging import logger


  class DocumentImporter(object):
      """
      Given a nested directory of files, import them into Wagtail's document model - preserving the
      folder structure as nested collections.
      """

      def import_all(self, pdf_directory, base_collection, dry_run=False):
          """
          Import every file found under ``pdf_directory``.

          Args:
              pdf_directory: path of the local directory tree to walk.
              base_collection: collection under which the directory structure is recreated.
              dry_run: when true, only log what would be done; nothing is created or saved.
          """
          for path, file in self._get_files(pdf_directory):
              collection = self._get_collection(path, pdf_directory, base_collection, dry_run)
              self._create_document(file, path, collection, dry_run)

      def _get_files(self, root):
          """Recursively yield ``(directory, filename)`` for every file in ``root`` and below."""
          for path, _dirs, files in os.walk(root):
              yield from ((path, file) for file in files)

      def _get_collection(self, path, pdf_directory, base_collection, dry_run):
          """
          Find (or create) the nested collections corresponding to the nested directories.

          Args:
              path: directory containing the file being imported.
              pdf_directory: root of the import; ``path`` is interpreted relative to it.
              base_collection: collection that corresponds to ``pdf_directory`` itself.
              dry_run: when true, missing collections are logged but not created.

          Returns:
              The collection for ``path``. In a dry run this is the deepest collection
              that already exists, since missing ones are not created.
          """
          current_parent = base_collection
          rel_path = os.path.relpath(path, pdf_directory)
          if rel_path == os.curdir:
              # Files sitting directly in pdf_directory belong to the base collection
              # itself; without this guard a collection literally named "." is created.
              return current_parent

          # Once a level is missing in a dry run nothing below it can exist either,
          # so stop querying (the lookup would be made under the wrong parent).
          parent_exists = True
          # relpath() joins with os.sep, so split on that rather than a hard-coded '/'.
          for part in rel_path.split(os.sep):
              collection = None
              if parent_exists:
                  # Only direct children: a same-named collection deeper in the tree
                  # is a different folder and must not be reused.
                  collection = current_parent.get_children().filter(name=part).first()
              if collection:
                  current_parent = collection
                  logger.info(
                      'document_importer.collection.found',
                      dry_run=dry_run,
                      name=part,
                  )
              else:
                  # create this collection
                  if not dry_run:
                      collection = Collection(name=part)
                      current_parent.add_child(instance=collection)
                      # Set this as the parent for the next node in our list
                      current_parent = collection
                  else:
                      parent_exists = False
                  logger.info(
                      'document_importer.collection.create',
                      dry_run=dry_run,
                      name=part,
                  )
          return current_parent

      def _create_document(self, file, path, collection, dry_run):
          """
          Create the document for ``file``, or update the existing one if its content
          or collection differs.

          Args:
              file: file name (no directory part).
              path: directory containing ``file``.
              collection: collection the document should belong to.
              dry_run: when true, only log the operation that would be attempted.
          """
          document_model = get_document_model()
          # NOTE(review): matching on the file-name suffix also matches documents whose
          # stored name merely ends with this name (e.g. "myfile1.pdf" for "file1.pdf"),
          # and treats same-named files in different directories as one document.
          # Confirm this is acceptable for the data being imported.
          doc = document_model.objects.filter(file__endswith=file).first()
          op = "update" if doc is not None else "create"

          if dry_run:
              self.__log_document_changes(op, file, collection, dry_run)
              return

          changed = False
          with open(os.path.join(path, file), "rb") as fd:
              content = fd.read()
              new_hash = hashlib.sha1(content).hexdigest()

              if doc is None:
                  doc = document_model(title=file, collection=collection)
                  replace_file = True
              else:
                  replace_file = new_hash != doc.file_hash
                  if collection != doc.collection:
                      doc.collection = collection
                      changed = True

              if replace_file:
                  # Rewind so the wrapped file is read from the start when stored.
                  fd.seek(0)
                  doc.file = File(fd, name=file)
                  doc.file_size = len(content)
                  doc.file_hash = new_hash
                  changed = True

              if changed:
                  # Save while fd is still open: the file content is read during save.
                  doc.save()

          # One save and one log entry per document, even when both the content and
          # the collection changed; unchanged documents are left alone and not logged.
          if changed:
              self.__log_document_changes(op, file, collection, dry_run)

      def __log_document_changes(self, op, file, collection, dry_run):
          """Log a ``document_importer.document.<op>`` event for ``file``."""
          logger.info(
              "document_importer.document.{}".format(op),
              dry_run=dry_run,
              file=file,
              collection=collection,
          )