Skip to content

Commit

Permalink
Add Dataset & Element equality methods, greatly speeding up tests tha…
Browse files Browse the repository at this point in the history
…t rely on dataset comparisons. (#280)

This change introduces well-defined equality methods for Dataset, Element, and some other data structures. This greatly speeds up tests that rely on checking equality of datasets (that previously needed reflection). For example, it reduces the total test suite from 1m24s to 10s on GitHub actions (mostly due to one test). These methods may also be of use to library users.

However, this does mean that if new fields are added to any of these structs it is important for the Equals method to be updated as well. For now this will be enforced during code review (helped by the fact most of these structs should not fail often), but we should investigate lint rules or some auto-generated reflection based tests that can help catch when this doesn't happen (see #281).

This change also makes a change to rely on pointers for []*frame.Frame in the PixelDataInfo.
  • Loading branch information
suyashkumar authored Aug 26, 2023
1 parent 5933371 commit 8f1f151
Show file tree
Hide file tree
Showing 11 changed files with 524 additions and 51 deletions.
2 changes: 1 addition & 1 deletion cmd/dicomutil/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,6 @@ func generateImage(fr *frame.Frame, frameIndex int, frameSuffix string, wg *sync
func writePixelDataElement(e *dicom.Element, suffix string) {
imageInfo := e.Value.GetValue().(dicom.PixelDataInfo)
for idx, f := range imageInfo.Frames {
generateImage(&f, idx, suffix, nil)
generateImage(f, idx, suffix, nil)
}
}
14 changes: 14 additions & 0 deletions dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,20 @@ func (d *Dataset) String() string {
return b.String()
}

// Equals returns true if this Dataset equals the provided target Dataset,
// otherwise false.
func (d *Dataset) Equals(target *Dataset) bool {
if target == nil || d == nil {
return d == target
}
for idx, e := range d.Elements {
if !e.Equals(target.Elements[idx]) {
return false
}
}
return true
}

type elementWithLevel struct {
e *Element
// l represents the nesting level of the Element
Expand Down
186 changes: 159 additions & 27 deletions element.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dicom

import (
"bytes"
"encoding/json"
"errors"
"fmt"
Expand All @@ -24,6 +25,24 @@ type Element struct {
Value Value `json:"value"`
}

// Equals returns true if this Element equals the provided target Element,
// otherwise false.
func (e *Element) Equals(target *Element) bool {
if target == nil || e == nil {
return e == target
}
if !e.Tag.Equals(target.Tag) ||
e.RawValueRepresentation != target.RawValueRepresentation ||
e.ValueLength != target.ValueLength ||
e.ValueRepresentation != target.ValueRepresentation {
return false
}
if !e.Value.Equals(target.Value) {
return false
}
return true
}

func (e *Element) String() string {
var tagName string
if tagInfo, err := tag.Find(e.Tag); err == nil {
Expand Down Expand Up @@ -75,6 +94,8 @@ type Value interface {
GetValue() interface{} // TODO: rename to Get to read cleaner
String() string
MarshalJSON() ([]byte, error)
// Equals returns true if this value equals the input Value.
Equals(Value) bool
}

// NewValue creates a new DICOM value for the supplied data. Likely most useful
Expand Down Expand Up @@ -204,6 +225,16 @@ func (b *bytesValue) MarshalJSON() ([]byte, error) {
return json.Marshal(b.value)
}

func (b *bytesValue) Equals(target Value) bool {
if target.ValueType() != Bytes {
return false
}
if !bytes.Equal(b.value, target.GetValue().([]byte)) {
return false
}
return true
}

// stringsValue represents a value of []string.
type stringsValue struct {
value []string
Expand All @@ -219,34 +250,81 @@ func (s *stringsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(s.value)
}

func (s *stringsValue) Equals(target Value) bool {
if target.ValueType() != Strings {
return false
}
targetVal := target.GetValue().([]string)
if len(s.value) != len(targetVal) {
return false
}
for idx, val := range s.value {
if val != targetVal[idx] {
return false
}
}
return true
}

// intsValue represents a value of []int.
type intsValue struct {
value []int
}

func (s *intsValue) isElementValue() {}
func (s *intsValue) ValueType() ValueType { return Ints }
func (s *intsValue) GetValue() interface{} { return s.value }
func (s *intsValue) String() string {
return fmt.Sprintf("%v", s.value)
func (i *intsValue) isElementValue() {}
func (i *intsValue) ValueType() ValueType { return Ints }
func (i *intsValue) GetValue() interface{} { return i.value }
func (i *intsValue) String() string {
return fmt.Sprintf("%v", i.value)
}
func (s *intsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(s.value)
func (i *intsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(i.value)
}

func (i *intsValue) Equals(target Value) bool {
if target.ValueType() != Ints {
return false
}
targetVal := target.GetValue().([]int)
if len(i.value) != len(targetVal) {
return false
}
for idx, val := range i.value {
if val != targetVal[idx] {
return false
}
}
return true
}

// floatsValue represents a value of []float64.
type floatsValue struct {
value []float64
}

func (s *floatsValue) isElementValue() {}
func (s *floatsValue) ValueType() ValueType { return Floats }
func (s *floatsValue) GetValue() interface{} { return s.value }
func (s *floatsValue) String() string {
return fmt.Sprintf("%v", s.value)
func (f *floatsValue) isElementValue() {}
func (f *floatsValue) ValueType() ValueType { return Floats }
func (f *floatsValue) GetValue() interface{} { return f.value }
func (f *floatsValue) String() string {
return fmt.Sprintf("%v", f.value)
}
func (s *floatsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(s.value)
func (f *floatsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(f.value)
}
func (f *floatsValue) Equals(target Value) bool {
if target.ValueType() != Floats {
return false
}
targetVal := target.GetValue().([]float64)
if len(f.value) != len(targetVal) {
return false
}
for idx, val := range f.value {
if val != targetVal[idx] {
return false
}
}
return true
}

// SequenceItemValue is a Value that represents a single Sequence Item. Learn
Expand Down Expand Up @@ -278,6 +356,22 @@ func (s *SequenceItemValue) MarshalJSON() ([]byte, error) {
return json.Marshal(s.elements)
}

func (s *SequenceItemValue) Equals(target Value) bool {
if target.ValueType() != SequenceItem {
return false
}
targetVal := target.GetValue().([]*Element)
if len(s.elements) != len(targetVal) {
return false
}
for idx, val := range s.elements {
if !val.Equals(targetVal[idx]) {
return false
}
}
return true
}

// sequencesValue represents a set of items in a DICOM sequence.
type sequencesValue struct {
value []*SequenceItemValue
Expand All @@ -293,6 +387,21 @@ func (s *sequencesValue) String() string {
func (s *sequencesValue) MarshalJSON() ([]byte, error) {
return json.Marshal(s.value)
}
func (s *sequencesValue) Equals(target Value) bool {
if target.ValueType() != Sequences {
return false
}
targetVal := target.GetValue().([]*SequenceItemValue)
if len(s.value) != len(targetVal) {
return false
}
for idx, val := range s.value {
if !val.Equals(targetVal[idx]) {
return false
}
}
return true
}

// PixelDataInfo is a representation of DICOM PixelData.
type PixelDataInfo struct {
Expand All @@ -304,7 +413,7 @@ type PixelDataInfo struct {

// Frames hold the processed PixelData frames (either Native or Encapsulated
// PixelData).
Frames []frame.Frame
Frames []*frame.Frame

// ParseErr indicates if there was an error when reading this Frame from the DICOM.
// If this is set, this means fallback behavior was triggered to blindly write the PixelData bytes to an encapsulated frame.
Expand All @@ -329,24 +438,47 @@ type pixelDataValue struct {
PixelDataInfo
}

func (e *pixelDataValue) isElementValue() {}
func (e *pixelDataValue) ValueType() ValueType { return PixelData }
func (e *pixelDataValue) GetValue() interface{} { return e.PixelDataInfo }
func (e *pixelDataValue) String() string {
if len(e.Frames) == 0 {
func (p *pixelDataValue) isElementValue() {}
func (p *pixelDataValue) ValueType() ValueType { return PixelData }
func (p *pixelDataValue) GetValue() interface{} { return p.PixelDataInfo }
func (p *pixelDataValue) String() string {
if len(p.Frames) == 0 {
return "empty pixel data"
}
if e.IsEncapsulated {
return fmt.Sprintf("encapsulated FramesLength=%d Frame[0] size=%d", len(e.Frames), len(e.Frames[0].EncapsulatedData.Data))
if p.IsEncapsulated {
return fmt.Sprintf("encapsulated FramesLength=%d Frame[0] size=%d", len(p.Frames), len(p.Frames[0].EncapsulatedData.Data))
}
if e.ParseErr != nil {
return fmt.Sprintf("parseErr err=%s FramesLength=%d Frame[0] size=%d", e.ParseErr.Error(), len(e.Frames), len(e.Frames[0].EncapsulatedData.Data))
if p.ParseErr != nil {
return fmt.Sprintf("parseErr err=%s FramesLength=%d Frame[0] size=%d", p.ParseErr.Error(), len(p.Frames), len(p.Frames[0].EncapsulatedData.Data))
}
return fmt.Sprintf("FramesLength=%d FrameSize rows=%d cols=%d", len(e.Frames), e.Frames[0].NativeData.Rows, e.Frames[0].NativeData.Cols)
return fmt.Sprintf("FramesLength=%d FrameSize rows=%d cols=%d", len(p.Frames), p.Frames[0].NativeData.Rows, p.Frames[0].NativeData.Cols)
}

func (e *pixelDataValue) MarshalJSON() ([]byte, error) {
return json.Marshal(e.PixelDataInfo)
func (p *pixelDataValue) MarshalJSON() ([]byte, error) {
return json.Marshal(p.PixelDataInfo)
}
func (p *pixelDataValue) Equals(target Value) bool {
if target.ValueType() != PixelData {
return false
}
targetVal := target.GetValue().(PixelDataInfo)
if p.IntentionallySkipped != targetVal.IntentionallySkipped ||
p.IntentionallyUnprocessed != targetVal.IntentionallyUnprocessed ||
p.ParseErr != targetVal.ParseErr ||
p.IsEncapsulated != targetVal.IsEncapsulated ||
!bytes.Equal(p.UnprocessedValueData, targetVal.UnprocessedValueData) {
return false
}
targetFrameVal := target.GetValue().(PixelDataInfo).Frames
if len(p.Frames) != len(targetFrameVal) {
return false
}
for idx, val := range p.Frames {
if !val.Equals(targetFrameVal[idx]) {
return false
}
}
return true
}

// MustGetInts attempts to get an Ints value out of the provided value, and will
Expand Down
Loading

0 comments on commit 8f1f151

Please sign in to comment.