Giter Club home page Giter Club logo

Comments (8)

csyonghe avatar csyonghe commented on July 20, 2024

This is a type system bug related to constructors in generic interfaces. PR 3799 should fix this issue.

from slang.

csyonghe avatar csyonghe commented on July 20, 2024

It seems that you are trying to build something interesting with generics. Would you mind sharing more details on what you are trying to achieve here? Specifically, I wonder if the builtin IFloat interface will work for you, and if you need anything that is missing from the builtin interfaces.

from slang.

kevinboulain avatar kevinboulain commented on July 20, 2024

@csyonghe Ah, gladly. Don't be too harsh though, I've been learning GPU programming for a few months and just started with HLSL/Slang (GLSL was a bit painful for this) but comments are definitely welcome :)
Here's an excerpt where I'd like to abstract over 'regular' and interval arithmetic (so I don't have to maintain two implementations):

// Conversion requirement: a conforming type must provide an initializer
// that constructs it from a single value of type T.
interface From<T> {
  __init(const T a);
}

// Three-component access requirement: a conforming type exposes readable and
// writable x/y/z properties whose element type is T.
interface Swizzle3<T> { // TODO: Can't write interface Swizzle3<X>: IArray<X> yet?
  property T x;
  property T y;
  property T z;
}

// Abstraction over the number representation used by the generic distance
// functions, so one implementation can run on different scalar types
// (e.g. plain floats or intervals). All operations are static functions over
// the two associated types below.
interface Arithmetic { // TODO: I guess there's no way to 'use Arithmetic' so that I could implement and scope operator resolution.
  // Scalar type; must be constructible from a plain float.
  associatedtype Float: From<float>;
  static Float abs(const Float a);
  static Float add(const Float a, const Float b);
  static Float divide(const Float a, const Float b);
  // Mixed overload: divide a Float by a plain float.
  static Float divide(const Float a, const float b);
  static bool less_than_equal(const Float a, const Float b);
  static Float max(const Float a, const Float b);
  static Float min(const Float a, const Float b);
  static Float multiply(const Float a, const Float b);
  static Float negate(const Float a);
  // Semantics are defined by implementers (Regular forwards to a global ::solve2).
  static Float solve2(const Float a, const Float b, const Float c);
  static Float subtract(const Float a, const Float b);
  // Mixed overload: subtract a plain float from a Float.
  static Float subtract(const Float a, const float b);

  // Three-component vector of Float: indexable (IArray<Float>), x/y/z
  // accessible (Swizzle3<Float>), and constructible from a Float[3].
  associatedtype Vector3: IArray<Float>, Swizzle3<Float>, From<Float[3]>; // vector and matrix are concrete types implementing IArray so can't inherit vector<Float, 3>.
  static Vector3 abs(const Vector3 a);
  // Reduces a Vector3 to a single Float (implementations sum the components).
  static Float addv(const Vector3 a);
  // Divides each component by the matching plain-float component.
  static Vector3 divide(const Vector3 a, const float3 b);
  // Reduces a Vector3 to a single Float (implementations take the maximum).
  static Float maxv(const Vector3 a);
  static Vector3 multiply(const Vector3 a, const Vector3 b);
  static Vector3 subtract(const Vector3 a, const float3 b);
}

// float conforms to From<float>; the requirement is presumably met by float's
// built-in constructors, since no initializer is declared here.
extension float: From<float> {}

// Exposes float3's components through the Swizzle3 interface.
// Components are accessed by index so that a member lookup of `x` inside the
// property `x` cannot resolve back to the property being defined (recursion).
// Setters must store Slang's implicit `newValue` parameter: an unqualified `x`
// in the setter would name the property itself, so the assignment would not
// store the new value.
extension float3: Swizzle3<float> {
  property float x {
    get { return this[0]; }
    set { this[0] = newValue; }
  }
  property float y {
    get { return this[1]; }
    set { this[1] = newValue; }
  }
  property float z {
    get { return this[2]; }
    set { this[2] = newValue; }
  }
}

// Builds a float3 from a three-element float array, in order.
extension float3: From<float[3]> {
  __init(const float[3] a) {
    return { a[0], a[1], a[2] };
  }
}

// "Regular" arithmetic: plain floating point. Float is float and Vector3 is
// float3, so every operation forwards either to the built-in operator or to
// the global helper of the same name. The leading `::` selects that global
// function rather than the static member currently being defined.
struct Regular: Arithmetic
{
  typedef float Float;
  typedef float3 Vector3;

  // ---- Scalar operations ----
  // With Float == float, the (Float, Float) and (Float, float) overloads that
  // Arithmetic requires have the same signature, so presumably one definition
  // each of divide and subtract satisfies both requirements.
  static Float abs(const Float a) { return ::abs(a); }
  static Float add(const Float a, const Float b) { return a + b; }
  static Float divide(const Float a, const Float b) { return a / b; }
  static bool less_than_equal(const Float a, const Float b) { return a <= b; }
  static Float max(const Float a, const Float b) { return ::max(a, b); }
  static Float min(const Float a, const Float b) { return ::min(a, b); }
  static Float multiply(const Float a, const Float b) { return a * b; }
  static Float negate(const Float a) { return -a; }
  static Float solve2(const Float a, const Float b, const Float c) { return ::solve2(a, b, c); }
  static Float subtract(const Float a, const float b) { return a - b; }

  // ---- Vector3 (float3) operations ----
  // Element-wise on float3; addv and maxv reduce a vector to one Float.
  static Vector3 abs(const Vector3 a) { return ::abs(a); }
  static Float addv(const Vector3 a) { return ::addv(a); }
  static Vector3 divide(const Vector3 a, const Vector3 b) { return a / b; }
  static Float maxv(const Vector3 a) { return ::maxv(a); }
  static Vector3 multiply(const Vector3 a, const Vector3 b) { return a * b; }
  static Vector3 subtract(const Vector3 a, const Vector3 b) { return a - b; }
}

// float2 (an interval [lo, hi]) conforms to From<float>; presumably via its
// built-in scalar constructor, which yields the degenerate interval [v, v].
extension float2: From<float> {}

// Exposes the three rows of a float3x2 (the component intervals) as x/y/z.
// Setters must store Slang's implicit `newValue` parameter: an unqualified `x`
// here names the property itself, so `this[0] = x` would just write the row
// back to itself and drop the assigned value.
extension float3x2: Swizzle3<float2> {
  property float2 x {
    get { return this[0]; }
    set { this[0] = newValue; }
  }
  property float2 y {
    get { return this[1]; }
    set { this[1] = newValue; }
  }
  property float2 z {
    get { return this[2]; }
    set { this[2] = newValue; }
  }
}

// Builds a float3x2 from three row intervals, in order.
extension float3x2: From<float2[3]> {
  __init(const float2[3] a) {
    return { a[0], a[1], a[2] };
  }
}

// Interval arithmetic: a Float is a closed interval stored in a float2
// (.x = lower bound, .y = upper bound), and a Vector3 is a float3x2 whose rows
// are the three component intervals.
//
// Calls to the global scalar helpers (max, min, minv, maxv, solve2) carry an
// explicit `::`, as in Regular. Written unqualified inside this struct they are
// looked up among the same-named static members first. Because float arguments
// can be splatted to float2, `max(a.x, b.x)` inside `max` could resolve to the
// member itself and recurse, and `Float(0., max(...))` could receive a float2.
struct Interval: Arithmetic {
  typedef float2 Float;
  // Absolute value of an interval.
  static Float abs(const Float a) {
    if (a.x >= 0.) {
      // Entirely non-negative: unchanged.
      return a;
    } else if (a.y <= 0.) {
      // Entirely non-positive: mirror about zero.
      return negate(a);
    } else {
      // Straddles zero: smallest magnitude is 0, largest is the farther bound.
      return Float(0., ::max(-a.x, a.y));
    }
  }
  static Float add(const Float a, const Float b) {
    return a + b;
  }
  // Bounds are the extremes of the four endpoint quotients
  // (a.x/b.x, a.x/b.y, a.y/b.x, a.y/b.y).
  // NOTE(review): assumes b does not contain 0; otherwise the bounds become inf/NaN.
  static Float divide(const Float a, const Float b) {
    const float4 c = float4(a.x / b, a.y / b);
    return Float(::minv(c), ::maxv(c));
  }
  // Division by a plain float; Float(b) presumably splats b into [b, b].
  static Float divide(const Float a, const float b) {
    return divide(a, Float(b));
  }
  // True only when each bound of a is <= the corresponding bound of b.
  static bool less_than_equal(const Float a, const Float b) {
    return all(a <= b);
  }
  // Bound-wise maximum: [max(a.x, b.x), max(a.y, b.y)].
  static Float max(const Float a, const Float b) {
    return Float(::max(a.x, b.x), ::max(a.y, b.y));
  }
  // Bound-wise minimum: [min(a.x, b.x), min(a.y, b.y)].
  static Float min(const Float a, const Float b) {
    return Float(::min(a.x, b.x), ::min(a.y, b.y));
  }
  // Bounds are the extremes of the four endpoint products.
  static Float multiply(const Float a, const Float b) {
    const float4 c = a.xxyy * b.xyxy;
    return Float(::minv(c), ::maxv(c));
  }
  // -[lo, hi] = [-hi, -lo].
  static Float negate(const Float a) {
    return -a.yx;
  }
  // Applies the scalar solve2 to the lower bounds and to the upper bounds separately.
  // NOTE(review): this only encloses the true result if ::solve2 is monotone in
  // its arguments in a compatible direction — confirm.
  static Float solve2(const Float a, const Float b, const Float c) {
    return Float(::solve2(a.x, b.x, c.x), ::solve2(a.y, b.y, c.y));
  }
  // [a.x - b.y, a.y - b.x].
  static Float subtract(const Float a, const Float b) {
    return a - b.yx;
  }
  // Subtraction of a plain float, via the degenerate interval Float(b).
  static Float subtract(const Float a, const float b) {
    return subtract(a, Float(b));
  }

  // Vector3 operations apply the interval operations above row by row
  // (unqualified names here intentionally resolve to the Float members).
  typedef float3x2 Vector3; // TODO: It doesn't look like Swizzle3 is resolved below so I have to use subscript notation.
  static Vector3 abs(const Vector3 a) {
    return { abs(a[0]), abs(a[1]), abs(a[2]) };
  }
  // Interval sum of the three components.
  static Float addv(const Vector3 a) {
    return add(add(a[0], a[1]), a[2]);
  }
  static Vector3 divide(const Vector3 a, const float3 b) {
    return { divide(a[0], b[0]), divide(a[1], b[1]), divide(a[2], b[2]) };
  }
  // Interval maximum of the three components.
  static Float maxv(const Vector3 a) {
    return max(max(a[0], a[1]), a[2]);
  }
  static Vector3 multiply(const Vector3 a, const Vector3 b) {
    return { multiply(a[0], b[0]), multiply(a[1], b[1]), multiply(a[2], b[2]) };
  }
  static Vector3 subtract(const Vector3 a, const float3 b) {
    return { subtract(a[0], Float(b[0])), subtract(a[1], Float(b[1])), subtract(a[2], Float(b[2])) };
  }
}

// Max-norm.

// Max-norm distance from `point` to the origin-centred, axis-aligned box with
// half-extents `half_sides`: the largest per-axis excess |p_i| - h_i.
// Generic over the number representation A (e.g. Regular or Interval).
A::Float box_distance<A: Arithmetic>(const A::Vector3 point, const float3 half_sides) {
  const A::Vector3 magnitude = A::abs(point);
  const A::Vector3 excess = A::subtract(magnitude, half_sides);
  return A::maxv(excess);
}

// Distance estimate from `point` to the origin-centred, axis-aligned ellipsoid
// with semi-axes `radii`, evaluated with the number representation A (e.g.
// Regular or Interval).
//
// Per-axis terms: a_i = 1/r_i^2, b_i = |p_i|/r_i^2, c_i = p_i^2/r_i^2.
// Several candidates are built from A::solve2 on sums of these terms, and from
// moving one axis onto its radius. The result is the minimum over candidates.
// NOTE(review): which root of which equation solve2 returns is defined by the
// global ::solve2, which is not visible here — confirm.
A::Float ellipsoid_distance<A: Arithmetic>(const A::Vector3 point, const float3 radii) {
  const A::Vector3 center = A::abs(point);
  const A::Vector3 center_squared = A::multiply(center, center);
  const float3 radii_squared = radii * radii;
  const float3 a = 1. / radii_squared;
  const A::Vector3 b = A::divide(center, radii_squared);
  const A::Vector3 c = A::divide(center_squared, radii_squared);

  // sum_i p_i^2 / r_i^2 - 1 (the ellipsoid's implicit function at |point|).
  // Kept distinct from the per-axis Vector3 `c`: redeclaring a scalar `c` here
  // would make `A::addv(c)` refer to the scalar being initialized.
  const A::Float implicit_value = A::subtract(A::addv(c), 1.);
  A::Float distance = A::solve2(A::Float(addv(a)), A::addv(b), implicit_value);
  // Implicit value <= 0 (per A::less_than_equal; presumably point on/inside
  // the ellipsoid): the first candidate is the answer.
  if (A::less_than_equal(implicit_value, A::Float(0.))) {
    return distance;
  }

  // Candidates from the remaining two axes, combined (max) with the third
  // axis's |p_i|.
  distance = A::min(distance, A::max(center.x, A::solve2(A::Float(a.y + a.z), A::add(b.y, b.z), A::subtract(A::add(c.y, c.z), 1.))));
  distance = A::min(distance, A::max(center.y, A::solve2(A::Float(a.x + a.z), A::add(b.x, b.z), A::subtract(A::add(c.x, c.z), 1.))));
  distance = A::min(distance, A::max(center.z, A::solve2(A::Float(a.x + a.y), A::add(b.x, b.y), A::subtract(A::add(c.x, c.y), 1.))));
  // Candidates where one axis is measured against its radius (||p_i| - r_i|)
  // and the other two use |p_j|; each is reduced with the max-norm.
  distance = A::min(distance, A::maxv(A::Vector3({ A::abs(A::subtract(center.x, radii.x)), center.y, center.z })));
  distance = A::min(distance, A::maxv(A::Vector3({ center.x, A::abs(A::subtract(center.y, radii.y)), center.z })));
  distance = A::min(distance, A::maxv(A::Vector3({ center.x, center.y, A::abs(A::subtract(center.z, radii.z)) })));
  return distance;
}

// Simple examples:
// box_distance<Regular>(point, float3(.5));
// ellipsoid_distance<Interval>({ Interval::Float(point.x), Interval::Float(point.y), Interval::Float(point.z) }, float3(.5, 2., 1.));

And thanks for another fix, I've integrated it above :)

from slang.

csyonghe avatar csyonghe commented on July 20, 2024

One thing I don't quite get from the code is why we need to package Float and Float3 in an Arithmetic type. It seems that you can have:

// Sketch: a scalar interval type conforming to the built-in IFloat interface,
// so generic IFloat code (including its operators) can run on intervals.
// NOTE(review): IFloat also requires toFloat, which has no natural meaning
// for an interval; IArithmetic may be the better fit.
struct IntervalFloat : IFloat
{
     // implement interval arithmetic for float
}

// Sketch: a fixed-length vector over any IFloat element type T.
// Making the vector itself conform to IFloat lets one generic function accept
// both scalars and vectors; drop the conformance if that is not needed.
struct Vector<T:IFloat, let n : int> : IFloat
{
     T values[n];
     // implement vector arithmetic here.
}

// Generic code written once against IFloat works for any conforming type,
// because IFloat supplies the arithmetic operators.
T myArithmeticFunc<T:IFloat>(T v1, T v2)
{
       T sum = v1 + v2;
       return sum;
}

The downside here is that you won't be able to use the built-in float3/float4 types. It might be a little inconvenient, but there shouldn't be any performance consequences.

from slang.

kevinboulain avatar kevinboulain commented on July 20, 2024

One thing I don't quite get from the code is why we need to package Float and Float3

Agree, I don't think the vector types should be part of the interface, there's no need for a hierarchy if I pass the generic everywhere (like in your last example).

struct IntervalFloat : IFloat

Now that I've tried to implement that, I can see the benefits (like free operators implementation). I was postponing newtype wrappers until I got a simple solution like the above working (so I had a comparison point). There's only one wrinkle to IFloat though: toFloat doesn't make sense for an interval. I should probably simply implement IArithmetic instead.

struct Vector<T:IFloat, let n : int> : IFloat

I'm not sure why the vector should inherit IFloat.

Thanks again for the input, I'll give it a proper shot.

from slang.

csyonghe avatar csyonghe commented on July 20, 2024

By making Vector conforming to IFloat, you can write generic functions that work on both scalar and vectors. If that is not required, then you can skip that part.

from slang.

kevinboulain avatar kevinboulain commented on July 20, 2024

Circling back on this: IArithmetic does make the implementation of the two example functions nicer (especially with a few top-level overloads, though not all of them can be made unambiguous) in roughly the same amount of code. And when compiling with -O3 the SPIR-V is basically the same.

from slang.

csyonghe avatar csyonghe commented on July 20, 2024

Thank you for sharing your use case with us and I am really glad that this works well for you!

from slang.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.