This commit is contained in:
Michael Witten 2017-08-25 05:04:52 +00:00 committed by GitHub
commit 790409024f
4 changed files with 148 additions and 101 deletions

View file

@ -7,7 +7,7 @@
--- ---
Scour is an SVG optimizer/cleaner that reduces the size of scalable vector graphics by optimizing structure and removing unnecessary data written in Python. Scour is an SVG optimizer/cleaner that reduces the size of scalable vector graphics by optimizing structure and removing unnecessary data; scour is written in Python.
It can be used to create streamlined vector graphics suitable for web deployment, publishing/sharing or further processing. It can be used to create streamlined vector graphics suitable for web deployment, publishing/sharing or further processing.
@ -16,11 +16,11 @@ The goal of Scour is to output a file that renderes identically at a fraction of
Scour is open-source and licensed under [Apache License 2.0](https://github.com/codedread/scour/blob/master/LICENSE). Scour is open-source and licensed under [Apache License 2.0](https://github.com/codedread/scour/blob/master/LICENSE).
Scour was originally developed by Jeff "codedread" Schiller and Louis Simard in 2010. Scour was originally developed by Jeff "codedread" Schiller and Louis Simard in 2010.
The project moved to GitLab in 2013 an is now maintained by Tobias "oberstet" Oberstein and Eduard "Ede_123" Braun. The project moved to GitLab in 2013, and then later to GitHub; it is now maintained by Tobias Oberstein ("oberstet") and Eduard Braun ("Ede123").
## Installation ## Installation
Scour requires [Python](https://www.python.org) 2.7 or 3.3+. Further, for installation, [pip](https://pip.pypa.io) should be used. Scour requires [Python](https://www.python.org) 2.7 or 3.3+; for installation, [pip](https://pip.pypa.io) should be used.
To install the [latest release](https://pypi.python.org/pypi/scour) of Scour from PyPI: To install the [latest release](https://pypi.python.org/pypi/scour) of Scour from PyPI:
@ -28,7 +28,7 @@ To install the [latest release](https://pypi.python.org/pypi/scour) of Scour fro
pip install scour pip install scour
``` ```
To install the [latest trunk](https://github.com/codedread/scour) version (which might be broken!) from GitHub: To install the [latest version](https://github.com/codedread/scour) (*which might be broken!*):
```console ```console
pip install https://github.com/codedread/scour/archive/master.zip pip install https://github.com/codedread/scour/archive/master.zip

View file

@ -56,6 +56,7 @@ import re
import sys import sys
import time import time
import xml.dom.minidom import xml.dom.minidom
from xml.dom import Node
from collections import namedtuple from collections import namedtuple
from decimal import Context, Decimal, InvalidOperation, getcontext from decimal import Context, Decimal, InvalidOperation, getcontext
@ -540,7 +541,7 @@ def findElementsWithId(node, elems=None):
for child in node.childNodes: for child in node.childNodes:
# from http://www.w3.org/TR/DOM-Level-2-Core/idl-definitions.html # from http://www.w3.org/TR/DOM-Level-2-Core/idl-definitions.html
# we are only really interested in nodes of type Element (1) # we are only really interested in nodes of type Element (1)
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
findElementsWithId(child, elems) findElementsWithId(child, elems)
return elems return elems
@ -604,7 +605,7 @@ def findReferencedElements(node, ids=None):
if node.hasChildNodes(): if node.hasChildNodes():
for child in node.childNodes: for child in node.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
findReferencedElements(child, ids) findReferencedElements(child, ids)
return ids return ids
@ -645,8 +646,8 @@ def removeUnusedDefs(doc, defElem, elemsToRemove=None):
keepTags = ['font', 'style', 'metadata', 'script', 'title', 'desc'] keepTags = ['font', 'style', 'metadata', 'script', 'title', 'desc']
for elem in defElem.childNodes: for elem in defElem.childNodes:
# only look at it if an element and not referenced anywhere else # only look at it if an element and not referenced anywhere else
if elem.nodeType == 1 and (elem.getAttribute('id') == '' or if elem.nodeType == Node.ELEMENT_NODE and (elem.getAttribute('id') == '' or
elem.getAttribute('id') not in referencedIDs): elem.getAttribute('id') not in referencedIDs):
# we only inspect the children of a group in a defs if the group # we only inspect the children of a group in a defs if the group
# is not referenced anywhere else # is not referenced anywhere else
if elem.nodeName == 'g' and elem.namespaceURI == NS['SVG']: if elem.nodeName == 'g' and elem.namespaceURI == NS['SVG']:
@ -879,7 +880,7 @@ def removeUnreferencedIDs(referencedIDs, identifiedElements):
def removeNamespacedAttributes(node, namespaces): def removeNamespacedAttributes(node, namespaces):
global _num_attributes_removed global _num_attributes_removed
num = 0 num = 0
if node.nodeType == 1: if node.nodeType == Node.ELEMENT_NODE:
# remove all namespace'd attributes from this element # remove all namespace'd attributes from this element
attrList = node.attributes attrList = node.attributes
attrsToRemove = [] attrsToRemove = []
@ -901,7 +902,7 @@ def removeNamespacedAttributes(node, namespaces):
def removeNamespacedElements(node, namespaces): def removeNamespacedElements(node, namespaces):
global _num_elements_removed global _num_elements_removed
num = 0 num = 0
if node.nodeType == 1: if node.nodeType == Node.ELEMENT_NODE:
# remove all namespace'd child nodes from this element # remove all namespace'd child nodes from this element
childList = node.childNodes childList = node.childNodes
childrenToRemove = [] childrenToRemove = []
@ -959,12 +960,12 @@ def removeNestedGroups(node):
groupsToRemove = [] groupsToRemove = []
# Only consider <g> elements for promotion if this element isn't a <switch>. # Only consider <g> elements for promotion if this element isn't a <switch>.
# (partial fix for bug 594930, required by the SVG spec however) # (partial fix for bug 594930, required by the SVG spec however)
if not (node.nodeType == 1 and node.nodeName == 'switch'): if not (node.nodeType == Node.ELEMENT_NODE and node.nodeName == 'switch'):
for child in node.childNodes: for child in node.childNodes:
if child.nodeName == 'g' and child.namespaceURI == NS['SVG'] and len(child.attributes) == 0: if child.nodeName == 'g' and child.namespaceURI == NS['SVG'] and len(child.attributes) == 0:
# only collapse group if it does not have a title or desc as a direct descendant, # only collapse group if it does not have a title or desc as a direct descendant,
for grandchild in child.childNodes: for grandchild in child.childNodes:
if grandchild.nodeType == 1 and grandchild.namespaceURI == NS['SVG'] and \ if grandchild.nodeType == Node.ELEMENT_NODE and grandchild.namespaceURI == NS['SVG'] and \
grandchild.nodeName in ['title', 'desc']: grandchild.nodeName in ['title', 'desc']:
break break
else: else:
@ -979,7 +980,7 @@ def removeNestedGroups(node):
# now recurse for children # now recurse for children
for child in node.childNodes: for child in node.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
num += removeNestedGroups(child) num += removeNestedGroups(child)
return num return num
@ -997,14 +998,14 @@ def moveCommonAttributesToParentGroup(elem, referencedElements):
childElements = [] childElements = []
# recurse first into the children (depth-first) # recurse first into the children (depth-first)
for child in elem.childNodes: for child in elem.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
# only add and recurse if the child is not referenced elsewhere # only add and recurse if the child is not referenced elsewhere
if not child.getAttribute('id') in referencedElements: if not child.getAttribute('id') in referencedElements:
childElements.append(child) childElements.append(child)
num += moveCommonAttributesToParentGroup(child, referencedElements) num += moveCommonAttributesToParentGroup(child, referencedElements)
# else if the parent has non-whitespace text children, do not # else if the parent has non-whitespace text children, do not
# try to move common attributes # try to move common attributes
elif child.nodeType == 3 and child.nodeValue.strip(): elif child.nodeType == Node.TEXT_NODE and child.nodeValue.strip():
return num return num
# only process the children if there are more than one element # only process the children if there are more than one element
@ -1102,23 +1103,27 @@ def createGroupsForCommonAttributes(elem):
while curChild >= 0: while curChild >= 0:
childNode = elem.childNodes.item(curChild) childNode = elem.childNodes.item(curChild)
if childNode.nodeType == 1 and childNode.getAttribute(curAttr) != '' and childNode.nodeName in [ if (
# only attempt to group elements that the content model allows to be children of a <g> childNode.nodeType == Node.ELEMENT_NODE and
childNode.getAttribute(curAttr) != '' and
childNode.nodeName in [
# only attempt to group elements that the content model allows to be children of a <g>
# SVG 1.1 (see https://www.w3.org/TR/SVG/struct.html#GElement) # SVG 1.1 (see https://www.w3.org/TR/SVG/struct.html#GElement)
'animate', 'animateColor', 'animateMotion', 'animateTransform', 'set', # animation elements 'animate', 'animateColor', 'animateMotion', 'animateTransform', 'set', # animation elements
'desc', 'metadata', 'title', # descriptive elements 'desc', 'metadata', 'title', # descriptive elements
'circle', 'ellipse', 'line', 'path', 'polygon', 'polyline', 'rect', # shape elements 'circle', 'ellipse', 'line', 'path', 'polygon', 'polyline', 'rect', # shape elements
'defs', 'g', 'svg', 'symbol', 'use', # structural elements 'defs', 'g', 'svg', 'symbol', 'use', # structural elements
'linearGradient', 'radialGradient', # gradient elements 'linearGradient', 'radialGradient', # gradient elements
'a', 'altGlyphDef', 'clipPath', 'color-profile', 'cursor', 'filter', 'a', 'altGlyphDef', 'clipPath', 'color-profile', 'cursor', 'filter',
'font', 'font-face', 'foreignObject', 'image', 'marker', 'mask', 'font', 'font-face', 'foreignObject', 'image', 'marker', 'mask',
'pattern', 'script', 'style', 'switch', 'text', 'view', 'pattern', 'script', 'style', 'switch', 'text', 'view',
# SVG 1.2 (see https://www.w3.org/TR/SVGTiny12/elementTable.html) # SVG 1.2 (see https://www.w3.org/TR/SVGTiny12/elementTable.html)
'animation', 'audio', 'discard', 'handler', 'listener', 'animation', 'audio', 'discard', 'handler', 'listener',
'prefetch', 'solidColor', 'textArea', 'video' 'prefetch', 'solidColor', 'textArea', 'video'
]: ]
):
# We're in a possible run! Track the value and run length. # We're in a possible run! Track the value and run length.
value = childNode.getAttribute(curAttr) value = childNode.getAttribute(curAttr)
runStart, runEnd = curChild, curChild runStart, runEnd = curChild, curChild
@ -1130,7 +1135,7 @@ def createGroupsForCommonAttributes(elem):
# attribute value, preserving any nodes in-between. # attribute value, preserving any nodes in-between.
while runStart > 0: while runStart > 0:
nextNode = elem.childNodes.item(runStart - 1) nextNode = elem.childNodes.item(runStart - 1)
if nextNode.nodeType == 1: if nextNode.nodeType == Node.ELEMENT_NODE:
if nextNode.getAttribute(curAttr) != value: if nextNode.getAttribute(curAttr) != value:
break break
else: else:
@ -1142,7 +1147,7 @@ def createGroupsForCommonAttributes(elem):
if runElements >= 3: if runElements >= 3:
# Include whitespace/comment/etc. nodes in the run. # Include whitespace/comment/etc. nodes in the run.
while runEnd < elem.childNodes.length - 1: while runEnd < elem.childNodes.length - 1:
if elem.childNodes.item(runEnd + 1).nodeType == 1: if elem.childNodes.item(runEnd + 1).nodeType == Node.ELEMENT_NODE:
break break
else: else:
runEnd += 1 runEnd += 1
@ -1186,7 +1191,7 @@ def createGroupsForCommonAttributes(elem):
# each child gets the same treatment, recursively # each child gets the same treatment, recursively
for childNode in elem.childNodes: for childNode in elem.childNodes:
if childNode.nodeType == 1: if childNode.nodeType == Node.ELEMENT_NODE:
num += createGroupsForCommonAttributes(childNode) num += createGroupsForCommonAttributes(childNode)
return num return num
@ -1202,7 +1207,7 @@ def removeUnusedAttributesOnParent(elem):
childElements = [] childElements = []
# recurse first into the children (depth-first) # recurse first into the children (depth-first)
for child in elem.childNodes: for child in elem.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
childElements.append(child) childElements.append(child)
num += removeUnusedAttributesOnParent(child) num += removeUnusedAttributesOnParent(child)
@ -1302,11 +1307,15 @@ def collapseSinglyReferencedGradients(doc):
# (Cyn: I've seen documents with #id references but no element with that ID!) # (Cyn: I've seen documents with #id references but no element with that ID!)
if count == 1 and rid in identifiedElements: if count == 1 and rid in identifiedElements:
elem = identifiedElements[rid] elem = identifiedElements[rid]
if elem is not None and elem.nodeType == 1 and elem.nodeName in ['linearGradient', 'radialGradient'] \ if (
and elem.namespaceURI == NS['SVG']: elem is not None and
elem.nodeType == Node.ELEMENT_NODE and
elem.nodeName in ['linearGradient', 'radialGradient'] and
elem.namespaceURI == NS['SVG']
):
# found a gradient that is referenced by only 1 other element # found a gradient that is referenced by only 1 other element
refElem = nodes[0] refElem = nodes[0]
if refElem.nodeType == 1 and refElem.nodeName in ['linearGradient', 'radialGradient'] \ if refElem.nodeType == Node.ELEMENT_NODE and refElem.nodeName in ['linearGradient', 'radialGradient'] \
and refElem.namespaceURI == NS['SVG']: and refElem.namespaceURI == NS['SVG']:
# elem is a gradient referenced by only one other gradient (refElem) # elem is a gradient referenced by only one other gradient (refElem)
@ -1448,7 +1457,7 @@ def removeDuplicateGradients(doc):
def _getStyle(node): def _getStyle(node):
u"""Returns the style attribute of a node as a dictionary.""" u"""Returns the style attribute of a node as a dictionary."""
if node.nodeType == 1 and len(node.getAttribute('style')) > 0: if node.nodeType == Node.ELEMENT_NODE and len(node.getAttribute('style')) > 0:
styleMap = {} styleMap = {}
rawStyles = node.getAttribute('style').split(';') rawStyles = node.getAttribute('style').split(';')
for style in rawStyles: for style in rawStyles:
@ -1614,7 +1623,7 @@ def styleInheritedFromParent(node, style):
parentNode = node.parentNode parentNode = node.parentNode
# return None if we reached the Document element # return None if we reached the Document element
if parentNode.nodeType == 9: if parentNode.nodeType == Node.DOCUMENT_NODE:
return None return None
# check styles first (they take precedence over presentation attributes) # check styles first (they take precedence over presentation attributes)
@ -1647,7 +1656,7 @@ def styleInheritedByChild(node, style, nodeIsChild=False):
any style sheets are ignored! any style sheets are ignored!
""" """
# Comment, text and CDATA nodes don't have attributes and aren't containers so they can't inherit attributes # Comment, text and CDATA nodes don't have attributes and aren't containers so they can't inherit attributes
if node.nodeType != 1: if node.nodeType != Node.ELEMENT_NODE:
return False return False
if nodeIsChild: if nodeIsChild:
@ -1702,7 +1711,7 @@ def mayContainTextNodes(node):
result = True # Default value result = True # Default value
# Comment, text and CDATA nodes don't have attributes and aren't containers # Comment, text and CDATA nodes don't have attributes and aren't containers
if node.nodeType != 1: if node.nodeType != Node.ELEMENT_NODE:
result = False result = False
# Non-SVG elements? Unknown elements! # Non-SVG elements? Unknown elements!
elif node.namespaceURI != NS['SVG']: elif node.namespaceURI != NS['SVG']:
@ -1920,7 +1929,7 @@ def removeDefaultAttributeValues(node, options, tainted=set()):
For such attributes, we don't delete attributes with default values.""" For such attributes, we don't delete attributes with default values."""
num = 0 num = 0
if node.nodeType != 1: if node.nodeType != Node.ELEMENT_NODE:
return 0 return 0
# Conditionally remove all default attributes defined in 'default_attributes' (a list of 'DefaultAttribute's) # Conditionally remove all default attributes defined in 'default_attributes' (a list of 'DefaultAttribute's)
@ -1997,7 +2006,7 @@ def convertColors(element):
""" """
numBytes = 0 numBytes = 0
if element.nodeType != 1: if element.nodeType != Node.ELEMENT_NODE:
return 0 return 0
# set up list of color attributes for each element type # set up list of color attributes for each element type
@ -2772,7 +2781,7 @@ def reducePrecision(element):
_setStyle(element, styles) _setStyle(element, styles)
for child in element.childNodes: for child in element.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
num += reducePrecision(child) num += reducePrecision(child)
return num return num
@ -2989,7 +2998,7 @@ def optimizeTransforms(element, options):
num += len(val) - len(newVal) num += len(val) - len(newVal)
for child in element.childNodes: for child in element.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
num += optimizeTransforms(child, options) num += optimizeTransforms(child, options)
return num return num
@ -3133,7 +3142,7 @@ def properlySizeDoc(docElement, options):
def remapNamespacePrefix(node, oldprefix, newprefix): def remapNamespacePrefix(node, oldprefix, newprefix):
if node is None or node.nodeType != 1: if node is None or node.nodeType != Node.ELEMENT_NODE:
return return
if node.prefix == oldprefix: if node.prefix == oldprefix:
@ -3169,20 +3178,10 @@ def remapNamespacePrefix(node, oldprefix, newprefix):
remapNamespacePrefix(child, oldprefix, newprefix) remapNamespacePrefix(child, oldprefix, newprefix)
def makeWellFormed(str): def makeWellFormed(str, quote=''):
# Don't escape quotation marks for now as they are fine in text nodes
# as well as in attributes if used reciprocally
# xml_ents = { '<':'&lt;', '>':'&gt;', '&':'&amp;', "'":'&apos;', '"':'&quot;'}
xml_ents = {'<': '&lt;', '>': '&gt;', '&': '&amp;'} xml_ents = {'<': '&lt;', '>': '&gt;', '&': '&amp;'}
if quote:
# starr = [] xml_ents[quote] = '&apos;' if (quote == "'") else "&quot;"
# for c in str:
# if c in xml_ents:
# starr.append(xml_ents[c])
# else:
# starr.append(c)
# this list comprehension is short-form for the above for-loop:
return ''.join([xml_ents[c] if c in xml_ents else c for c in str]) return ''.join([xml_ents[c] if c in xml_ents else c for c in str])
@ -3206,25 +3205,11 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
outParts.extend([(I * ind), '<', element.nodeName]) outParts.extend([(I * ind), '<', element.nodeName])
# always serialize the id or xml:id attributes first
if element.getAttribute('id') != '':
id = element.getAttribute('id')
quot = '"'
if id.find('"') != -1:
quot = "'"
outParts.extend([' id=', quot, id, quot])
if element.getAttribute('xml:id') != '':
id = element.getAttribute('xml:id')
quot = '"'
if id.find('"') != -1:
quot = "'"
outParts.extend([' xml:id=', quot, id, quot])
# now serialize the other attributes # now serialize the other attributes
known_attr = [ known_attr = [
# TODO: Maybe update with full list from https://www.w3.org/TR/SVG/attindex.html # TODO: Maybe update with full list from https://www.w3.org/TR/SVG/attindex.html
# (but should be kept inuitively ordered) # (but should be kept inuitively ordered)
'id', 'class', 'id', 'xml:id', 'class',
'transform', 'transform',
'x', 'y', 'z', 'width', 'height', 'x1', 'x2', 'y1', 'y2', 'x', 'y', 'z', 'width', 'height', 'x1', 'x2', 'y1', 'y2',
'dx', 'dy', 'rotate', 'startOffset', 'method', 'spacing', 'dx', 'dy', 'rotate', 'startOffset', 'method', 'spacing',
@ -3244,14 +3229,24 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
attrIndices += [attrName2Index[name] for name in sorted(attrName2Index.keys())] attrIndices += [attrName2Index[name] for name in sorted(attrName2Index.keys())]
for index in attrIndices: for index in attrIndices:
attr = attrList.item(index) attr = attrList.item(index)
if attr.nodeName == 'id' or attr.nodeName == 'xml:id':
continue
# if the attribute value contains a double-quote, use single-quotes
quot = '"'
if attr.nodeValue.find('"') != -1:
quot = "'"
attrValue = makeWellFormed(attr.nodeValue) attrValue = attr.nodeValue
quot_count = 0
apos_count = 0
for c in attrValue:
if c == '"':
quot_count += 1
elif c == "'":
apos_count += 1
if quot_count > apos_count:
quote = "'"
else:
quote = '"'
attrValue = makeWellFormed(attrValue, quote if (quot_count or apos_count) else '')
if attr.nodeName == 'style': if attr.nodeName == 'style':
# sort declarations # sort declarations
attrValue = ';'.join([p for p in sorted(attrValue.split(';'))]) attrValue = ';'.join([p for p in sorted(attrValue.split(';'))])
@ -3265,7 +3260,7 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
outParts.append('xmlns:') outParts.append('xmlns:')
elif attr.namespaceURI == 'http://www.w3.org/1999/xlink': elif attr.namespaceURI == 'http://www.w3.org/1999/xlink':
outParts.append('xlink:') outParts.append('xlink:')
outParts.extend([attr.localName, '=', quot, attrValue, quot]) outParts.extend([attr.localName, '=', quote, attrValue, quote])
if attr.nodeName == 'xml:space': if attr.nodeName == 'xml:space':
if attrValue == 'preserve': if attrValue == 'preserve':
@ -3273,22 +3268,25 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
elif attrValue == 'default': elif attrValue == 'default':
preserveWhitespace = False preserveWhitespace = False
# if no children, self-close
children = element.childNodes children = element.childNodes
if children.length > 0: if children.length == 0:
outParts.append('/>')
if indent > 0:
outParts.append(newline)
else:
outParts.append('>') outParts.append('>')
onNewLine = False onNewLine = False
for child in element.childNodes: for child in element.childNodes:
# element node # element node
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
if preserveWhitespace: if preserveWhitespace:
outParts.append(serializeXML(child, options, 0, preserveWhitespace)) outParts.append(serializeXML(child, options, 0, preserveWhitespace))
else: else:
outParts.extend([newline, serializeXML(child, options, indent + 1, preserveWhitespace)]) outParts.extend([newline, serializeXML(child, options, indent + 1, preserveWhitespace)])
onNewLine = True onNewLine = True
# text node # text node
elif child.nodeType == 3: elif child.nodeType == Node.TEXT_NODE:
# trim it only in the case of not being a child of an element # trim it only in the case of not being a child of an element
# where whitespace might be important # where whitespace might be important
if preserveWhitespace: if preserveWhitespace:
@ -3296,10 +3294,10 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
else: else:
outParts.append(makeWellFormed(child.nodeValue.strip())) outParts.append(makeWellFormed(child.nodeValue.strip()))
# CDATA node # CDATA node
elif child.nodeType == 4: elif child.nodeType == Node.CDATA_SECTION_NODE:
outParts.extend(['<![CDATA[', child.nodeValue, ']]>']) outParts.extend(['<![CDATA[', child.nodeValue, ']]>'])
# Comment node # Comment node
elif child.nodeType == 8: elif child.nodeType == Node.COMMENT_NODE:
outParts.extend(['<!--', child.nodeValue, '-->']) outParts.extend(['<!--', child.nodeValue, '-->'])
# TODO: entities, processing instructions, what else? # TODO: entities, processing instructions, what else?
else: # ignore the rest else: # ignore the rest
@ -3310,10 +3308,6 @@ def serializeXML(element, options, ind=0, preserveWhitespace=False):
outParts.extend(['</', element.nodeName, '>']) outParts.extend(['</', element.nodeName, '>'])
if indent > 0: if indent > 0:
outParts.append(newline) outParts.append(newline)
else:
outParts.append('/>')
if indent > 0:
outParts.append(newline)
return "".join(outParts) return "".join(outParts)
@ -3468,9 +3462,9 @@ def scourString(in_string, options=None):
removeElem = not elem.hasChildNodes() removeElem = not elem.hasChildNodes()
if removeElem is False: if removeElem is False:
for child in elem.childNodes: for child in elem.childNodes:
if child.nodeType in [1, 4, 8]: if child.nodeType in [Node.ELEMENT_NODE, Node.CDATA_SECTION_NODE, Node.COMMENT_NODE]:
break break
elif child.nodeType == 3 and not child.nodeValue.isspace(): elif child.nodeType == Node.TEXT_NODE and not child.nodeValue.isspace():
break break
else: else:
removeElem = True removeElem = True
@ -3594,7 +3588,7 @@ def scourString(in_string, options=None):
total_output = "" total_output = ""
for child in doc.childNodes: for child in doc.childNodes:
if child.nodeType == 1: if child.nodeType == Node.ELEMENT_NODE:
total_output += "".join(lines) total_output += "".join(lines)
else: # doctypes, entities, comments else: # doctypes, entities, comments
total_output += child.toxml() + '\n' total_output += child.toxml() + '\n'
@ -3603,9 +3597,7 @@ def scourString(in_string, options=None):
# used mostly by unit tests # used mostly by unit tests
# input is a filename def scourXmlFileAndReturnString(filename, options=None):
# returns the minidom doc representation of the SVG
def scourXmlFile(filename, options=None):
# sanitize options (take missing attributes from defaults, discard unknown attributes) # sanitize options (take missing attributes from defaults, discard unknown attributes)
options = sanitizeOptions(options) options = sanitizeOptions(options)
# we need to make sure infilename is set correctly (otherwise relative references in the SVG won't work) # we need to make sure infilename is set correctly (otherwise relative references in the SVG won't work)
@ -3614,7 +3606,14 @@ def scourXmlFile(filename, options=None):
# open the file and scour it # open the file and scour it
with open(filename, "rb") as f: with open(filename, "rb") as f:
in_string = f.read() in_string = f.read()
out_string = scourString(in_string, options)
return scourString(in_string, options)
# used mostly by unit tests
# returns the minidom doc representation of the SVG
def scourXmlFile(filename, options=None):
out_string = scourXmlFileAndReturnString(filename, options)
# prepare the output xml.dom.minidom object # prepare the output xml.dom.minidom object
doc = xml.dom.minidom.parseString(out_string.encode('utf-8')) doc = xml.dom.minidom.parseString(out_string.encode('utf-8'))

View file

@ -30,7 +30,12 @@ import unittest
import six import six
from six.moves import map, range from six.moves import map, range
from scour.scour import makeWellFormed, parse_args, scourString, scourXmlFile, start, run from scour.scour import (
makeWellFormed, parse_args, scourString,
scourXmlFileAndReturnString, scourXmlFile,
start, run
)
from scour.svg_regex import svg_parser from scour.svg_regex import svg_parser
from scour import __version__ from scour import __version__
@ -1779,7 +1784,42 @@ class XmlEntities(unittest.TestCase):
def runTest(self): def runTest(self):
self.assertEqual(makeWellFormed('<>&'), '&lt;&gt;&amp;', self.assertEqual(makeWellFormed('<>&'), '&lt;&gt;&amp;',
'Incorrectly translated XML entities') 'Incorrectly translated unquoted XML entities')
self.assertEqual(makeWellFormed('<>&', "'"), '&lt;&gt;&amp;',
'Incorrectly translated single-quoted XML entities')
self.assertEqual(makeWellFormed('<>&', '"'), '&lt;&gt;&amp;',
'Incorrectly translated double-quoted XML entities')
self.assertEqual(makeWellFormed("'"), "'",
'Incorrectly translated unquoted single quote')
self.assertEqual(makeWellFormed('"'), '"',
'Incorrectly translated unquoted double quote')
self.assertEqual(makeWellFormed("'", '"'), "'",
'Incorrectly translated double-quoted single quote')
self.assertEqual(makeWellFormed('"', "'"), '"',
'Incorrectly translated single-quoted double quote')
self.assertEqual(makeWellFormed("'", "'"), '&apos;',
'Incorrectly translated single-quoted single quote')
self.assertEqual(makeWellFormed('"', '"'), '&quot;',
'Incorrectly translated double-quoted double quote')
class HandleQuotesInAttributes(unittest.TestCase):
    """Check how attribute values containing quote characters are serialized.

    The fixture ``unittests/entities.svg`` defines attributes ``a`` to ``e``
    on the root element, each mixing single and double quotes in a different
    proportion. Each expectation below pins both the delimiter chosen for the
    attribute and which quote characters are emitted as ``&quot;``/``&apos;``.
    """

    def runTest(self):
        output = scourXmlFileAndReturnString('unittests/entities.svg')

        # (expected serialized attribute, message reported on failure)
        expectations = [
            ('a="\'"',
             'Failed on attribute value with non-double quote'),
            ("b='\"'",
             'Failed on attribute value with non-single quote'),
            ("c=\"''&quot;\"",
             'Failed on attribute value with more single quotes than double quotes'),
            ('d=\'""&apos;\'',
             'Failed on attribute value with more double quotes than single quotes'),
            ("e=\"''&quot;&quot;\"",
             'Failed on attribute value with the same number of double quotes as single quotes'),
        ]

        # assertIn (rather than assertTrue(x in y)) makes a failure report
        # show the missing substring together with the actual output.
        for expected, message in expectations:
            self.assertIn(expected, output, message)
class DoNotStripCommentsOutsideOfRoot(unittest.TestCase): class DoNotStripCommentsOutsideOfRoot(unittest.TestCase):

8
unittests/entities.svg Normal file
View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg"
a="'"
b='"'
c="''&quot;"
d='""&apos;'
e='&apos;&apos;""'
/>

After

Width:  |  Height:  |  Size: 144 B